lsnu's picture
Add files using upload-large-folder tool
5ce8761 verified
import json
from torch.utils.data import Dataset
from .utils import to_tensor, read_zarr_with_cache, to_relative_action
class BaseDataset(Dataset):
"""Base dataset."""
def __init__(
self,
root, # the directory path of the dataset
instructions, # path to instruction file
copies=None, # copy the dataset for less loader restarts
relative_action=False, # whether to return relative actions
mem_limit=8, # cache limit per dataset class in GigaBytes
actions_only=False, # return actions without observations
chunk_size=4 # chunk size for zarr
):
super().__init__()
self.copies = self.train_copies if copies is None else copies
self._relative_action = relative_action
self._actions_only = actions_only
self.chunk_size = chunk_size
# Load instructions
self._instructions = self._load_instructions(instructions)
# Load all annotations lazily
self.annos = read_zarr_with_cache(root, mem_gb=mem_limit)
# Sanity check
len_ = len(self.annos['action'])
for key in self.annos:
assert len(self.annos[key]) == len_, f'length mismatch in {key}'
print(f"Found {len(self.annos['action'])} samples")
def _load_instructions(self, instruction_file):
return json.load(open(instruction_file))
def _get_attr_by_idx(self, idx, attr, filter_cam=False):
t = to_tensor(self.annos[attr][idx:idx + self.chunk_size])
if filter_cam and self.camera_inds is not None:
t = t[:, self.camera_inds]
return t
def _get_task(self, idx):
return ["task"] * self.chunk_size
def _get_instr(self, idx):
return ["instruction"] * self.chunk_size
def _get_rgb(self, idx, key='rgb'):
return self._get_attr_by_idx(idx, key, True)
def _get_depth(self, idx, key='depth'):
return self._get_attr_by_idx(idx, key, True)
def _get_proprioception(self, idx):
return self._get_attr_by_idx(idx, 'proprioception', False)
def _get_action(self, idx):
if self._relative_action:
if 'rel_action' in self.annos:
return self._get_attr_by_idx(idx, 'rel_action', False)
else:
action = self._get_attr_by_idx(idx, 'action', False)
prop = self._get_proprioception(idx)[[-1]]
action = to_relative_action(action, prop, self.quat_format)
else:
action = self._get_attr_by_idx(idx, 'action', False)
return action
def __getitem__(self, idx):
"""
self.annos: {
action: (N, T, 8) float
depth: (N, n_cam, H, W) float16
proprioception: (N, nhist, 8) float
rgb: (N, n_cam, 3, H, W) uint8
}
In addition self.annos may contain fields for task/instruction ids
"""
# First detect which copy we fall into
idx = idx % (len(self.annos['action']) // self.chunk_size)
# and then which chunk
idx = idx * self.chunk_size
if self._actions_only:
return {"action": self._get_action(idx)}
return {
"task": self._get_task(idx),
"instr": self._get_instr(idx), # [str]
"rgb": self._get_rgb(idx), # tensor(n_cam, 3, H, W)
"depth": self._get_depth(idx), # tensor(n_cam, H, W)
"proprioception": self._get_proprioception(idx), # tensor(1, 8)
"action": self._get_action(idx) # tensor(T, 8)
}
def __len__(self):
return self.copies * (len(self.annos['action']) // self.chunk_size)