from typing import Generator, List, Optional

import numpy as np
import torch

from .dataset import CSGOHdf5Dataset, Dataset
from .segment import SegmentId


class BatchSampler(torch.utils.data.Sampler):
    """Endless sampler yielding batches of fixed-length episode segments.

    Episodes are partitioned across workers by striding: a sampler with rank
    ``rank`` only draws from episodes ``rank, rank + world_size, ...``.
    Every call to :meth:`sample` returns ``batch_size`` ``SegmentId`` objects,
    each spanning exactly ``seq_length`` steps.
    """

    def __init__(
        self,
        dataset: Dataset,
        rank: int,
        world_size: int,
        batch_size: int,
        seq_length: int,
        sample_weights: Optional[List[float]] = None,
        can_sample_beyond_end: bool = False,
    ) -> None:
        """Store the sampling configuration.

        Args:
            dataset: Source of episodes; must be a ``Dataset`` or
                ``CSGOHdf5Dataset``. ``sample`` reads its ``num_episodes``,
                ``num_steps`` and ``lengths`` attributes.
            rank: Index of this worker; selects which episodes it may draw.
            world_size: Total number of workers (stride of the partition).
            batch_size: Number of segments returned by each ``sample`` call.
            seq_length: Length (``stop - start``) of every segment.
            sample_weights: Optional probabilities, one per contiguous bucket
                of episodes (in index order). Must lie in ``[0, 1]`` and sum
                to 1. Ignored when there are fewer episodes than weights.
            can_sample_beyond_end: If true, segments may extend past the end
                of an episode as well as before its start; otherwise only
                before the start.
        """
        super().__init__(dataset)
        assert isinstance(dataset, (Dataset, CSGOHdf5Dataset))
        self.dataset = dataset
        self.rank = rank
        self.world_size = world_size
        self.sample_weights = sample_weights
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.can_sample_beyond_end = can_sample_beyond_end

    def __len__(self):
        """Not defined: iteration over this sampler never terminates."""
        raise NotImplementedError

    def __iter__(self) -> Generator[List[SegmentId], None, None]:
        """Yield freshly sampled batches forever."""
        while True:
            yield self.sample()

    def sample(self) -> List[SegmentId]:
        """Draw one batch of segment identifiers.

        Returns:
            A list of ``batch_size`` ``SegmentId`` objects, each built from
            ``(episode_id, start, stop)`` with ``stop - start == seq_length``.
            ``start`` may be negative, and ``stop`` may exceed the episode
            length when ``can_sample_beyond_end`` is true.

        Raises:
            AssertionError: If ``sample_weights`` are used and are not a valid
                probability distribution.
        """
        num_episodes = self.dataset.num_episodes

        if (self.sample_weights is None) or num_episodes < len(self.sample_weights):
            # Default: each episode is weighted by its share of the total steps.
            # NOTE(review): assumes dataset.lengths is a NumPy array -- it is
            # divided by a scalar here and fancy-indexed below; confirm.
            weights = self.dataset.lengths / self.dataset.num_steps
        else:
            weights = self.sample_weights
            num_weights = len(self.sample_weights)
            # Compare the sum with a tolerance: an exact `== 1` check rejects
            # valid distributions such as [0.1] * 10 because of float rounding.
            assert all(0 <= x <= 1 for x in weights) and bool(np.isclose(sum(weights), 1.0)), (
                "sample_weights must lie in [0, 1] and sum to 1"
            )
            # Split the episodes into `num_weights` contiguous buckets of equal
            # size; the last bucket also absorbs the remainder.
            sizes = [
                num_episodes // num_weights + (num_episodes % num_weights) * (i == num_weights - 1)
                for i in range(num_weights)
            ]
            # Spread each bucket's weight uniformly over its episodes, giving
            # one weight per episode.
            weights = [w / s for (w, s) in zip(weights, sizes) for _ in range(s)]

        # Restrict both the episode indices and their weights to this rank's
        # strided partition; the two slices stay aligned element by element.
        episodes_partition = np.arange(self.rank, num_episodes, self.world_size)
        weights = np.array(weights[self.rank::self.world_size])

        max_eps = self.batch_size
        episode_ids = np.random.choice(
            episodes_partition, size=max_eps, replace=True, p=weights / weights.sum()
        )
        # With max_eps == batch_size this repeats each id exactly once.
        episode_ids = episode_ids.repeat(self.batch_size // max_eps)

        # One uniformly drawn timestep in [0, episode length) per chosen episode.
        timesteps = np.random.randint(low=0, high=self.dataset.lengths[episode_ids])

        if self.can_sample_beyond_end:
            # Padding allowed both before the start and after the end: the
            # window is shifted back by a random offset in [0, seq_length),
            # so it always contains the drawn timestep.
            starts = timesteps - np.random.randint(0, self.seq_length, len(timesteps))
            stops = starts + self.seq_length
        else:
            # Padding allowed only before the start: the stop is clamped to
            # the episode length, and the start may become negative.
            stops = np.minimum(
                self.dataset.lengths[episode_ids],
                timesteps + 1 + np.random.randint(0, self.seq_length, len(timesteps)),
            )
            starts = stops - self.seq_length

        return [SegmentId(*x) for x in zip(episode_ids, starts, stops)]