Delete delta-iris/src/data
Browse files- delta-iris/src/data/__init__.py +0 -7
- delta-iris/src/data/__pycache__/__init__.cpython-310.pyc +0 -0
- delta-iris/src/data/__pycache__/batch.cpython-310.pyc +0 -0
- delta-iris/src/data/__pycache__/dataset.cpython-310.pyc +0 -0
- delta-iris/src/data/__pycache__/episode.cpython-310.pyc +0 -0
- delta-iris/src/data/__pycache__/episode_count.cpython-310.pyc +0 -0
- delta-iris/src/data/__pycache__/sampler.cpython-310.pyc +0 -0
- delta-iris/src/data/__pycache__/segment.cpython-310.pyc +0 -0
- delta-iris/src/data/__pycache__/utils.cpython-310.pyc +0 -0
- delta-iris/src/data/batch.py +0 -24
- delta-iris/src/data/dataset.py +0 -104
- delta-iris/src/data/episode.py +0 -41
- delta-iris/src/data/episode_count.py +0 -41
- delta-iris/src/data/sampler.py +0 -42
- delta-iris/src/data/segment.py +0 -25
- delta-iris/src/data/utils.py +0 -69
delta-iris/src/data/__init__.py
DELETED
|
@@ -1,7 +0,0 @@
|
|
| 1 |
-
from .batch import Batch
|
| 2 |
-
from .dataset import EpisodeDataset
|
| 3 |
-
from .episode import Episode
|
| 4 |
-
from .episode_count import EpisodeCountManager
|
| 5 |
-
from .sampler import BatchSampler
|
| 6 |
-
from .segment import SegmentId
|
| 7 |
-
from .utils import collate_segments_to_batch, DatasetTraverser, make_segment
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
delta-iris/src/data/__pycache__/__init__.cpython-310.pyc
DELETED
|
Binary file (543 Bytes)
|
|
|
delta-iris/src/data/__pycache__/batch.cpython-310.pyc
DELETED
|
Binary file (1.46 kB)
|
|
|
delta-iris/src/data/__pycache__/dataset.cpython-310.pyc
DELETED
|
Binary file (4.9 kB)
|
|
|
delta-iris/src/data/__pycache__/episode.cpython-310.pyc
DELETED
|
Binary file (1.8 kB)
|
|
|
delta-iris/src/data/__pycache__/episode_count.cpython-310.pyc
DELETED
|
Binary file (2.78 kB)
|
|
|
delta-iris/src/data/__pycache__/sampler.cpython-310.pyc
DELETED
|
Binary file (1.96 kB)
|
|
|
delta-iris/src/data/__pycache__/segment.cpython-310.pyc
DELETED
|
Binary file (1.06 kB)
|
|
|
delta-iris/src/data/__pycache__/utils.cpython-310.pyc
DELETED
|
Binary file (3.98 kB)
|
|
|
delta-iris/src/data/batch.py
DELETED
|
@@ -1,24 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
from dataclasses import dataclass
|
| 3 |
-
from typing import List
|
| 4 |
-
|
| 5 |
-
import torch
|
| 6 |
-
|
| 7 |
-
from .segment import SegmentId
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
@dataclass
class Batch:
    """A batch of collated segments, ready to be fed to a model.

    All tensor fields share the same leading (batch, time) dimensions;
    `segment_ids` records where each row of the batch came from.
    """
    observations: torch.ByteTensor
    actions: torch.LongTensor
    rewards: torch.FloatTensor
    ends: torch.LongTensor
    mask_padding: torch.BoolTensor
    segment_ids: List[SegmentId]

    def _map_tensors(self, fn) -> Batch:
        # Apply `fn` to every tensor field; `segment_ids` is plain metadata
        # and is carried over untouched.
        fields = {}
        for field_name, value in self.__dict__.items():
            fields[field_name] = value if field_name == 'segment_ids' else fn(value)
        return Batch(**fields)

    def pin_memory(self) -> Batch:
        """Return a copy with every tensor field pinned in page-locked memory."""
        return self._map_tensors(lambda t: t.pin_memory())

    def to(self, device: torch.device) -> Batch:
        """Return a copy with every tensor field moved to `device`."""
        return self._map_tensors(lambda t: t.to(device))
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
delta-iris/src/data/dataset.py
DELETED
|
@@ -1,104 +0,0 @@
|
|
| 1 |
-
from pathlib import Path
|
| 2 |
-
import shutil
|
| 3 |
-
from typing import Dict, Optional, Union
|
| 4 |
-
|
| 5 |
-
import numpy as np
|
| 6 |
-
import torch
|
| 7 |
-
|
| 8 |
-
from .episode import Episode
|
| 9 |
-
from .segment import Segment, SegmentId
|
| 10 |
-
from .utils import make_segment
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
class EpisodeDataset(torch.utils.data.Dataset):
    """Disk-backed dataset of episodes, addressed by `SegmentId`.

    Episodes are stored one ``.pt`` file per episode under `directory`, spread
    over a three-level folder hierarchy so no single folder accumulates too many
    files.  Bookkeeping (episode count, step count, per-episode start indices
    and lengths) is persisted in ``directory/info.pt``.
    """

    def __init__(self, directory: Path, name: str) -> None:
        super().__init__()
        self.name = name
        self.directory = Path(directory)
        # Filled in below, either with empty values (fresh directory) or from
        # the saved info file.
        self.num_episodes, self.num_steps, self.start_idx, self.lengths = None, None, None, None

        if not self.directory.is_dir():
            self._init_empty()
        else:
            self._load_info()
        print(f'({name}) {self.num_episodes} episodes, {self.num_steps} steps.')

    @property
    def info_path(self) -> Path:
        """Path of the bookkeeping file."""
        return self.directory / 'info.pt'

    @property
    def info(self) -> Dict[str, Union[int, np.ndarray]]:
        """Bookkeeping data persisted at `info_path`."""
        return {'num_episodes': self.num_episodes, 'num_steps': self.num_steps, 'start_idx': self.start_idx, 'lengths': self.lengths}

    def __len__(self) -> int:
        # Length is the total number of steps, not the number of episodes.
        return self.num_steps

    def __getitem__(self, segment_id: SegmentId) -> Segment:
        return self._load_segment(segment_id)

    def _init_empty(self) -> None:
        # Create the directory (must not already exist) and save empty bookkeeping.
        self.directory.mkdir(parents=True, exist_ok=False)
        self.num_episodes = 0
        self.num_steps = 0
        self.start_idx = np.array([], dtype=np.int64)
        self.lengths = np.array([], dtype=np.int64)
        self.save_info()

    def _load_info(self) -> None:
        # NOTE(review): torch.load on a local bookkeeping file; assumes the
        # directory is trusted.
        info = torch.load(self.info_path)
        self.num_steps = info['num_steps']
        self.num_episodes = info['num_episodes']
        self.start_idx = info['start_idx']
        self.lengths = info['lengths']

    def save_info(self) -> None:
        """Persist the current bookkeeping to disk."""
        torch.save(self.info, self.info_path)

    def clear(self) -> None:
        """Delete all episodes on disk and reset bookkeeping to empty."""
        shutil.rmtree(self.directory)
        self._init_empty()

    def _get_episode_path(self, episode_id: int) -> Path:
        # Spread episode files over a 3-level hierarchy built from the decimal
        # digits of episode_id, coarse to fine: e.g. id 1234 -> 200/30/4/1234.pt.
        n = 3  # number of hierarchies
        powers = np.arange(n)
        subfolders = list(map(int, np.floor((episode_id % 10 ** (1 + powers)) / 10 ** powers) * 10 ** powers))[::-1]
        return self.directory / '/'.join(list(map(lambda x: f'{x[1]:0{n - x[0]}d}', enumerate(subfolders)))) / f'{episode_id}.pt'

    def _load_segment(self, segment_id: SegmentId, should_pad: bool = True) -> Segment:
        episode = self.load_episode(segment_id.episode_id)
        return make_segment(episode, segment_id, should_pad)

    def load_episode(self, episode_id: int) -> Episode:
        """Load one episode from disk and rebuild it as an `Episode`."""
        return Episode(**torch.load(self._get_episode_path(episode_id)))

    def add_episode(self, episode: Episode, *, episode_id: Optional[int] = None) -> int:
        """Append a new episode, or extend an existing one.

        When `episode_id` is None, a fresh id is allocated and the bookkeeping
        arrays grow; otherwise `episode` is merged onto the stored episode with
        that id and the start indices of all later episodes are shifted.
        Returns the id of the written episode.
        """
        if episode_id is None:
            episode_id = self.num_episodes
            self.start_idx = np.concatenate((self.start_idx, np.array([self.num_steps])))
            self.lengths = np.concatenate((self.lengths, np.array([len(episode)])))
            self.num_steps += len(episode)
            self.num_episodes += 1

        else:
            assert episode_id < self.num_episodes
            old_episode = self.load_episode(episode_id)
            episode = old_episode.merge(episode)
            incr_num_steps = len(episode) - len(old_episode)
            self.lengths[episode_id] = len(episode)
            self.start_idx[episode_id + 1:] += incr_num_steps
            self.num_steps += incr_num_steps

        episode_path = self._get_episode_path(episode_id)
        episode_path.parent.mkdir(parents=True, exist_ok=True)
        # Write to a temporary file first, then rename, so a crash mid-write
        # cannot leave a truncated episode file under the final name.
        torch.save(episode.__dict__, episode_path.with_suffix('.tmp'))
        episode_path.with_suffix('.tmp').rename(episode_path)

        return episode_id

    def get_episode_id_from_global_idx(self, global_idx: np.ndarray) -> np.ndarray:
        # For each global step index, find the containing episode: the last
        # episode whose start index is <= the global index.  When no start is
        # strictly greater (index lies in the last episode), argmax yields 0 and
        # the -1 is mapped back to the last episode by the modulo.
        return (np.argmax(self.start_idx.reshape(-1, 1) > global_idx, axis=0) - 1) % self.num_episodes

    def get_global_idx_from_segment_id(self, segment_id: SegmentId) -> np.ndarray:
        # Global step indices covered by the segment (start/stop are episode-relative).
        start_idx = self.start_idx[segment_id.episode_id]
        return np.arange(start_idx + segment_id.start, start_idx + segment_id.stop)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
