|
|
from pathlib import Path |
|
|
import shutil |
|
|
from typing import Dict, Optional, Union |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
|
|
|
from .episode import Episode |
|
|
from .segment import Segment, SegmentId |
|
|
from .utils import make_segment |
|
|
|
|
|
|
|
|
class EpisodeDataset(torch.utils.data.Dataset): |
|
|
def __init__(self, directory: Path, name: str) -> None: |
|
|
super().__init__() |
|
|
self.name = name |
|
|
self.directory = Path(directory) |
|
|
self.num_episodes, self.num_steps, self.start_idx, self.lengths = None, None, None, None |
|
|
|
|
|
if not self.directory.is_dir(): |
|
|
self._init_empty() |
|
|
else: |
|
|
self._load_info() |
|
|
print(f'({name}) {self.num_episodes} episodes, {self.num_steps} steps.') |
|
|
|
|
|
@property |
|
|
def info_path(self) -> Path: |
|
|
return self.directory / 'info.pt' |
|
|
|
|
|
@property |
|
|
def info(self) -> Dict[str, Union[int, np.ndarray]]: |
|
|
return {'num_episodes': self.num_episodes, 'num_steps': self.num_steps, 'start_idx': self.start_idx, 'lengths': self.lengths} |
|
|
|
|
|
def __len__(self) -> int: |
|
|
return self.num_steps |
|
|
|
|
|
def __getitem__(self, segment_id: SegmentId) -> Segment: |
|
|
return self._load_segment(segment_id) |
|
|
|
|
|
def _init_empty(self) -> None: |
|
|
self.directory.mkdir(parents=True, exist_ok=False) |
|
|
self.num_episodes = 0 |
|
|
self.num_steps = 0 |
|
|
self.start_idx = np.array([], dtype=np.int64) |
|
|
self.lengths = np.array([], dtype=np.int64) |
|
|
self.save_info() |
|
|
|
|
|
def _load_info(self) -> None: |
|
|
info = torch.load(self.info_path) |
|
|
self.num_steps = info['num_steps'] |
|
|
self.num_episodes = info['num_episodes'] |
|
|
self.start_idx = info['start_idx'] |
|
|
self.lengths = info['lengths'] |
|
|
|
|
|
def save_info(self) -> None: |
|
|
torch.save(self.info, self.info_path) |
|
|
|
|
|
def clear(self) -> None: |
|
|
shutil.rmtree(self.directory) |
|
|
self._init_empty() |
|
|
|
|
|
def _get_episode_path(self, episode_id: int) -> Path: |
|
|
n = 3 |
|
|
powers = np.arange(n) |
|
|
subfolders = list(map(int, np.floor((episode_id % 10 ** (1 + powers)) / 10 ** powers) * 10 ** powers))[::-1] |
|
|
return self.directory / '/'.join(list(map(lambda x: f'{x[1]:0{n - x[0]}d}', enumerate(subfolders)))) / f'{episode_id}.pt' |
|
|
|
|
|
def _load_segment(self, segment_id: SegmentId, should_pad: bool = True) -> Segment: |
|
|
episode = self.load_episode(segment_id.episode_id) |
|
|
return make_segment(episode, segment_id, should_pad) |
|
|
|
|
|
def load_episode(self, episode_id: int) -> Episode: |
|
|
return Episode(**torch.load(self._get_episode_path(episode_id))) |
|
|
|
|
|
def add_episode(self, episode: Episode, *, episode_id: Optional[int] = None) -> int: |
|
|
if episode_id is None: |
|
|
episode_id = self.num_episodes |
|
|
self.start_idx = np.concatenate((self.start_idx, np.array([self.num_steps]))) |
|
|
self.lengths = np.concatenate((self.lengths, np.array([len(episode)]))) |
|
|
self.num_steps += len(episode) |
|
|
self.num_episodes += 1 |
|
|
|
|
|
else: |
|
|
assert episode_id < self.num_episodes |
|
|
old_episode = self.load_episode(episode_id) |
|
|
episode = old_episode.merge(episode) |
|
|
incr_num_steps = len(episode) - len(old_episode) |
|
|
self.lengths[episode_id] = len(episode) |
|
|
self.start_idx[episode_id + 1:] += incr_num_steps |
|
|
self.num_steps += incr_num_steps |
|
|
|
|
|
episode_path = self._get_episode_path(episode_id) |
|
|
episode_path.parent.mkdir(parents=True, exist_ok=True) |
|
|
torch.save(episode.__dict__, episode_path.with_suffix('.tmp')) |
|
|
episode_path.with_suffix('.tmp').rename(episode_path) |
|
|
|
|
|
return episode_id |
|
|
|
|
|
def get_episode_id_from_global_idx(self, global_idx: np.ndarray) -> np.ndarray: |
|
|
return (np.argmax(self.start_idx.reshape(-1, 1) > global_idx, axis=0) - 1) % self.num_episodes |
|
|
|
|
|
def get_global_idx_from_segment_id(self, segment_id: SegmentId) -> np.ndarray: |
|
|
start_idx = self.start_idx[segment_id.episode_id] |
|
|
return np.arange(start_idx + segment_id.start, start_idx + segment_id.stop) |
|
|
|