File size: 4,130 Bytes
23bc32f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from pathlib import Path
import shutil
from typing import Dict, Optional, Union

import numpy as np
import torch

from .episode import Episode
from .segment import Segment, SegmentId
from .utils import make_segment


class EpisodeDataset(torch.utils.data.Dataset):
    """Disk-backed collection of episodes, indexed by segment.

    Every episode is stored in its own ``.pt`` file below ``directory`` (see
    ``_get_episode_path`` for the layout). Bookkeeping lives in ``directory / 'info.pt'``:

    * ``num_episodes``: number of stored episodes.
    * ``num_steps``: sum of the lengths of all stored episodes.
    * ``start_idx``: for each episode, the global index of its first step.
    * ``lengths``: for each episode, its length.

    In-memory bookkeeping only reaches the disk through ``save_info``;
    ``add_episode`` does not call it.
    """

    def __init__(self, directory: Path, name: str) -> None:
        """Open the dataset in ``directory``, creating an empty one if the directory is missing.

        Args:
            directory: Root folder of the dataset.
            name: Label used in the summary line printed when an existing dataset is loaded.
        """
        super().__init__()
        self.name = name
        self.directory = Path(directory)
        self.num_episodes, self.num_steps, self.start_idx, self.lengths = None, None, None, None

        if self.directory.is_dir():
            self._load_info()
            print(f'({name}) {self.num_episodes} episodes, {self.num_steps} steps.')
        else:
            self._init_empty()

    @property
    def info_path(self) -> Path:
        """Location of the bookkeeping file."""
        return self.directory / 'info.pt'

    @property
    def info(self) -> Dict[str, Union[int, np.ndarray]]:
        """Current bookkeeping, in the form that ``save_info`` writes to disk."""
        return {
            'num_episodes': self.num_episodes,
            'num_steps': self.num_steps,
            'start_idx': self.start_idx,
            'lengths': self.lengths,
        }

    def __len__(self) -> int:
        """Return the total number of steps (not the number of episodes)."""
        return self.num_steps

    def __getitem__(self, segment_id: SegmentId) -> Segment:
        """Load the segment designated by ``segment_id`` (with ``should_pad`` left at its default)."""
        return self._load_segment(segment_id)

    def _init_empty(self) -> None:
        """Create the dataset directory and write bookkeeping for zero episodes.

        Raises:
            FileExistsError: If the directory already exists.
        """
        self.directory.mkdir(parents=True, exist_ok=False)
        self.num_episodes = 0
        self.num_steps = 0
        self.start_idx = np.array([], dtype=np.int64)
        self.lengths = np.array([], dtype=np.int64)
        self.save_info()

    def _load_info(self) -> None:
        """Read the bookkeeping file back into the instance attributes."""
        # info.pt contains NumPy arrays, which the restricted unpickler that is the
        # default since PyTorch 2.6 refuses to load. weights_only=False uses the full
        # pickle machinery, so only open datasets that come from a trusted source.
        info = torch.load(self.info_path, weights_only=False)
        self.num_steps = info['num_steps']
        self.num_episodes = info['num_episodes']
        self.start_idx = info['start_idx']
        self.lengths = info['lengths']

    def save_info(self) -> None:
        """Write the current bookkeeping to ``info_path``."""
        torch.save(self.info, self.info_path)

    def clear(self) -> None:
        """Delete the whole dataset directory and start again from an empty dataset."""
        shutil.rmtree(self.directory)
        self._init_empty()

    def _get_episode_path(self, episode_id: int) -> Path:
        """Return the file path of an episode inside the nested folder layout.

        The three lowest decimal digits of ``episode_id`` select the folders, from the
        hundreds down to the units, each written with its trailing zeros. For example
        episode 1234 maps to ``200/30/4/1234.pt`` and episode 5 to ``000/00/5/5.pt``.
        """
        n = 3  # number of hierarchies
        folders = []
        for power in reversed(range(n)):
            # Keep only the digit at this decimal position, scaled back to its place value.
            bucket = int((episode_id // 10 ** power) % 10 * 10 ** power)
            folders.append(f'{bucket:0{power + 1}d}')
        return self.directory.joinpath(*folders, f'{episode_id}.pt')

    def _load_segment(self, segment_id: SegmentId, should_pad: bool = True) -> Segment:
        """Load the episode referenced by ``segment_id`` and build the segment with ``make_segment``."""
        episode = self.load_episode(segment_id.episode_id)
        return make_segment(episode, segment_id, should_pad)

    def load_episode(self, episode_id: int) -> Episode:
        """Rebuild the ``Episode`` stored on disk under ``episode_id``."""
        # Episode files are written from ``episode.__dict__`` and may hold arbitrary
        # Python objects, so the full unpickler is requested explicitly (the default
        # became weights_only=True in PyTorch 2.6). Only load trusted datasets.
        stored = torch.load(self._get_episode_path(episode_id), weights_only=False)
        return Episode(**stored)

    def add_episode(self, episode: Episode, *, episode_id: Optional[int] = None) -> int:
        """Store a new episode, or merge ``episode`` into an existing one.

        Args:
            episode: Episode to store.
            episode_id: ``None`` to append ``episode`` as a new episode. Otherwise the id
                of an existing episode; the stored episode is then replaced by
                ``old_episode.merge(episode)``.

        Returns:
            The id under which the episode is stored.

        Note:
            Only the in-memory bookkeeping is updated; call ``save_info`` to persist it.
        """
        if episode_id is None:
            episode_id = self.num_episodes
            self.start_idx = np.concatenate((self.start_idx, np.array([self.num_steps])))
            self.lengths = np.concatenate((self.lengths, np.array([len(episode)])))
            self.num_steps += len(episode)
            self.num_episodes += 1

        else:
            assert episode_id < self.num_episodes
            old_episode = self.load_episode(episode_id)
            episode = old_episode.merge(episode)
            incr_num_steps = len(episode) - len(old_episode)
            self.lengths[episode_id] = len(episode)
            # Every later episode starts further along by the amount this one changed.
            self.start_idx[episode_id + 1:] += incr_num_steps
            self.num_steps += incr_num_steps

        episode_path = self._get_episode_path(episode_id)
        episode_path.parent.mkdir(parents=True, exist_ok=True)
        # Write to a temporary sibling first so a partially written file never sits at
        # the final path. Path.replace overwrites an existing target on every platform,
        # whereas Path.rename fails on Windows when the episode file already exists
        # (which is always the case when merging).
        tmp_path = episode_path.with_suffix('.tmp')
        torch.save(episode.__dict__, tmp_path)
        tmp_path.replace(episode_path)

        return episode_id

    def get_episode_id_from_global_idx(self, global_idx: np.ndarray) -> np.ndarray:
        """Map global step indices to the ids of the episodes that contain them.

        Args:
            global_idx: One-dimensional array of global step indices.

        Returns:
            Array of episode ids, one per entry of ``global_idx``. Indices at or beyond
            the start of the last episode map to the last episode.

        Raises:
            ValueError: If the dataset holds no episode.
        """
        if self.num_episodes == 0:
            raise ValueError('attempt to map global indices to episodes of an empty dataset')
        # start_idx is non-decreasing, so a binary search finds, for each index, the
        # first episode that starts strictly after it; the containing episode is the
        # one just before. The modulo keeps the wrap-around to the last episode for
        # indices that precede every start.
        first_later_start = np.searchsorted(self.start_idx, np.atleast_1d(global_idx), side='right')
        return (first_later_start - 1) % self.num_episodes

    def get_global_idx_from_segment_id(self, segment_id: SegmentId) -> np.ndarray:
        """Return the global step indices covered by ``segment_id``.

        The result is the range ``[segment_id.start, segment_id.stop)`` shifted by the
        global start index of the segment's episode.
        """
        start_idx = self.start_idx[segment_id.episode_id]
        return np.arange(start_idx + segment_id.start, start_idx + segment_id.stop)