delta-iris/src/data/episode.py
DELETED
|
@@ -1,41 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
from dataclasses import dataclass
|
| 3 |
-
|
| 4 |
-
import torch
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
@dataclass
class EpisodeMetrics:
    """Summary statistics of an episode (see `Episode.compute_metrics`)."""
    episode_length: int  # number of steps in the episode
    episode_return: float  # sum of rewards (produced as `rewards.sum()`, a 0-dim tensor in practice)
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
@dataclass
class Episode:
    """A trajectory of aligned per-step tensors.

    On construction, the episode is truncated right after the first `end` flag,
    so no steps past the terminal one are kept.
    """
    observations: torch.ByteTensor
    actions: torch.LongTensor
    rewards: torch.FloatTensor
    ends: torch.LongTensor

    def __post_init__(self):
        # All time-indexed fields must have the same number of steps.
        assert len(self.observations) == len(self.actions) == len(self.rewards) == len(self.ends)
        if self.ends.sum() > 0:
            # Keep everything up to and including the first terminal step.
            cutoff = torch.argmax(self.ends) + 1
            for field_name in ('observations', 'actions', 'rewards', 'ends'):
                setattr(self, field_name, getattr(self, field_name)[:cutoff])

    def __len__(self) -> int:
        return self.observations.size(0)

    def merge(self, other: Episode) -> Episode:
        """Concatenate `other` after this episode along the time axis."""
        merged_fields = (
            torch.cat((getattr(self, field_name), getattr(other, field_name)), dim=0)
            for field_name in ('observations', 'actions', 'rewards', 'ends')
        )
        return Episode(*merged_fields)

    def compute_metrics(self) -> EpisodeMetrics:
        """Length and undiscounted return of the episode."""
        return EpisodeMetrics(len(self), self.rewards.sum())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
