Spaces:
Sleeping
Sleeping
| from __future__ import absolute_import | |
| from __future__ import print_function | |
| from __future__ import division | |
| import torch | |
| import joblib | |
| import numpy as np | |
| from .._dataset import BaseDataset | |
| from ..utils.augmentor import * | |
| from ...utils import data_utils as d_utils | |
| from ...utils import transforms | |
| from ...models import build_body_model | |
| from ...utils.kp_utils import convert_kps, root_centering | |
class Dataset3D(BaseDataset):
    """Video dataset base class for 3D human-pose benchmarks (3DPW, Human36M, MPII3D).

    Serves fixed-length clips of ``cfg.DATASET.SEQLEN + 1`` frames from
    precomputed labels stored in a joblib file. Frame 0 of each clip is used
    only as the initial-state input; frames 1..SEQLEN are supervised.
    """

    def __init__(self, cfg, fname, training):
        """
        Args:
            cfg: experiment config node (reads ``cfg.DATASET.SEQLEN``).
            fname: path to the joblib file holding the precomputed labels.
            training: whether this dataset is used for training.
        """
        super(Dataset3D, self).__init__(cfg, training)

        self.epoch = 0
        self.labels = joblib.load(fname)
        # +1: one extra leading frame that serves only as the initial state.
        self.n_frames = cfg.DATASET.SEQLEN + 1

        if self.training:
            self.prepare_video_batch()

        self.smpl = build_body_model('cpu', self.n_frames)
        self.SMPLAugmentor = SMPLAugmentor(cfg, False)
        self.VideoAugmentor = VideoAugmentor(cfg)

    def __getitem__(self, index):
        return self.get_single_sequence(index)

    def get_inputs(self, index, target, vis_thr=0.6):
        """Fill ``target`` with network inputs: normalized 2D keypoints, bbox,
        low-confidence mask, and per-frame image features.

        Args:
            index: clip index into ``self.video_indices``.
            target: dict being assembled for this sample (mutated in place).
            vis_thr: keypoint detections below this confidence are masked out.
        """
        start_index, end_index = self.video_indices[index]

        # 2D keypoints detection
        kp2d = self.labels['kp2d'][start_index:end_index+1][..., :2].clone()
        bbox = self.labels['bbox'][start_index:end_index+1][..., [0, 1, -1]].clone()
        # Bbox scale appears to be stored in units of 200 px — TODO confirm
        # against the preprocessing that produced the labels.
        bbox[:, 2] = bbox[:, 2] / 200
        kp2d, bbox = self.keypoints_normalizer(
            kp2d, target['res'], self.cam_intrinsics, 224, 224, bbox)

        target['bbox'] = bbox[1:]  # drop the initial-state frame
        target['kp2d'] = kp2d
        # True where the detection confidence falls below vis_thr.
        target['mask'] = self.labels['kp2d'][start_index+1:end_index+1][..., -1] < vis_thr

        # Image features
        target['features'] = self.labels['features'][start_index+1:end_index+1].clone()
        return target

    def get_labels(self, index, target):
        """Fill ``target`` with supervision: SMPL parameters, 2D/3D keypoints,
        foot-contact labels, and (for 3DPW) SMPL vertices."""
        start_index, end_index = self.video_indices[index]

        # SMPL parameters
        # NOTE: We use NeuralAnnot labels for Human36m and MPII3D only for the 0th frame input.
        # We do not supervise the network on SMPL parameters.
        target['pose'] = transforms.axis_angle_to_matrix(
            self.labels['pose'][start_index:end_index+1].clone().reshape(-1, 24, 3))
        target['betas'] = self.labels['betas'][start_index:end_index+1].clone()

        # Apply SMPL augmentor (y-axis rotation and initial frame noise)
        target = self.SMPLAugmentor(target)

        # 3D and 2D keypoints
        if self.__name__ == 'ThreeDPW':  # 3DPW has SMPL labels
            gt_kp3d = self.labels['joints3D'][start_index:end_index+1].clone()
            gt_kp2d = self.labels['joints2D'][start_index+1:end_index+1, ..., :2].clone()
            gt_kp3d = root_centering(gt_kp3d.clone())
        else:  # Human36m and MPII do not have SMPL labels
            # First n_joints slots stay zero; only the trailing 14 "common"
            # joints carry converted labels.
            gt_kp3d = torch.zeros((self.n_frames, self.n_joints + 14, 3))
            gt_kp3d[:, self.n_joints:] = convert_kps(
                self.labels['joints3D'][start_index:end_index+1], 'spin', 'common')
            gt_kp2d = torch.zeros((self.n_frames - 1, self.n_joints + 14, 2))
            gt_kp2d[:, self.n_joints:] = convert_kps(
                self.labels['joints2D'][start_index+1:end_index+1, ..., :2], 'spin', 'common')

        # NOTE(review): this reads self.mask while get_inputs writes
        # target['mask'] — confirm self.mask is assigned elsewhere
        # (e.g. in BaseDataset or a subclass).
        conf = self.mask.repeat(self.n_frames, 1).unsqueeze(-1)
        gt_kp2d = torch.cat((gt_kp2d, conf[1:]), dim=-1)
        gt_kp3d = torch.cat((gt_kp3d, conf), dim=-1)

        target['kp3d'] = gt_kp3d
        target['full_kp2d'] = gt_kp2d
        target['weak_kp2d'] = torch.zeros_like(gt_kp2d)

        if self.__name__ != 'ThreeDPW':  # 3DPW does not contain world-coordinate motion
            # Foot ground contact labels for Human36M and MPII3D
            target['contact'] = self.labels['stationaries'][start_index+1:end_index+1].clone()
        else:
            # No foot ground contact label available for 3DPW; -1 marks "unknown".
            target['contact'] = torch.ones((self.n_frames - 1, 4)) * (-1)

        if self.has_verts:
            # SMPL vertices available for 3DPW
            with torch.no_grad():
                gender = self.labels['gender'][start_index].item()
                output = self.smpl_gender[gender](
                    body_pose=target['pose'][1:, 1:],
                    global_orient=target['pose'][1:, :1],
                    betas=target['betas'][1:],
                    pose2rot=False,
                )
                target['verts'] = output.vertices.clone()
        else:
            # No SMPL vertices available
            target['verts'] = torch.zeros((self.n_frames - 1, 6890, 3)).float()

        return target

    def get_init_frame(self, target):
        """Compute root-centered 3D joints of frame 0 as the initial state."""
        # Prepare initial frame
        output = self.smpl.get_output(
            body_pose=target['init_pose'][:, 1:],
            global_orient=target['init_pose'][:, :1],
            betas=target['betas'][:1],
            pose2rot=False
        )
        target['init_kp3d'] = root_centering(output.joints[:1, :self.n_joints]).reshape(1, -1)
        return target

    def get_camera_info(self, index, target):
        """Fill ``target`` with camera intrinsics and extrinsic rotation.

        A random yaw is composed into the axis-remapping so the world frame
        is augmented per draw.
        """
        start_index, end_index = self.video_indices[index]

        # Intrinsics
        target['res'] = self.labels['res'][start_index:end_index+1][0].clone()
        self.get_naive_intrinsics(target['res'])
        target['cam_intrinsics'] = self.cam_intrinsics.clone()

        # Extrinsics pose
        R = self.labels['cam_poses'][start_index:end_index+1, :3, :3].clone().float()
        yaw = transforms.axis_angle_to_matrix(
            torch.tensor([[0, 2 * np.pi * np.random.uniform(), 0]])).float()
        if self.__name__ == 'Human36M':
            # Map Z-up to Y-down coordinate
            zup2ydown = transforms.axis_angle_to_matrix(torch.tensor([[-np.pi/2, 0, 0]])).float()
            zup2ydown = torch.matmul(yaw, zup2ydown)
            R = torch.matmul(R, zup2ydown)
        elif self.__name__ == 'MPII3D':
            # Map Y-up to Y-down coordinate
            yup2ydown = transforms.axis_angle_to_matrix(torch.tensor([[np.pi, 0, 0]])).float()
            yup2ydown = torch.matmul(yaw, yup2ydown)
            R = torch.matmul(R, yup2ydown)

        # BUG FIX: R was computed (and rotated) but never stored, so
        # target['R'] silently kept the identity default set in
        # get_single_sequence and the extrinsics computation above was dead
        # code. NOTE(review): confirm downstream consumers expect this
        # world-to-camera convention.
        target['R'] = R
        return target

    def get_single_sequence(self, index):
        """Assemble one sample: a clip of ``self.n_frames`` frames with inputs,
        labels, camera info, and the initial-frame state."""
        # Universal target
        target = {'has_full_screen': torch.tensor(True),
                  'has_smpl': torch.tensor(self.has_smpl),
                  'has_traj': torch.tensor(self.has_traj),
                  'has_verts': torch.tensor(self.has_verts),
                  'transl': torch.zeros((self.n_frames, 3)),

                  # Null camera motion
                  'R': torch.eye(3).repeat(self.n_frames, 1, 1),
                  'cam_angvel': torch.zeros((self.n_frames - 1, 6)),

                  # Null root orientation and velocity
                  'pose_root': torch.zeros((self.n_frames, 6)),
                  'vel_root': torch.zeros((self.n_frames - 1, 3)),
                  'init_root': torch.zeros((1, 6)),
                  }

        self.get_camera_info(index, target)
        self.get_inputs(index, target)
        self.get_labels(index, target)
        self.get_init_frame(target)

        target = d_utils.prepare_keypoints_data(target)
        target = d_utils.prepare_smpl_data(target)

        return target