# Copyright (c) OpenMMLab. All rights reserved.
import inspect
import logging
import mimetypes
import os
from collections import defaultdict
from typing import (Callable, Dict, Generator, Iterable, List, Optional,
                    Sequence, Tuple, Union)

import cv2
import mmcv
import mmengine
import numpy as np
import torch.nn as nn
from mmengine.config import Config, ConfigDict
from mmengine.dataset import Compose
from mmengine.fileio import (get_file_backend, isdir, join_path,
                             list_dir_or_file)
from mmengine.infer.infer import BaseInferencer, ModelType
from mmengine.logging import print_log
from mmengine.registry import init_default_scope
from mmengine.runner.checkpoint import _load_checkpoint_to_model
from mmengine.structures import InstanceData
from mmengine.utils import mkdir_or_exist
from rich.progress import track

from mmpose.apis.inference import dataset_meta_from_config
from mmpose.registry import DATASETS
from mmpose.structures import PoseDataSample, split_instances
from .utils import default_det_models

try:
    from mmdet.apis.det_inferencer import DetInferencer
    has_mmdet = True
except (ImportError, ModuleNotFoundError):
    has_mmdet = False

InstanceList = List[InstanceData]
InputType = Union[str, np.ndarray]
InputsType = Union[InputType, Sequence[InputType]]
PredType = Union[InstanceData, InstanceList]
ImgType = Union[np.ndarray, Sequence[np.ndarray]]
ConfigType = Union[Config, ConfigDict]
ResType = Union[Dict, List[Dict], InstanceData, List[InstanceData]]


class BaseMMPoseInferencer(BaseInferencer):
    """The base class for MMPose inferencers."""

    preprocess_kwargs: set = {'bbox_thr', 'nms_thr', 'bboxes'}
    forward_kwargs: set = set()
    visualize_kwargs: set = {
        'return_vis', 'show', 'wait_time', 'draw_bbox', 'radius', 'thickness',
        'kpt_thr', 'vis_out_dir', 'black_background'
    }
    postprocess_kwargs: set = {'pred_out_dir', 'return_datasample'}
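
    # Illustrative note (not part of the original source): ``BaseInferencer``
    # dispatches the ``**kwargs`` given to ``__call__`` into the four stages
    # using the sets above, so a call such as
    #
    #   inferencer(inputs, bbox_thr=0.5, radius=4, pred_out_dir='preds')
    #
    # would roughly split into
    #   preprocess_kwargs  -> {'bbox_thr': 0.5}
    #   visualize_kwargs   -> {'radius': 4}
    #   postprocess_kwargs -> {'pred_out_dir': 'preds'}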

    def __init__(self,
                 model: Union[ModelType, str, None] = None,
                 weights: Optional[str] = None,
                 device: Optional[str] = None,
                 scope: Optional[str] = None,
                 show_progress: bool = False) -> None:
        super().__init__(
            model, weights, device, scope, show_progress=show_progress)

    def _init_detector(
        self,
        det_model: Optional[Union[ModelType, str]] = None,
        det_weights: Optional[str] = None,
        det_cat_ids: Optional[Union[int, Tuple]] = None,
        device: Optional[str] = None,
    ):
        object_type = DATASETS.get(self.cfg.dataset_type).__module__.split(
            'datasets.')[-1].split('.')[0].lower()
        if det_model in ('whole_image', 'whole-image') or \
                (det_model is None and
                 object_type not in default_det_models):
            self.detector = None

        else:
            det_scope = 'mmdet'
            if det_model is None:
                det_info = default_det_models[object_type]
                det_model, det_weights, det_cat_ids = det_info[
                    'model'], det_info['weights'], det_info['cat_ids']
            elif os.path.exists(det_model):
                det_cfg = Config.fromfile(det_model)
                det_scope = det_cfg.default_scope

            if has_mmdet:
                det_kwargs = dict(
                    model=det_model,
                    weights=det_weights,
                    device=device,
                    scope=det_scope,
                )
                # for compatibility with low version of mmdet
                if 'show_progress' in inspect.signature(
                        DetInferencer).parameters:
                    det_kwargs['show_progress'] = False

                self.detector = DetInferencer(**det_kwargs)
            else:
                raise RuntimeError(
                    'MMDetection (v3.0.0 or above) is required to build '
                    'inferencers for top-down pose estimation models.')

            if isinstance(det_cat_ids, (tuple, list)):
                self.det_cat_ids = det_cat_ids
            else:
                self.det_cat_ids = (det_cat_ids, )
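
    # Hedged example (not from the original file): for a COCO-style human
    # dataset, the fallback entry from ``default_det_models`` is shaped
    # roughly like
    #   {'model': <mmdet config or alias>, 'weights': <checkpoint url>,
    #    'cat_ids': (0, )}
    # and a scalar ``det_cat_ids`` such as 0 is normalized to the tuple
    # ``(0, )`` before being stored on the inferencer.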

    def _load_weights_to_model(self, model: nn.Module,
                               checkpoint: Optional[dict],
                               cfg: Optional[ConfigType]) -> None:
        """Loading model weights and meta information from cfg and checkpoint.

        Subclasses could override this method to load extra meta information
        from ``checkpoint`` and ``cfg`` to model.

        Args:
            model (nn.Module): Model to load weights and meta information.
            checkpoint (dict, optional): The loaded checkpoint.
            cfg (Config or ConfigDict, optional): The loaded config.
        """
        if checkpoint is not None:
            _load_checkpoint_to_model(model, checkpoint)
            checkpoint_meta = checkpoint.get('meta', {})
            # save the dataset_meta in the model for convenience
            if 'dataset_meta' in checkpoint_meta:
                # mmpose 1.x
                model.dataset_meta = checkpoint_meta['dataset_meta']
            else:
                print_log(
                    'dataset_meta are not saved in the checkpoint\'s '
                    'meta data, load via config.',
                    logger='current',
                    level=logging.WARNING)
                model.dataset_meta = dataset_meta_from_config(
                    cfg, dataset_mode='train')
        else:
            print_log(
                'Checkpoint is not loaded, and the inference '
                'result is calculated by the randomly initialized '
                'model!',
                logger='current',
                level=logging.WARNING)
            model.dataset_meta = dataset_meta_from_config(
                cfg, dataset_mode='train')

    def _inputs_to_list(self, inputs: InputsType) -> Iterable:
        """Preprocess the inputs to a list.

        Preprocess inputs to a list according to its type:

        - list or tuple: return inputs
        - str:
            - Directory path: return all files in the directory
            - other cases: return a list containing the string. The string
              could be a path to file, a url or other types of string
              according to the task.

        Args:
            inputs (InputsType): Inputs for the inferencer.

        Returns:
            list: List of input for the :meth:`preprocess`.
        """
        self._video_input = False

        if isinstance(inputs, str):
            backend = get_file_backend(inputs)
            if hasattr(backend, 'isdir') and isdir(inputs):
                # Backends like HttpsBackend do not implement `isdir`, so only
                # those backends that implement `isdir` could accept the
                # inputs as a directory
                filepath_list = [
                    join_path(inputs, fname)
                    for fname in list_dir_or_file(inputs, list_dir=False)
                ]
                inputs = []
                for filepath in filepath_list:
                    input_type = mimetypes.guess_type(filepath)[0].split(
                        '/')[0]
                    if input_type == 'image':
                        inputs.append(filepath)
                inputs.sort()
            else:
                # if inputs is a path to a video file, it will be converted
                # to a list containing separated frame filenames
                input_type = mimetypes.guess_type(inputs)[0].split('/')[0]
                if input_type == 'video':
                    self._video_input = True
                    video = mmcv.VideoReader(inputs)
                    self.video_info = dict(
                        fps=video.fps,
                        name=os.path.basename(inputs),
                        writer=None,
                        width=video.width,
                        height=video.height,
                        predictions=[])
                    inputs = video
                elif input_type == 'image':
                    inputs = [inputs]
                else:
                    raise ValueError(f'Expected input to be an image, video, '
                                     f'or folder, but received {inputs} of '
                                     f'type {input_type}.')

        elif isinstance(inputs, np.ndarray):
            inputs = [inputs]

        return inputs
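
    # Illustrative behaviour of ``_inputs_to_list`` (sketch, paths made up):
    #   'path/to/images/'   -> sorted list of image paths in the folder
    #   'path/to/video.mp4' -> an mmcv.VideoReader (frames read lazily),
    #                          with ``self._video_input`` set to True
    #   'path/to/img.jpg'   -> ['path/to/img.jpg']
    #   np.ndarray frame    -> [frame]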

    def _get_webcam_inputs(self, inputs: str) -> Generator:
        """Sets up and returns a generator function that reads frames from a
        webcam input. The generator function returns a new frame each time it
        is iterated over.

        Args:
            inputs (str): A string describing the webcam input, in the format
                "webcam:id".

        Returns:
            A generator function that yields frames from the webcam input.

        Raises:
            ValueError: If the inputs string is not in the expected format.
        """
        # Ensure the inputs string is in the expected format.
        inputs = inputs.lower()
        assert inputs.startswith('webcam'), f'Expected input to start with ' \
            f'"webcam", but got "{inputs}"'

        # Parse the camera ID from the inputs string.
        inputs_ = inputs.split(':')
        if len(inputs_) == 1:
            camera_id = 0
        elif len(inputs_) == 2 and str.isdigit(inputs_[1]):
            camera_id = int(inputs_[1])
        else:
            raise ValueError(
                f'Expected webcam input to have format "webcam:id", '
                f'but got "{inputs}"')

        # Attempt to open the video capture object.
        vcap = cv2.VideoCapture(camera_id)
        if not vcap.isOpened():
            print_log(
                f'Cannot open camera (ID={camera_id})',
                logger='current',
                level=logging.WARNING)
            return []

        # Set video input flag and metadata.
        self._video_input = True
        (major_ver, minor_ver, subminor_ver) = (cv2.__version__).split('.')
        if int(major_ver) < 3:
            fps = vcap.get(cv2.cv.CV_CAP_PROP_FPS)
            width = vcap.get(cv2.cv.CV_CAP_PROP_FRAME_WIDTH)
            height = vcap.get(cv2.cv.CV_CAP_PROP_FRAME_HEIGHT)
        else:
            fps = vcap.get(cv2.CAP_PROP_FPS)
            width = vcap.get(cv2.CAP_PROP_FRAME_WIDTH)
            height = vcap.get(cv2.CAP_PROP_FRAME_HEIGHT)
        self.video_info = dict(
            fps=fps,
            name='webcam.mp4',
            writer=None,
            width=width,
            height=height,
            predictions=[])

        def _webcam_reader() -> Generator:
            while True:
                if cv2.waitKey(5) & 0xFF == 27:
                    vcap.release()
                    break

                ret_val, frame = vcap.read()
                if not ret_val:
                    break

                yield frame

        return _webcam_reader()
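
    # Illustrative parsing of the webcam string (sketch): 'webcam' opens
    # camera 0 and 'webcam:1' opens camera 1; anything else, e.g.
    # 'webcam:left', raises a ValueError. Pressing ESC (key code 27) in the
    # display window releases the capture and stops the generator.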

    def _init_pipeline(self, cfg: ConfigType) -> Callable:
        """Initialize the test pipeline.

        Args:
            cfg (ConfigType): model config path or dict

        Returns:
            A pipeline to handle various input data, such as ``str``,
            ``np.ndarray``. The returned pipeline will be used to process
            a single data.
        """
        scope = cfg.get('default_scope', 'mmpose')
        if scope is not None:
            init_default_scope(scope)
        return Compose(cfg.test_dataloader.dataset.pipeline)

    def update_model_visualizer_settings(self, **kwargs):
        """Update the settings of models and visualizer according to inference
        arguments."""
        pass

    def preprocess(self,
                   inputs: InputsType,
                   batch_size: int = 1,
                   bboxes: Optional[List] = None,
                   bbox_thr: float = 0.3,
                   nms_thr: float = 0.3,
                   **kwargs):
        """Process the inputs into a model-feedable format.

        Args:
            inputs (InputsType): Inputs given by user.
            batch_size (int): batch size. Defaults to 1.
            bbox_thr (float): threshold for bounding box detection.
                Defaults to 0.3.
            nms_thr (float): IoU threshold for bounding box NMS.
                Defaults to 0.3.

        Yields:
            Any: Data processed by the ``pipeline`` and ``collate_fn``.
            List[str or np.ndarray]: List of original inputs in the batch
        """
        # One-stage pose estimators perform prediction filtering within the
        # head's `predict` method. Here, we set the arguments for filtering
        if self.cfg.model.type == 'BottomupPoseEstimator':
            # 1. init with default arguments
            test_cfg = self.model.head.test_cfg.copy()
            # 2. update the score_thr and nms_thr in the test_cfg of the head
            if 'score_thr' in test_cfg:
                test_cfg['score_thr'] = bbox_thr
            if 'nms_thr' in test_cfg:
                test_cfg['nms_thr'] = nms_thr
            self.model.test_cfg = test_cfg

        for i, input in enumerate(inputs):
            bbox = bboxes[i] if bboxes else []
            data_infos = self.preprocess_single(
                input,
                index=i,
                bboxes=bbox,
                bbox_thr=bbox_thr,
                nms_thr=nms_thr,
                **kwargs)
            # only supports inference with batch size 1
            yield self.collate_fn(data_infos), [input]
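
    # Illustrative sketch of consuming ``preprocess`` (not in the original
    # source); ``pose_inferencer`` stands for an instance of a concrete
    # subclass:
    #
    #   for batch, ori_inputs in pose_inferencer.preprocess(['img.jpg']):
    #       preds = pose_inferencer.forward(batch)
    #
    # Each yielded pair holds the collated data for one input and the raw
    # input it came from, since only batch size 1 is supported here.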

    def __call__(
        self,
        inputs: InputsType,
        return_datasamples: bool = False,
        batch_size: int = 1,
        out_dir: Optional[str] = None,
        **kwargs,
    ) -> dict:
        """Call the inferencer.

        Args:
            inputs (InputsType): Inputs for the inferencer.
            return_datasamples (bool): Whether to return results as
                :obj:`BaseDataElement`. Defaults to False.
            batch_size (int): Batch size. Defaults to 1.
            out_dir (str, optional): Directory to save visualization
                results and predictions. Will be overridden if vis_out_dir or
                pred_out_dir are given. Defaults to None.
            **kwargs: Keyword arguments passed to :meth:`preprocess`,
                :meth:`forward`, :meth:`visualize` and :meth:`postprocess`.
                Each key in kwargs should be in the corresponding set of
                ``preprocess_kwargs``, ``forward_kwargs``,
                ``visualize_kwargs`` and ``postprocess_kwargs``.

        Returns:
            dict: Inference and visualization results.
        """
        if out_dir is not None:
            if 'vis_out_dir' not in kwargs:
                kwargs['vis_out_dir'] = f'{out_dir}/visualizations'
            if 'pred_out_dir' not in kwargs:
                kwargs['pred_out_dir'] = f'{out_dir}/predictions'

        (
            preprocess_kwargs,
            forward_kwargs,
            visualize_kwargs,
            postprocess_kwargs,
        ) = self._dispatch_kwargs(**kwargs)

        self.update_model_visualizer_settings(**kwargs)

        # preprocessing
        if isinstance(inputs, str) and inputs.startswith('webcam'):
            inputs = self._get_webcam_inputs(inputs)
            batch_size = 1
            if not visualize_kwargs.get('show', False):
                print_log(
                    'The display mode is closed when using webcam '
                    'input. It will be turned on automatically.',
                    logger='current',
                    level=logging.WARNING)
            visualize_kwargs['show'] = True
        else:
            inputs = self._inputs_to_list(inputs)

        # check the compatibility between inputs/outputs
        if not self._video_input and len(inputs) > 0:
            vis_out_dir = visualize_kwargs.get('vis_out_dir', None)
            if vis_out_dir is not None:
                _, file_extension = os.path.splitext(vis_out_dir)
                assert not file_extension, f'the argument `vis_out_dir` ' \
                    f'should be a folder while the input contains multiple ' \
                    f'images, but got {vis_out_dir}'

        if 'bbox_thr' in self.forward_kwargs:
            forward_kwargs['bbox_thr'] = preprocess_kwargs.get('bbox_thr', -1)
        inputs = self.preprocess(
            inputs, batch_size=batch_size, **preprocess_kwargs)

        preds = []

        for proc_inputs, ori_inputs in (track(inputs, description='Inference')
                                        if self.show_progress else inputs):
            preds = self.forward(proc_inputs, **forward_kwargs)

            visualization = self.visualize(ori_inputs, preds,
                                           **visualize_kwargs)
            results = self.postprocess(
                preds,
                visualization,
                return_datasamples=return_datasamples,
                **postprocess_kwargs)
            yield results

        if self._video_input:
            self._finalize_video_processing(
                postprocess_kwargs.get('pred_out_dir', ''))

        # In 3D Inferencers, some intermediate results (e.g. 2d keypoints)
        # will be temporarily stored in `self._buffer`. It's essential to
        # clear this information to prevent any interference with subsequent
        # inferences.
        if hasattr(self, '_buffer'):
            self._buffer.clear()
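
    # Illustrative use of ``__call__`` (sketch, not part of the original
    # file); note that it is a generator, so results must be iterated:
    #
    #   results = inferencer('demo.jpg', out_dir='outputs', draw_bbox=True)
    #   for result in results:
    #       print(result['predictions'])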

    def visualize(self,
                  inputs: list,
                  preds: List[PoseDataSample],
                  return_vis: bool = False,
                  show: bool = False,
                  draw_bbox: bool = False,
                  wait_time: float = 0,
                  radius: int = 3,
                  thickness: int = 1,
                  kpt_thr: float = 0.3,
                  vis_out_dir: str = '',
                  window_name: str = '',
                  black_background: bool = False,
                  **kwargs) -> List[np.ndarray]:
        """Visualize predictions.

        Args:
            inputs (list): Inputs preprocessed by :meth:`_inputs_to_list`.
            preds (Any): Predictions of the model.
            return_vis (bool): Whether to return images with predicted
                results.
            show (bool): Whether to display the image in a popup window.
                Defaults to False.
            wait_time (float): The interval of show (ms). Defaults to 0
            draw_bbox (bool): Whether to draw the bounding boxes.
                Defaults to False
            radius (int): Keypoint radius for visualization. Defaults to 3
            thickness (int): Link thickness for visualization. Defaults to 1
            kpt_thr (float): The threshold to visualize the keypoints.
                Defaults to 0.3
            vis_out_dir (str, optional): Directory to save visualization
                results w/o predictions. If left as empty, no file will
                be saved. Defaults to ''.
            window_name (str, optional): Title of display window.
            black_background (bool, optional): Whether to plot keypoints on a
                black image instead of the input image. Defaults to False.

        Returns:
            List[np.ndarray]: Visualization results.
        """
        if (not return_vis) and (not show) and (not vis_out_dir):
            return

        if getattr(self, 'visualizer', None) is None:
            raise ValueError('Visualization needs the "visualizer" term '
                             'defined in the config, but got None.')

        self.visualizer.radius = radius
        self.visualizer.line_width = thickness

        results = []

        for single_input, pred in zip(inputs, preds):
            if isinstance(single_input, str):
                img = mmcv.imread(single_input, channel_order='rgb')
            elif isinstance(single_input, np.ndarray):
                img = mmcv.bgr2rgb(single_input)
            else:
                raise ValueError('Unsupported input type: '
                                 f'{type(single_input)}')
            if black_background:
                img = img * 0

            img_name = os.path.basename(pred.metainfo['img_path'])
            window_name = window_name if window_name else img_name

            # since visualization and inference utilize the same process,
            # the wait time is reduced when a video input is utilized,
            # thereby eliminating the issue of inference getting stuck.
            wait_time = 1e-5 if self._video_input else wait_time

            visualization = self.visualizer.add_datasample(
                window_name,
                img,
                pred,
                draw_gt=False,
                draw_bbox=draw_bbox,
                show=show,
                wait_time=wait_time,
                kpt_thr=kpt_thr,
                **kwargs)
            results.append(visualization)

            if vis_out_dir:
                self.save_visualization(
                    visualization,
                    vis_out_dir,
                    img_name=img_name,
                )

        if return_vis:
            return results
        else:
            return []
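
    # Sketch (assumed, not from the original source): ``visualize`` returns
    # the rendered RGB frames only when ``return_vis=True``; with
    # ``vis_out_dir`` set it additionally writes each frame through
    # ``save_visualization`` below, e.g.
    #
    #   frames = inferencer.visualize(ori_inputs, preds, return_vis=True,
    #                                 vis_out_dir='outputs/vis')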

    def save_visualization(self, visualization, vis_out_dir, img_name=None):
        out_img = mmcv.rgb2bgr(visualization)
        _, file_extension = os.path.splitext(vis_out_dir)
        if file_extension:
            dir_name = os.path.dirname(vis_out_dir)
            file_name = os.path.basename(vis_out_dir)
        else:
            dir_name = vis_out_dir
            file_name = None
        mkdir_or_exist(dir_name)

        if self._video_input:
            if self.video_info['writer'] is None:
                fourcc = cv2.VideoWriter_fourcc(*'mp4v')
                if file_name is None:
                    file_name = os.path.basename(self.video_info['name'])
                out_file = join_path(dir_name, file_name)
                self.video_info['output_file'] = out_file
                self.video_info['writer'] = cv2.VideoWriter(
                    out_file, fourcc, self.video_info['fps'],
                    (visualization.shape[1], visualization.shape[0]))
            self.video_info['writer'].write(out_img)
        else:
            if file_name is None:
                file_name = img_name if img_name else 'visualization.jpg'
            out_file = join_path(dir_name, file_name)
            mmcv.imwrite(out_img, out_file)
            print_log(
                f'the output image has been saved at {out_file}',
                logger='current',
                level=logging.INFO)

    def postprocess(
        self,
        preds: List[PoseDataSample],
        visualization: List[np.ndarray],
        return_datasample=None,
        return_datasamples=False,
        pred_out_dir: str = '',
    ) -> dict:
        """Process the predictions and visualization results from ``forward``
        and ``visualize``.

        This method should be responsible for the following tasks:

        1. Convert datasamples into a json-serializable dict if needed.
        2. Pack the predictions and visualization results and return them.
        3. Dump or log the predictions.

        Args:
            preds (List[Dict]): Predictions of the model.
            visualization (np.ndarray): Visualized predictions.
            return_datasamples (bool): Whether to return results as
                datasamples. Defaults to False.
            pred_out_dir (str): Directory to save the inference results w/o
                visualization. If left as empty, no file will be saved.
                Defaults to ''.

        Returns:
            dict: Inference and visualization results with key ``predictions``
            and ``visualization``

            - ``visualization (Any)``: Returned by :meth:`visualize`
            - ``predictions`` (dict or DataSample): Returned by
              :meth:`forward` and processed in :meth:`postprocess`.
              If ``return_datasamples=False``, it usually should be a
              json-serializable dict containing only basic data elements such
              as strings and numbers.
        """
        if return_datasample is not None:
            print_log(
                'The `return_datasample` argument is deprecated '
                'and will be removed in future versions. Please '
                'use `return_datasamples`.',
                logger='current',
                level=logging.WARNING)
            return_datasamples = return_datasample

        result_dict = defaultdict(list)

        result_dict['visualization'] = visualization
        for pred in preds:
            if not return_datasamples:
                # convert datasamples to list of instance predictions
                pred = split_instances(pred.pred_instances)
            result_dict['predictions'].append(pred)

        if pred_out_dir != '':
            for pred, data_sample in zip(result_dict['predictions'], preds):
                if self._video_input:
                    # For video or webcam input, predictions for each frame
                    # are gathered in the 'predictions' key of 'video_info'
                    # dictionary. All frame predictions are then stored into
                    # a single file after processing all frames.
                    self.video_info['predictions'].append(pred)
                else:
                    # For non-video inputs, predictions are stored in separate
                    # JSON files. The filename is determined by the basename
                    # of the input image path with a '.json' extension. The
                    # predictions are then dumped into this file.
                    fname = os.path.splitext(
                        os.path.basename(
                            data_sample.metainfo['img_path']))[0] + '.json'
                    mmengine.dump(
                        pred, join_path(pred_out_dir, fname), indent='  ')

        return result_dict
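
    # Illustrative shape of the returned dict when ``return_datasamples`` is
    # False (sketch, key layout assumed from ``split_instances``):
    #   {
    #       'visualization': [<np.ndarray or None>, ...],
    #       'predictions': [[{'keypoints': [...], 'keypoint_scores': [...],
    #                         ...}, ...]]   # one list of instances per image
    #   }
    # With ``return_datasamples=True``, each entry in 'predictions' is a
    # PoseDataSample instead.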

    def _finalize_video_processing(
        self,
        pred_out_dir: str = '',
    ):
        """Finalize video processing by releasing the video writer and saving
        predictions to a file.

        This method should be called after completing the video processing.
        It releases the video writer, if it exists, and saves the predictions
        to a JSON file if a prediction output directory is provided.
        """
        # Release the video writer if it exists
        if self.video_info['writer'] is not None:
            out_file = self.video_info['output_file']
            print_log(
                f'the output video has been saved at {out_file}',
                logger='current',
                level=logging.INFO)
            self.video_info['writer'].release()

        # Save predictions
        if pred_out_dir:
            fname = os.path.splitext(
                os.path.basename(self.video_info['name']))[0] + '.json'
            predictions = [
                dict(frame_id=i, instances=pred)
                for i, pred in enumerate(self.video_info['predictions'])
            ]

            mmengine.dump(
                predictions, join_path(pred_out_dir, fname), indent='  ')
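

# ---------------------------------------------------------------------------
# End-to-end usage sketch (illustrative, not part of the original module).
# `MMPoseInferencer` is the high-level entry point built on this base class;
# the model alias and file path below are assumptions.
#
#   from mmpose.apis import MMPoseInferencer
#
#   inferencer = MMPoseInferencer(pose2d='human')  # assumed model alias
#   # __call__ is a generator: iterate it to run inference input by input
#   for result in inferencer('demo.jpg', out_dir='outputs'):
#       print(result['predictions'])
# ---------------------------------------------------------------------------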