delta-iris/src/data/episode_count.py
DELETED
|
@@ -1,41 +0,0 @@
|
|
| 1 |
-
from pathlib import Path
|
| 2 |
-
from typing import Tuple
|
| 3 |
-
|
| 4 |
-
import numpy as np
|
| 5 |
-
import torch
|
| 6 |
-
|
| 7 |
-
from .dataset import EpisodeDataset
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
class EpisodeCountManager:
    """Per-key counters of how many times each episode of a dataset was used.

    Each registered key owns an int64 array with one slot per episode; the
    counts drive the inverse-frequency sampling distribution returned by
    `compute_probabilities`.
    """

    def __init__(self, dataset: EpisodeDataset) -> None:
        self.dataset = dataset
        self.all_counts = dict()

    def load(self, path_to_checkpoint: Path) -> None:
        """Restore all counters from a checkpoint file."""
        self.all_counts = torch.load(path_to_checkpoint)
        # Every restored counter must cover exactly the dataset's episodes.
        for counts in self.all_counts.values():
            assert counts.shape[0] == self.dataset.num_episodes

    def save(self, path_to_checkpoint: Path) -> None:
        """Persist all counters to a checkpoint file."""
        torch.save(self.all_counts, path_to_checkpoint)

    def register(self, *keys: Tuple[str]) -> None:
        """Create a zeroed counter for each new key (keys must be unused)."""
        assert not any(key in self.all_counts for key in keys)
        for key in keys:
            self.all_counts[key] = np.zeros(self.dataset.num_episodes, dtype=np.int64)

    def add_episode(self, episode_id: int) -> None:
        """Grow every counter by one zeroed slot when `episode_id` is brand new."""
        for key in list(self.all_counts):
            counts = self.all_counts[key]
            assert episode_id <= counts.shape[0]
            if episode_id == counts.shape[0]:
                self.all_counts[key] = np.concatenate((counts, np.zeros(1, dtype=np.int64)))
            assert self.all_counts[key].shape[0] == self.dataset.num_episodes

    def increment_episode_count(self, key: str, episode_id: int) -> None:
        """Record one more use of `episode_id` under `key`."""
        assert key in self.all_counts
        self.all_counts[key][episode_id] += 1

    def compute_probabilities(self, key: str, alpha: float) -> np.ndarray:
        """Normalized sampling distribution favoring rarely-used episodes.

        Weights are (1 / (1 + count)) ** alpha; alpha=0 gives uniform sampling.
        """
        assert key in self.all_counts
        weights = (1 / (1 + self.all_counts[key])) ** alpha
        return weights / weights.sum()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
