"""
This file contains Dataset classes that are used by torch dataloaders
to fetch batches from hdf5 files.
"""
import os
import h5py
import numpy as np
from copy import deepcopy
from contextlib import contextmanager
from collections import OrderedDict

import torch.utils.data

import robomimic.utils.tensor_utils as TensorUtils
import robomimic.utils.obs_utils as ObsUtils
import robomimic.utils.action_utils as AcUtils
import robomimic.utils.log_utils as LogUtils


class SequenceDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        hdf5_path,
        obs_keys,
        action_keys,
        dataset_keys,
        action_config,
        frame_stack=1,
        seq_length=1,
        pad_frame_stack=True,
        pad_seq_length=True,
        get_pad_mask=False,
        goal_mode=None,
        hdf5_cache_mode=None,
        hdf5_use_swmr=True,
        hdf5_normalize_obs=False,
        filter_by_attribute=None,
        load_next_obs=True,
    ):
        """
        Dataset class for fetching sequences of experience.
        Length of the fetched sequence is equal to (@frame_stack - 1 + @seq_length)

        Args:
            hdf5_path (str): path to hdf5

            obs_keys (tuple, list): keys to observation items (image, object, etc) to be fetched from the dataset

            action_keys (tuple, list): keys to action items to be fetched from the dataset. These are
                merged into @dataset_keys, and their (normalized) values are concatenated into the
                "actions" entry of every fetched item. None is treated as "no action keys".

            dataset_keys (tuple, list): keys to dataset items (actions, rewards, etc) to be fetched from the dataset

            action_config (dict): maps each action key to a dict of settings. The "normalization"
                entry of that dict (None or "min_max") selects how the action is normalized
                (see `action_stats_to_normalization_stats`).

            frame_stack (int): numbers of stacked frames to fetch. Defaults to 1 (single frame).

            seq_length (int): length of sequences to sample. Defaults to 1 (single frame).

            pad_frame_stack (bool): whether to pad sequence for frame stacking at the beginning of a demo. This
                ensures that partial frame stacks are observed, such as (s_0, s_0, s_0, s_1). Otherwise, the
                first frame stacked observation would be (s_0, s_1, s_2, s_3).

            pad_seq_length (bool): whether to pad sequence for sequence fetching at the end of a demo. This
                ensures that partial sequences at the end of a demonstration are observed, such as
                (s_{T-1}, s_{T}, s_{T}, s_{T}). Otherwise, the last sequence provided would be
                (s_{T-3}, s_{T-2}, s_{T-1}, s_{T}).

            get_pad_mask (bool): if True, also provide padding masks as part of the batch. This can be
                useful for masking loss functions on padded parts of the data.

            goal_mode (str): either "last" or None. Defaults to None, which is to not fetch goals

            hdf5_cache_mode (str): one of ["all", "low_dim", or None]. Set to "all" to cache entire hdf5
                in memory - this is by far the fastest for data loading. Set to "low_dim" to cache all
                non-image data. Set to None to use no caching - in this case, every batch sample is
                retrieved via file i/o. You should almost never set this to None, even for large
                image datasets.

            hdf5_use_swmr (bool): whether to use swmr feature when opening the hdf5 file. This ensures
                that multiple Dataset instances can all access the same hdf5 file without problems.

            hdf5_normalize_obs (bool): if True, compute the mean and std of each observation
                (in each dimension and modality) over the dataset, for use in normalizing observations.

            filter_by_attribute (str): if provided, use the provided filter key to look up a subset of
                demonstrations to load

            load_next_obs (bool): whether to load next_obs from the dataset
        """
        super(SequenceDataset, self).__init__()

        self.hdf5_path = os.path.expandvars(os.path.expanduser(hdf5_path))
        self.hdf5_use_swmr = hdf5_use_swmr
        self.hdf5_normalize_obs = hdf5_normalize_obs
        self._hdf5_file = None

        assert hdf5_cache_mode in ["all", "low_dim", None]
        self.hdf5_cache_mode = hdf5_cache_mode

        self.load_next_obs = load_next_obs
        self.filter_by_attribute = filter_by_attribute

        # get all keys that needs to be fetched
        self.obs_keys = tuple(obs_keys)
        # check for None *before* converting - tuple(None) would raise
        self.action_keys = tuple(action_keys) if action_keys is not None else ()
        self.dataset_keys = tuple(dataset_keys)

        # add action keys to dataset keys
        self.dataset_keys = tuple(set(self.dataset_keys).union(set(self.action_keys)))

        self.action_config = action_config

        self.n_frame_stack = frame_stack
        assert self.n_frame_stack >= 1

        self.seq_length = seq_length
        assert self.seq_length >= 1

        self.goal_mode = goal_mode
        if self.goal_mode is not None:
            assert self.goal_mode in ["last"]
        if not self.load_next_obs:
            assert self.goal_mode != "last"  # we use last next_obs as goal

        self.pad_seq_length = pad_seq_length
        self.pad_frame_stack = pad_frame_stack
        self.get_pad_mask = get_pad_mask

        self.load_demo_info(filter_by_attribute=self.filter_by_attribute)

        # maybe prepare for observation normalization
        self.obs_normalization_stats = None
        if self.hdf5_normalize_obs:
            self.obs_normalization_stats = self.normalize_obs()

        # prepare for action normalization (computed lazily, or set externally)
        self.action_normalization_stats = None

        # maybe store dataset in memory for fast access
        if self.hdf5_cache_mode in ["all", "low_dim"]:
            obs_keys_in_memory = self.obs_keys
            if self.hdf5_cache_mode == "low_dim":
                # only store low-dim observations
                obs_keys_in_memory = []
                for k in self.obs_keys:
                    if ObsUtils.key_is_obs_modality(k, "low_dim"):
                        obs_keys_in_memory.append(k)
            self.obs_keys_in_memory = obs_keys_in_memory

            self.hdf5_cache = self.load_dataset_in_memory(
                demo_list=self.demos,
                hdf5_file=self.hdf5_file,
                obs_keys=self.obs_keys_in_memory,
                dataset_keys=self.dataset_keys,
                load_next_obs=self.load_next_obs
            )

            if self.hdf5_cache_mode == "all":
                # cache getitem calls for even more speedup. We don't do this for
                # "low-dim" since image observations require calls to getitem anyways.
                print("SequenceDataset: caching get_item calls...")
                self.getitem_cache = [self.get_item(i) for i in LogUtils.custom_tqdm(range(len(self)))]

                # don't need the previous cache anymore
                del self.hdf5_cache
                self.hdf5_cache = None
        else:
            self.hdf5_cache = None

        self.close_and_delete_hdf5_handle()

    def load_demo_info(self, filter_by_attribute=None, demos=None):
        """
        Determine which demos to use and build the index maps from flat
        dataset indices to demos.

        Args:
            filter_by_attribute (str): if provided, use the provided filter key
                to select a subset of demonstration trajectories to load

            demos (list): list of demonstration keys to load from the hdf5 file. If
                omitted, all demos in the file (or under the @filter_by_attribute
                filter key) are used.
        """
        # filter demo trajectory by mask
        if demos is not None:
            self.demos = demos
        elif filter_by_attribute is not None:
            self.demos = [elem.decode("utf-8") for elem in np.array(self.hdf5_file["mask/{}".format(filter_by_attribute)][:])]
        else:
            self.demos = list(self.hdf5_file["data"].keys())

        # sort demo keys numerically (keys look like "demo_<int>", so strip the 5-char prefix)
        inds = np.argsort([int(elem[5:]) for elem in self.demos])
        self.demos = [self.demos[i] for i in inds]

        self.n_demos = len(self.demos)

        # keep internal index maps to know which transitions belong to which demos
        self._index_to_demo_id = dict()  # maps every index to a demo id
        self._demo_id_to_start_indices = dict()  # gives start index per demo id
        self._demo_id_to_demo_length = dict()

        # determine index mapping
        self.total_num_sequences = 0
        for ep in self.demos:
            demo_length = self.hdf5_file["data/{}".format(ep)].attrs["num_samples"]
            self._register_demo_sequences(ep, demo_length)

    def _register_demo_sequences(self, ep, demo_length):
        """
        Record the start index and length of demo @ep and append one index-map
        entry per sequence that can be sampled from it.

        Args:
            ep (str): demo key
            demo_length (int): number of samples in the demo
        """
        self._demo_id_to_start_indices[ep] = self.total_num_sequences
        self._demo_id_to_demo_length[ep] = demo_length

        num_sequences = demo_length
        # determine actual number of sequences taking into account whether to pad for frame_stack and seq_length
        if not self.pad_frame_stack:
            num_sequences -= (self.n_frame_stack - 1)
        if not self.pad_seq_length:
            num_sequences -= (self.seq_length - 1)

        if self.pad_seq_length:
            assert demo_length >= 1  # sequence needs to have at least one sample
            num_sequences = max(num_sequences, 1)
        else:
            assert num_sequences >= 1  # assume demo_length >= (self.n_frame_stack - 1 + self.seq_length)

        for _ in range(num_sequences):
            self._index_to_demo_id[self.total_num_sequences] = ep
            self.total_num_sequences += 1

    @property
    def hdf5_file(self):
        """
        This property allows for a lazy hdf5 file open.
        """
        if self._hdf5_file is None:
            self._hdf5_file = h5py.File(self.hdf5_path, 'r', swmr=self.hdf5_use_swmr, libver='latest')
        return self._hdf5_file

    def close_and_delete_hdf5_handle(self):
        """
        Maybe close the file handle.
        """
        # getattr: this is also reached from __del__ on instances whose __init__
        # failed before the attribute was assigned
        if getattr(self, "_hdf5_file", None) is not None:
            self._hdf5_file.close()
        self._hdf5_file = None

    @contextmanager
    def hdf5_file_opened(self):
        """
        Convenient context manager to open the file on entering the scope
        and then close it on leaving. If the file was already open on entry,
        it is left open.
        """
        should_close = self._hdf5_file is None
        try:
            yield self.hdf5_file
        finally:
            # also release the handle when the body of the with-statement raises
            if should_close:
                self.close_and_delete_hdf5_handle()

    def __del__(self):
        self.close_and_delete_hdf5_handle()

    def __repr__(self):
        """
        Pretty print the class and important attributes on a call to `print`.
        """
        msg = str(self.__class__.__name__)
        msg += " (\n\tpath={}\n\tobs_keys={}\n\tseq_length={}\n\tfilter_key={}\n\tframe_stack={}\n"
        msg += "\tpad_seq_length={}\n\tpad_frame_stack={}\n\tgoal_mode={}\n"
        msg += "\tcache_mode={}\n"
        msg += "\tnum_demos={}\n\tnum_sequences={}\n)"
        filter_key_str = self.filter_by_attribute if self.filter_by_attribute is not None else "none"
        goal_mode_str = self.goal_mode if self.goal_mode is not None else "none"
        cache_mode_str = self.hdf5_cache_mode if self.hdf5_cache_mode is not None else "none"
        msg = msg.format(self.hdf5_path, self.obs_keys, self.seq_length, filter_key_str, self.n_frame_stack,
                         self.pad_seq_length, self.pad_frame_stack, goal_mode_str, cache_mode_str,
                         self.n_demos, self.total_num_sequences)
        return msg

    def __len__(self):
        """
        Ensure that the torch dataloader will do a complete pass through all sequences in
        the dataset before starting a new iteration.
        """
        return self.total_num_sequences

    def load_dataset_in_memory(self, demo_list, hdf5_file, obs_keys, dataset_keys, load_next_obs):
        """
        Loads the hdf5 dataset into memory, preserving the structure of the file. Note that this
        differs from `self.getitem_cache`, which, if active, actually caches the outputs of the
        `getitem` operation.

        Args:
            demo_list (list): list of demo keys, e.g., 'demo_0'
            hdf5_file (h5py.File): file handle to the hdf5 dataset.
            obs_keys (list, tuple): observation keys to fetch, e.g., 'images'
            dataset_keys (list, tuple): dataset keys to fetch, e.g., 'actions'
            load_next_obs (bool): whether to load next_obs from the dataset

        Returns:
            all_data (dict): dictionary of loaded data.
        """
        all_data = dict()
        print("SequenceDataset: loading dataset into memory...")
        for ep in LogUtils.custom_tqdm(demo_list):
            all_data[ep] = {}
            all_data[ep]["attrs"] = {}
            all_data[ep]["attrs"]["num_samples"] = hdf5_file["data/{}".format(ep)].attrs["num_samples"]
            # get obs
            all_data[ep]["obs"] = {k: hdf5_file["data/{}/obs/{}".format(ep, k)][()] for k in obs_keys}
            if load_next_obs:
                all_data[ep]["next_obs"] = {k: hdf5_file["data/{}/next_obs/{}".format(ep, k)][()] for k in obs_keys}
            # get other dataset keys (stored flat under their full key, cast to float32)
            for k in dataset_keys:
                if k in hdf5_file["data/{}".format(ep)]:
                    all_data[ep][k] = hdf5_file["data/{}/{}".format(ep, k)][()].astype('float32')
                else:
                    # missing keys are filled with zeros of shape (num_samples, 1)
                    all_data[ep][k] = np.zeros((all_data[ep]["attrs"]["num_samples"], 1), dtype=np.float32)

            if "model_file" in hdf5_file["data/{}".format(ep)].attrs:
                all_data[ep]["attrs"]["model_file"] = hdf5_file["data/{}".format(ep)].attrs["model_file"]

        return all_data

    def normalize_obs(self):
        """
        Computes a dataset-wide mean and standard deviation for the observations
        (per dimension and per obs key) and returns it.
        """

        # Run through all trajectories. For each one, compute minimal observation statistics, and then aggregate
        # with the previous statistics.
        ep = self.demos[0]
        obs_traj = {k: self.hdf5_file["data/{}/obs/{}".format(ep, k)][()].astype('float32') for k in self.obs_keys}
        obs_traj = ObsUtils.process_obs_dict(obs_traj)
        merged_stats = _compute_traj_stats(obs_traj)
        print("SequenceDataset: normalizing observations...")
        for ep in LogUtils.custom_tqdm(self.demos[1:]):
            obs_traj = {k: self.hdf5_file["data/{}/obs/{}".format(ep, k)][()].astype('float32') for k in self.obs_keys}
            obs_traj = ObsUtils.process_obs_dict(obs_traj)
            traj_stats = _compute_traj_stats(obs_traj)
            merged_stats = _aggregate_traj_stats(merged_stats, traj_stats)

        obs_normalization_stats = { k : {} for k in merged_stats }
        for k in merged_stats:
            # note we add a small tolerance of 1e-3 for std
            obs_normalization_stats[k]["mean"] = merged_stats[k]["mean"].astype(np.float32)
            obs_normalization_stats[k]["std"] = (np.sqrt(merged_stats[k]["sqdiff"] / merged_stats[k]["n"]) + 1e-3).astype(np.float32)
        return obs_normalization_stats

    def get_obs_normalization_stats(self):
        """
        Returns dictionary of mean and std for each observation key if using
        observation normalization, otherwise None.

        Returns:
            obs_normalization_stats (dict): a dictionary for observation
                normalization. This maps observation keys to dicts
                with a "mean" and "std" of shape (1, ...) where ... is the default
                shape for the observation.
        """
        assert self.hdf5_normalize_obs, "not using observation normalization!"
        return deepcopy(self.obs_normalization_stats)

    def get_action_traj(self, ep):
        """
        Read the full trajectory of every action key for demo @ep from the
        hdf5 file, cast to float32.

        Args:
            ep (str): demo key

        Returns:
            action_traj (dict): maps each action key to its array
        """
        action_traj = dict()
        for key in self.action_keys:
            action_traj[key] = self.hdf5_file["data/{}/{}".format(ep, key)][()].astype('float32')
        return action_traj

    def get_action_stats(self):
        """
        Compute per-trajectory action statistics (n, mean, sqdiff, min, max per
        action key) and aggregate them over all demos.
        """
        ep = self.demos[0]
        action_traj = self.get_action_traj(ep)
        action_stats = _compute_traj_stats(action_traj)
        print("SequenceDataset: normalizing actions...")
        for ep in LogUtils.custom_tqdm(self.demos[1:]):
            action_traj = self.get_action_traj(ep)
            traj_stats = _compute_traj_stats(action_traj)
            action_stats = _aggregate_traj_stats(action_stats, traj_stats)
        return action_stats

    def set_action_normalization_stats(self, action_normalization_stats):
        """
        Override the action normalization stats used by `get_item`.
        """
        self.action_normalization_stats = action_normalization_stats

    def get_action_normalization_stats(self):
        """
        Returns the action normalization stats, computing them from the
        dataset-wide action statistics on first use if they were not set.
        """
        if self.action_normalization_stats is None:
            action_stats = self.get_action_stats()
            self.action_normalization_stats = action_stats_to_normalization_stats(
                action_stats, self.action_config)
        return self.action_normalization_stats

    def get_dataset_for_ep(self, ep, key):
        """
        Helper utility to get a dataset for a specific demonstration.
        Takes into account whether the dataset has been loaded into memory.

        Args:
            ep (str): demo key
            key (str): dataset key, either flat (e.g. 'rewards') or of the form
                '<group>/<name>' with group in ['obs', 'next_obs', 'action_dict']
        """

        # check if this key should be in memory
        key_should_be_in_memory = (self.hdf5_cache_mode in ["all", "low_dim"])
        if key_should_be_in_memory and '/' in key:
            # if key is an observation, it may not be in memory
            key1, key2 = key.split('/')
            assert(key1 in ['obs', 'next_obs', 'action_dict'])
            # 'action_dict/...' keys are always read from file: the cache stores
            # dataset keys flat (cache[ep]['action_dict/x']), so a nested lookup
            # cache[ep]['action_dict']['x'] would raise KeyError
            if key1 == 'action_dict' or key2 not in self.obs_keys_in_memory:
                key_should_be_in_memory = False

        if key_should_be_in_memory:
            # read cache
            if '/' in key:
                key1, key2 = key.split('/')
                ret = self.hdf5_cache[ep][key1][key2]
            else:
                ret = self.hdf5_cache[ep][key]
        else:
            # read from file
            hd5key = "data/{}/{}".format(ep, key)
            ret = self.hdf5_file[hd5key]
        return ret

    def __getitem__(self, index):
        """
        Fetch dataset sequence @index (inferred through internal index map), using the getitem_cache if available.
        """
        if self.hdf5_cache_mode == "all":
            return self.getitem_cache[index]
        return self.get_item(index)

    def get_item(self, index):
        """
        Main implementation of getitem when not using cache.
        """

        demo_id = self._index_to_demo_id[index]
        demo_start_index = self._demo_id_to_start_indices[demo_id]
        demo_length = self._demo_id_to_demo_length[demo_id]

        # start at offset index if not padding for frame stacking
        demo_index_offset = 0 if self.pad_frame_stack else (self.n_frame_stack - 1)
        index_in_demo = index - demo_start_index + demo_index_offset

        # end at offset index if not padding for seq length
        demo_length_offset = 0 if self.pad_seq_length else (self.seq_length - 1)
        end_index_in_demo = demo_length - demo_length_offset

        meta = self.get_dataset_sequence_from_demo(
            demo_id,
            index_in_demo=index_in_demo,
            keys=self.dataset_keys,
            num_frames_to_stack=self.n_frame_stack - 1,  # note: need to decrement self.n_frame_stack by one
            seq_length=self.seq_length
        )

        # determine goal index
        goal_index = None
        if self.goal_mode == "last":
            goal_index = end_index_in_demo - 1

        meta["obs"] = self.get_obs_sequence_from_demo(
            demo_id,
            index_in_demo=index_in_demo,
            keys=self.obs_keys,
            num_frames_to_stack=self.n_frame_stack - 1,
            seq_length=self.seq_length,
            prefix="obs"
        )

        if self.load_next_obs:
            meta["next_obs"] = self.get_obs_sequence_from_demo(
                demo_id,
                index_in_demo=index_in_demo,
                keys=self.obs_keys,
                num_frames_to_stack=self.n_frame_stack - 1,
                seq_length=self.seq_length,
                prefix="next_obs"
            )

        if goal_index is not None:
            goal = self.get_obs_sequence_from_demo(
                demo_id,
                index_in_demo=goal_index,
                keys=self.obs_keys,
                num_frames_to_stack=0,
                seq_length=1,
                prefix="next_obs",
            )
            meta["goal_obs"] = {k: goal[k][0] for k in goal}  # remove sequence dimension for goal

        # get action components
        ac_dict = OrderedDict()
        for k in self.action_keys:
            ac = meta[k]
            # expand action shape if needed
            if len(ac.shape) == 1:
                ac = ac.reshape(-1, 1)
            ac_dict[k] = ac

        # normalize actions
        action_normalization_stats = self.get_action_normalization_stats()
        ac_dict = ObsUtils.normalize_dict(ac_dict, normalization_stats=action_normalization_stats)

        # concatenate all action components
        meta["actions"] = AcUtils.action_dict_to_vector(ac_dict)

        # also return the sampled index
        meta["index"] = index

        return meta

    def get_sequence_from_demo(self, demo_id, index_in_demo, keys, num_frames_to_stack=0, seq_length=1):
        """
        Extract a (sub)sequence of data items from a demo given the @keys of the items.

        Args:
            demo_id (str): id of the demo, e.g., demo_0
            index_in_demo (int): beginning index of the sequence wrt the demo
            keys (tuple): list of keys to extract
            num_frames_to_stack (int): numbers of frame to stack. Seq gets prepended with repeated items if out of range
            seq_length (int): sequence length to extract. Seq gets post-pended with repeated items if out of range

        Returns:
            a dictionary of extracted items, and a boolean pad mask of shape
            (sequence length, 1) that is False at padded positions.
        """
        assert num_frames_to_stack >= 0
        assert seq_length >= 1

        demo_length = self._demo_id_to_demo_length[demo_id]
        assert index_in_demo < demo_length

        # determine begin and end of sequence
        seq_begin_index = max(0, index_in_demo - num_frames_to_stack)
        seq_end_index = min(demo_length, index_in_demo + seq_length)

        # determine sequence padding
        seq_begin_pad = max(0, num_frames_to_stack - index_in_demo)  # pad for frame stacking
        seq_end_pad = max(0, index_in_demo + seq_length - demo_length)  # pad for sequence length

        # make sure we are not padding if specified.
        if not self.pad_frame_stack:
            assert seq_begin_pad == 0
        if not self.pad_seq_length:
            assert seq_end_pad == 0

        # fetch observation from the dataset file
        seq = dict()
        for k in keys:
            data = self.get_dataset_for_ep(demo_id, k)
            seq[k] = data[seq_begin_index: seq_end_index]

        seq = TensorUtils.pad_sequence(seq, padding=(seq_begin_pad, seq_end_pad), pad_same=True)
        pad_mask = np.array([0] * seq_begin_pad + [1] * (seq_end_index - seq_begin_index) + [0] * seq_end_pad)
        pad_mask = pad_mask[:, None].astype(bool)

        return seq, pad_mask

    def get_obs_sequence_from_demo(self, demo_id, index_in_demo, keys, num_frames_to_stack=0, seq_length=1, prefix="obs"):
        """
        Extract a (sub)sequence of observation items from a demo given the @keys of the items.

        Args:
            demo_id (str): id of the demo, e.g., demo_0
            index_in_demo (int): beginning index of the sequence wrt the demo
            keys (tuple): list of keys to extract
            num_frames_to_stack (int): numbers of frame to stack. Seq gets prepended with repeated items if out of range
            seq_length (int): sequence length to extract. Seq gets post-pended with repeated items if out of range
            prefix (str): one of "obs", "next_obs"

        Returns:
            a dictionary of extracted items (plus "pad_mask" if @self.get_pad_mask).
        """
        obs, pad_mask = self.get_sequence_from_demo(
            demo_id,
            index_in_demo=index_in_demo,
            keys=tuple('{}/{}'.format(prefix, k) for k in keys),
            num_frames_to_stack=num_frames_to_stack,
            seq_length=seq_length,
        )
        obs = {'/'.join(k.split('/')[1:]): obs[k] for k in obs}  # strip the prefix
        if self.get_pad_mask:
            obs["pad_mask"] = pad_mask

        return obs

    def get_dataset_sequence_from_demo(self, demo_id, index_in_demo, keys, num_frames_to_stack=0, seq_length=1):
        """
        Extract a (sub)sequence of dataset items from a demo given the @keys of the items (e.g., states, actions).

        Args:
            demo_id (str): id of the demo, e.g., demo_0
            index_in_demo (int): beginning index of the sequence wrt the demo
            keys (tuple): list of keys to extract
            num_frames_to_stack (int): numbers of frame to stack. Seq gets prepended with repeated items if out of range
            seq_length (int): sequence length to extract. Seq gets post-pended with repeated items if out of range

        Returns:
            a dictionary of extracted items (plus "pad_mask" if @self.get_pad_mask).
        """
        data, pad_mask = self.get_sequence_from_demo(
            demo_id,
            index_in_demo=index_in_demo,
            keys=keys,
            num_frames_to_stack=num_frames_to_stack,
            seq_length=seq_length,
        )
        if self.get_pad_mask:
            data["pad_mask"] = pad_mask
        return data

    def get_trajectory_at_index(self, index):
        """
        Method provided as a utility to get an entire trajectory, given
        the corresponding @index.
        """
        demo_id = self.demos[index]
        demo_length = self._demo_id_to_demo_length[demo_id]

        meta = self.get_dataset_sequence_from_demo(
            demo_id,
            index_in_demo=0,
            keys=self.dataset_keys,
            num_frames_to_stack=self.n_frame_stack - 1,  # note: need to decrement self.n_frame_stack by one
            seq_length=demo_length
        )
        # NOTE: unlike the dataset keys above, observations are fetched without frame stacking
        meta["obs"] = self.get_obs_sequence_from_demo(
            demo_id,
            index_in_demo=0,
            keys=self.obs_keys,
            seq_length=demo_length
        )
        if self.load_next_obs:
            meta["next_obs"] = self.get_obs_sequence_from_demo(
                demo_id,
                index_in_demo=0,
                keys=self.obs_keys,
                seq_length=demo_length,
                prefix="next_obs"
            )

        meta["ep"] = demo_id
        return meta

    def get_dataset_sampler(self):
        """
        Return instance of torch.utils.data.Sampler or None. Allows
        for dataset to define custom sampling logic, such as
        re-weighting the probability of samples being drawn.
        See the `train` function in scripts/train.py, and torch
        `DataLoader` documentation, for more info.
        """
        return None


