# NOTE: removed extraction artifacts ("Spaces:" / "Runtime error" lines) that
# were not part of the original module and made the file unparsable.
| import json | |
| import os | |
| import os.path | |
| from abc import ABCMeta | |
| from collections import OrderedDict | |
| from typing import Any, List, Optional, Union | |
| import mmcv | |
| import numpy as np | |
| import torch | |
| import torch.distributed as dist | |
| from mmcv.runner import get_dist_info | |
| from detrsmpl.core.conventions.keypoints_mapping import ( | |
| convert_kps, | |
| get_keypoint_num, | |
| get_mapping, | |
| ) | |
| from detrsmpl.core.evaluation import ( | |
| keypoint_3d_auc, | |
| keypoint_3d_pck, | |
| keypoint_mpjpe, | |
| vertice_pve, | |
| ) | |
| from detrsmpl.data.data_structures.human_data import HumanData | |
| from detrsmpl.data.data_structures.human_data_cache import ( | |
| HumanDataCacheReader, | |
| HumanDataCacheWriter, | |
| ) | |
| from detrsmpl.models.body_models.builder import build_body_model | |
| from .base_dataset import BaseDataset | |
| from .builder import DATASETS | |
class HumanImageDataset(BaseDataset, metaclass=ABCMeta):
    """Human Image Dataset.

    Args:
        data_prefix (str): the prefix of data path.
        pipeline (list): a list of dict, where each element represents
            a operation defined in `detrsmpl.datasets.pipelines`.
        dataset_name (str | None): the name of dataset. It is used to
            identify the type of evaluation metric. Default: None.
        body_model (dict | None, optional): the config for body model,
            which will be used to generate meshes and keypoints.
            Default: None.
        ann_file (str | None, optional): the annotation file. When ann_file
            is str, the subclass is expected to read from the ann_file.
            When ann_file is None, the subclass is expected to read
            according to data_prefix.
        convention (str, optional): keypoints convention. Keypoints will be
            converted from "human_data" to the given one.
            Default: "human_data"
        cache_data_path (str | None, optional): the path to store the cache
            file. When cache_data_path is None, each dataset will store a copy
            into memory. If cache_data_path is set, the dataset will first
            create one cache file and then use a cache reader to reduce memory
            cost and initialization time. The cache file will be generated
            only once if they are not found at the path. Otherwise, only
            cache readers will be established.
        test_mode (bool, optional): in train mode or test mode.
            Default: False.
    """

    # Metric names accepted by `evaluate()`. The 'pa-' prefix means the
    # prediction is rigidly (Procrustes) aligned to the ground truth first.
    ALLOWED_METRICS = {
        'mpjpe', 'pa-mpjpe', 'pve', '3dpck', 'pa-3dpck', '3dauc', 'pa-3dauc',
        'ihmr'
    }