delta-iris/src/data/sampler.py
DELETED
|
@@ -1,42 +0,0 @@
|
|
| 1 |
-
from typing import Generator, List
|
| 2 |
-
|
| 3 |
-
import numpy as np
|
| 4 |
-
import torch
|
| 5 |
-
|
| 6 |
-
from .dataset import EpisodeDataset
|
| 7 |
-
from .segment import SegmentId
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
class BatchSampler(torch.utils.data.Sampler):
    """Sampler yielding batches of random `SegmentId`s over an `EpisodeDataset`.

    Episodes are drawn with replacement according to `self.probabilities`
    (uniform while it is None — np.random.choice's default), then a window of
    `sequence_length` steps is placed randomly around a random timestep of each
    drawn episode.
    """

    def __init__(self, dataset: EpisodeDataset, num_steps_per_epoch: int, batch_size: int, sequence_length: int, can_sample_beyond_end: bool) -> None:
        super().__init__(dataset)
        self.dataset = dataset
        # Per-episode sampling distribution; expected to be set externally
        # (e.g. from EpisodeCountManager.compute_probabilities). None = uniform.
        self.probabilities = None
        self.num_steps_per_epoch = num_steps_per_epoch
        self.batch_size = batch_size
        self.sequence_length = sequence_length
        # When True, windows may run past the episode end (right padding);
        # when False, they are clamped so only left padding can occur.
        self.can_sample_beyond_end = can_sample_beyond_end

    def __len__(self) -> int:
        return self.num_steps_per_epoch

    def __iter__(self) -> Generator[List[SegmentId], None, None]:
        for _ in range(self.num_steps_per_epoch):
            yield self.sample()

    def sample(self) -> List[SegmentId]:
        """Draw `batch_size` segment ids, one random window per drawn episode."""
        episode_ids = np.random.choice(np.arange(self.dataset.num_episodes), size=self.batch_size, replace=True, p=self.probabilities)
        # One anchor timestep per drawn episode, uniform over that episode's length.
        timesteps = np.random.randint(low=0, high=self.dataset.lengths[episode_ids])

        # padding allowed, both before start and after end
        if self.can_sample_beyond_end:
            starts = timesteps - np.random.randint(0, self.sequence_length, len(timesteps))
            stops = starts + self.sequence_length

        # padding allowed only before start
        else:
            stops = np.minimum(self.dataset.lengths[episode_ids], timesteps + 1 + np.random.randint(0, self.sequence_length, len(timesteps)))
            starts = stops - self.sequence_length

        return list(map(lambda x: SegmentId(*x), zip(episode_ids, starts, stops)))