class R2D2Dataset(SequenceDataset):
    def get_action_traj(self, ep):
        """
        Read the full trajectory of every action key from the hdf5 file, cast
        to float32. The file holds a single trajectory, so @ep is unused and
        keys are looked up at the file root. 1-D arrays are reshaped to (N, 1).
        """
        action_traj = dict()
        for key in self.action_keys:
            action_traj[key] = self.hdf5_file[key][()].astype('float32')
            if len(action_traj[key].shape) == 1:
                action_traj[key] = np.reshape(action_traj[key], (-1, 1))
        return action_traj

    def load_demo_info(self, filter_by_attribute=None, demos=None, n_demos=None):
        """
        Build the index maps for the single trajectory stored in the file, and
        split it into segments at the first gripper close / re-open.

        Args:
            filter_by_attribute (str): unused by this subclass

            demos (list): unused by this subclass

            n_demos (int): unused by this subclass
        """
        self.demos = ["demo"]

        self.n_demos = len(self.demos)

        # keep internal index maps to know which transitions belong to which demos
        self._index_to_demo_id = dict()  # maps every index to a demo id
        self._demo_id_to_start_indices = dict()  # gives start index per demo id
        self._demo_id_to_demo_length = dict()

        # segment time stamps
        self._demo_id_to_segments = dict()

        ep = self.demos[0]

        # determine index mapping
        self.total_num_sequences = 0
        demo_length = self.hdf5_file["action/cartesian_velocity"].shape[0]

        # separate demo into segments for better alignment
        gripper_actions = list(self.hdf5_file["action/gripper_position"])
        gripper_closed = [1 if x > 0 else 0 for x in gripper_actions]

        try:
            # find when the gripper first opens/closes
            gripper_close = gripper_closed.index(1)
            gripper_open = gripper_close + gripper_closed[gripper_close:].index(0)
        except ValueError:
            # special case for (invalid) trajectories: split into thirds
            gripper_close, gripper_open = int(demo_length / 3), int(demo_length / 3 * 2)
            print("No gripper action:", gripper_actions)

        self._demo_id_to_segments[ep] = [0, gripper_close, gripper_open, demo_length - 1]

        self._register_demo_sequences(ep, demo_length)

    def load_dataset_in_memory(self, demo_list, hdf5_file, obs_keys, dataset_keys, load_next_obs):
        """
        Loads the hdf5 dataset into memory, preserving the structure of the file. Note that this
        differs from `self.getitem_cache`, which, if active, actually caches the outputs of the
        `getitem` operation.

        Args:
            demo_list (list): list of demo keys, e.g., 'demo_0'
            hdf5_file (h5py.File): file handle to the hdf5 dataset.
            obs_keys (list, tuple): observation keys to fetch, e.g., 'images'
            dataset_keys (list, tuple): dataset keys to fetch, e.g., 'actions'
            load_next_obs (bool): whether to load next_obs from the dataset (not supported here)

        Returns:
            all_data (dict): dictionary of loaded data.

        Raises:
            NotImplementedError: if @load_next_obs is set, or a dataset key is missing from the file
        """
        all_data = dict()
        print("SequenceDataset: loading dataset into memory...")
        for ep in LogUtils.custom_tqdm(demo_list):
            all_data[ep] = {}
            all_data[ep]["attrs"] = {}
            all_data[ep]["attrs"]["num_samples"] = hdf5_file["action/cartesian_velocity"].shape[0]  # hack to get traj len
            # get obs
            all_data[ep]["obs"] = {k: hdf5_file["observation/{}".format(k)][()].astype('float32') for k in obs_keys}
            if load_next_obs:
                raise NotImplementedError
            # get other dataset keys
            for k in dataset_keys:
                if k in hdf5_file.keys():
                    all_data[ep][k] = hdf5_file["{}".format(k)][()].astype('float32')
                else:
                    raise NotImplementedError

        return all_data

    def get_dataset_for_ep(self, ep, key, try_to_use_cache=True):
        """
        Helper utility to get a dataset for a specific demonstration.
        Takes into account whether the dataset has been loaded into memory.

        Args:
            ep (str): demo key (only used to index the in-memory cache)
            key (str): key relative to the file root, e.g. 'observation/<name>'
            try_to_use_cache (bool): if False, always read from file
        """
        # check if this key should be in memory
        key_should_be_in_memory = try_to_use_cache and (self.hdf5_cache_mode in ["all", "low_dim"])
        if key_should_be_in_memory:
            # if key is an observation, it may not be in memory
            if '/' in key:
                key_splits = key.split('/')
                key1 = key_splits[0]
                key2 = "/".join(key_splits[1:])
                if key1 == "observation" and key2 not in self.obs_keys_in_memory:
                    key_should_be_in_memory = False

        if key_should_be_in_memory:
            # read cache
            if '/' in key:
                key_splits = key.split('/')
                key1 = key_splits[0]
                key2 = "/".join(key_splits[1:])
                if key1 == "observation":
                    # observations are cached under "obs", keyed without the prefix
                    ret = self.hdf5_cache[ep]["obs"][key2]
                else:
                    # other dataset keys are cached flat under their full key
                    ret = self.hdf5_cache[ep][key]
            else:
                ret = self.hdf5_cache[ep][key]
        else:
            # read from file
            hd5key = "{}".format(key)
            ret = self.hdf5_file[hd5key]
        return ret

    def get_sequence_from_demo(self, demo_id, index_in_demo, keys, num_frames_to_stack=0, seq_length=1):
        """
        Extract a (sub)sequence of data items from a demo given the @keys of the items.
        Same as the parent implementation, except that every extracted item is cast to float32.

        Args:
            demo_id (str): id of the demo, e.g., demo_0
            index_in_demo (int): beginning index of the sequence wrt the demo
            keys (tuple): list of keys to extract
            num_frames_to_stack (int): numbers of frame to stack. Seq gets prepended with repeated items if out of range
            seq_length (int): sequence length to extract. Seq gets post-pended with repeated items if out of range

        Returns:
            a dictionary of extracted items, and a boolean pad mask of shape
            (sequence length, 1) that is False at padded positions.
        """
        assert num_frames_to_stack >= 0
        assert seq_length >= 1

        demo_length = self._demo_id_to_demo_length[demo_id]
        assert index_in_demo < demo_length

        # determine begin and end of sequence
        seq_begin_index = max(0, index_in_demo - num_frames_to_stack)
        seq_end_index = min(demo_length, index_in_demo + seq_length)

        # determine sequence padding
        seq_begin_pad = max(0, num_frames_to_stack - index_in_demo)  # pad for frame stacking
        seq_end_pad = max(0, index_in_demo + seq_length - demo_length)  # pad for sequence length

        # make sure we are not padding if specified.
        if not self.pad_frame_stack:
            assert seq_begin_pad == 0
        if not self.pad_seq_length:
            assert seq_end_pad == 0

        # fetch observation from the dataset file
        seq = dict()
        for k in keys:
            data = self.get_dataset_for_ep(demo_id, k)
            seq[k] = data[seq_begin_index: seq_end_index].astype("float32")

        seq = TensorUtils.pad_sequence(seq, padding=(seq_begin_pad, seq_end_pad), pad_same=True)
        pad_mask = np.array([0] * seq_begin_pad + [1] * (seq_end_index - seq_begin_index) + [0] * seq_end_pad)
        pad_mask = pad_mask[:, None].astype(bool)

        return seq, pad_mask

    def get_item(self, index):
        """
        Main implementation of getitem when not using cache.
        """

        demo_id = self._index_to_demo_id[index]
        demo_start_index = self._demo_id_to_start_indices[demo_id]
        demo_length = self._demo_id_to_demo_length[demo_id]

        # start at offset index if not padding for frame stacking
        demo_index_offset = 0 if self.pad_frame_stack else (self.n_frame_stack - 1)
        index_in_demo = index - demo_start_index + demo_index_offset

        # end at offset index if not padding for seq length
        demo_length_offset = 0 if self.pad_seq_length else (self.seq_length - 1)
        end_index_in_demo = demo_length - demo_length_offset

        meta = self.get_dataset_sequence_from_demo(
            demo_id,
            index_in_demo=index_in_demo,
            keys=self.dataset_keys,
            num_frames_to_stack=self.n_frame_stack - 1,
            seq_length=self.seq_length,
        )

        # determine goal index
        goal_index = None
        if self.goal_mode == "last":
            goal_index = end_index_in_demo - 1

        meta["obs"] = self.get_obs_sequence_from_demo(
            demo_id,
            index_in_demo=index_in_demo,
            keys=self.obs_keys,
            num_frames_to_stack=self.n_frame_stack - 1,
            seq_length=self.seq_length,
            prefix="observation"
        )

        if self.load_next_obs:
            meta["next_obs"] = self.get_obs_sequence_from_demo(
                demo_id,
                index_in_demo=index_in_demo,
                keys=self.obs_keys,
                num_frames_to_stack=self.n_frame_stack - 1,
                seq_length=self.seq_length,
                prefix="next_obs"
            )

        if goal_index is not None:
            goal = self.get_obs_sequence_from_demo(
                demo_id,
                index_in_demo=goal_index,
                keys=self.obs_keys,
                num_frames_to_stack=0,
                seq_length=1,
                prefix="next_obs",
            )
            meta["goal_obs"] = {k: goal[k][0] for k in goal}  # remove sequence dimension for goal

        # get action components
        ac_dict = OrderedDict()
        for k in self.action_keys:
            ac = meta[k]
            # expand action shape if needed
            if len(ac.shape) == 1:
                ac = ac.reshape(-1, 1)
            ac_dict[k] = ac

        # normalize actions
        action_normalization_stats = self.get_action_normalization_stats()
        ac_dict = ObsUtils.normalize_dict(ac_dict, normalization_stats=action_normalization_stats)

        # concatenate all action components
        meta["actions"] = AcUtils.action_dict_to_vector(ac_dict)

        # keys to reshape: give 1-D observations a trailing feature dimension
        for k in meta["obs"]:
            if len(meta["obs"][k].shape) == 1:
                meta["obs"][k] = np.expand_dims(meta["obs"][k], axis=1)

        # also return the sampled index
        meta["index"] = index

        return meta


class MetaDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        datasets,
        ds_weights,
        normalize_weights_by_ds_size=False,
        ds_labels=None,
    ):
        """
        Concatenation of several datasets that share one set of action
        normalization stats and can be sampled with per-dataset weights.

        Args:
            datasets (list): datasets to combine. Each must provide `__len__`, `__getitem__`,
                `hdf5_cache_mode`, `action_config`, `get_action_stats` and
                `set_action_normalization_stats`.

            ds_weights (list): sampling weight per dataset (applied to each of its samples)

            normalize_weights_by_ds_size (bool): if True, divide each weight by its dataset's length

            ds_labels (list): one label per dataset. Defaults to the label "dummy" for every dataset.
        """
        super(MetaDataset, self).__init__()
        self.datasets = datasets
        ds_lens = np.array([len(ds) for ds in self.datasets])
        if normalize_weights_by_ds_size:
            self.ds_weights = np.array(ds_weights) / ds_lens
        else:
            self.ds_weights = ds_weights
        # bin edges: dataset i owns the global indices [bins[i], bins[i + 1])
        self._ds_ind_bins = np.cumsum([0] + list(ds_lens))

        # cache mode "all" not supported! The action normalization stats of each
        # dataset will change after the datasets are already initialized
        for ds in self.datasets:
            assert ds.hdf5_cache_mode != "all"

        # compute ds_labels to one hot ids
        if ds_labels is None:
            # one label per dataset, so that label lookups by dataset index
            # are valid for every dataset and not just the first one
            self.ds_labels = ["dummy"] * len(self.datasets)
        else:
            self.ds_labels = ds_labels

        unique_labels = sorted(set(self.ds_labels))

        self.ds_labels_to_ids = {}
        for i, label in enumerate(unique_labels):
            one_hot_id = np.zeros(len(unique_labels))
            one_hot_id[i] = 1.0
            self.ds_labels_to_ids[label] = one_hot_id

        # aggregate action stats over all datasets and make every dataset
        # use the same resulting normalization stats
        action_stats = self.get_action_stats()
        self.action_normalization_stats = action_stats_to_normalization_stats(
            action_stats, self.datasets[0].action_config)
        self.set_action_normalization_stats(self.action_normalization_stats)

    def __len__(self):
        # return a plain python int (np.sum would return a numpy scalar)
        return int(sum(len(ds) for ds in self.datasets))

    def __getitem__(self, idx):
        """
        Fetch the item at global index @idx from the dataset that owns it. The
        "index" entry of the returned item is the global index.
        """
        ds_ind = np.digitize(idx, self._ds_ind_bins) - 1
        ind_in_ds = idx - self._ds_ind_bins[ds_ind]
        meta = self.datasets[ds_ind].__getitem__(ind_in_ds)
        meta["index"] = idx
        return meta

    def get_ds_label(self, idx):
        """
        Return the label of the dataset that owns global index @idx.
        """
        ds_ind = np.digitize(idx, self._ds_ind_bins) - 1
        ds_label = self.ds_labels[ds_ind]
        return ds_label

    def get_ds_id(self, idx):
        """
        Return the one-hot id of the label of the dataset that owns global index @idx.
        """
        ds_ind = np.digitize(idx, self._ds_ind_bins) - 1
        ds_label = self.ds_labels[ds_ind]
        return self.ds_labels_to_ids[ds_label]

    def __repr__(self):
        str_output = '\n'.join([ds.__repr__() for ds in self.datasets])
        return str_output

    def get_dataset_sampler(self):
        """
        Return a WeightedRandomSampler (with replacement, one epoch = len(self)
        draws) in which every sample carries the weight of its dataset.
        """
        weights = np.ones(len(self))
        for i, (start, end) in enumerate(zip(self._ds_ind_bins[:-1], self._ds_ind_bins[1:])):
            weights[start:end] = self.ds_weights[i]

        sampler = torch.utils.data.WeightedRandomSampler(
            weights=weights,
            num_samples=len(self),
            replacement=True,
        )
        return sampler

    def get_action_stats(self):
        """
        Aggregate the action statistics of all datasets.
        """
        meta_action_stats = self.datasets[0].get_action_stats()
        for dataset in self.datasets[1:]:
            ds_action_stats = dataset.get_action_stats()
            meta_action_stats = _aggregate_traj_stats(meta_action_stats, ds_action_stats)

        return meta_action_stats

    def set_action_normalization_stats(self, action_normalization_stats):
        """
        Set the action normalization stats on this object and on every dataset.
        """
        self.action_normalization_stats = action_normalization_stats
        for ds in self.datasets:
            ds.set_action_normalization_stats(self.action_normalization_stats)

    def get_action_normalization_stats(self):
        """
        Returns the action normalization stats, computing them from the
        aggregated action statistics of all datasets if they were not set.
        """
        if self.action_normalization_stats is None:
            action_stats = self.get_action_stats()
            self.action_normalization_stats = action_stats_to_normalization_stats(
                action_stats, self.datasets[0].action_config)
        return self.action_normalization_stats


