Spaces: Running on Zero
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from collections import defaultdict | |
| from typing import Any, Optional, Sequence, Union | |
| from mmengine.evaluator.evaluator import Evaluator | |
| from mmengine.evaluator.metric import BaseMetric | |
| from mmengine.structures import BaseDataElement | |
| from mmpose.datasets.datasets.utils import parse_pose_metainfo | |
| from mmpose.registry import DATASETS, EVALUATORS | |
class MultiDatasetEvaluator(Evaluator):
    """Wrapper class to compose multiple :class:`BaseMetric` instances.

    Each metric is bound to one dataset; during evaluation, data samples
    are routed to the metric of the dataset they originate from.

    Args:
        metrics (dict or BaseMetric or Sequence): The configs of metrics.
        datasets (Sequence[dict]): The configs of datasets, parallel to
            ``metrics`` (one dataset config per metric).
    """

    def __init__(
        self,
        metrics: Union[dict, BaseMetric, Sequence],
        datasets: Sequence[dict],
    ):
        assert len(metrics) == len(datasets), 'the argument ' \
            'datasets should have same length as metrics'

        super().__init__(metrics)

        # Build a lookup from dataset name -> metric so that ``process``
        # can dispatch each sample to the metric of its source dataset.
        metrics_dict = dict()
        for dataset, metric in zip(datasets, self.metrics):
            # Resolve the dataset class from the registry and parse its
            # METAINFO into a full meta dict (incl. ``dataset_name``).
            metainfo_file = DATASETS.module_dict[dataset['type']].METAINFO
            dataset_meta = parse_pose_metainfo(metainfo_file)
            metric.dataset_meta = dataset_meta
            dataset_name = dataset_meta['dataset_name']
            metrics_dict[dataset_name] = metric
        self.metrics_dict = metrics_dict

    @property
    def dataset_meta(self) -> Optional[dict]:
        """Optional[dict]: Meta info of the dataset."""
        return self._dataset_meta

    @dataset_meta.setter
    def dataset_meta(self, dataset_meta: dict) -> None:
        """Set the dataset meta info to the evaluator and its metrics.

        Note:
            Unlike the base :class:`Evaluator`, the meta info is NOT
            propagated to the metrics here: each metric already holds the
            meta info of its own dataset (assigned in ``__init__``).
        """
        self._dataset_meta = dataset_meta

    def process(self,
                data_samples: Sequence[BaseDataElement],
                data_batch: Optional[Any] = None):
        """Convert ``BaseDataSample`` to dict and invoke process method of each
        metric.

        Args:
            data_samples (Sequence[BaseDataElement]): predictions of the model,
                and the ground truth of the validation set.
            data_batch (Any, optional): A batch of data from the dataloader.
        """
        # Group samples (and the matching batch entries) by dataset name.
        _data_samples = defaultdict(list)
        _data_batch = dict(
            inputs=defaultdict(list),
            data_samples=defaultdict(list),
        )

        for inputs, data_ds, data_sample in zip(data_batch['inputs'],
                                                data_batch['data_samples'],
                                                data_samples):
            if isinstance(data_sample, BaseDataElement):
                data_sample = data_sample.to_dict()
            assert isinstance(data_sample, dict)
            # Fall back to the evaluator-level meta when the sample does
            # not carry a dataset name of its own.
            dataset_name = data_sample.get('dataset_name',
                                           self.dataset_meta['dataset_name'])
            _data_samples[dataset_name].append(data_sample)
            _data_batch['inputs'][dataset_name].append(inputs)
            _data_batch['data_samples'][dataset_name].append(data_ds)

        # Forward each per-dataset group to its dedicated metric.
        for dataset_name, metric in self.metrics_dict.items():
            if dataset_name in _data_samples:
                data_batch = dict(
                    inputs=_data_batch['inputs'][dataset_name],
                    data_samples=_data_batch['data_samples'][dataset_name])
                metric.process(data_batch, _data_samples[dataset_name])