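# Spaces: Runtime error
# NOTE(review): the line above appears to be a web-page status banner that was
# captured together with this file during extraction; it is not part of the
# module's code.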
# Standard library.
import json
import os
import os.path
from abc import ABCMeta
from collections import OrderedDict
from typing import Any, List, Optional, Union

# Third-party.
import mmcv
import numpy as np
import torch
import torch.distributed as dist
from mmcv.runner import get_dist_info

# Project.
from detrsmpl.core.conventions.keypoints_mapping import (
    convert_kps,
    get_keypoint_num,
    get_mapping,
)
from detrsmpl.core.evaluation import (
    keypoint_3d_auc,
    keypoint_3d_pck,
    keypoint_mpjpe,
    vertice_pve,
)
from detrsmpl.data.data_structures.multi_human_data import MultiHumanData
from detrsmpl.models.body_models.builder import build_body_model

from .base_dataset import BaseDataset
from .builder import DATASETS
class MultiHumanImageDataset(BaseDataset, metaclass=ABCMeta):
    """Image dataset backed by a ``MultiHumanData`` annotation file.

    Each sample is one image plus the slice of per-instance annotations given
    by ``frame_range[idx]`` (bounding boxes, 2D/3D keypoints, SMPL parameters).

    NOTE(review): ``evaluate`` reads ``self.ALLOWED_METRICS``, which is not
    defined in this class -- presumably it is provided by a base class;
    confirm. ``DATASETS`` is imported by this module but not used by this
    class; a registration decorator may be missing -- confirm before adding
    one.

    Args:
        data_prefix (str): root folder of the data. Annotation files are
            looked up under ``<data_prefix>/preprocessed_datasets`` and images
            under ``<data_prefix>/datasets/<dataset_name>``.
        pipeline (list): forwarded to ``BaseDataset``.
        body_model (dict, optional): config passed to ``build_body_model``.
            When None, ``self.body_model`` is None.
        ann_file (str, optional): annotation file name, relative to the
            ``preprocessed_datasets`` folder.
        convention (str, optional): keypoint convention the annotations are
            converted to. Default: 'human_data'.
        test_mode (bool, optional): forwarded to ``BaseDataset``.
        dataset_name (str, optional): dataset identifier; also selects how
            ground truth is built in ``_parse_result``.
    """

    def __init__(self,
                 data_prefix: str,
                 pipeline: list,
                 body_model: Optional[Union[dict, None]] = None,
                 ann_file: Optional[Union[str, None]] = None,
                 convention: Optional[str] = 'human_data',
                 test_mode: Optional[bool] = False,
                 dataset_name: Optional[Union[str, None]] = None):
        # These two are assigned before the base initialiser runs;
        # presumably it triggers ``load_annotations``, which reads
        # ``self.convention`` -- TODO confirm against BaseDataset.
        self.num_keypoints = get_keypoint_num(convention)
        self.convention = convention
        super().__init__(data_prefix, pipeline, ann_file, test_mode,
                         dataset_name)
        self.body_model = (build_body_model(body_model)
                           if body_model is not None else None)

    def get_annotation_file(self):
        """Rewrite ``self.ann_file`` as a path under the annotation folder.

        The result is ``<data_prefix>/preprocessed_datasets/<ann_file>``.
        """
        ann_prefix = os.path.join(self.data_prefix, 'preprocessed_datasets')
        self.ann_file = os.path.join(ann_prefix, self.ann_file)

    def _convert_keypoints_to_convention(self, src_key: str, dst_key: str):
        """Convert one keypoint entry of ``self.human_data`` in place.

        Reads ``src_key`` and ``<src_key>_mask``, converts them from the
        'human_data' convention to ``self.convention`` with ``convert_kps``
        and stores the result under ``dst_key``, ``<dst_key>_convention`` and
        ``<dst_key>_mask``.
        """
        mask_key = f'{src_key}_mask'
        assert mask_key in self.human_data
        keypoints, mask = convert_kps(self.human_data[src_key],
                                      src='human_data',
                                      dst=self.convention,
                                      mask=self.human_data[mask_key])
        self.human_data[dst_key] = keypoints
        self.human_data[f'{dst_key}_convention'] = self.convention
        self.human_data[f'{dst_key}_mask'] = mask

    def load_annotations(self):
        """Load the annotation file and convert keypoints to the convention.

        Populates ``human_data``, ``instance_num``, ``image_path``,
        ``frame_range`` and ``num_data``. When the file has no
        'frame_range' entry, every index ``i`` gets the range ``[i, i + 1]``.
        """
        self.get_annotation_file()
        self.human_data = MultiHumanData()
        self.human_data.load(self.ann_file)

        self.instance_num = self.human_data.instance_num
        self.image_path = self.human_data['image_path']
        self.num_data = self.human_data.data_len

        try:
            self.frame_range = self.human_data['frame_range']
        except KeyError:
            starts = np.arange(self.num_data)
            self.frame_range = np.stack([starts, starts + 1], axis=1)
        # From here on a "sample" is one row of frame_range.
        self.num_data = self.frame_range.shape[0]

        if self.human_data.check_keypoints_compressed():
            self.human_data.decompress_keypoints()

        # change keypoint from 'human_data' to the given convention.
        # '*_ori' entries are preferred; otherwise the plain 'keypoints3d' /
        # 'keypoints2d' entry is converted and stored under the '*_ori' name.
        for dim in ('3d', '2d'):
            ori_key = f'keypoints{dim}_ori'
            plain_key = f'keypoints{dim}'
            if ori_key in self.human_data:
                self._convert_keypoints_to_convention(ori_key, ori_key)
            elif plain_key in self.human_data:
                self._convert_keypoints_to_convention(plain_key, ori_key)

        for smpl_key in ('keypoints3d_smpl', 'keypoints2d_smpl'):
            if smpl_key in self.human_data:
                self._convert_keypoints_to_convention(smpl_key, smpl_key)

        self.human_data.compress_keypoints_by_mask()

    def prepare_raw_data(self, idx: int):
        """Get item from self.human_data.

        Args:
            idx (int): sample index, i.e. a row of ``self.frame_range``.

        Returns:
            dict: image path, dataset name, sample index and the
                annotation slices for ``frame_range[idx]``. Annotations that
                are absent from ``human_data`` are replaced by zero arrays
                and their ``has_*`` flag is zero.
        """
        frame_start, frame_end = self.frame_range[idx]
        frame_num = frame_end - frame_start

        # TODO: Support cache_reader?
        info = {}
        info['img_prefix'] = None
        image_path = self.human_data['image_path'][frame_start]
        info['image_path'] = os.path.join(self.data_prefix, 'datasets',
                                          self.dataset_name, image_path)
        # TODO: Support smc?
        info['dataset_name'] = self.dataset_name
        info['sample_idx'] = idx

        if 'bbox_xywh' in self.human_data:
            info['bbox_xywh'] = self.human_data['bbox_xywh'][
                frame_start:frame_end]
            center, scale = [], []
            for bbox in info['bbox_xywh']:
                # Each row holds five values; the fifth is not used here.
                x, y, w, h, _ = bbox
                # TODO: verify if we should keep w = h = max(w, h) for
                # multi human data
                side = max(w, h)
                center.append([x + w / 2, y + h / 2])
                scale.append([side, side])
            info['center'] = np.array(center)
            info['scale'] = np.array(scale)
        else:
            info['bbox_xywh'] = np.zeros((frame_num, 5))
            info['center'] = np.zeros((frame_num, 2))
            info['scale'] = np.zeros((frame_num, 2))

        # (entry name, size of the last axis used for the zero placeholder)
        keypoint_specs = (
            ('keypoints2d_ori', 3),
            ('keypoints3d_ori', 4),
            ('keypoints2d_smpl', 3),
            ('keypoints3d_smpl', 4),
        )
        for key, last_dim in keypoint_specs:
            if key in self.human_data:
                keypoints = self.human_data[key][frame_start:frame_end]
                info[key] = keypoints
                # An instance "has" keypoints when the sum over its last
                # channel is positive.
                conf = keypoints[..., -1].sum(-1) > 0
                info[f'has_{key}'] = np.ones(
                    (frame_num, 1)) * conf[..., None]
            else:
                info[key] = np.zeros(
                    (frame_num, self.num_keypoints, last_dim))
                info[f'has_{key}'] = np.zeros((frame_num, 1))

        if 'smpl' in self.human_data:
            if 'has_smpl' in self.human_data:
                info['has_smpl'] = \
                    self.human_data['has_smpl'][frame_start:frame_end]
            else:
                info['has_smpl'] = np.ones((frame_num, 1))
            smpl_dict = self.human_data['smpl']
        else:
            info['has_smpl'] = np.zeros((frame_num, 1))
            smpl_dict = {}

        # (SMPL parameter name, per-instance shape of the zero placeholder)
        smpl_specs = (
            ('body_pose', (23, 3)),
            ('global_orient', (3, )),
            ('betas', (10, )),
            ('transl', (3, )),
        )
        for name, shape in smpl_specs:
            if name in smpl_dict:
                info[f'smpl_{name}'] = smpl_dict[name][frame_start:frame_end]
            else:
                info[f'smpl_{name}'] = np.zeros((frame_num, ) + shape)

        if 'area' in self.human_data:
            info['area'] = self.human_data['area'][frame_start:frame_end]
        else:
            info['area'] = np.zeros((frame_num, 0))

        return info

    def prepare_data(self, idx: int):
        """Generate and transform data.

        Builds the raw sample with ``prepare_raw_data`` and passes it through
        ``self.pipeline``.
        """
        return self.pipeline(self.prepare_raw_data(idx))

    def evaluate(self,
                 outputs: list,
                 res_folder: str,
                 metric: Optional[Union[str, List[str]]] = 'pa-mpjpe',
                 **kwargs: dict):
        """Evaluate 3D keypoint results.

        Args:
            outputs (list): results from model inference. Each element is a
                dict providing at least 'image_idx', 'pred_smpl_kp3d',
                'gt_smpl_pose', 'gt_smpl_beta', 'pred_smpl_pose' and
                'pred_smpl_beta', indexed per batch entry.
            res_folder (str): path to store results. Currently unused:
                dumping the results to disk is disabled.
            metric (Optional[Union[str, List(str)]]):
                the type of metric. Default: 'pa-mpjpe'
            kwargs (dict): other arguments.

        Returns:
            dict:
                A dict of all evaluation results.

        Raises:
            KeyError: if a metric is not in ``self.ALLOWED_METRICS`` or if
                ``outputs`` lacks a result for some sample index.
            NotImplementedError: if an allowed metric has no reporter.
        """
        metrics = metric if isinstance(metric, list) else [metric]
        for metric_name in metrics:
            if metric_name not in self.ALLOWED_METRICS:
                raise KeyError(f'metric {metric_name} is not supported')

        # for keeping correctness during multi-gpu test, we sort all results
        # by their image index (later duplicates overwrite earlier ones).
        res_dict = {}
        for out in outputs:
            target_id = out['image_idx']
            batch_size = len(out['pred_smpl_kp3d'])
            for i in range(batch_size):
                res_dict[int(target_id[i])] = dict(
                    keypoints=out['pred_smpl_kp3d'][i],
                    gt_poses=out['gt_smpl_pose'][i],
                    gt_betas=out['gt_smpl_beta'][i],
                    pred_poses=out['pred_smpl_pose'][i],
                    pred_betas=out['pred_smpl_beta'][i])

        missing = [i for i in range(self.num_data) if i not in res_dict]
        if missing:
            raise KeyError(
                f'outputs contain no result for {len(missing)} sample(s); '
                f'first missing indices: {missing[:10]}')

        ordered = [res_dict[i] for i in range(self.num_data)]
        fields = ('keypoints', 'gt_poses', 'gt_betas', 'pred_poses',
                  'pred_betas')
        res = {
            field: [entry[field] for entry in ordered]
            for field in fields
        }

        # metric name -> (reporter, extra keyword arguments)
        reporters = {
            'mpjpe': (self._report_mpjpe, {}),
            'pa-mpjpe': (self._report_mpjpe, {'metric': 'pa-mpjpe'}),
            '3dpck': (self._report_3d_pck, {}),
            'pa-3dpck': (self._report_3d_pck, {'metric': 'pa-3dpck'}),
            '3dauc': (self._report_3d_auc, {}),
            'pa-3dauc': (self._report_3d_auc, {'metric': 'pa-3dauc'}),
            'pve': (self._report_pve, {}),
            'ihmr': (self._report_ihmr, {}),
        }
        name_value_tuples = []
        for metric_name in metrics:
            if metric_name not in reporters:
                raise NotImplementedError
            report, extra = reporters[metric_name]
            name_value_tuples.extend(report(res, **extra))

        return OrderedDict(name_value_tuples)

    @staticmethod
    def _write_keypoint_results(keypoints: Any, res_file: str):
        """Write results into a json file.

        Args:
            keypoints (Any): JSON-serialisable results.
            res_file (str): destination path; it is overwritten.
        """
        with open(res_file, 'w') as f:
            json.dump(keypoints, f, sort_keys=True, indent=4)

    def _parse_result(self, res, mode='keypoint', body_part=None):
        """Parse results.

        Args:
            res (dict): the per-field result lists built by ``evaluate``.
            mode (str): 'keypoint' or 'vertice'.
            body_part: accepted for interface compatibility; not used.

        Returns:
            tuple: ``(prediction, ground_truth, mask)``, with prediction and
                ground truth scaled by 1000.

        Raises:
            ValueError: if ``mode`` is neither 'keypoint' nor 'vertice'.
            NotImplementedError: in keypoint mode, for an unsupported
                dataset name or ground-truth joint count.
        """
        if mode == 'vertice':
            return self._parse_vertices(res)
        if mode == 'keypoint':
            return self._parse_keypoints(res)
        raise ValueError(f'Invalid mode: {mode}')

    def _parse_vertices(self, res):
        """Build predicted and ground-truth mesh vertices with the body model.

        NOTE(review): requires ``self.body_model``; it is None when no
        ``body_model`` config was given.
        """
        # gt
        gt_smpl_dict = self.human_data['smpl']
        indices = range(self.num_data)
        gt_beta = torch.FloatTensor(
            [gt_smpl_dict['betas'][i] for i in indices])
        gt_pose = torch.FloatTensor(
            [gt_smpl_dict['body_pose'][i] for i in indices]).view(-1, 69)
        gt_global_orient = torch.FloatTensor(
            [gt_smpl_dict['global_orient'][i] for i in indices])
        # Gender code: 0 for 'm', 1 for anything else.
        gender = torch.Tensor([
            0 if self.human_data['meta']['gender'][i] == 'm' else 1
            for i in indices
        ])
        gt_output = self.body_model(betas=gt_beta,
                                    body_pose=gt_pose,
                                    global_orient=gt_global_orient,
                                    gender=gender)
        gt_vertices = gt_output['vertices'].detach().cpu().numpy() * 1000.
        gt_mask = np.ones(gt_vertices.shape[:-1])

        # pred: only index 0 along the second axis is used.
        pred_pose = torch.FloatTensor(res['pred_poses'])
        pred_beta = torch.FloatTensor(res['pred_betas'])
        pred_output = self.body_model(
            betas=pred_beta[:, 0],
            body_pose=pred_pose[:, 0, 1:],
            global_orient=pred_pose[:, 0, 0].unsqueeze(1),
            pose2rot=False)
        pred_vertices = \
            pred_output['vertices'].detach().cpu().numpy() * 1000.

        assert len(pred_vertices) == self.num_data
        return pred_vertices, gt_vertices, gt_mask

    def _parse_keypoints(self, res):
        """Build pelvis-centred predicted / ground-truth 3D keypoints.

        Ground truth is obtained per ``self.dataset_name`` ('pw3d', 'h36m',
        'humman'); both sets are then reduced to 14 joints, centred on the
        pelvis and scaled by 1000.
        """
        pred_keypoints3d = res['keypoints']
        assert len(pred_keypoints3d) == self.num_data
        num_samples = len(pred_keypoints3d)
        pred_keypoints3d = np.array(pred_keypoints3d).reshape(
            num_samples, -1, 3)

        if self.dataset_name == 'pw3d':
            # Ground truth comes from the SMPL parameters stored in `res`.
            gt_smpl_pose = np.array(res['gt_poses'])
            gt_betas = np.array(res['gt_betas'])
            # NOTE(review): gender is all zeros here (the code used for 'm'
            # in vertice mode) instead of being read from the annotations --
            # confirm this is intended.
            gender = np.zeros([gt_betas.shape[0], gt_betas.shape[1]])
            gt_output = self.body_model(
                betas=torch.FloatTensor(gt_betas).view(-1, 10),
                body_pose=torch.FloatTensor(
                    gt_smpl_pose[..., 1:, :]).view(-1, 69),
                global_orient=torch.FloatTensor(
                    gt_smpl_pose[..., 0, :]).view(-1, 3),
                gender=torch.Tensor(gender).view(-1))
            gt_keypoints3d = gt_output['joints'].detach().cpu().numpy()
            gt_keypoints3d_mask = np.ones((num_samples, 17))
        elif self.dataset_name == 'h36m':
            _, h36m_idxs, _ = get_mapping('human_data', 'h36m')
            gt_keypoints3d = \
                self.human_data['keypoints3d'][:, h36m_idxs, :3]
            gt_keypoints3d_mask = np.ones((num_samples, 17))
        elif self.dataset_name == 'humman':
            smpl_dict = self.human_data['smpl']
            indices = range(self.num_data)
            betas = torch.FloatTensor(
                [smpl_dict['betas'][i] for i in indices])
            body_pose = torch.FloatTensor(
                [smpl_dict['body_pose'][i] for i in indices]).view(-1, 69)
            global_orient = torch.FloatTensor(
                [smpl_dict['global_orient'][i] for i in indices])
            gt_output = self.body_model(betas=betas,
                                        body_pose=body_pose,
                                        global_orient=global_orient)
            gt_keypoints3d = gt_output['joints'].detach().cpu().numpy()
            gt_keypoints3d_mask = np.ones((num_samples, 24))
        else:
            raise NotImplementedError()

        num_gt_joints = gt_keypoints3d.shape[1]
        if num_gt_joints in (49, 24):
            assert pred_keypoints3d.shape[1] == num_gt_joints
            if num_gt_joints == 49:
                # SMPL_49 only! Drop the first 25 joints, keep the last 24.
                gt_keypoints3d = gt_keypoints3d[:, 25:, :]
                pred_keypoints3d = pred_keypoints3d[:, 25:, :]
            # we only evaluate on 14 lsp joints
            joint_mapper = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 18]
            gt_keypoints3d = gt_keypoints3d[:, joint_mapper, :]
            pred_keypoints3d = pred_keypoints3d[:, joint_mapper, :]
            # Pelvis is the midpoint of mapped joints 2 and 3.
            pred_pelvis = (pred_keypoints3d[:, 2] +
                           pred_keypoints3d[:, 3]) / 2
            gt_pelvis = (gt_keypoints3d[:, 2] + gt_keypoints3d[:, 3]) / 2
        elif num_gt_joints == 17:
            # H36M for testing!
            assert pred_keypoints3d.shape[-2] == 17
            H36M_TO_J17 = [
                6, 5, 4, 1, 2, 3, 16, 15, 14, 11, 12, 13, 8, 10, 0, 7, 9
            ]
            joint_mapper = H36M_TO_J17[:14]
            # Pelvis is joint 0, taken before the joints are re-ordered.
            pred_pelvis = pred_keypoints3d[:, 0]
            gt_pelvis = gt_keypoints3d[:, 0]
            gt_keypoints3d = gt_keypoints3d[:, joint_mapper, :]
            pred_keypoints3d = pred_keypoints3d[:, joint_mapper, :]
        else:
            # Previously this case fell through and failed later on an
            # unbound local variable.
            raise NotImplementedError(
                'unsupported number of ground-truth joints: '
                f'{num_gt_joints} (expected 49, 24 or 17)')

        pred_keypoints3d = (pred_keypoints3d -
                            pred_pelvis[:, None, :]) * 1000
        gt_keypoints3d = (gt_keypoints3d - gt_pelvis[:, None, :]) * 1000
        gt_keypoints3d_mask = gt_keypoints3d_mask[:, joint_mapper] > 0
        return pred_keypoints3d, gt_keypoints3d, gt_keypoints3d_mask

    @staticmethod
    def _alignment_for_metric(metric, base_name):
        """Map a metric name to the alignment mode of the metric functions.

        Returns 'none' for ``base_name`` and 'procrustes' for
        ``'pa-' + base_name``.

        Raises:
            ValueError: for any other metric name.
        """
        if metric == base_name:
            return 'none'
        if metric == f'pa-{base_name}':
            return 'procrustes'
        raise ValueError(f'Invalid metric: {metric}')

    def _report_mpjpe(self, res_file, metric='mpjpe', body_part=''):
        """Calculate mean per joint position error (MPJPE) or its variant
        PA-MPJPE (MPJPE after rigid alignment).

        Args:
            res_file (dict): the result dict built by ``evaluate`` (despite
                the name, not a file path).
            metric (str): 'mpjpe' or 'pa-mpjpe'.
            body_part (str): optional prefix for the reported name.

        Returns:
            list: a single ``(name, error)`` tuple.
        """
        alignment = self._alignment_for_metric(metric, 'mpjpe')
        pred_keypoints3d, gt_keypoints3d, gt_keypoints3d_mask = \
            self._parse_result(res_file, mode='keypoint', body_part=body_part)

        err_name = metric.upper()
        if body_part != '':
            err_name = body_part.upper() + ' ' + err_name

        error = keypoint_mpjpe(pred_keypoints3d, gt_keypoints3d,
                               gt_keypoints3d_mask, alignment)
        return [(err_name, error)]

    def _report_3d_pck(self, res_file, metric='3dpck'):
        """Calculate Percentage of Correct Keypoints (3DPCK) w. or w/o
        Procrustes alignment.

        Args:
            res_file (dict): the result dict built by ``evaluate`` (despite
                the name, not a file path).
            metric (str): Specify 3DPCK variants. Supported options are:
                - ``'3dpck'``: Standard 3DPCK.
                - ``'pa-3dpck'``:
                    3DPCK after aligning prediction to groundtruth
                    via a rigid transformation (scale, rotation and
                    translation).

        Returns:
            list: a single ``(name, value)`` tuple.
        """
        alignment = self._alignment_for_metric(metric, '3dpck')
        pred_keypoints3d, gt_keypoints3d, gt_keypoints3d_mask = \
            self._parse_result(res_file)

        error = keypoint_3d_pck(pred_keypoints3d, gt_keypoints3d,
                                gt_keypoints3d_mask, alignment)
        return [(metric.upper(), error)]

    def _report_3d_auc(self, res_file, metric='3dauc'):
        """Calculate the Area Under the Curve (AUC) computed for a range of
        3DPCK thresholds.

        Args:
            res_file (dict): the result dict built by ``evaluate`` (despite
                the name, not a file path).
            metric (str): Specify 3DAUC variants. Supported options are:
                - ``'3dauc'``: Standard 3DAUC.
                - ``'pa-3dauc'``: 3DAUC after aligning prediction to
                    groundtruth via a rigid transformation (scale, rotation
                    and translation).

        Returns:
            list: a single ``(name, value)`` tuple.
        """
        alignment = self._alignment_for_metric(metric, '3dauc')
        pred_keypoints3d, gt_keypoints3d, gt_keypoints3d_mask = \
            self._parse_result(res_file)

        error = keypoint_3d_auc(pred_keypoints3d, gt_keypoints3d,
                                gt_keypoints3d_mask, alignment)
        return [(metric.upper(), error)]

    def _report_pve(self, res_file, metric='pve', body_part=''):
        """Calculate per vertex error.

        Args:
            res_file (dict): the result dict built by ``evaluate`` (despite
                the name, not a file path).
            metric (str): 'pve' or 'pa-pve'.
            body_part (str): optional prefix for the reported name.

        Returns:
            list: a single ``(name, error)`` tuple.
        """
        alignment = self._alignment_for_metric(metric, 'pve')
        pred_verts, gt_verts, _ = \
            self._parse_result(res_file, mode='vertice', body_part=body_part)

        err_name = metric.upper()
        if body_part != '':
            err_name = body_part.upper() + ' ' + err_name

        error = vertice_pve(pred_verts, gt_verts, alignment)
        return [(err_name, error)]

    def _report_ihmr(self, res_file):
        """Calculate IHMR metric.

        https://arxiv.org/abs/2203.16427

        Samples are ranked by the mean per-vertex distance between their
        ground-truth mesh and the mesh of the mean SMPL parameters. Errors
        are reported averaged over 100 equal-width distance bins ("All") and
        over the 10% / 5% most distant samples ("Tail").

        Args:
            res_file (dict): the result dict built by ``evaluate`` (despite
                the name, not a file path).

        Returns:
            list: ``(name, value)`` tuples.
        """
        pred_keypoints3d, gt_keypoints3d, gt_keypoints3d_mask = \
            self._parse_result(res_file, mode='keypoint')
        pred_verts, gt_verts, _ = \
            self._parse_result(res_file, mode='vertice')

        from detrsmpl.utils.geometry import rot6d_to_rotmat

        # Path is relative to the current working directory.
        mean_param_path = 'data/body_models/smpl_mean_params.npz'
        # Use a context manager so the archive handle is closed.
        with np.load(mean_param_path) as mean_params:
            mean_pose = torch.from_numpy(mean_params['pose'][:]).unsqueeze(0)
            mean_shape = torch.from_numpy(
                mean_params['shape'][:].astype('float32')).unsqueeze(0)
        mean_pose = rot6d_to_rotmat(mean_pose).view(1, 24, 3, 3)
        mean_output = self.body_model(betas=mean_shape,
                                      body_pose=mean_pose[:, 1:],
                                      global_orient=mean_pose[:, :1],
                                      pose2rot=False)
        mean_verts = mean_output['vertices'].detach().cpu().numpy() * 1000.

        offset = gt_verts - mean_verts
        dis = np.sqrt((offset * offset).sum(axis=-1)).mean(axis=-1)

        # from the most remote one to the nearest one
        idx_order = np.argsort(dis)[::-1]
        num_data = idx_order.shape[0]

        def report_ihmr_idx(idx):
            """Return (MPVPE, MPJPE, PA-MPJPE) over the samples in `idx`."""
            mpvpe = vertice_pve(pred_verts[idx], gt_verts[idx])
            mpjpe = keypoint_mpjpe(pred_keypoints3d[idx], gt_keypoints3d[idx],
                                   gt_keypoints3d_mask[idx], 'none')
            pampjpe = keypoint_mpjpe(pred_keypoints3d[idx],
                                     gt_keypoints3d[idx],
                                     gt_keypoints3d_mask[idx], 'procrustes')
            return (mpvpe, mpjpe, pampjpe)

        def report_ihmr_tail(percentage):
            """Errors over the `percentage`% most distant samples."""
            cur_data = int(num_data * percentage / 100.0)
            mpvpe, mpjpe, pampjpe = report_ihmr_idx(idx_order[:cur_data])
            suffix = f' Tail {percentage}%'
            return [('bMPVPE' + suffix, mpvpe), ('bMPJPE' + suffix, mpjpe),
                    ('bPA-MPJPE' + suffix, pampjpe)]

        def report_ihmr_all(num_bin):
            """Mean of the per-bin mean errors over the non-empty bins."""
            num_per_bin = np.zeros(num_bin, dtype=np.float32)
            sum_mpvpe = np.zeros(num_bin, dtype=np.float32)
            sum_mpjpe = np.zeros(num_bin, dtype=np.float32)
            sum_pampjpe = np.zeros(num_bin, dtype=np.float32)

            max_dis = dis[idx_order[0]]
            min_dis = dis[idx_order[-1]]
            delta = (max_dis - min_dis) / num_bin
            for i in range(num_data):
                if delta > 0:
                    # The small offset puts a value sitting exactly on a bin
                    # edge (including the maximum) into the lower bin.
                    bin_idx = int((dis[i] - min_dis) / delta - 0.001)
                    bin_idx = min(max(bin_idx, 0), num_bin - 1)
                else:
                    # All distances are equal: a single bin, and no division
                    # by zero.
                    bin_idx = 0
                res_mpvpe, res_mpjpe, res_pampjpe = report_ihmr_idx([i])
                num_per_bin[bin_idx] += 1
                sum_mpvpe[bin_idx] += res_mpvpe
                sum_mpjpe[bin_idx] += res_mpjpe
                sum_pampjpe[bin_idx] += res_pampjpe

            valid = num_per_bin > 0
            counts = num_per_bin[valid]
            return [('bMPVPE All', (sum_mpvpe[valid] / counts).mean()),
                    ('bMPJPE All', (sum_mpjpe[valid] / counts).mean()),
                    ('bPA-MPJPE All', (sum_pampjpe[valid] / counts).mean())]

        metrics = []
        metrics.extend(report_ihmr_all(num_bin=100))
        metrics.extend(report_ihmr_tail(percentage=10))
        metrics.extend(report_ihmr_tail(percentage=5))
        return metrics