import random import numpy as np import torch # from utils.action_label_to_idx import action_label_to_idx from data_loaders.tensors import collate from utils.misc import to_torch import utils.rotation_conversions as geometry class Dataset(torch.utils.data.Dataset): def __init__(self, num_frames=1, sampling="conseq", sampling_step=1, split="train", pose_rep="rot6d", translation=True, glob=True, max_len=-1, min_len=-1, num_seq_max=-1, **kwargs): self.num_frames = num_frames self.sampling = sampling self.sampling_step = sampling_step self.split = split self.pose_rep = pose_rep self.translation = translation self.glob = glob self.max_len = max_len self.min_len = min_len self.num_seq_max = num_seq_max self.align_pose_frontview = kwargs.get('align_pose_frontview', False) self.use_action_cat_as_text_labels = kwargs.get('use_action_cat_as_text_labels', False) self.only_60_classes = kwargs.get('only_60_classes', False) self.leave_out_15_classes = kwargs.get('leave_out_15_classes', False) self.use_only_15_classes = kwargs.get('use_only_15_classes', False) if self.split not in ["train", "val", "test"]: raise ValueError(f"{self.split} is not a valid split") super().__init__() # to remove shuffling self._original_train = None self._original_test = None def action_to_label(self, action): return self._action_to_label[action] def label_to_action(self, label): import numbers if isinstance(label, numbers.Integral): return self._label_to_action[label] else: # if it is one hot vector label = np.argmax(label) return self._label_to_action[label] def get_pose_data(self, data_index, frame_ix): pose = self._load(data_index, frame_ix) label = self.get_label(data_index) return pose, label def get_label(self, ind): action = self.get_action(ind) return self.action_to_label(action) def get_action(self, ind): return self._actions[ind] def action_to_action_name(self, action): return self._action_classes[action] def action_name_to_action(self, action_name): # self._action_classes is either a list or a dictionary. If it's a dictionary, we 1st convert it to a list all_action_names = self._action_classes if isinstance(all_action_names, dict): all_action_names = list(all_action_names.values()) assert list(self._action_classes.keys()) == list(range(len(all_action_names))) # the keys should be ordered from 0 to num_actions sorter = np.argsort(all_action_names) actions = sorter[np.searchsorted(all_action_names, action_name, sorter=sorter)] return actions def __getitem__(self, index): if self.split == 'train': data_index = self._train[index] else: data_index = self._test[index] # inp, target = self._get_item_data_index(data_index) # return inp, target return self._get_item_data_index(data_index) def _load(self, ind, frame_ix): pose_rep = self.pose_rep if pose_rep == "xyz" or self.translation: if getattr(self, "_load_joints3D", None) is not None: # Locate the root joint of initial pose at origin joints3D = self._load_joints3D(ind, frame_ix) joints3D = joints3D - joints3D[0, 0, :] ret = to_torch(joints3D) if self.translation: ret_tr = ret[:, 0, :] else: if pose_rep == "xyz": raise ValueError("This representation is not possible.") if getattr(self, "_load_translation") is None: raise ValueError("Can't extract translations.") ret_tr = self._load_translation(ind, frame_ix) ret_tr = to_torch(ret_tr - ret_tr[0]) if pose_rep != "xyz": if getattr(self, "_load_rotvec", None) is None: raise ValueError("This representation is not possible.") else: pose = self._load_rotvec(ind, frame_ix) if not self.glob: pose = pose[:, 1:, :] pose = to_torch(pose) if self.align_pose_frontview: first_frame_root_pose_matrix = geometry.axis_angle_to_matrix(pose[0][0]) all_root_poses_matrix = geometry.axis_angle_to_matrix(pose[:, 0, :]) aligned_root_poses_matrix = torch.matmul(torch.transpose(first_frame_root_pose_matrix, 0, 1), all_root_poses_matrix) pose[:, 0, :] = geometry.matrix_to_axis_angle(aligned_root_poses_matrix) if self.translation: ret_tr = torch.matmul(torch.transpose(first_frame_root_pose_matrix, 0, 1).float(), torch.transpose(ret_tr, 0, 1)) ret_tr = torch.transpose(ret_tr, 0, 1) if pose_rep == "rotvec": ret = pose elif pose_rep == "rotmat": ret = geometry.axis_angle_to_matrix(pose).view(*pose.shape[:2], 9) elif pose_rep == "rotquat": ret = geometry.axis_angle_to_quaternion(pose) elif pose_rep == "rot6d": ret = geometry.matrix_to_rotation_6d(geometry.axis_angle_to_matrix(pose)) if pose_rep != "xyz" and self.translation: padded_tr = torch.zeros((ret.shape[0], ret.shape[2]), dtype=ret.dtype) padded_tr[:, :3] = ret_tr ret = torch.cat((ret, padded_tr[:, None]), 1) ret = ret.permute(1, 2, 0).contiguous() return ret.float() def _get_item_data_index(self, data_index): nframes = self._num_frames_in_video[data_index] if self.num_frames == -1 and (self.max_len == -1 or nframes <= self.max_len): frame_ix = np.arange(nframes) else: if self.num_frames == -2: if self.min_len <= 0: raise ValueError("You should put a min_len > 0 for num_frames == -2 mode") if self.max_len != -1: max_frame = min(nframes, self.max_len) else: max_frame = nframes num_frames = random.randint(self.min_len, max(max_frame, self.min_len)) else: num_frames = self.num_frames if self.num_frames != -1 else self.max_len if num_frames > nframes: fair = False # True if fair: # distills redundancy everywhere choices = np.random.choice(range(nframes), num_frames, replace=True) frame_ix = sorted(choices) else: # adding the last frame until done ntoadd = max(0, num_frames - nframes) lastframe = nframes - 1 padding = lastframe * np.ones(ntoadd, dtype=int) frame_ix = np.concatenate((np.arange(0, nframes), padding)) elif self.sampling in ["conseq", "random_conseq"]: step_max = (nframes - 1) // (num_frames - 1) if self.sampling == "conseq": if self.sampling_step == -1 or self.sampling_step * (num_frames - 1) >= nframes: step = step_max else: step = self.sampling_step elif self.sampling == "random_conseq": step = random.randint(1, step_max) lastone = step * (num_frames - 1) shift_max = nframes - lastone - 1 shift = random.randint(0, max(0, shift_max - 1)) frame_ix = shift + np.arange(0, lastone + 1, step) elif self.sampling == "random": choices = np.random.choice(range(nframes), num_frames, replace=False) frame_ix = sorted(choices) else: raise ValueError("Sampling not recognized.") inp, action = self.get_pose_data(data_index, frame_ix) output = {'inp': inp, 'action': action} if hasattr(self, '_actions') and hasattr(self, '_action_classes'): output['action_text'] = self.action_to_action_name(self.get_action(data_index)) return output def get_mean_length_label(self, label): if self.num_frames != -1: return self.num_frames if self.split == 'train': index = self._train else: index = self._test action = self.label_to_action(label) choices = np.argwhere(self._actions[index] == action).squeeze(1) lengths = self._num_frames_in_video[np.array(index)[choices]] if self.max_len == -1: return np.mean(lengths) else: # make the lengths less than max_len lengths[lengths > self.max_len] = self.max_len return np.mean(lengths) def __len__(self): num_seq_max = getattr(self, "num_seq_max", -1) if num_seq_max == -1: from math import inf num_seq_max = inf if self.split == 'train': return min(len(self._train), num_seq_max) else: return min(len(self._test), num_seq_max) def shuffle(self): if self.split == 'train': random.shuffle(self._train) else: random.shuffle(self._test) def reset_shuffle(self): if self.split == 'train': if self._original_train is None: self._original_train = self._train else: self._train = self._original_train else: if self._original_test is None: self._original_test = self._test else: self._test = self._original_test