# Copyright (c) OpenMMLab. All rights reserved.
import logging
import os
from collections import defaultdict
from typing import Dict, List, Optional, Sequence, Tuple, Union

import mmcv
import numpy as np
import torch
from mmengine.config import Config, ConfigDict
from mmengine.infer.infer import ModelType
from mmengine.logging import print_log
from mmengine.model import revert_sync_batchnorm
from mmengine.registry import init_default_scope
from mmengine.structures import InstanceData

from mmpose.evaluation.functional import nms
from mmpose.registry import INFERENCERS
from mmpose.structures import PoseDataSample, merge_data_samples

from .base_mmpose_inferencer import BaseMMPoseInferencer

InstanceList = List[InstanceData]
InputType = Union[str, np.ndarray]
InputsType = Union[InputType, Sequence[InputType]]
PredType = Union[InstanceData, InstanceList]
ImgType = Union[np.ndarray, Sequence[np.ndarray]]
ConfigType = Union[Config, ConfigDict]
ResType = Union[Dict, List[Dict], InstanceData, List[InstanceData]]
@INFERENCERS.register_module()
class Hand3DInferencer(BaseMMPoseInferencer):
    """The inferencer for 3D hand pose estimation.

    Args:
        model (str, optional): Pretrained pose estimation algorithm.
            It's the path to the config file or the model name defined in
            metafile. For example, it could be:

            - model alias, e.g. ``'body'``,
            - config name, e.g. ``'simcc_res50_8xb64-210e_coco-256x192'``,
            - config path

            Defaults to ``None``.
        weights (str, optional): Path to the checkpoint. If it is not
            specified and "model" is a model name of metafile, the weights
            will be loaded from metafile. Defaults to None.
        device (str, optional): Device to run inference. If None, the
            available device will be automatically used. Defaults to None.
        scope (str, optional): The scope of the model. Defaults to "mmpose".
        det_model (str, optional): Config path or alias of detection model.
            Defaults to None.
        det_weights (str, optional): Path to the checkpoints of detection
            model. Defaults to None.
        det_cat_ids (int or list[int], optional): Category id for
            detection model. Defaults to None.
        show_progress (bool): Whether to display a progress bar during
            inference. Defaults to False.
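
    Example:
        A minimal usage sketch (not part of the upstream docstring). The
        model name ``'hand3d'`` is an assumption; use a config name, config
        path, or metafile model name that your MMPose installation
        provides::

            >>> from mmpose.apis.inferencers import Hand3DInferencer
            >>> inferencer = Hand3DInferencer(model='hand3d')
            >>> results = [res for res in inferencer('demo.jpg')]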
| """ | |

    preprocess_kwargs: set = {'bbox_thr', 'nms_thr', 'bboxes'}
    forward_kwargs: set = {'disable_rebase_keypoint'}
    visualize_kwargs: set = {
        'return_vis',
        'show',
        'wait_time',
        'draw_bbox',
        'radius',
        'thickness',
        'kpt_thr',
        'vis_out_dir',
        'num_instances',
    }
    postprocess_kwargs: set = {'pred_out_dir', 'return_datasample'}

    def __init__(self,
                 model: Union[ModelType, str],
                 weights: Optional[str] = None,
                 device: Optional[str] = None,
                 scope: Optional[str] = 'mmpose',
                 det_model: Optional[Union[ModelType, str]] = None,
                 det_weights: Optional[str] = None,
                 det_cat_ids: Optional[Union[int, Tuple]] = None,
                 show_progress: bool = False) -> None:

        init_default_scope(scope)
        super().__init__(
            model=model,
            weights=weights,
            device=device,
            scope=scope,
            show_progress=show_progress)
        self.model = revert_sync_batchnorm(self.model)

        # assign dataset metainfo to self.visualizer
        self.visualizer.set_dataset_meta(self.model.dataset_meta)

        # initialize hand detector
        self._init_detector(
            det_model=det_model,
            det_weights=det_weights,
            det_cat_ids=det_cat_ids,
            device=device,
        )

        self._video_input = False
        self._buffer = defaultdict(list)

    def preprocess_single(self,
                          input: InputType,
                          index: int,
                          bbox_thr: float = 0.3,
                          nms_thr: float = 0.3,
                          bboxes: Union[List[List], List[np.ndarray],
                                        np.ndarray] = []):
        """Process a single input into a model-feedable format.

        Args:
            input (InputType): Input given by user.
            index (int): Index of the input.
            bbox_thr (float): Threshold for bounding box detection.
                Defaults to 0.3.
            nms_thr (float): IoU threshold for bounding box NMS.
                Defaults to 0.3.
            bboxes (list or np.ndarray, optional): Bounding boxes provided
                by the user, used when no detector is initialized.
                Defaults to ``[]``.

        Returns:
            list: Data processed by the ``pipeline``, one item per
            bounding box.
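
        Example:
            A sketch of user-provided boxes (values are illustrative). Each
            row follows the ``[x1, y1, x2, y2, score]`` layout that the
            method slices below::

                >>> boxes = np.array([[10., 20., 200., 220., 1.]])
                >>> data_infos = inferencer.preprocess_single(
                ...     'hand.jpg', index=0, bboxes=boxes)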
| """ | |
        if isinstance(input, str):
            data_info = dict(img_path=input)
        else:
            data_info = dict(
                img=input, img_path=f'{index}.jpg'.rjust(10, '0'))
        data_info.update(self.model.dataset_meta)
        if self.detector is not None:
            try:
                det_results = self.detector(
                    input, return_datasamples=True)['predictions']
            except ValueError:
                print_log(
                    'Support for mmpose and mmdet versions up to 3.1.0 '
                    'will be discontinued in upcoming releases. To '
                    'ensure ongoing compatibility, please upgrade to '
                    'mmdet version 3.2.0 or later.',
                    logger='current',
                    level=logging.WARNING)
                det_results = self.detector(
                    input, return_datasample=True)['predictions']
            pred_instance = det_results[0].pred_instances.cpu().numpy()

            # stack detections into an (N, 5) array of
            # [x1, y1, x2, y2, score] rows
            bboxes = np.concatenate(
                (pred_instance.bboxes, pred_instance.scores[:, None]), axis=1)

            # keep detections whose label matches one of the target
            # categories and whose score exceeds ``bbox_thr``, then
            # suppress overlaps with NMS
            label_mask = np.zeros(len(bboxes), dtype=np.uint8)
            for cat_id in self.det_cat_ids:
                label_mask = np.logical_or(label_mask,
                                           pred_instance.labels == cat_id)
            bboxes = bboxes[np.logical_and(label_mask,
                                           pred_instance.scores > bbox_thr)]
            bboxes = bboxes[nms(bboxes, nms_thr)]

        data_infos = []
        if len(bboxes) > 0:
            for bbox in bboxes:
                inst = data_info.copy()
                inst['bbox'] = bbox[None, :4]
                inst['bbox_score'] = bbox[4:5]
                data_infos.append(self.pipeline(inst))
        else:
            inst = data_info.copy()

            # get bbox from the image size
            if isinstance(input, str):
                input = mmcv.imread(input)
            h, w = input.shape[:2]

            inst['bbox'] = np.array([[0, 0, w, h]], dtype=np.float32)
            inst['bbox_score'] = np.ones(1, dtype=np.float32)
            data_infos.append(self.pipeline(inst))

        return data_infos

    def forward(self,
                inputs: Union[dict, tuple],
                disable_rebase_keypoint: bool = False):
        """Performs a forward pass through the model.

        Args:
            inputs (Union[dict, tuple]): The input data to be processed. Can
                be either a dictionary or a tuple.
            disable_rebase_keypoint (bool, optional): Flag to disable
                rebasing the height of the keypoints. Defaults to False.

        Returns:
            A list of data samples with prediction instances.
        """
        data_samples = self.model.test_step(inputs)
        data_samples_2d = []

        for idx, res in enumerate(data_samples):
            pred_instances = res.pred_instances
            keypoints = pred_instances.keypoints
            rel_root_depth = pred_instances.rel_root_depth
            scores = pred_instances.keypoint_scores
            hand_type = pred_instances.hand_type

            res_2d = PoseDataSample()
            gt_instances = res.gt_instances.clone()
            pred_instances = pred_instances.clone()
            res_2d.gt_instances = gt_instances
            res_2d.pred_instances = pred_instances

            # the first 21 joints belong to the right hand and the last 21
            # to the left hand; add the relative root depth to the left
            # hand joints
            keypoints[:, 21:, 2] += rel_root_depth

            # set joint scores according to hand type
            scores[:, :21] *= hand_type[:, [0]]
            scores[:, 21:] *= hand_type[:, [1]]

            # normalize kpt score
            if scores.max() > 1:
                scores /= 255

            res_2d.pred_instances.set_field(keypoints[..., :2].copy(),
                                            'keypoints')

            # rotate the keypoints so that the z-axis corresponds to height
            # for better visualization; as a row-vector product, vis_R maps
            # (x, y, z) -> (x, z, -y)
            vis_R = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]])
            keypoints[..., :3] = keypoints[..., :3] @ vis_R

            # rebase height (z-axis) so the lowest valid keypoint sits at 0
            if not disable_rebase_keypoint:
                valid = scores > 0
                keypoints[..., 2] -= np.min(
                    keypoints[valid, 2], axis=-1, keepdims=True)

            data_samples[idx].pred_instances.keypoints = keypoints
            data_samples[idx].pred_instances.keypoint_scores = scores
            data_samples_2d.append(res_2d)

        data_samples = [merge_data_samples(data_samples)]
        data_samples_2d = merge_data_samples(data_samples_2d)

        self._buffer['pose2d_results'] = data_samples_2d

        return data_samples

    def visualize(
        self,
        inputs: list,
        preds: List[PoseDataSample],
        return_vis: bool = False,
        show: bool = False,
        draw_bbox: bool = False,
        wait_time: float = 0,
        radius: int = 3,
        thickness: int = 1,
        kpt_thr: float = 0.3,
        num_instances: int = 1,
        vis_out_dir: str = '',
        window_name: str = '',
    ) -> List[np.ndarray]:
| """Visualize predictions. | |
| Args: | |
| inputs (list): Inputs preprocessed by :meth:`_inputs_to_list`. | |
| preds (Any): Predictions of the model. | |
| return_vis (bool): Whether to return images with predicted results. | |
| show (bool): Whether to display the image in a popup window. | |
| Defaults to False. | |
| wait_time (float): The interval of show (ms). Defaults to 0 | |
| draw_bbox (bool): Whether to draw the bounding boxes. | |
| Defaults to False | |
| radius (int): Keypoint radius for visualization. Defaults to 3 | |
| thickness (int): Link thickness for visualization. Defaults to 1 | |
| kpt_thr (float): The threshold to visualize the keypoints. | |
| Defaults to 0.3 | |
| vis_out_dir (str, optional): Directory to save visualization | |
| results w/o predictions. If left as empty, no file will | |
| be saved. Defaults to ''. | |
| window_name (str, optional): Title of display window. | |
| window_close_event_handler (callable, optional): | |
| Returns: | |
| List[np.ndarray]: Visualization results. | |
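
        Example:
            A sketch of how these options are usually routed through the
            inferencer call rather than by invoking this method directly
            (the model name is an assumption)::

                >>> inferencer = Hand3DInferencer(model='hand3d')
                >>> _ = list(inferencer('demo.jpg', vis_out_dir='vis/',
                ...                     draw_bbox=True, kpt_thr=0.3))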
| """ | |
        if (not return_vis) and (not show) and (not vis_out_dir):
            return

        if getattr(self, 'visualizer', None) is None:
            raise ValueError('Visualization needs the "visualizer" term '
                             'defined in the config, but got None.')

        self.visualizer.radius = radius
        self.visualizer.line_width = thickness

        results = []

        for single_input, pred in zip(inputs, preds):
            if isinstance(single_input, str):
                img = mmcv.imread(single_input, channel_order='rgb')
            elif isinstance(single_input, np.ndarray):
                img = mmcv.bgr2rgb(single_input)
            else:
                raise ValueError('Unsupported input type: '
                                 f'{type(single_input)}')

            img_name = os.path.basename(pred.metainfo['img_path'])

            # since visualization and inference run in the same process,
            # reduce the wait time for video inputs so that inference does
            # not get stuck waiting on the display window
            wait_time = 1e-5 if self._video_input else wait_time

            if num_instances < 0:
                num_instances = len(pred.pred_instances)

            visualization = self.visualizer.add_datasample(
                window_name,
                img,
                data_sample=pred,
                det_data_sample=self._buffer['pose2d_results'],
                draw_gt=False,
                draw_bbox=draw_bbox,
                show=show,
                wait_time=wait_time,
                convert_keypoint=False,
                axis_azimuth=-115,
                axis_limit=200,
                axis_elev=15,
                kpt_thr=kpt_thr,
                num_instances=num_instances)
            results.append(visualization)

            if vis_out_dir:
                self.save_visualization(
                    visualization,
                    vis_out_dir,
                    img_name=img_name,
                )

        if return_vis:
            return results
        else:
            return []
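

# A minimal end-to-end sketch, added for illustration; it is not part of the
# original module. The model name and image path are assumptions, so
# substitute a config or metafile model name and an image available in your
# environment.
if __name__ == '__main__':
    inferencer = Hand3DInferencer(model='hand3d')  # hypothetical model name
    for result in inferencer('demo.jpg', vis_out_dir='vis/'):
        # each result is expected to be a dict with 'visualization' and
        # 'predictions' entries (the default inferencer output format)
        print(result['predictions'])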