| def __init__(self, | |
| data_prefix: str, | |
| pipeline: list, | |
| dataset_name: str, | |
| body_model: Optional[Union[dict, None]] = None, | |
| ann_file: Optional[Union[str, None]] = None, | |
| convention: Optional[str] = 'human_data', | |
| cache_data_path: Optional[Union[str, None]] = None, | |
| test_mode: Optional[bool] = False): | |
| self.convention = convention | |
| self.num_keypoints = get_keypoint_num(convention) | |
| self.cache_data_path = cache_data_path | |
| super(HumanImageDataset, | |
| self).__init__(data_prefix, pipeline, ann_file, test_mode, | |
| dataset_name) | |
| if body_model is not None: | |
| self.body_model = build_body_model(body_model) | |
| else: | |
| self.body_model = None | |
| def get_annotation_file(self): | |
| """Get path of the annotation file.""" | |
| ann_prefix = os.path.join(self.data_prefix, 'preprocessed_datasets') | |
| self.ann_file = os.path.join(ann_prefix, self.ann_file) | |
    def load_annotations(self):
        """Load annotation from the annotation file.

        Here we simply use :obj:`HumanData` to parse the annotation.

        Behavior depends on ``self.cache_data_path``:
        - ``None``: every process parses the full annotation into memory.
        - set but the cache file is missing: only rank 0 parses the
          annotation and writes the cache; all ranks then read via a
          :obj:`HumanDataCacheReader`.
        - set and the cache file exists: no process parses the annotation;
          all ranks read via the cache reader.
        """
        rank, world_size = get_dist_info()
        self.get_annotation_file()
        # Decide whether this process must load the full HumanData.
        if self.cache_data_path is None:
            use_human_data = True
        elif rank == 0 and not os.path.exists(self.cache_data_path):
            # Rank 0 loads once so it can build the cache below.
            use_human_data = True
        else:
            use_human_data = False
        if use_human_data:
            self.human_data = HumanData.fromfile(self.ann_file)
            if self.human_data.check_keypoints_compressed():
                self.human_data.decompress_keypoints()
            # change keypoint from 'human_data' to the given convention
            if 'keypoints3d' in self.human_data:
                keypoints3d = self.human_data['keypoints3d']
                assert 'keypoints3d_mask' in self.human_data
                keypoints3d_mask = self.human_data['keypoints3d_mask']
                keypoints3d, keypoints3d_mask = \
                    convert_kps(
                        keypoints3d,
                        src='human_data',
                        dst=self.convention,
                        mask=keypoints3d_mask)
                self.human_data.__setitem__('keypoints3d', keypoints3d)
                self.human_data.__setitem__('keypoints3d_convention',
                                            self.convention)
                self.human_data.__setitem__('keypoints3d_mask',
                                            keypoints3d_mask)
            if 'keypoints2d' in self.human_data:
                keypoints2d = self.human_data['keypoints2d']
                assert 'keypoints2d_mask' in self.human_data
                keypoints2d_mask = self.human_data['keypoints2d_mask']
                keypoints2d, keypoints2d_mask = \
                    convert_kps(
                        keypoints2d,
                        src='human_data',
                        dst=self.convention,
                        mask=keypoints2d_mask)
                self.human_data.__setitem__('keypoints2d', keypoints2d)
                self.human_data.__setitem__('keypoints2d_convention',
                                            self.convention)
                self.human_data.__setitem__('keypoints2d_mask',
                                            keypoints2d_mask)
            # Re-compress to keep the in-memory copy small.
            self.human_data.compress_keypoints_by_mask()
        if self.cache_data_path is not None:
            if rank == 0 and not os.path.exists(self.cache_data_path):
                writer_kwargs, sliced_data = self.human_data.get_sliced_cache()
                writer = HumanDataCacheWriter(**writer_kwargs)
                writer.update_sliced_dict(sliced_data)
                writer.dump(self.cache_data_path)
            if world_size > 1:
                # Ensure rank 0 has finished writing the cache before any
                # other rank opens it.
                dist.barrier()
            self.cache_reader = HumanDataCacheReader(
                npz_path=self.cache_data_path)
            self.num_data = self.cache_reader.data_len
            # Drop the in-memory copy; samples are fetched via cache_reader
            # in prepare_raw_data().
            self.human_data = None
        else:
            self.cache_reader = None
            self.num_data = self.human_data.data_len
