Spaces:
Sleeping
Sleeping
| from collections import Counter | |
| import multiprocessing as mp | |
| from pathlib import Path | |
| import shutil | |
| from typing import Any, Dict, List, Optional | |
| import h5py | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from torch.utils.data import Dataset as TorchDataset | |
| from .episode import Episode | |
| from .segment import Segment, SegmentId | |
| from .utils import make_segment | |
| from ..utils import StateDictMixin | |
class Dataset(StateDictMixin, TorchDataset):
    """Disk-backed dataset of `Episode`s, indexed by `SegmentId`.

    Episodes are stored one file per episode under `directory` (sharded into
    sub-folders by `_get_episode_path`). Bookkeeping arrays (`start_idx`,
    `lengths`, reward/end counters) are part of the state dict handled by
    `StateDictMixin`, so the dataset can be checkpointed and restored.

    Args:
        directory: Root folder holding episode files and `info.pt`.
        dataset_full_res: Optional companion dataset providing full-resolution
            observations for the same segments (looked up via the episode's
            `info["original_file_id"]`).
        name: Display name; defaults to the directory stem.
        cache_in_ram: Keep loaded episodes in an in-memory cache.
        use_manager: Back the cache with a `multiprocessing.Manager` dict so
            it can be shared across DataLoader worker processes.
        save_on_disk: Persist episodes added via `add_episode`.
    """

    def __init__(
        self,
        directory: Path,
        dataset_full_res: Optional[TorchDataset],
        name: Optional[str] = None,
        cache_in_ram: bool = False,
        use_manager: bool = False,
        save_on_disk: bool = True,
    ) -> None:
        super().__init__()

        # State (saved/restored by StateDictMixin); populated by _reset().
        self.is_static = False
        self.num_episodes = None
        self.num_steps = None
        self.start_idx = None
        self.lengths = None
        self.counter_rew = None
        self.counter_end = None

        self._directory = Path(directory).expanduser()
        self._name = name if name is not None else self._directory.stem
        self._cache_in_ram = cache_in_ram
        self._save_on_disk = save_on_disk
        self._default_path = self._directory / "info.pt"
        # Manager-backed dict lets multiple worker processes share the cache.
        self._cache = mp.Manager().dict() if use_manager else {}
        self._reset()
        self._dataset_full_res = dataset_full_res

    def __len__(self) -> int:
        return self.num_steps

    def __getitem__(self, segment_id: SegmentId) -> Segment:
        episode = self.load_episode(segment_id.episode_id)
        segment = make_segment(episode, segment_id, should_pad=True)
        if self._dataset_full_res is not None:
            # Fetch the matching full-resolution observations from the
            # companion dataset, keyed by the episode's original file id.
            segment_id_full_res = SegmentId(episode.info["original_file_id"], segment_id.start, segment_id.stop)
            segment.info["full_res"] = self._dataset_full_res[segment_id_full_res].obs
        elif "full_res" in segment.info:
            # Episode already carries full-res frames: slice them to the segment.
            segment.info["full_res"] = segment.info["full_res"][segment_id.start:segment_id.stop]
        return segment

    def __str__(self) -> str:
        return f"{self.name}: {self.num_episodes} episodes, {self.num_steps} steps."

    # FIX: these three accessors are used attribute-style (`self.name` in
    # __str__ above); without @property that interpolates the bound-method
    # repr instead of the value.
    @property
    def name(self) -> str:
        return self._name

    @property
    def counts_rew(self) -> List[int]:
        """Counts of negative / zero / positive rewards, in that order."""
        return [self.counter_rew[r] for r in [-1, 0, 1]]

    @property
    def counts_end(self) -> List[int]:
        """Counts of non-terminal / terminal steps, in that order."""
        return [self.counter_end[e] for e in [0, 1]]

    def _reset(self) -> None:
        """Zero out all bookkeeping state and drop the RAM cache."""
        self.num_episodes = 0
        self.num_steps = 0
        self.start_idx = np.array([], dtype=np.int64)
        self.lengths = np.array([], dtype=np.int64)
        self.counter_rew = Counter()
        self.counter_end = Counter()
        self._cache.clear()

    def clear(self) -> None:
        """Delete everything on disk and reset in-memory state."""
        self.assert_not_static()
        if self._directory.is_dir():
            shutil.rmtree(self._directory)
        self._reset()

    def load_episode(self, episode_id: int) -> Episode:
        """Return the episode, serving from the RAM cache when enabled."""
        if self._cache_in_ram and episode_id in self._cache:
            episode = self._cache[episode_id]
        else:
            episode = Episode.load(self._get_episode_path(episode_id))
            if self._cache_in_ram:
                self._cache[episode_id] = episode
        return episode

    def add_episode(self, episode: Episode, *, episode_id: Optional[int] = None) -> int:
        """Append a new episode, or overwrite an existing one in place.

        Args:
            episode: Episode to store (moved to CPU first).
            episode_id: When given, replaces that episode and shifts the
                start indices of all subsequent episodes accordingly;
                otherwise the episode is appended.

        Returns:
            The id under which the episode was stored.
        """
        self.assert_not_static()
        episode = episode.to("cpu")

        if episode_id is None:
            # Append: new episode starts where the dataset currently ends.
            episode_id = self.num_episodes
            self.start_idx = np.concatenate((self.start_idx, np.array([self.num_steps])))
            self.lengths = np.concatenate((self.lengths, np.array([len(episode)])))
            self.num_steps += len(episode)
            self.num_episodes += 1
        else:
            # Overwrite: adjust lengths/offsets and retract the old episode's
            # contribution to the reward/end counters below.
            assert episode_id < self.num_episodes
            old_episode = self.load_episode(episode_id)
            incr_num_steps = len(episode) - len(old_episode)
            self.lengths[episode_id] = len(episode)
            self.start_idx[episode_id + 1 :] += incr_num_steps
            self.num_steps += incr_num_steps
            self.counter_rew.subtract(old_episode.rew.sign().tolist())
            self.counter_end.subtract(old_episode.end.tolist())

        self.counter_rew.update(episode.rew.sign().tolist())
        self.counter_end.update(episode.end.tolist())

        if self._save_on_disk:
            episode.save(self._get_episode_path(episode_id))

        if self._cache_in_ram:
            self._cache[episode_id] = episode

        return episode_id

    def _get_episode_path(self, episode_id: int) -> Path:
        """Shard episode files into a 3-level folder hierarchy by decimal digits.

        E.g. episode 1234 -> <dir>/1000/230/4.pt-style bucketing, so no single
        folder accumulates an unbounded number of files.
        """
        n = 3  # number of hierarchies
        powers = np.arange(n)
        subfolders = np.floor((episode_id % 10 ** (1 + powers)) / 10**powers) * 10**powers
        subfolders = [int(x) for x in subfolders[::-1]]
        subfolders = "/".join([f"{x:0{n - i}d}" for i, x in enumerate(subfolders)])
        return self._directory / subfolders / f"{episode_id}.pt"

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        super().load_state_dict(state_dict)
        # Cached episodes may be stale relative to the restored state.
        self._cache.clear()

    def assert_not_static(self) -> None:
        assert not self.is_static, "Trying to modify a static dataset."

    def save_to_default_path(self) -> None:
        self._default_path.parent.mkdir(exist_ok=True, parents=True)
        torch.save(self.state_dict(), self._default_path)

    def load_from_default_path(self) -> None:
        if self._default_path.is_file():
            self.load_state_dict(torch.load(self._default_path))