def _compute_traj_stats(traj_obs_dict):
    """
    Helper function to compute statistics over a single trajectory of observations.

    Returns a dict that maps each key to its sample count "n" (size of axis 0)
    and to "mean", "sqdiff" (sum of squared deviations from the mean), "min"
    and "max", all reduced over axis 0 with that axis kept.
    """
    traj_stats = { k : {} for k in traj_obs_dict }
    for k in traj_obs_dict:
        traj_stats[k]["n"] = traj_obs_dict[k].shape[0]
        traj_stats[k]["mean"] = traj_obs_dict[k].mean(axis=0, keepdims=True)  # [1, ...]
        traj_stats[k]["sqdiff"] = ((traj_obs_dict[k] - traj_stats[k]["mean"]) ** 2).sum(axis=0, keepdims=True)  # [1, ...]
        traj_stats[k]["min"] = traj_obs_dict[k].min(axis=0, keepdims=True)
        traj_stats[k]["max"] = traj_obs_dict[k].max(axis=0, keepdims=True)
    return traj_stats


def _aggregate_traj_stats(traj_stats_a, traj_stats_b):
    """
    Helper function to aggregate trajectory statistics.
    See https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
    for more information.
    """
    merged_stats = {}
    for k in traj_stats_a:
        n_a, avg_a, M2_a, min_a, max_a = traj_stats_a[k]["n"], traj_stats_a[k]["mean"], traj_stats_a[k]["sqdiff"], traj_stats_a[k]["min"], traj_stats_a[k]["max"]
        n_b, avg_b, M2_b, min_b, max_b = traj_stats_b[k]["n"], traj_stats_b[k]["mean"], traj_stats_b[k]["sqdiff"], traj_stats_b[k]["min"], traj_stats_b[k]["max"]
        n = n_a + n_b
        mean = (n_a * avg_a + n_b * avg_b) / n
        delta = (avg_b - avg_a)
        M2 = M2_a + M2_b + (delta ** 2) * (n_a * n_b) / n
        min_ = np.minimum(min_a, min_b)
        max_ = np.maximum(max_a, max_b)
        merged_stats[k] = dict(n=n, mean=mean, sqdiff=M2, min=min_, max=max_)
    return merged_stats


