File size: 1,720 Bytes
23bc32f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 |
from typing import Generator, List
import numpy as np
import torch
from .dataset import EpisodeDataset
from .segment import SegmentId
class BatchSampler(torch.utils.data.Sampler):
    """Sampler that yields batches of randomly drawn episode segments.

    One pass over the sampler produces ``num_steps_per_epoch`` batches. Each
    batch is a list of ``SegmentId`` objects built from
    ``(episode_id, start, stop)`` triples where ``stop - start`` always equals
    ``sequence_length``.
    """

    def __init__(
        self,
        dataset: EpisodeDataset,
        num_steps_per_epoch: int,
        batch_size: int,
        sequence_length: int,
        can_sample_beyond_end: bool,
    ) -> None:
        """Store the sampling configuration.

        Args:
            dataset: Source of episodes; ``num_episodes`` and ``lengths`` are
                read from it when sampling.
            num_steps_per_epoch: Number of batches produced per iteration pass.
            batch_size: Number of segments in each batch.
            sequence_length: Width (``stop - start``) of every segment.
            can_sample_beyond_end: If true, a segment's ``stop`` may exceed the
                episode length; otherwise ``stop`` is clamped to it.
        """
        super().__init__(dataset)
        self.dataset = dataset
        self.num_steps_per_epoch = num_steps_per_epoch
        self.batch_size = batch_size
        self.sequence_length = sequence_length
        self.can_sample_beyond_end = can_sample_beyond_end
        # Forwarded as ``p`` to np.random.choice; None means uniform sampling
        # over episodes. Presumably overwritten from outside this class when
        # weighted sampling is wanted -- nothing here ever changes it.
        self.probabilities = None

    def __len__(self) -> int:
        """Return the number of batches produced by one iteration pass."""
        return self.num_steps_per_epoch

    def __iter__(self) -> Generator[List[SegmentId], None, None]:
        """Lazily yield ``num_steps_per_epoch`` freshly sampled batches."""
        yield from (self.sample() for _ in range(self.num_steps_per_epoch))

    def sample(self) -> List[SegmentId]:
        """Draw one batch of segment ids.

        Episodes are drawn with replacement, then one anchor timestep is drawn
        inside each chosen episode, and finally a window of
        ``sequence_length`` steps containing that anchor is placed around it.

        Returns:
            A list with one ``SegmentId(episode_id, start, stop)`` per drawn
            episode. ``start`` may be negative in both modes.
        """
        episode_ids = np.random.choice(
            np.arange(self.dataset.num_episodes),
            size=self.batch_size,
            replace=True,
            p=self.probabilities,
        )
        # Fancy-indexing with an id array: assumes ``dataset.lengths`` is a
        # NumPy array (or behaves like one) -- TODO confirm.
        episode_lengths = self.dataset.lengths[episode_ids]
        # One anchor per episode, drawn from [0, episode length).
        timesteps = np.random.randint(low=0, high=episode_lengths)
        # Random displacement in [0, sequence_length) for each anchor.
        offsets = np.random.randint(0, self.sequence_length, len(timesteps))

        if self.can_sample_beyond_end:
            # Window may spill over both ends of the episode (the consumer is
            # expected to pad before the start and after the end).
            starts = timesteps - offsets
            stops = starts + self.sequence_length
        else:
            # Window never runs past the episode end; it may still begin
            # before step 0 (padding only before the start).
            stops = np.minimum(episode_lengths, timesteps + 1 + offsets)
            starts = stops - self.sequence_length

        return [
            SegmentId(episode_id, start, stop)
            for episode_id, start, stop in zip(episode_ids, starts, stops)
        ]
|