Spaces:
Sleeping
Sleeping
| from collections import Counter | |
| import multiprocessing as mp | |
| from pathlib import Path | |
| import shutil | |
| from typing import Any, Dict, List, Optional | |
| import numpy as np | |
| import torch | |
| from .episode import Episode | |
| from .segment import Segment, SegmentId | |
| from .utils import make_segment | |
| from utils import StateDictMixin | |
| class Dataset(StateDictMixin, torch.utils.data.Dataset): | |
| def __init__( | |
| self, | |
| directory: Path, | |
| name: Optional[str] = None, | |
| cache_in_ram: bool = False, | |
| use_manager: bool = False, | |
| save_on_disk: bool = True, | |
| ) -> None: | |
| super().__init__() | |
| # State | |
| self.is_static = False | |
| self.num_episodes = None | |
| self.num_steps = None | |
| self.start_idx = None | |
| self.lengths = None | |
| self.counter_rew = None | |
| self.counter_end = None | |
| self._directory = Path(directory).expanduser() | |
| self._name = name if name is not None else self._directory.stem | |
| self._cache_in_ram = cache_in_ram | |
| self._save_on_disk = save_on_disk | |
| self._default_path = self._directory / "info.pt" | |
| self._cache = mp.Manager().dict() if use_manager else {} | |
| self._reset() | |
| def __len__(self) -> int: | |
| return self.num_steps | |
| def __getitem__(self, segment_id: SegmentId) -> Segment: | |
| episode = self.load_episode(segment_id.episode_id) | |
| return make_segment(episode, segment_id, should_pad=True) | |
| def __str__(self) -> str: | |
| return f"{self.name}: {self.num_episodes} episodes, {self.num_steps} steps." | |
| def name(self) -> str: | |
| return self._name | |
| def counts_rew(self) -> List[int]: | |
| return [self.counter_rew[r] for r in [-1, 0, 1]] | |
| def counts_end(self) -> List[int]: | |
| return [self.counter_end[e] for e in [0, 1]] | |
| def _reset(self) -> None: | |
| self.num_episodes = 0 | |
| self.num_steps = 0 | |
| self.start_idx = np.array([], dtype=np.int64) | |
| self.lengths = np.array([], dtype=np.int64) | |
| self.counter_rew = Counter() | |
| self.counter_end = Counter() | |
| self._cache.clear() | |
| def clear(self) -> None: | |
| self.assert_not_static() | |
| if self._directory.is_dir(): | |
| shutil.rmtree(self._directory) | |
| self._reset() | |
| def load_episode(self, episode_id: int) -> Episode: | |
| if self._cache_in_ram and episode_id in self._cache: | |
| episode = self._cache[episode_id] | |
| else: | |
| episode = Episode.load(self._get_episode_path(episode_id)) | |
| if self._cache_in_ram: | |
| self._cache[episode_id] = episode | |
| return episode | |
| def add_episode(self, episode: Episode, *, episode_id: Optional[int] = None) -> int: | |
| self.assert_not_static() | |
| episode = episode.to("cpu") | |
| if episode_id is None: | |
| episode_id = self.num_episodes | |
| self.start_idx = np.concatenate((self.start_idx, np.array([self.num_steps]))) | |
| self.lengths = np.concatenate((self.lengths, np.array([len(episode)]))) | |
| self.num_steps += len(episode) | |
| self.num_episodes += 1 | |
| else: | |
| assert episode_id < self.num_episodes | |
| old_episode = self.load_episode(episode_id) | |
| incr_num_steps = len(episode) - len(old_episode) | |
| self.lengths[episode_id] = len(episode) | |
| self.start_idx[episode_id + 1 :] += incr_num_steps | |
| self.num_steps += incr_num_steps | |
| self.counter_rew.subtract(old_episode.rew.sign().tolist()) | |
| self.counter_end.subtract(old_episode.end.tolist()) | |
| self.counter_rew.update(episode.rew.sign().tolist()) | |
| self.counter_end.update(episode.end.tolist()) | |
| if self._save_on_disk: | |
| episode.save(self._get_episode_path(episode_id)) | |
| if self._cache_in_ram: | |
| self._cache[episode_id] = episode | |
| return episode_id | |
| def _get_episode_path(self, episode_id: int) -> Path: | |
| n = 3 # number of hierarchies | |
| powers = np.arange(n) | |
| subfolders = np.floor((episode_id % 10 ** (1 + powers)) / 10**powers) * 10**powers | |
| subfolders = [int(x) for x in subfolders[::-1]] | |
| subfolders = "/".join([f"{x:0{n - i}d}" for i, x in enumerate(subfolders)]) | |
| return self._directory / subfolders / f"{episode_id}.pt" | |
| def load_state_dict(self, state_dict: Dict[str, Any]) -> None: | |
| super().load_state_dict(state_dict) | |
| self._cache.clear() | |
| def assert_not_static(self) -> None: | |
| assert not self.is_static, "Trying to modify a static dataset." | |
| def save_to_default_path(self) -> None: | |
| self._default_path.parent.mkdir(exist_ok=True, parents=True) | |
| torch.save(self.state_dict(), self._default_path) | |
| def load_from_default_path(self) -> None: | |
| if self._default_path.is_file(): | |
| self.load_state_dict(torch.load(self._default_path)) | |