Spaces:
Sleeping
Sleeping
| import random | |
| import numpy as np | |
| import torch | |
| # from utils.action_label_to_idx import action_label_to_idx | |
| from data_loaders.tensors import collate | |
| from utils.misc import to_torch | |
| import utils.rotation_conversions as geometry | |
class Dataset(torch.utils.data.Dataset):
    """Base dataset for action-labelled pose sequences.

    This class only implements frame sampling, pose-representation conversion
    and label bookkeeping. It reads the following attributes, none of which
    are set here (presumably they are provided by subclasses -- verify):

    * ``_train`` / ``_test``: index collections for the two splits
    * ``_num_frames_in_video``: per-sequence frame counts
    * ``_actions``, ``_action_to_label``, ``_label_to_action``,
      ``_action_classes``: action/label lookup tables
    * ``_load_joints3D``, ``_load_rotvec``, ``_load_translation``: optional
      loaders called as ``loader(ind, frame_ix)``
    """

    def __init__(self, num_frames=1, sampling="conseq", sampling_step=1, split="train",
                 pose_rep="rot6d", translation=True, glob=True, max_len=-1, min_len=-1, num_seq_max=-1, **kwargs):
        """Store the sampling / representation options.

        Args:
            num_frames: frames per item; ``-1`` means "the whole sequence"
                (capped by ``max_len``), ``-2`` means "random length between
                ``min_len`` and the sequence length / ``max_len``".
            sampling: ``"conseq"``, ``"random_conseq"`` or ``"random"``.
            sampling_step: step between frames for ``"conseq"``; ``-1`` means
                the largest step that fits in the sequence.
            split: ``"train"``, ``"val"`` or ``"test"``.
            pose_rep: ``"xyz"``, ``"rotvec"``, ``"rotmat"``, ``"rotquat"`` or
                ``"rot6d"``.
            translation: whether a translation entry is appended to rotations.
            glob: when False, the first joint of the rotation data is dropped.
            max_len: maximum number of frames, ``-1`` for no limit.
            min_len: minimum number of frames (used by ``num_frames == -2``).
            num_seq_max: cap on ``len(self)``, ``-1`` for no cap.
            **kwargs: optional flags ``align_pose_frontview``,
                ``use_action_cat_as_text_labels``, ``only_60_classes``,
                ``leave_out_15_classes``, ``use_only_15_classes``
                (all default to False).

        Raises:
            ValueError: if ``split`` is not one of the accepted values.
        """
        self.num_frames = num_frames
        self.sampling = sampling
        self.sampling_step = sampling_step
        self.split = split
        self.pose_rep = pose_rep
        self.translation = translation
        self.glob = glob
        self.max_len = max_len
        self.min_len = min_len
        self.num_seq_max = num_seq_max

        self.align_pose_frontview = kwargs.get('align_pose_frontview', False)
        self.use_action_cat_as_text_labels = kwargs.get('use_action_cat_as_text_labels', False)
        self.only_60_classes = kwargs.get('only_60_classes', False)
        self.leave_out_15_classes = kwargs.get('leave_out_15_classes', False)
        self.use_only_15_classes = kwargs.get('use_only_15_classes', False)

        if self.split not in ["train", "val", "test"]:
            raise ValueError(f"{self.split} is not a valid split")

        super().__init__()

        # Snapshots of the un-shuffled index lists, filled by reset_shuffle().
        self._original_train = None
        self._original_test = None

    def action_to_label(self, action):
        """Map an action id to its label via ``self._action_to_label``."""
        return self._action_to_label[action]

    def label_to_action(self, label):
        """Map a label to its action id.

        ``label`` is either an integer or a one-hot (array-like) vector, in
        which case the position of its maximum is used.
        """
        import numbers
        if not isinstance(label, numbers.Integral):
            # Not a plain integer: treat it as a one-hot vector.
            label = np.argmax(label)
        return self._label_to_action[label]

    def get_pose_data(self, data_index, frame_ix):
        """Return ``(pose, label)`` for sequence ``data_index`` at ``frame_ix``."""
        pose = self._load(data_index, frame_ix)
        label = self.get_label(data_index)
        return pose, label

    def get_label(self, ind):
        """Return the label of sequence ``ind``."""
        return self.action_to_label(self.get_action(ind))

    def get_action(self, ind):
        """Return the action id of sequence ``ind`` from ``self._actions``."""
        return self._actions[ind]

    def action_to_action_name(self, action):
        """Return the entry of ``self._action_classes`` for ``action``."""
        return self._action_classes[action]

    def action_name_to_action(self, action_name):
        """Return the position(s) of ``action_name`` in the action-name table.

        ``self._action_classes`` is either a list or a dictionary. A dictionary
        is first converted to the list of its values, and its keys must be
        exactly ``0 .. num_actions - 1`` in order.

        NOTE(review): the lookup uses ``np.searchsorted``, which returns an
        insertion point; a name that is not in the table is therefore not
        detected here.
        """
        all_action_names = self._action_classes
        if isinstance(all_action_names, dict):
            all_action_names = list(all_action_names.values())
            # the keys should be ordered from 0 to num_actions
            assert list(self._action_classes.keys()) == list(range(len(all_action_names)))

        sorter = np.argsort(all_action_names)
        actions = sorter[np.searchsorted(all_action_names, action_name, sorter=sorter)]
        return actions

    def __getitem__(self, index):
        """Return the item dict for position ``index`` of the current split."""
        indices = self._train if self.split == 'train' else self._test
        return self._get_item_data_index(indices[index])

    def _load(self, ind, frame_ix):
        """Load sequence ``ind`` at ``frame_ix`` in the configured ``pose_rep``.

        The result is permuted with ``(1, 2, 0)`` before being returned, i.e.
        the leading axis of the loaded data (indexed as the frame axis below)
        becomes the last one, and it is cast to float.

        Raises:
            ValueError: if the loader required by the configuration is missing
                or ``pose_rep`` is not a known representation.
        """
        pose_rep = self.pose_rep

        if pose_rep == "xyz" or self.translation:
            if getattr(self, "_load_joints3D", None) is not None:
                # Locate the root joint of initial pose at origin
                joints3D = self._load_joints3D(ind, frame_ix)
                joints3D = joints3D - joints3D[0, 0, :]
                ret = to_torch(joints3D)
                if self.translation:
                    # Trajectory of joint 0 serves as the translation.
                    ret_tr = ret[:, 0, :]
            else:
                if pose_rep == "xyz":
                    raise ValueError("This representation is not possible.")
                # Default of None so a missing loader reaches the ValueError
                # below instead of raising AttributeError.
                if getattr(self, "_load_translation", None) is None:
                    raise ValueError("Can't extract translations.")
                ret_tr = self._load_translation(ind, frame_ix)
                # Make translations relative to the first selected frame.
                ret_tr = to_torch(ret_tr - ret_tr[0])

        if pose_rep != "xyz":
            if getattr(self, "_load_rotvec", None) is None:
                raise ValueError("This representation is not possible.")

            pose = self._load_rotvec(ind, frame_ix)
            if not self.glob:
                # Drop the first joint of every frame.
                pose = pose[:, 1:, :]
            pose = to_torch(pose)

            if self.align_pose_frontview:
                # Left-multiply every frame's joint-0 rotation by the transpose
                # of the first frame's one, so the first frame becomes identity.
                # NOTE(review): with glob=False joint 0 here is what used to be
                # joint 1 before the slice above -- confirm this is intended.
                first_frame_root_pose_matrix = geometry.axis_angle_to_matrix(pose[0][0])
                all_root_poses_matrix = geometry.axis_angle_to_matrix(pose[:, 0, :])
                aligned_root_poses_matrix = torch.matmul(
                    torch.transpose(first_frame_root_pose_matrix, 0, 1),
                    all_root_poses_matrix)
                pose[:, 0, :] = geometry.matrix_to_axis_angle(aligned_root_poses_matrix)

                if self.translation:
                    # Apply the same transform to the translations.
                    ret_tr = torch.matmul(
                        torch.transpose(first_frame_root_pose_matrix, 0, 1).float(),
                        torch.transpose(ret_tr, 0, 1))
                    ret_tr = torch.transpose(ret_tr, 0, 1)

            if pose_rep == "rotvec":
                ret = pose
            elif pose_rep == "rotmat":
                ret = geometry.axis_angle_to_matrix(pose).view(*pose.shape[:2], 9)
            elif pose_rep == "rotquat":
                ret = geometry.axis_angle_to_quaternion(pose)
            elif pose_rep == "rot6d":
                ret = geometry.matrix_to_rotation_6d(geometry.axis_angle_to_matrix(pose))
            else:
                # Previously fell through with `ret` unbound (or silently
                # reused the joints tensor when one had been loaded).
                raise ValueError(f"{pose_rep} is not a valid pose representation")

        if pose_rep != "xyz" and self.translation:
            # Append the translation as one extra entry along axis 1; it is
            # zero-padded to the feature width of the rotation representation.
            padded_tr = torch.zeros((ret.shape[0], ret.shape[2]), dtype=ret.dtype)
            padded_tr[:, :3] = ret_tr
            ret = torch.cat((ret, padded_tr[:, None]), 1)

        ret = ret.permute(1, 2, 0).contiguous()
        return ret.float()

    def _sample_frame_indices(self, nframes):
        """Choose which frames of an ``nframes``-long sequence to load.

        Returns:
            A sequence of frame indices (numpy array, or a sorted list for
            ``sampling == "random"``).

        Raises:
            ValueError: for ``num_frames == -2`` without a positive
                ``min_len``, or for an unknown ``sampling`` mode.
        """
        # Whole sequence requested and it fits within max_len.
        if self.num_frames == -1 and (self.max_len == -1 or nframes <= self.max_len):
            return np.arange(nframes)

        if self.num_frames == -2:
            if self.min_len <= 0:
                raise ValueError("You should put a min_len > 0 for num_frames == -2 mode")
            max_frame = nframes if self.max_len == -1 else min(nframes, self.max_len)
            num_frames = random.randint(self.min_len, max(max_frame, self.min_len))
        elif self.num_frames != -1:
            num_frames = self.num_frames
        else:
            # num_frames == -1 but the sequence is longer than max_len.
            num_frames = self.max_len

        if num_frames > nframes:
            # Sequence too short: repeat the last frame until done.
            padding = (nframes - 1) * np.ones(num_frames - nframes, dtype=int)
            return np.concatenate((np.arange(0, nframes), padding))

        if self.sampling in ["conseq", "random_conseq"]:
            # A single frame has no spacing; guard the division, which used to
            # raise ZeroDivisionError for num_frames == 1 (the default).
            step_max = 1 if num_frames == 1 else (nframes - 1) // (num_frames - 1)
            if self.sampling == "conseq":
                if self.sampling_step == -1 or self.sampling_step * (num_frames - 1) >= nframes:
                    step = step_max
                else:
                    step = self.sampling_step
            else:
                step = random.randint(1, step_max)

            lastone = step * (num_frames - 1)
            shift_max = nframes - lastone - 1
            # NOTE(review): the upper bound shift_max - 1 means the very last
            # valid shift is never drawn; kept as in the original.
            shift = random.randint(0, max(0, shift_max - 1))
            return shift + np.arange(0, lastone + 1, step)

        if self.sampling == "random":
            choices = np.random.choice(range(nframes), num_frames, replace=False)
            return sorted(choices)

        raise ValueError("Sampling not recognized.")

    def _get_item_data_index(self, data_index):
        """Build the item dict for sequence ``data_index``.

        Returns:
            ``{'inp': pose, 'action': label}``, plus ``'action_text'`` when
            both ``_actions`` and ``_action_classes`` exist on the instance.
        """
        nframes = self._num_frames_in_video[data_index]
        frame_ix = self._sample_frame_indices(nframes)

        inp, action = self.get_pose_data(data_index, frame_ix)
        output = {'inp': inp, 'action': action}

        if hasattr(self, '_actions') and hasattr(self, '_action_classes'):
            output['action_text'] = self.action_to_action_name(self.get_action(data_index))

        return output

    def get_mean_length_label(self, label):
        """Return the mean sequence length for ``label`` in the current split.

        When a fixed ``num_frames`` is configured (anything but ``-1``) that
        value is returned directly. Otherwise lengths are capped at
        ``max_len`` (unless it is ``-1``) before averaging.
        """
        if self.num_frames != -1:
            return self.num_frames

        index = self._train if self.split == 'train' else self._test

        action = self.label_to_action(label)
        choices = np.argwhere(self._actions[index] == action).squeeze(1)
        lengths = self._num_frames_in_video[np.array(index)[choices]]

        if self.max_len == -1:
            return np.mean(lengths)
        # make the lengths less than max_len
        return np.mean(np.minimum(lengths, self.max_len))

    def __len__(self):
        """Return the size of the current split, capped by ``num_seq_max``."""
        indices = self._train if self.split == 'train' else self._test
        num_seq_max = getattr(self, "num_seq_max", -1)
        if num_seq_max == -1:
            return len(indices)
        return min(len(indices), num_seq_max)

    def shuffle(self):
        """Shuffle the index collection of the current split in place."""
        if self.split == 'train':
            random.shuffle(self._train)
        else:
            random.shuffle(self._test)

    def reset_shuffle(self):
        """Snapshot the current split's order on first call, restore it after.

        Copies are taken on both save and restore: ``shuffle`` permutes the
        index collection in place, so keeping a bare reference would shuffle
        the saved "original" as well and make the reset a no-op.
        """
        import copy
        if self.split == 'train':
            if self._original_train is None:
                self._original_train = copy.copy(self._train)
            else:
                self._train = copy.copy(self._original_train)
        else:
            if self._original_test is None:
                self._original_test = copy.copy(self._test)
            else:
                self._test = copy.copy(self._original_test)