def action_stats_to_normalization_stats(action_stats, action_config):
    action_normalization_stats = OrderedDict()
    for action_key in action_stats.keys():
        # get how this action should be normalized from config, default to None
        norm_method = action_config[action_key].get("normalization", None)
        if norm_method is None:
            # no normalization, unit scale, zero offset
            action_normalization_stats[action_key] = {
                "scale": np.ones_like(action_stats[action_key]["mean"], dtype=np.float32),
                "offset": np.zeros_like(action_stats[action_key]["mean"], dtype=np.float32)
            }
        elif norm_method == "min_max":
            # normalize min to -1 and max to 1
            range_eps = 1e-4
            input_min = action_stats[action_key]["min"].astype(np.float32)
            input_max = action_stats[action_key]["max"].astype(np.float32)

            # instead of -1 and 1 use numbers just below threshold to prevent numerical instability issues
            output_min = -0.999999
            output_max = 0.999999

            # ignore input dimentions that is too small to prevent division by zero
            input_range = input_max - input_min
            ignore_dim = input_range < range_eps
            input_range[ignore_dim] = output_max - output_min

            # expected usage of scale and offset
            # normalized_action = (raw_action - offset) / scale
            # raw_action = scale * normalized_action + offset

            # eq1: input_max = scale * output_max + offset
            # eq2: input_min = scale * output_min + offset

            # solution for scale and offset
            #
