| import json |
|
|
| from torch.utils.data import Dataset |
|
|
| from .utils import to_tensor, read_zarr_with_cache, to_relative_action |
|
|
|
|
class BaseDataset(Dataset):
    """Base dataset that serves fixed-size chunks of consecutive samples.

    Annotations are loaded once with ``read_zarr_with_cache`` and stored in
    ``self.annos``, a mapping from field name to an indexable array. One
    dataset item is a slice of ``chunk_size`` consecutive samples, and the
    reported length is the number of full chunks multiplied by ``copies``.

    This class reads ``self.train_copies``, ``self.camera_inds`` and
    ``self.quat_format`` but never assigns them; they must be provided
    elsewhere (presumably by subclasses as class attributes).
    """

    def __init__(
        self,
        root,
        instructions,
        copies=None,
        relative_action=False,
        mem_limit=8,
        actions_only=False,
        chunk_size=4
    ):
        """Load instructions and annotations and validate their lengths.

        Args:
            root: Location of the annotations, forwarded to
                ``read_zarr_with_cache``.
            instructions: Path of the JSON file read by
                ``_load_instructions``.
            copies: Multiplier applied to the dataset length. Falls back to
                ``self.train_copies`` when None.
            relative_action: If true, ``_get_action`` returns relative
                actions instead of the raw ``action`` field.
            mem_limit: Forwarded to ``read_zarr_with_cache`` as ``mem_gb``.
            actions_only: If true, ``__getitem__`` returns only the
                ``"action"`` entry.
            chunk_size: Number of consecutive samples in one item.
        """
        super().__init__()
        self.copies = self.train_copies if copies is None else copies
        self._relative_action = relative_action
        self._actions_only = actions_only
        self.chunk_size = chunk_size

        # Instructions are parsed from a JSON file.
        self._instructions = self._load_instructions(instructions)

        # Annotations: mapping from field name to per-sample data.
        self.annos = read_zarr_with_cache(root, mem_gb=mem_limit)

        # Every field must hold the same number of samples as 'action'.
        len_ = len(self.annos['action'])
        for key in self.annos:
            assert len(self.annos[key]) == len_, f'length mismatch in {key}'
        print(f"Found {len_} samples")

    def _load_instructions(self, instruction_file):
        """Return the parsed content of the JSON file ``instruction_file``."""
        # Use a context manager so the file handle is closed deterministically
        # instead of being left for the garbage collector.
        with open(instruction_file) as f:
            return json.load(f)

    def _num_chunks(self):
        """Return how many full ``chunk_size`` chunks the annotations hold."""
        return len(self.annos['action']) // self.chunk_size

    def _get_attr_by_idx(self, idx, attr, filter_cam=False):
        """Return samples ``idx:idx + chunk_size`` of field ``attr`` as a tensor.

        When ``filter_cam`` is true and ``self.camera_inds`` is not None, the
        result is additionally indexed with ``camera_inds`` along dim 1.
        """
        t = to_tensor(self.annos[attr][idx:idx + self.chunk_size])
        if filter_cam and self.camera_inds is not None:
            t = t[:, self.camera_inds]
        return t

    def _get_task(self, idx):
        """Return a placeholder task name for each sample of the chunk."""
        return ["task"] * self.chunk_size

    def _get_instr(self, idx):
        """Return a placeholder instruction for each sample of the chunk."""
        return ["instruction"] * self.chunk_size

    def _get_rgb(self, idx, key='rgb'):
        """Return the camera-filtered chunk of field ``key`` (default 'rgb')."""
        return self._get_attr_by_idx(idx, key, True)

    def _get_depth(self, idx, key='depth'):
        """Return the camera-filtered chunk of field ``key`` (default 'depth')."""
        return self._get_attr_by_idx(idx, key, True)

    def _get_proprioception(self, idx):
        """Return the chunk of the 'proprioception' field."""
        return self._get_attr_by_idx(idx, 'proprioception', False)

    def _get_action(self, idx):
        """Return the chunk of actions, relative if so configured.

        With ``relative_action`` disabled the raw 'action' field is returned.
        Otherwise a stored 'rel_action' field is preferred; if it is absent,
        relative actions are computed with ``to_relative_action`` from the
        raw actions and the proprioception.
        """
        if not self._relative_action:
            return self._get_attr_by_idx(idx, 'action', False)
        if 'rel_action' in self.annos:
            return self._get_attr_by_idx(idx, 'rel_action', False)
        action = self._get_attr_by_idx(idx, 'action', False)
        # NOTE(review): [[-1]] keeps only the last element along the first
        # (chunk) dimension. If the intent is the last history step of every
        # sample this should be [:, [-1]] -- confirm against
        # to_relative_action before changing.
        prop = self._get_proprioception(idx)[[-1]]
        return to_relative_action(action, prop, self.quat_format)

    def __getitem__(self, idx):
        """
        self.annos: {
            action: (N, T, 8) float
            depth: (N, n_cam, H, W) float16
            proprioception: (N, nhist, 8) float
            rgb: (N, n_cam, 3, H, W) uint8
        }
        In addition self.annos may contain fields for task/instruction ids

        Returns a dict for the chunk selected by ``idx``: only "action" when
        ``actions_only`` is set, otherwise "task", "instr", "rgb", "depth",
        "proprioception" and "action".

        Raises:
            IndexError: if the annotations hold fewer than ``chunk_size``
                samples, i.e. the dataset is empty.
        """
        num_chunks = self._num_chunks()
        if num_chunks == 0:
            # Without this guard the modulo below fails with an opaque
            # ZeroDivisionError although __len__ correctly reports 0.
            raise IndexError(
                f"dataset has no full chunk of {self.chunk_size} samples"
            )
        # Wrap around so that indices of repeated copies map to real chunks.
        idx = idx % num_chunks
        # Convert the chunk index into the index of its first sample.
        idx = idx * self.chunk_size
        if self._actions_only:
            return {"action": self._get_action(idx)}
        return {
            "task": self._get_task(idx),
            "instr": self._get_instr(idx),
            "rgb": self._get_rgb(idx),
            "depth": self._get_depth(idx),
            "proprioception": self._get_proprioception(idx),
            "action": self._get_action(idx)
        }

    def __len__(self):
        """Return the number of full chunks multiplied by ``copies``."""
        return self.copies * self._num_chunks()
|
|