| def prepare_raw_data(self, idx: int): | |
| """Get item from self.human_data.""" | |
| sample_idx = idx | |
| if self.cache_reader is not None: | |
| self.human_data = self.cache_reader.get_item(idx) | |
| idx = idx % self.cache_reader.slice_size | |
| info = {} | |
| info['img_prefix'] = None | |
| image_path = self.human_data['image_path'][idx] | |
| info['image_path'] = os.path.join(self.data_prefix, 'datasets', | |
| self.dataset_name, image_path) | |
| if image_path.endswith('smc'): | |
| device, device_id, frame_id = self.human_data['image_id'][idx] | |
| info['image_id'] = (device, int(device_id), int(frame_id)) | |
| info['dataset_name'] = self.dataset_name | |
| info['sample_idx'] = sample_idx | |
| if 'bbox_xywh' in self.human_data: | |
| info['bbox_xywh'] = self.human_data['bbox_xywh'][idx] | |
| x, y, w, h, s = info['bbox_xywh'] | |
| cx = x + w / 2 | |
| cy = y + h / 2 | |
| w = h = max(w, h) | |
| info['center'] = np.array([cx, cy]) | |
| info['scale'] = np.array([w, h]) | |
| else: | |
| info['bbox_xywh'] = np.zeros((5)) | |
| info['center'] = np.zeros((2)) | |
| info['scale'] = np.zeros((2)) | |
| # in later modules, we will check validity of each keypoint by | |
| # its confidence. Therefore, we do not need the mask of keypoints. | |
| if 'keypoints2d' in self.human_data: | |
| info['keypoints2d'] = self.human_data['keypoints2d'][idx] | |
| info['has_keypoints2d'] = 1 | |
| else: | |
| info['keypoints2d'] = np.zeros((self.num_keypoints, 3)) | |
| info['has_keypoints2d'] = 0 | |
| if 'keypoints3d' in self.human_data: | |
| info['keypoints3d'] = self.human_data['keypoints3d'][idx] | |
| info['has_keypoints3d'] = 1 | |
| else: | |
| info['keypoints3d'] = np.zeros((self.num_keypoints, 4)) | |
| info['has_keypoints3d'] = 0 | |
| if 'smpl' in self.human_data: | |
| smpl_dict = self.human_data['smpl'] | |
| else: | |
| smpl_dict = {} | |
| if 'smpl' in self.human_data: | |
| if 'has_smpl' in self.human_data: | |
| info['has_smpl'] = int(self.human_data['has_smpl'][idx]) | |
| else: | |
| info['has_smpl'] = 1 | |
| else: | |
| info['has_smpl'] = 0 | |
| if 'body_pose' in smpl_dict: | |
| info['smpl_body_pose'] = smpl_dict['body_pose'][idx] | |
| else: | |
| info['smpl_body_pose'] = np.zeros((23, 3)) | |
| if 'global_orient' in smpl_dict: | |
| info['smpl_global_orient'] = smpl_dict['global_orient'][idx] | |
| else: | |
| info['smpl_global_orient'] = np.zeros((3)) | |
| if 'betas' in smpl_dict: | |
| info['smpl_betas'] = smpl_dict['betas'][idx] | |
| else: | |
| info['smpl_betas'] = np.zeros((10)) | |
| if 'transl' in smpl_dict: | |
| info['smpl_transl'] = smpl_dict['transl'][idx] | |
| else: | |
| info['smpl_transl'] = np.zeros((3)) | |
| return info | |
| def prepare_data(self, idx: int): | |
| """Generate and transform data.""" | |
| info = self.prepare_raw_data(idx) | |
| return self.pipeline(info) | |
| def evaluate(self, | |
| outputs: list, | |
| res_folder: str, | |
| metric: Optional[Union[str, List[str]]] = 'pa-mpjpe', | |
| **kwargs: dict): | |
| """Evaluate 3D keypoint results. | |
| Args: | |
| outputs (list): results from model inference. | |
| res_folder (str): path to store results. | |
| metric (Optional[Union[str, List(str)]]): | |
| the type of metric. Default: 'pa-mpjpe' | |
| kwargs (dict): other arguments. | |
| Returns: | |
| dict: | |
| A dict of all evaluation results. | |
| """ | |
| metrics = metric if isinstance(metric, list) else [metric] | |
| for metric in metrics: | |
| if metric not in self.ALLOWED_METRICS: | |
| raise KeyError(f'metric {metric} is not supported') | |
| res_file = os.path.join(res_folder, 'result_keypoints.json') | |
| # for keeping correctness during multi-gpu test, we sort all results | |
| res_dict = {} | |
| for out in outputs: | |
| target_id = out['image_idx'] | |
| batch_size = len(out['keypoints_3d']) | |
| for i in range(batch_size): | |
| res_dict[int(target_id[i])] = dict( | |
| keypoints=out['keypoints_3d'][i], | |
| poses=out['smpl_pose'][i], | |
| betas=out['smpl_beta'][i], | |
| ) | |
| keypoints, poses, betas = [], [], [] | |
| for i in range(self.num_data): | |
| keypoints.append(res_dict[i]['keypoints']) | |
| poses.append(res_dict[i]['poses']) | |
| betas.append(res_dict[i]['betas']) | |
| res = dict(keypoints=keypoints, poses=poses, betas=betas) | |
| mmcv.dump(res, res_file) | |
| name_value_tuples = [] | |
| for _metric in metrics: | |
| if _metric == 'mpjpe': | |
| _nv_tuples = self._report_mpjpe(res) | |
| elif _metric == 'pa-mpjpe': | |
| _nv_tuples = self._report_mpjpe(res, metric='pa-mpjpe') | |
| elif _metric == '3dpck': | |
| _nv_tuples = self._report_3d_pck(res) | |
| elif _metric == 'pa-3dpck': | |
| _nv_tuples = self._report_3d_pck(res, metric='pa-3dpck') | |
| elif _metric == '3dauc': | |
| _nv_tuples = self._report_3d_auc(res) | |
| elif _metric == 'pa-3dauc': | |
| _nv_tuples = self._report_3d_auc(res, metric='pa-3dauc') | |
| elif _metric == 'pve': | |
| _nv_tuples = self._report_pve(res) | |
| elif _metric == 'ihmr': | |
| _nv_tuples = self._report_ihmr(res) | |
| else: | |
| raise NotImplementedError | |
| name_value_tuples.extend(_nv_tuples) | |
| name_value = OrderedDict(name_value_tuples) | |
| return name_value | |
| def _write_keypoint_results(keypoints: Any, res_file: str): | |
| """Write results into a json file.""" | |
| with open(res_file, 'w') as f: | |
| json.dump(keypoints, f, sort_keys=True, indent=4) | |
    def _parse_result(self, res, mode='keypoint', body_part=None):
        """Parse results.

        Args:
            res (dict): dumped results with keys 'keypoints', 'poses' and
                'betas', ordered by sample index.
            mode (str): 'vertice' returns SMPL mesh vertices; 'keypoint'
                returns 3D keypoints reduced to an evaluation joint set.
            body_part: unused in this method; presumably kept for subclass
                overrides — TODO confirm.

        Returns:
            tuple: (pred, gt, gt_mask) numpy arrays for the requested mode.
        """
        if mode == 'vertice':
            # gt
            gt_beta, gt_pose, gt_global_orient, gender = [], [], [], []
            gt_smpl_dict = self.human_data['smpl']
            for idx in range(self.num_data):
                gt_beta.append(gt_smpl_dict['betas'][idx])
                gt_pose.append(gt_smpl_dict['body_pose'][idx])
                gt_global_orient.append(gt_smpl_dict['global_orient'][idx])
                # gender encoded as 0 = male ('m'), 1 = otherwise
                if self.human_data['meta']['gender'][idx] == 'm':
                    gender.append(0)
                else:
                    gender.append(1)
            gt_beta = torch.FloatTensor(gt_beta)
            gt_pose = torch.FloatTensor(gt_pose).view(-1, 69)  # 23 joints * 3
            gt_global_orient = torch.FloatTensor(gt_global_orient)
            gender = torch.Tensor(gender)
            gt_output = self.body_model(betas=gt_beta,
                                        body_pose=gt_pose,
                                        global_orient=gt_global_orient,
                                        gender=gender)
            # * 1000: presumably metres -> millimetres — TODO confirm
            gt_vertices = gt_output['vertices'].detach().cpu().numpy() * 1000.
            gt_mask = np.ones(gt_vertices.shape[:-1])
            # pred
            pred_pose = torch.FloatTensor(res['poses'])
            pred_beta = torch.FloatTensor(res['betas'])
            # pose2rot=False: predicted poses are presumably rotation
            # matrices, not axis-angle — TODO confirm against the model head
            pred_output = self.body_model(
                betas=pred_beta,
                body_pose=pred_pose[:, 1:],
                global_orient=pred_pose[:, 0].unsqueeze(1),
                pose2rot=False,
                gender=gender)
            pred_vertices = pred_output['vertices'].detach().cpu().numpy(
            ) * 1000.
            assert len(pred_vertices) == self.num_data
            return pred_vertices, gt_vertices, gt_mask
        elif mode == 'keypoint':
            pred_keypoints3d = res['keypoints']
            assert len(pred_keypoints3d) == self.num_data
            # (B, 17, 3)
            pred_keypoints3d = np.array(pred_keypoints3d)
            # Ground truth depends on the dataset: pw3d/humman regenerate
            # joints from SMPL parameters, h36m uses stored keypoints3d.
            if self.dataset_name == 'pw3d':
                betas = []
                body_pose = []
                global_orient = []
                gender = []
                smpl_dict = self.human_data['smpl']
                for idx in range(self.num_data):
                    betas.append(smpl_dict['betas'][idx])
                    body_pose.append(smpl_dict['body_pose'][idx])
                    global_orient.append(smpl_dict['global_orient'][idx])
                    if self.human_data['meta']['gender'][idx] == 'm':
                        gender.append(0)
                    else:
                        gender.append(1)
                betas = torch.FloatTensor(betas)
                body_pose = torch.FloatTensor(body_pose).view(-1, 69)
                global_orient = torch.FloatTensor(global_orient)
                gender = torch.Tensor(gender)
                gt_output = self.body_model(betas=betas,
                                            body_pose=body_pose,
                                            global_orient=global_orient,
                                            gender=gender)
                gt_keypoints3d = gt_output['joints'].detach().cpu().numpy()
                gt_keypoints3d_mask = np.ones((len(pred_keypoints3d), 24))
            elif self.dataset_name == 'h36m':
                _, h36m_idxs, _ = get_mapping('human_data', 'h36m')
                # NOTE(review): requires self.human_data in memory; with a
                # cache reader human_data may be None/sliced — TODO confirm.
                gt_keypoints3d = \
                    self.human_data['keypoints3d'][:, h36m_idxs, :3]
                gt_keypoints3d_mask = np.ones((len(pred_keypoints3d), 17))
            elif self.dataset_name == 'humman':
                betas = []
                body_pose = []
                global_orient = []
                smpl_dict = self.human_data['smpl']
                for idx in range(self.num_data):
                    betas.append(smpl_dict['betas'][idx])
                    body_pose.append(smpl_dict['body_pose'][idx])
                    global_orient.append(smpl_dict['global_orient'][idx])
                betas = torch.FloatTensor(betas)
                body_pose = torch.FloatTensor(body_pose).view(-1, 69)
                global_orient = torch.FloatTensor(global_orient)
                gt_output = self.body_model(betas=betas,
                                            body_pose=body_pose,
                                            global_orient=global_orient)
                gt_keypoints3d = gt_output['joints'].detach().cpu().numpy()
                gt_keypoints3d_mask = np.ones((len(pred_keypoints3d), 24))
            else:
                raise NotImplementedError()
            # Reduce to a common evaluation joint set and compute the pelvis
            # used for root-centering, per GT joint count.
            # SMPL_49 only!
            if gt_keypoints3d.shape[1] == 49:
                assert pred_keypoints3d.shape[1] == 49
                # first 25 are OpenPose-style extras, drop them
                gt_keypoints3d = gt_keypoints3d[:, 25:, :]
                pred_keypoints3d = pred_keypoints3d[:, 25:, :]
                joint_mapper = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 18]
                gt_keypoints3d = gt_keypoints3d[:, joint_mapper, :]
                pred_keypoints3d = pred_keypoints3d[:, joint_mapper, :]
                # we only evaluate on 14 lsp joints
                pred_pelvis = (pred_keypoints3d[:, 2] +
                               pred_keypoints3d[:, 3]) / 2
                gt_pelvis = (gt_keypoints3d[:, 2] + gt_keypoints3d[:, 3]) / 2
            # H36M for testing!
            elif gt_keypoints3d.shape[1] == 17:
                assert pred_keypoints3d.shape[1] == 17
                H36M_TO_J17 = [
                    6, 5, 4, 1, 2, 3, 16, 15, 14, 11, 12, 13, 8, 10, 0, 7, 9
                ]
                H36M_TO_J14 = H36M_TO_J17[:14]
                joint_mapper = H36M_TO_J14
                # pelvis is computed BEFORE remapping (joint 0 of H36M)
                pred_pelvis = pred_keypoints3d[:, 0]
                gt_pelvis = gt_keypoints3d[:, 0]
                gt_keypoints3d = gt_keypoints3d[:, joint_mapper, :]
                pred_keypoints3d = pred_keypoints3d[:, joint_mapper, :]
            # keypoint 24
            elif gt_keypoints3d.shape[1] == 24:
                assert pred_keypoints3d.shape[1] == 24
                joint_mapper = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 18]
                gt_keypoints3d = gt_keypoints3d[:, joint_mapper, :]
                pred_keypoints3d = pred_keypoints3d[:, joint_mapper, :]
                # we only evaluate on 14 lsp joints
                pred_pelvis = (pred_keypoints3d[:, 2] +
                               pred_keypoints3d[:, 3]) / 2
                gt_pelvis = (gt_keypoints3d[:, 2] + gt_keypoints3d[:, 3]) / 2
            else:
                # NOTE(review): falls through with pred_pelvis/joint_mapper
                # undefined -> NameError below for any other joint count.
                pass
            # Root-center both sets and convert to millimetres.
            pred_keypoints3d = (pred_keypoints3d -
                                pred_pelvis[:, None, :]) * 1000
            gt_keypoints3d = (gt_keypoints3d - gt_pelvis[:, None, :]) * 1000
            gt_keypoints3d_mask = gt_keypoints3d_mask[:, joint_mapper] > 0
            return pred_keypoints3d, gt_keypoints3d, gt_keypoints3d_mask