def action_stats_to_normalization_stats(action_stats, action_config):
    """
    Convert aggregated action statistics into per-key normalization stats.

    The returned "scale" and "offset" are meant to be used as

        normalized_action = (raw_action - offset) / scale
        raw_action = scale * normalized_action + offset

    Args:
        action_stats (dict): maps each action key to a dict with the entries
            "n", "mean", "sqdiff", "min" and "max" (see _compute_traj_stats)

        action_config (dict): maps each action key to a dict whose optional
            "normalization" entry is one of None, "min_max" or "gaussian"

    Returns:
        action_normalization_stats (OrderedDict): maps each action key to a
            dict with float32 arrays "scale" and "offset"

    Raises:
        NotImplementedError: if a normalization method is not supported
    """
    action_normalization_stats = OrderedDict()
    for action_key in action_stats.keys():
        # get how this action should be normalized from config, default to None
        norm_method = action_config[action_key].get("normalization", None)
        if norm_method is None:
            # no normalization, unit scale, zero offset
            action_normalization_stats[action_key] = {
                "scale": np.ones_like(action_stats[action_key]["mean"], dtype=np.float32),
                "offset": np.zeros_like(action_stats[action_key]["mean"], dtype=np.float32)
            }
        elif norm_method == "min_max":
            # normalize min to -1 and max to 1
            range_eps = 1e-4
            input_min = action_stats[action_key]["min"].astype(np.float32)
            input_max = action_stats[action_key]["max"].astype(np.float32)

            # instead of -1 and 1 use numbers just below threshold to prevent numerical instability issues
            output_min = -0.999999
            output_max = 0.999999

            # ignore input dimensions whose range is too small to prevent division by zero;
            # giving them the full output range results in a unit scale
            input_range = input_max - input_min
            ignore_dim = input_range < range_eps
            input_range[ignore_dim] = output_max - output_min

            # eq1: input_max = scale * output_max + offset
            # eq2: input_min = scale * output_min + offset
            # eq1 - eq2:
            #   input_max - input_min = scale * (output_max - output_min)
            #   scale = (input_max - input_min) / (output_max - output_min)
            #   offset = input_min - scale * output_min
            scale = input_range / (output_max - output_min)
            offset = input_min - scale * output_min

            # ignored dimensions are only shifted: their minimum maps to the center of the output range
            offset[ignore_dim] = input_min[ignore_dim] - (output_max + output_min) / 2

            action_normalization_stats[action_key] = {
                "scale": scale,
                "offset": offset
            }
        elif norm_method == "gaussian":
            # normalize to zero mean unit variance
            input_mean = action_stats[action_key]["mean"].astype(np.float32)
            input_std = np.sqrt(action_stats[action_key]["sqdiff"] / action_stats[action_key]["n"]).astype(np.float32)

            # ignore input dimensions whose std is too small to prevent division by zero
            std_eps = 1e-6
            ignore_dim = input_std < std_eps
            input_std[ignore_dim] = 1.0

            # (raw_action - offset) / scale must give (raw_action - mean) / std,
            # so the std is the scale and the mean is the offset
            action_normalization_stats[action_key] = {
                "scale": input_std,
                "offset": input_mean
            }
        else:
            raise NotImplementedError(
                'action_config.actions.normalization: "{}" is not supported'.format(norm_method))

    return action_normalization_stats