ShaswatRobotics's picture
Upload 35 files
23bc32f verified
raw
history blame
4.13 kB
from pathlib import Path
import shutil
from typing import Dict, Optional, Union
import numpy as np
import torch
from .episode import Episode
from .segment import Segment, SegmentId
from .utils import make_segment
class EpisodeDataset(torch.utils.data.Dataset):
    """Disk-backed dataset of episodes, addressed by segment.

    Each episode is stored in its own ``<episode_id>.pt`` file below
    ``directory``. A small index (``info.pt``) records the number of episodes,
    the total number of steps, and the start offset and length of every
    episode. ``len(dataset)`` is the total number of steps, and indexing with
    a segment id returns the segment built by ``make_segment``.
    """

    def __init__(self, directory: Path, name: str) -> None:
        """Open the dataset stored in ``directory``, creating it if missing.

        Args:
            directory: Root folder of the dataset. If it is not an existing
                directory, it is created and an empty index is written.
            name: Label for this dataset; used in the summary printed when an
                existing dataset is loaded.
        """
        super().__init__()
        self.name = name
        self.directory = Path(directory)
        self.num_episodes, self.num_steps, self.start_idx, self.lengths = None, None, None, None

        if not self.directory.is_dir():
            self._init_empty()
        else:
            self._load_info()
            print(f'({name}) {self.num_episodes} episodes, {self.num_steps} steps.')

    @property
    def info_path(self) -> Path:
        """Path of the index file inside the dataset directory."""
        return self.directory / 'info.pt'

    @property
    def info(self) -> Dict[str, Union[int, np.ndarray]]:
        """Snapshot of the index fields, in the layout written by ``save_info``."""
        return {
            'num_episodes': self.num_episodes,
            'num_steps': self.num_steps,
            'start_idx': self.start_idx,
            'lengths': self.lengths,
        }

    def __len__(self) -> int:
        """Return the total number of steps over all episodes."""
        return self.num_steps

    def __getitem__(self, segment_id: "SegmentId") -> "Segment":
        """Load the segment identified by ``segment_id``."""
        return self._load_segment(segment_id)

    def _init_empty(self) -> None:
        """Create the dataset directory and write an empty index.

        Raises:
            FileExistsError: If the directory path already exists.
        """
        self.directory.mkdir(parents=True, exist_ok=False)
        self.num_episodes = 0
        self.num_steps = 0
        self.start_idx = np.array([], dtype=np.int64)
        self.lengths = np.array([], dtype=np.int64)
        self.save_info()

    def _load_info(self) -> None:
        """Populate the index fields from ``info.pt``."""
        # The index holds numpy arrays, which torch.load rejects when
        # weights_only=True (the default from PyTorch 2.6), so full unpickling
        # is requested explicitly.
        # NOTE(security): this unpickles arbitrary objects; only open dataset
        # directories that come from a trusted source.
        info = torch.load(self.info_path, weights_only=False)
        self.num_steps = info['num_steps']
        self.num_episodes = info['num_episodes']
        self.start_idx = info['start_idx']
        self.lengths = info['lengths']

    def save_info(self) -> None:
        """Write the current index fields to ``info.pt``."""
        torch.save(self.info, self.info_path)

    def clear(self) -> None:
        """Delete the whole dataset directory and recreate it empty."""
        shutil.rmtree(self.directory)
        self._init_empty()

    def _get_episode_path(self, episode_id: int) -> Path:
        """Return the file path used to store episode ``episode_id``.

        Episodes are spread over nested folders named after the last three
        decimal digits of the id, most significant first, each kept at its
        place value: episode 1234 lives in ``200/30/4/1234.pt``.
        """
        n = 3  # number of hierarchies
        subfolders = []
        for power in range(n - 1, -1, -1):
            scale = 10 ** power
            # Isolate the digit at this decimal position, kept at its place value.
            place_value = (int(episode_id) % (10 * scale)) // scale * scale
            # Folder names are zero-padded to the width of their decimal position.
            subfolders.append(f'{place_value:0{power + 1}d}')
        return self.directory.joinpath(*subfolders, f'{episode_id}.pt')

    def _load_segment(self, segment_id: "SegmentId", should_pad: bool = True) -> "Segment":
        """Load the episode referenced by ``segment_id`` and cut the segment out of it.

        ``should_pad`` is forwarded unchanged to ``make_segment``.
        """
        episode = self.load_episode(segment_id.episode_id)
        return make_segment(episode, segment_id, should_pad)

    def load_episode(self, episode_id: int) -> "Episode":
        """Read episode ``episode_id`` from disk and rebuild the ``Episode`` object."""
        # Episode files are written by add_episode(); weights_only=False keeps
        # them loadable whatever their fields contain.
        # NOTE(security): this unpickles arbitrary objects; only load trusted files.
        fields = torch.load(self._get_episode_path(episode_id), weights_only=False)
        return Episode(**fields)

    def add_episode(self, episode: "Episode", *, episode_id: Optional[int] = None) -> int:
        """Store ``episode`` on disk and update the in-memory index.

        Args:
            episode: Episode to store.
            episode_id: If ``None``, the episode is appended and receives the
                next free id. Otherwise the stored episode with that id is
                loaded, merged with ``episode`` via ``old_episode.merge``, and
                the start offsets of all later episodes are shifted by the
                change in length.

        Returns:
            The id under which the episode was stored.

        The index file is not rewritten here; call ``save_info()`` to persist
        the updated index.
        """
        if episode_id is None:
            episode_id = self.num_episodes
            self.start_idx = np.concatenate((self.start_idx, np.array([self.num_steps])))
            self.lengths = np.concatenate((self.lengths, np.array([len(episode)])))
            self.num_steps += len(episode)
            self.num_episodes += 1
        else:
            assert episode_id < self.num_episodes, (
                f'episode_id {episode_id} does not exist ({self.num_episodes} episodes stored)'
            )
            old_episode = self.load_episode(episode_id)
            episode = old_episode.merge(episode)
            incr_num_steps = len(episode) - len(old_episode)
            self.lengths[episode_id] = len(episode)
            self.start_idx[episode_id + 1:] += incr_num_steps
            self.num_steps += incr_num_steps

        episode_path = self._get_episode_path(episode_id)
        episode_path.parent.mkdir(parents=True, exist_ok=True)

        # Write to a temporary file first so a partially written episode never
        # sits at the final path.
        tmp_path = episode_path.with_suffix('.tmp')
        torch.save(episode.__dict__, tmp_path)
        # replace() overwrites an existing target on every platform, whereas
        # rename() raises FileExistsError on Windows when the episode file
        # already exists (the merge case above).
        tmp_path.replace(episode_path)

        return episode_id

    def get_episode_id_from_global_idx(self, global_idx: np.ndarray) -> np.ndarray:
        """Map global step indices to the ids of the episodes containing them.

        For each index, ``argmax`` finds the first episode whose start offset
        is greater than the index; the episode just before it contains the
        index. When no start offset is greater, ``argmax`` returns 0, so
        ``0 - 1`` wraps to the last episode through the modulo.
        """
        first_later_episode = np.argmax(self.start_idx.reshape(-1, 1) > global_idx, axis=0)
        return (first_later_episode - 1) % self.num_episodes

    def get_global_idx_from_segment_id(self, segment_id: "SegmentId") -> np.ndarray:
        """Return the global step indices ``start .. stop - 1`` of a segment.

        The segment's ``start`` and ``stop`` are offset by the start index of
        the episode it belongs to.
        """
        start_idx = self.start_idx[segment_id.episode_id]
        return np.arange(start_idx + segment_id.start, start_idx + segment_id.stop)