|
| 42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
delta-iris/src/data/segment.py
DELETED
|
@@ -1,25 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
from dataclasses import dataclass
|
| 3 |
-
|
| 4 |
-
import torch
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
@dataclass
class SegmentId:
    """Address of a segment: steps [start, stop) of episode `episode_id`.

    `start` may be negative and `stop` may exceed the episode length; the
    out-of-range part becomes padding (see `make_segment` in utils.py).
    """
    episode_id: int
    start: int
    stop: int
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
@dataclass
class Segment:
    """A fixed-length window cut from an episode, possibly padded at either end.

    `mask_padding` is True on real timesteps and False on padded ones; `id`
    records which episode and which step range the segment came from.
    """
    observations: torch.ByteTensor
    actions: torch.LongTensor
    rewards: torch.FloatTensor
    ends: torch.LongTensor
    mask_padding: torch.BoolTensor
    id: SegmentId

    @property
    def effective_size(self) -> int:
        """Number of real (non-padding) timesteps in the segment."""
        return int(self.mask_padding.count_nonzero().item())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
delta-iris/src/data/utils.py
DELETED
|
@@ -1,69 +0,0 @@
|
|
| 1 |
-
import math
|
| 2 |
-
from typing import Generator, List
|
| 3 |
-
|
| 4 |
-
import torch
|
| 5 |
-
|
| 6 |
-
from .batch import Batch
|
| 7 |
-
from .episode import Episode
|
| 8 |
-
from .segment import Segment, SegmentId
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
def collate_segments_to_batch(segments: List[Segment]) -> Batch:
    """Stack a list of same-length segments into a single `Batch`.

    Observations are divided by 255 (byte images rescaled towards [0, 1]);
    every other tensor field is stacked unchanged, and segment ids are
    collected into a plain list.
    """
    def stacked(attr: str) -> torch.Tensor:
        # Stack one per-segment field along a new leading batch dimension.
        return torch.stack([getattr(segment, attr) for segment in segments])

    return Batch(
        stacked('observations').div(255),
        stacked('actions'),
        stacked('rewards'),
        stacked('ends'),
        stacked('mask_padding'),
        [segment.id for segment in segments]
    )
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
def make_segment(episode: Episode, segment_id: SegmentId, should_pad: bool = True) -> Segment:
    """Extract the [start, stop) window of `episode` as a `Segment`.

    The window may extend beyond the episode on either side; the out-of-range
    steps are zero-padded (only allowed when `should_pad`) and flagged False in
    the segment's `mask_padding`.  The returned segment's id stores the clipped
    (in-range) start/stop.
    """
    # The requested window must intersect the episode and be non-empty.
    assert segment_id.start < len(episode) and segment_id.stop > 0 and segment_id.start < segment_id.stop
    padding_length_right = max(0, segment_id.stop - len(episode))
    padding_length_left = max(0, -segment_id.start)
    assert padding_length_right == padding_length_left == 0 or should_pad

    def pad(x):
        # Zero-pad along the first (time) dimension only.  F.pad's width list
        # starts at the LAST dimension, so the pair for dim 0 sits at the end:
        # right pad first, then left pad.
        pad_right = torch.nn.functional.pad(x, [0 for _ in range(2 * x.ndim - 1)] + [padding_length_right]) if padding_length_right > 0 else x
        return torch.nn.functional.pad(pad_right, [0 for _ in range(2 * x.ndim - 2)] + [padding_length_left, 0]) if padding_length_left > 0 else pad_right

    # Clip the window to the episode's actual range.
    start = max(0, segment_id.start)
    stop = min(len(episode), segment_id.stop)

    return Segment(
        pad(episode.observations[start:stop]),
        pad(episode.actions[start:stop]),
        pad(episode.rewards[start:stop]),
        pad(episode.ends[start:stop]),
        mask_padding=torch.cat((torch.zeros(padding_length_left), torch.ones(stop - start), torch.zeros(padding_length_right))).bool(),
        id=SegmentId(segment_id.episode_id, start, stop)
    )
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
class DatasetTraverser:
    """Iterate over a whole dataset in order, as batches of fixed-size chunks.

    Every episode is cut into consecutive chunks of `chunk_size` steps (the
    last chunk of each episode is right-padded), and chunks are grouped into
    batches of `batch_num_samples`.  Deterministic full pass, in contrast to
    the random sampling of `BatchSampler`.
    """

    def __init__(self, dataset, batch_num_samples: int, chunk_size: int) -> None:
        self.dataset = dataset
        self.batch_num_samples = batch_num_samples
        self.chunk_size = chunk_size
        # Total batch count: chunks per episode, minus one when an episode's
        # final chunk would hold a single real step (such chunks are dropped in
        # __iter__), all divided by the batch size.
        self._num_batches = math.ceil(sum([math.ceil(dataset.lengths[episode_id] / chunk_size) - int(dataset.lengths[episode_id] % chunk_size == 1) for episode_id in range(dataset.num_episodes)]) / batch_num_samples)

    def __len__(self) -> int:
        return self._num_batches

    def __iter__(self) -> Generator[Batch, None, None]:
        chunks = []

        for episode_id in range(self.dataset.num_episodes):
            episode = self.dataset.load_episode(episode_id)
            chunks.extend(make_segment(episode, SegmentId(episode_id, start=i * self.chunk_size, stop=(i + 1) * self.chunk_size), should_pad=True) for i in range(math.ceil(len(episode) / self.chunk_size)))
            # Drop a trailing chunk with fewer than 2 real steps — this mirrors
            # the `% chunk_size == 1` correction in `_num_batches` above.
            if chunks[-1].effective_size < 2:
                chunks.pop()

        while len(chunks) >= self.batch_num_samples:
            yield collate_segments_to_batch(chunks[:self.batch_num_samples])
            chunks = chunks[self.batch_num_samples:]

        # Flush the final partial batch, if any.
        if len(chunks) > 0:
            yield collate_segments_to_batch(chunks)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|