| def _report_mpjpe(self, res_file, metric='mpjpe', body_part=''): | |
| """Cauculate mean per joint position error (MPJPE) or its variants PA- | |
| MPJPE. | |
| Report mean per joint position error (MPJPE) and mean per joint | |
| position error after rigid alignment (PA-MPJPE) | |
| """ | |
| pred_keypoints3d, gt_keypoints3d, gt_keypoints3d_mask = \ | |
| self._parse_result(res_file, mode='keypoint', body_part=body_part) | |
| err_name = metric.upper() | |
| if body_part != '': | |
| err_name = body_part.upper() + ' ' + err_name | |
| if metric == 'mpjpe': | |
| alignment = 'none' | |
| elif metric == 'pa-mpjpe': | |
| alignment = 'procrustes' | |
| else: | |
| raise ValueError(f'Invalid metric: {metric}') | |
| error = keypoint_mpjpe(pred_keypoints3d, gt_keypoints3d, | |
| gt_keypoints3d_mask, alignment) | |
| info_str = [(err_name, error)] | |
| return info_str | |
| def _report_3d_pck(self, res_file, metric='3dpck'): | |
| """Cauculate Percentage of Correct Keypoints (3DPCK) w. or w/o | |
| Procrustes alignment. | |
| Args: | |
| keypoint_results (list): Keypoint predictions. See | |
| 'Body3DMpiInf3dhpDataset.evaluate' for details. | |
| metric (str): Specify mpjpe variants. Supported options are: | |
| - ``'3dpck'``: Standard 3DPCK. | |
| - ``'pa-3dpck'``: | |
| 3DPCK after aligning prediction to groundtruth | |
| via a rigid transformation (scale, rotation and | |
| translation). | |
| """ | |
| pred_keypoints3d, gt_keypoints3d, gt_keypoints3d_mask = \ | |
| self._parse_result(res_file) | |
| err_name = metric.upper() | |
| if metric == '3dpck': | |
| alignment = 'none' | |
| elif metric == 'pa-3dpck': | |
| alignment = 'procrustes' | |
| else: | |
| raise ValueError(f'Invalid metric: {metric}') | |
| error = keypoint_3d_pck(pred_keypoints3d, gt_keypoints3d, | |
| gt_keypoints3d_mask, alignment) | |
| name_value_tuples = [(err_name, error)] | |
| return name_value_tuples | |
| def _report_3d_auc(self, res_file, metric='3dauc'): | |
| """Cauculate the Area Under the Curve (AUC) computed for a range of | |
| 3DPCK thresholds. | |
| Args: | |
| keypoint_results (list): Keypoint predictions. See | |
| 'Body3DMpiInf3dhpDataset.evaluate' for details. | |
| metric (str): Specify mpjpe variants. Supported options are: | |
| - ``'3dauc'``: Standard 3DAUC. | |
| - ``'pa-3dauc'``: 3DAUC after aligning prediction to | |
| groundtruth via a rigid transformation (scale, rotation and | |
| translation). | |
| """ | |
| pred_keypoints3d, gt_keypoints3d, gt_keypoints3d_mask = \ | |
| self._parse_result(res_file) | |
| err_name = metric.upper() | |
| if metric == '3dauc': | |
| alignment = 'none' | |
| elif metric == 'pa-3dauc': | |
| alignment = 'procrustes' | |
| else: | |
| raise ValueError(f'Invalid metric: {metric}') | |
| error = keypoint_3d_auc(pred_keypoints3d, gt_keypoints3d, | |
| gt_keypoints3d_mask, alignment) | |
| name_value_tuples = [(err_name, error)] | |
| return name_value_tuples | |
| def _report_pve(self, res_file, metric='pve', body_part=''): | |
| """Cauculate per vertex error.""" | |
| pred_verts, gt_verts, _ = \ | |
| self._parse_result(res_file, mode='vertice', body_part=body_part) | |
| err_name = metric.upper() | |
| if body_part != '': | |
| err_name = body_part.upper() + ' ' + err_name | |
| if metric == 'pve': | |
| alignment = 'none' | |
| elif metric == 'pa-pve': | |
| alignment = 'procrustes' | |
| else: | |
| raise ValueError(f'Invalid metric: {metric}') | |
| error = vertice_pve(pred_verts, gt_verts, alignment) | |
| return [(err_name, error)] | |
    def _report_ihmr(self, res_file):
        """Calculate IHMR metric.

        https://arxiv.org/abs/2203.16427

        Samples are ranked by how far their GT mesh is from the mean-shape
        mesh; metrics are then reported over distance bins ("All") and over
        the hardest tail percentiles.
        """
        pred_keypoints3d, gt_keypoints3d, gt_keypoints3d_mask = \
            self._parse_result(res_file, mode='keypoint')
        pred_verts, gt_verts, _ = \
            self._parse_result(res_file, mode='vertice')
        from detrsmpl.utils.geometry import rot6d_to_rotmat
        # Mean SMPL parameters; path is relative to the working directory.
        mean_param_path = 'data/body_models/smpl_mean_params.npz'
        mean_params = np.load(mean_param_path)
        mean_pose = torch.from_numpy(mean_params['pose'][:]).unsqueeze(0)
        mean_shape = torch.from_numpy(
            mean_params['shape'][:].astype('float32')).unsqueeze(0)
        mean_pose = rot6d_to_rotmat(mean_pose).view(1, 24, 3, 3)
        mean_output = self.body_model(betas=mean_shape,
                                      body_pose=mean_pose[:, 1:],
                                      global_orient=mean_pose[:, :1],
                                      pose2rot=False)
        mean_verts = mean_output['vertices'].detach().cpu().numpy() * 1000.
        # Per-sample mean vertex distance between GT mesh and mean mesh.
        dis = (gt_verts - mean_verts) * (gt_verts - mean_verts)
        dis = np.sqrt(dis.sum(axis=-1)).mean(axis=-1)
        # from the most remote one to the nearest one
        idx_order = np.argsort(dis)[::-1]
        num_data = idx_order.shape[0]

        def report_ihmr_idx(idx):
            # Metrics restricted to the samples selected by `idx`.
            mpvpe = vertice_pve(pred_verts[idx], gt_verts[idx])
            mpjpe = keypoint_mpjpe(pred_keypoints3d[idx], gt_keypoints3d[idx],
                                   gt_keypoints3d_mask[idx], 'none')
            pampjpe = keypoint_mpjpe(pred_keypoints3d[idx],
                                     gt_keypoints3d[idx],
                                     gt_keypoints3d_mask[idx], 'procrustes')
            return (mpvpe, mpjpe, pampjpe)

        def report_ihmr_tail(percentage):
            # Metrics over the hardest `percentage`% of samples.
            cur_data = int(num_data * percentage / 100.0)
            idx = idx_order[:cur_data]
            mpvpe, mpjpe, pampjpe = report_ihmr_idx(idx)
            res_mpvpe = ('bMPVPE Tail ' + str(percentage) + '%', mpvpe)
            res_mpjpe = ('bMPJPE Tail ' + str(percentage) + '%', mpjpe)
            res_pampjpe = ('bPA-MPJPE Tail ' + str(percentage) + '%', pampjpe)
            return [res_mpvpe, res_mpjpe, res_pampjpe]

        def report_ihmr_all(num_bin):
            # Bin samples by distance, average per bin, then average the
            # non-empty bins (balanced metrics, hence the 'b' prefix).
            num_per_bin = np.array([0 for _ in range(num_bin)
                                    ]).astype(np.float32)
            sum_mpvpe = np.array([0
                                  for _ in range(num_bin)]).astype(np.float32)
            sum_mpjpe = np.array([0
                                  for _ in range(num_bin)]).astype(np.float32)
            sum_pampjpe = np.array([0 for _ in range(num_bin)
                                    ]).astype(np.float32)
            max_dis = dis[idx_order[0]]
            min_dis = dis[idx_order[-1]]
            # NOTE(review): delta is 0 if all distances are equal, which
            # would divide by zero below — TODO confirm this cannot occur.
            delta = (max_dis - min_dis) / num_bin
            for i in range(num_data):
                # -0.001 keeps the maximum-distance sample in the last bin.
                idx = int((dis[i] - min_dis) / delta - 0.001)
                res_mpvpe, res_mpjpe, res_pampjpe = report_ihmr_idx([i])
                num_per_bin[idx] += 1
                sum_mpvpe[idx] += res_mpvpe
                sum_mpjpe[idx] += res_mpjpe
                sum_pampjpe[idx] += res_pampjpe
            for i in range(num_bin):
                if num_per_bin[i] > 0:
                    sum_mpvpe[i] = sum_mpvpe[i] / num_per_bin[i]
                    sum_mpjpe[i] = sum_mpjpe[i] / num_per_bin[i]
                    sum_pampjpe[i] = sum_pampjpe[i] / num_per_bin[i]
            valid_idx = np.where(num_per_bin > 0)
            res_mpvpe = ('bMPVPE All', sum_mpvpe[valid_idx].mean())
            res_mpjpe = ('bMPJPE All', sum_mpjpe[valid_idx].mean())
            res_pampjpe = ('bPA-MPJPE All', sum_pampjpe[valid_idx].mean())
            return [res_mpvpe, res_mpjpe, res_pampjpe]

        metrics = []
        metrics.extend(report_ihmr_all(num_bin=100))
        metrics.extend(report_ihmr_tail(percentage=10))
        metrics.extend(report_ihmr_tail(percentage=5))
        return metrics