class CSGOHdf5Dataset(StateDictMixin, TorchDataset):
    """Read-only dataset over HDF5 files of fixed-length (1000-step) episodes.

    Each `.hdf5` file is one episode; files are keyed by `"<parent>/<name>"`,
    so `SegmentId.episode_id` is expected to be that string key here (unlike
    the integer ids of `Dataset`) — TODO confirm against callers.
    """

    def __init__(self, directory: Path) -> None:
        super().__init__()
        # Sort files by the trailing integer in their stem (e.g. "..._12.hdf5").
        filenames = sorted(Path(directory).rglob("*.hdf5"), key=lambda x: int(x.stem.split("_")[-1]))
        self._filenames = {f"{x.parent.name}/{x.name}": x for x in filenames}
        self._length_one_episode = 1000  # every episode file holds exactly 1000 frames
        self.num_episodes = len(self._filenames)
        self.num_steps = self._length_one_episode * self.num_episodes
        self.lengths = np.array([self._length_one_episode] * self.num_episodes, dtype=np.int64)

    def __len__(self) -> int:
        return self.num_steps

    def save_to_default_path(self) -> None:
        # Static dataset: nothing to persist.
        pass

    def __getitem__(self, segment_id: SegmentId) -> Segment:
        # Requested range may overhang the episode on either side; the
        # overhang is zero-padded and flagged via mask_padding.
        assert segment_id.start < self._length_one_episode and segment_id.stop > 0 and segment_id.start < segment_id.stop
        pad_len_right = max(0, segment_id.stop - self._length_one_episode)
        pad_len_left = max(0, -segment_id.start)
        # Clamp to the valid frame range actually read from the file.
        start = max(0, segment_id.start)
        stop = min(self._length_one_episode, segment_id.stop)
        # True where the timestep holds real data, False on padded positions.
        mask_padding = torch.cat((torch.zeros(pad_len_left), torch.ones(stop - start), torch.zeros(pad_len_right))).bool()

        with h5py.File(self._filenames[segment_id.episode_id], "r") as f:
            # Frames are stored per-step as "frame_{i}_x" (HWC, presumably BGR
            # uint8 — TODO confirm): flip(2) reverses the channel axis, permute
            # gives CHW, and values are rescaled from [0, 255] to [-1, 1].
            obs = torch.stack([torch.tensor(f[f"frame_{i}_x"][:]).flip(2).permute(2, 0, 1).div(255).mul(2).sub(1) for i in range(start, stop)])
            act = torch.tensor(np.array([f[f"frame_{i}_y"][:] for i in range(start, stop)]))
            states = torch.stack([torch.tensor(f[f"frame_{i}_observation"][:]) for i in range(start, stop)])
            ego_state = torch.stack([torch.tensor(f[f"frame_{i}_ego_state"][:]) for i in range(start, stop)])

        def pad(x):
            # F.pad's pad list is ordered last-dim-first; these lists place the
            # padding on the first (time) dimension only.
            right = F.pad(x, [0 for _ in range(2 * x.ndim - 1)] + [pad_len_right]) if pad_len_right > 0 else x
            return F.pad(right, [0 for _ in range(2 * x.ndim - 2)] + [pad_len_left, 0]) if pad_len_left > 0 else right

        # NOTE(review): only obs/act are padded; states/ego_state keep length
        # (stop - start) — confirm downstream consumers expect that.
        obs = pad(obs)
        act = pad(act)
        # CSGO recordings carry no reward/termination signal: all zeros.
        rew = torch.zeros(obs.size(0))
        end = torch.zeros(obs.size(0), dtype=torch.uint8)
        trunc = torch.zeros(obs.size(0), dtype=torch.uint8)

        return Segment(obs, act, rew, end, trunc, mask_padding, states=states, ego_state=ego_state, info={}, id=SegmentId(segment_id.episode_id, start, stop))

    def load_episode(self, episode_id: int) -> Episode:  # used by DatasetTraverser
        # Materialize the whole fixed-length episode as one segment; states and
        # ego_state are dropped since Episode takes only obs/act/rew/end/trunc/info.
        s = self[SegmentId(episode_id, 0, self._length_one_episode)]
        return Episode(s.obs, s.act, s.rew, s.end, s.trunc, s.info)