|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import copy |
|
|
import importlib |
|
|
import os.path as osp |
|
|
import re |
|
|
import warnings |
|
|
from abc import ABCMeta, abstractmethod |
|
|
from datetime import datetime |
|
|
from typing import (Any, Callable, Dict, Iterable, List, Optional, Sequence, |
|
|
Tuple, Union) |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from rich.progress import track |
|
|
|
|
|
from mmengine.config import Config, ConfigDict |
|
|
from mmengine.config.utils import MODULE2PACKAGE |
|
|
from mmengine.dataset import pseudo_collate |
|
|
from mmengine.device import get_device |
|
|
from mmengine.fileio import (get_file_backend, isdir, join_path, |
|
|
list_dir_or_file, load) |
|
|
from mmengine.logging import print_log |
|
|
from mmengine.registry import FUNCTIONS, MODELS, VISUALIZERS, DefaultScope |
|
|
from mmengine.runner.checkpoint import (_load_checkpoint, |
|
|
_load_checkpoint_to_model) |
|
|
from mmengine.structures import InstanceData |
|
|
from mmengine.visualization import Visualizer |
|
|
|
|
|
InstanceList = List[InstanceData] |
|
|
InputType = Union[str, np.ndarray, torch.Tensor] |
|
|
InputsType = Union[InputType, Sequence[InputType]] |
|
|
ImgType = Union[np.ndarray, Sequence[np.ndarray]] |
|
|
ResType = Union[Dict, List[Dict]] |
|
|
ConfigType = Union[Config, ConfigDict] |
|
|
ModelType = Union[dict, ConfigType, str] |
|
|
|
|
|
|
|
|
class InferencerMeta(ABCMeta): |
|
|
"""Check the legality of the inferencer. |
|
|
|
|
|
All Inferencers should not define duplicated keys for |
|
|
``preprocess_kwargs``, ``forward_kwargs``, ``visualize_kwargs`` and |
|
|
``postprocess_kwargs``. |
|
|
""" |
|
|
|
|
|
def __init__(self, *args, **kwargs): |
|
|
super().__init__(*args, **kwargs) |
|
|
assert isinstance(self.preprocess_kwargs, set) |
|
|
assert isinstance(self.forward_kwargs, set) |
|
|
assert isinstance(self.visualize_kwargs, set) |
|
|
assert isinstance(self.postprocess_kwargs, set) |
|
|
|
|
|
all_kwargs = ( |
|
|
self.preprocess_kwargs | self.forward_kwargs |
|
|
| self.visualize_kwargs | self.postprocess_kwargs) |
|
|
|
|
|
assert len(all_kwargs) == ( |
|
|
len(self.preprocess_kwargs) + len(self.forward_kwargs) + |
|
|
len(self.visualize_kwargs) + len(self.postprocess_kwargs)), ( |
|
|
f'Class define error! {self.__name__} should not ' |
|
|
'define duplicated keys for `preprocess_kwargs`, ' |
|
|
'`forward_kwargs`, `visualize_kwargs` and ' |
|
|
'`postprocess_kwargs` are not allowed.') |
|
|
|
|
|
|
|
|
class BaseInferencer(metaclass=InferencerMeta): |
|
|
"""Base inferencer for downstream tasks. |
|
|
|
|
|
The BaseInferencer provides the standard workflow for inference as follows: |
|
|
|
|
|
1. Preprocess the input data by :meth:`preprocess`. |
|
|
2. Forward the data to the model by :meth:`forward`. ``BaseInferencer`` |
|
|
assumes the model inherits from :class:`mmengine.models.BaseModel` and |
|
|
will call `model.test_step` in :meth:`forward` by default. |
|
|
3. Visualize the results by :meth:`visualize`. |
|
|
4. Postprocess and return the results by :meth:`postprocess`. |
|
|
|
|
|
When we call the subclasses inherited from BaseInferencer (not overriding |
|
|
``__call__``), the workflow will be executed in order. |
|
|
|
|
|
All subclasses of BaseInferencer could define the following class |
|
|
attributes for customization: |
|
|
|
|
|
- ``preprocess_kwargs``: The keys of the kwargs that will be passed to |
|
|
:meth:`preprocess`. |
|
|
- ``forward_kwargs``: The keys of the kwargs that will be passed to |
|
|
:meth:`forward` |
|
|
- ``visualize_kwargs``: The keys of the kwargs that will be passed to |
|
|
:meth:`visualize` |
|
|
- ``postprocess_kwargs``: The keys of the kwargs that will be passed to |
|
|
:meth:`postprocess` |
|
|
|
|
|
All attributes mentioned above should be a ``set`` of keys (strings), |
|
|
and each key should not be duplicated. Actually, :meth:`__call__` will |
|
|
dispatch all the arguments to the corresponding methods according to the |
|
|
``xxx_kwargs`` mentioned above, therefore, the key in sets should |
|
|
be unique to avoid ambiguous dispatching. |
|
|
|
|
|
Warning: |
|
|
If subclasses defined the class attributes mentioned above with |
|
|
duplicated keys, an ``AssertionError`` will be raised during import |
|
|
process. |
|
|
|
|
|
Subclasses inherited from ``BaseInferencer`` should implement |
|
|
:meth:`_init_pipeline`, :meth:`visualize` and :meth:`postprocess`: |
|
|
|
|
|
- _init_pipeline: Return a callable object to preprocess the input data. |
|
|
- visualize: Visualize the results returned by :meth:`forward`. |
|
|
- postprocess: Postprocess the results returned by :meth:`forward` and |
|
|
:meth:`visualize`. |
|
|
|
|
|
Args: |
|
|
model (str, optional): Path to the config file or the model name |
|
|
defined in metafile. Take the `mmdet metafile <https://github.com/open-mmlab/mmdetection/blob/master/configs/retinanet/metafile.yml>`_ |
|
|
as an example, the `model` could be `retinanet_r18_fpn_1x_coco` or |
|
|
its alias. If model is not specified, user must provide the |
|
|
`weights` saved by MMEngine which contains the config string. |
|
|
Defaults to None. |
|
|
weights (str, optional): Path to the checkpoint. If it is not specified |
|
|
and model is a model name of metafile, the weights will be loaded |
|
|
from metafile. Defaults to None. |
|
|
device (str, optional): Device to run inference. If None, the available |
|
|
device will be automatically used. Defaults to None. |
|
|
scope (str, optional): The scope of the model. Defaults to None. |
|
|
show_progress (bool): Control whether to display the progress bar during |
|
|
the inference process. Defaults to True. |
|
|
`New in version 0.7.4.` |
|
|
|
|
|
Note: |
|
|
Since ``Inferencer`` could be used to infer batch data, |
|
|
`collate_fn` should be defined. If `collate_fn` is not defined in config |
|
|
file, the `collate_fn` will be `pseudo_collate` by default. |
|
|
""" |
|
|
|
|
|
preprocess_kwargs: set = set() |
|
|
forward_kwargs: set = set() |
|
|
visualize_kwargs: set = set() |
|
|
postprocess_kwargs: set = set() |
|
|
|
|
|
def __init__(self, |
|
|
model: Union[ModelType, str, None] = None, |
|
|
weights: Optional[str] = None, |
|
|
device: Optional[str] = None, |
|
|
scope: Optional[str] = None, |
|
|
show_progress: bool = True) -> None: |
|
|
if scope is None: |
|
|
default_scope = DefaultScope.get_current_instance() |
|
|
if default_scope is not None: |
|
|
scope = default_scope.scope_name |
|
|
self.scope = scope |
|
|
|
|
|
cfg: ConfigType |
|
|
if isinstance(model, str): |
|
|
if osp.isfile(model): |
|
|
cfg = Config.fromfile(model) |
|
|
else: |
|
|
|
|
|
|
|
|
cfg, _weights = self._load_model_from_metafile(model) |
|
|
if weights is None: |
|
|
weights = _weights |
|
|
elif isinstance(model, (Config, ConfigDict)): |
|
|
cfg = copy.deepcopy(model) |
|
|
elif isinstance(model, dict): |
|
|
cfg = copy.deepcopy(ConfigDict(model)) |
|
|
elif model is None: |
|
|
if weights is None: |
|
|
raise ValueError( |
|
|
'If model is None, the weights must be specified since ' |
|
|
'the config needs to be loaded from the weights') |
|
|
cfg = ConfigDict() |
|
|
else: |
|
|
raise TypeError('model must be a filepath or any ConfigType' |
|
|
f'object, but got {type(model)}') |
|
|
|
|
|
if device is None: |
|
|
device = get_device() |
|
|
|
|
|
self.model = self._init_model(cfg, weights, device) |
|
|
self.pipeline = self._init_pipeline(cfg) |
|
|
self.collate_fn = self._init_collate(cfg) |
|
|
self.visualizer = self._init_visualizer(cfg) |
|
|
self.cfg = cfg |
|
|
self.show_progress = show_progress |
|
|
|
|
|
def __call__( |
|
|
self, |
|
|
inputs: InputsType, |
|
|
return_datasamples: bool = False, |
|
|
batch_size: int = 1, |
|
|
**kwargs, |
|
|
) -> dict: |
|
|
"""Call the inferencer. |
|
|
|
|
|
Args: |
|
|
inputs (InputsType): Inputs for the inferencer. |
|
|
return_datasamples (bool): Whether to return results as |
|
|
:obj:`BaseDataElement`. Defaults to False. |
|
|
batch_size (int): Batch size. Defaults to 1. |
|
|
**kwargs: Key words arguments passed to :meth:`preprocess`, |
|
|
:meth:`forward`, :meth:`visualize` and :meth:`postprocess`. |
|
|
Each key in kwargs should be in the corresponding set of |
|
|
``preprocess_kwargs``, ``forward_kwargs``, ``visualize_kwargs`` |
|
|
and ``postprocess_kwargs``. |
|
|
|
|
|
Returns: |
|
|
dict: Inference and visualization results. |
|
|
""" |
|
|
( |
|
|
preprocess_kwargs, |
|
|
forward_kwargs, |
|
|
visualize_kwargs, |
|
|
postprocess_kwargs, |
|
|
) = self._dispatch_kwargs(**kwargs) |
|
|
|
|
|
ori_inputs = self._inputs_to_list(inputs) |
|
|
inputs = self.preprocess( |
|
|
ori_inputs, batch_size=batch_size, **preprocess_kwargs) |
|
|
preds = [] |
|
|
for data in (track(inputs, description='Inference') |
|
|
if self.show_progress else inputs): |
|
|
preds.extend(self.forward(data, **forward_kwargs)) |
|
|
visualization = self.visualize( |
|
|
ori_inputs, preds, |
|
|
**visualize_kwargs) |
|
|
results = self.postprocess(preds, visualization, return_datasamples, |
|
|
**postprocess_kwargs) |
|
|
return results |
|
|
|
|
|
def _inputs_to_list(self, inputs: InputsType) -> list: |
|
|
"""Preprocess the inputs to a list. |
|
|
|
|
|
Preprocess inputs to a list according to its type: |
|
|
|
|
|
- list or tuple: return inputs |
|
|
- str: |
|
|
- Directory path: return all files in the directory |
|
|
- other cases: return a list containing the string. The string |
|
|
could be a path to file, a url or other types of string according |
|
|
to the task. |
|
|
|
|
|
Args: |
|
|
inputs (InputsType): Inputs for the inferencer. |
|
|
|
|
|
Returns: |
|
|
list: List of input for the :meth:`preprocess`. |
|
|
""" |
|
|
if isinstance(inputs, str): |
|
|
backend = get_file_backend(inputs) |
|
|
if hasattr(backend, 'isdir') and isdir(inputs): |
|
|
|
|
|
|
|
|
|
|
|
filename_list = list_dir_or_file(inputs, list_dir=False) |
|
|
inputs = [ |
|
|
join_path(inputs, filename) for filename in filename_list |
|
|
] |
|
|
|
|
|
if not isinstance(inputs, (list, tuple)): |
|
|
inputs = [inputs] |
|
|
|
|
|
return list(inputs) |
|
|
|
|
|
def preprocess(self, inputs: InputsType, batch_size: int = 1, **kwargs): |
|
|
"""Process the inputs into a model-feedable format. |
|
|
|
|
|
Customize your preprocess by overriding this method. Preprocess should |
|
|
return an iterable object, of which each item will be used as the |
|
|
input of ``model.test_step``. |
|
|
|
|
|
``BaseInferencer.preprocess`` will return an iterable chunked data, |
|
|
which will be used in __call__ like this: |
|
|
|
|
|
.. code-block:: python |
|
|
|
|
|
def __call__(self, inputs, batch_size=1, **kwargs): |
|
|
chunked_data = self.preprocess(inputs, batch_size, **kwargs) |
|
|
for batch in chunked_data: |
|
|
preds = self.forward(batch, **kwargs) |
|
|
|
|
|
Args: |
|
|
inputs (InputsType): Inputs given by user. |
|
|
batch_size (int): batch size. Defaults to 1. |
|
|
|
|
|
Yields: |
|
|
Any: Data processed by the ``pipeline`` and ``collate_fn``. |
|
|
""" |
|
|
chunked_data = self._get_chunk_data( |
|
|
map(self.pipeline, inputs), batch_size) |
|
|
yield from map(self.collate_fn, chunked_data) |
|
|
|
|
|
@torch.no_grad() |
|
|
def forward(self, inputs: Union[dict, tuple], **kwargs) -> Any: |
|
|
"""Feed the inputs to the model.""" |
|
|
return self.model.test_step(inputs) |
|
|
|
|
|
@abstractmethod |
|
|
def visualize(self, |
|
|
inputs: list, |
|
|
preds: Any, |
|
|
show: bool = False, |
|
|
**kwargs) -> List[np.ndarray]: |
|
|
"""Visualize predictions. |
|
|
|
|
|
Customize your visualization by overriding this method. visualize |
|
|
should return visualization results, which could be np.ndarray or any |
|
|
other objects. |
|
|
|
|
|
Args: |
|
|
inputs (list): Inputs preprocessed by :meth:`_inputs_to_list`. |
|
|
preds (Any): Predictions of the model. |
|
|
show (bool): Whether to display the image in a popup window. |
|
|
Defaults to False. |
|
|
|
|
|
Returns: |
|
|
List[np.ndarray]: Visualization results. |
|
|
""" |
|
|
|
|
|
@abstractmethod |
|
|
def postprocess( |
|
|
self, |
|
|
preds: Any, |
|
|
visualization: List[np.ndarray], |
|
|
return_datasample=False, |
|
|
**kwargs, |
|
|
) -> dict: |
|
|
"""Process the predictions and visualization results from ``forward`` |
|
|
and ``visualize``. |
|
|
|
|
|
This method should be responsible for the following tasks: |
|
|
|
|
|
1. Convert datasamples into a json-serializable dict if needed. |
|
|
2. Pack the predictions and visualization results and return them. |
|
|
3. Dump or log the predictions. |
|
|
|
|
|
Customize your postprocess by overriding this method. Make sure |
|
|
``postprocess`` will return a dict with visualization results and |
|
|
inference results. |
|
|
|
|
|
Args: |
|
|
preds (List[Dict]): Predictions of the model. |
|
|
visualization (np.ndarray): Visualized predictions. |
|
|
return_datasample (bool): Whether to return results as datasamples. |
|
|
Defaults to False. |
|
|
|
|
|
Returns: |
|
|
dict: Inference and visualization results with key ``predictions`` |
|
|
and ``visualization`` |
|
|
|
|
|
- ``visualization (Any)``: Returned by :meth:`visualize` |
|
|
- ``predictions`` (dict or DataSample): Returned by |
|
|
:meth:`forward` and processed in :meth:`postprocess`. |
|
|
If ``return_datasample=False``, it usually should be a |
|
|
json-serializable dict containing only basic data elements such |
|
|
as strings and numbers. |
|
|
""" |
|
|
|
|
|
def _load_model_from_metafile(self, model: str) -> Tuple[Config, str]: |
|
|
"""Load config and weights from metafile. |
|
|
|
|
|
Args: |
|
|
model (str): model name defined in metafile. |
|
|
|
|
|
Returns: |
|
|
Tuple[Config, str]: Loaded Config and weights path defined in |
|
|
metafile. |
|
|
""" |
|
|
model = model.lower() |
|
|
|
|
|
assert self.scope is not None, ( |
|
|
'scope should be initialized if you want ' |
|
|
'to load config from metafile.') |
|
|
assert self.scope in MODULE2PACKAGE, ( |
|
|
f'{self.scope} not in {MODULE2PACKAGE}!,' |
|
|
'please pass a valid scope.') |
|
|
|
|
|
repo_or_mim_dir = BaseInferencer._get_repo_or_mim_dir(self.scope) |
|
|
for model_cfg in BaseInferencer._get_models_from_metafile( |
|
|
repo_or_mim_dir): |
|
|
model_name = model_cfg['Name'].lower() |
|
|
model_aliases = model_cfg.get('Alias', []) |
|
|
if isinstance(model_aliases, str): |
|
|
model_aliases = [model_aliases.lower()] |
|
|
else: |
|
|
model_aliases = [alias.lower() for alias in model_aliases] |
|
|
if (model_name == model or model in model_aliases): |
|
|
cfg = Config.fromfile( |
|
|
osp.join(repo_or_mim_dir, model_cfg['Config'])) |
|
|
weights = model_cfg['Weights'] |
|
|
weights = weights[0] if isinstance(weights, list) else weights |
|
|
return cfg, weights |
|
|
raise ValueError(f'Cannot find model: {model} in {self.scope}') |
|
|
|
|
|
@staticmethod |
|
|
def _get_repo_or_mim_dir(scope): |
|
|
"""Get the directory where the ``Configs`` located when the package is |
|
|
installed or ``PYTHONPATH`` is set. |
|
|
|
|
|
Args: |
|
|
scope (str): The scope of repository. |
|
|
|
|
|
Returns: |
|
|
str: The directory where the ``Configs`` is located. |
|
|
""" |
|
|
try: |
|
|
module = importlib.import_module(scope) |
|
|
except ImportError: |
|
|
if scope not in MODULE2PACKAGE: |
|
|
raise KeyError( |
|
|
f'{scope} is not a valid scope. The available scopes ' |
|
|
f'are {MODULE2PACKAGE.keys()}') |
|
|
else: |
|
|
project = MODULE2PACKAGE[scope] |
|
|
raise ImportError( |
|
|
f'Cannot import {scope} correctly, please try to install ' |
|
|
f'the {project} by "pip install {project}"') |
|
|
|
|
|
|
|
|
|
|
|
package_path = module.__path__[0] |
|
|
|
|
|
if osp.exists(osp.join(osp.dirname(package_path), 'configs')): |
|
|
repo_dir = osp.dirname(package_path) |
|
|
return repo_dir |
|
|
else: |
|
|
mim_dir = osp.join(package_path, '.mim') |
|
|
if not osp.exists(osp.join(mim_dir, 'configs')): |
|
|
raise FileNotFoundError( |
|
|
f'Cannot find `configs` directory in {package_path}!, ' |
|
|
f'please check the completeness of the {scope}.') |
|
|
return mim_dir |
|
|
|
|
|
def _init_model( |
|
|
self, |
|
|
cfg: ConfigType, |
|
|
weights: Optional[str], |
|
|
device: str = 'cpu', |
|
|
) -> nn.Module: |
|
|
"""Initialize the model with the given config and checkpoint on the |
|
|
specific device. |
|
|
|
|
|
Args: |
|
|
cfg (ConfigType): Config containing the model information. |
|
|
weights (str, optional): Path to the checkpoint. |
|
|
device (str, optional): Device to run inference. Defaults to 'cpu'. |
|
|
|
|
|
Returns: |
|
|
nn.Module: Model loaded with checkpoint. |
|
|
""" |
|
|
checkpoint: Optional[dict] = None |
|
|
if weights is not None: |
|
|
checkpoint = _load_checkpoint(weights, map_location='cpu') |
|
|
|
|
|
if not cfg: |
|
|
assert checkpoint is not None |
|
|
try: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
cfg_string = checkpoint['message_hub']['runtime_info']['cfg'] |
|
|
except KeyError: |
|
|
assert 'meta' in checkpoint, ( |
|
|
'If model(config) is not provided, the checkpoint must' |
|
|
'contain the config string in `meta` or `message_hub`, ' |
|
|
'but both `meta` and `message_hub` are not found in the ' |
|
|
'checkpoint.') |
|
|
meta = checkpoint['meta'] |
|
|
if 'cfg' in meta: |
|
|
cfg_string = meta['cfg'] |
|
|
else: |
|
|
raise ValueError( |
|
|
'Cannot find the config in the checkpoint.') |
|
|
cfg.update( |
|
|
Config.fromstring(cfg_string, file_format='.py')._cfg_dict) |
|
|
|
|
|
|
|
|
|
|
|
if cfg.model.get('pretrained') is not None: |
|
|
del cfg.model.pretrained |
|
|
|
|
|
model = MODELS.build(cfg.model) |
|
|
model.cfg = cfg |
|
|
self._load_weights_to_model(model, checkpoint, cfg) |
|
|
model.to(device) |
|
|
model.eval() |
|
|
return model |
|
|
|
|
|
def _load_weights_to_model(self, model: nn.Module, |
|
|
checkpoint: Optional[dict], |
|
|
cfg: Optional[ConfigType]) -> None: |
|
|
"""Loading model weights and meta information from cfg and checkpoint. |
|
|
|
|
|
Subclasses could override this method to load extra meta information |
|
|
from ``checkpoint`` and ``cfg`` to model. |
|
|
|
|
|
Args: |
|
|
model (nn.Module): Model to load weights and meta information. |
|
|
checkpoint (dict, optional): The loaded checkpoint. |
|
|
cfg (Config or ConfigDict, optional): The loaded config. |
|
|
""" |
|
|
if checkpoint is not None: |
|
|
_load_checkpoint_to_model(model, checkpoint) |
|
|
else: |
|
|
warnings.warn('Checkpoint is not loaded, and the inference ' |
|
|
'result is calculated by the randomly initialized ' |
|
|
'model!') |
|
|
|
|
|
def _init_collate(self, cfg: ConfigType) -> Callable: |
|
|
"""Initialize the ``collate_fn`` with the given config. |
|
|
|
|
|
The returned ``collate_fn`` will be used to collate the batch data. |
|
|
If will be used in :meth:`preprocess` like this |
|
|
|
|
|
.. code-block:: python |
|
|
def preprocess(self, inputs, batch_size, **kwargs): |
|
|
... |
|
|
dataloader = map(self.collate_fn, dataloader) |
|
|
yield from dataloader |
|
|
|
|
|
Args: |
|
|
cfg (ConfigType): Config which could contained the `collate_fn` |
|
|
information. If `collate_fn` is not defined in config, it will |
|
|
be :func:`pseudo_collate`. |
|
|
|
|
|
Returns: |
|
|
Callable: Collate function. |
|
|
""" |
|
|
try: |
|
|
with FUNCTIONS.switch_scope_and_registry(self.scope) as registry: |
|
|
collate_fn = registry.get(cfg.test_dataloader.collate_fn) |
|
|
except AttributeError: |
|
|
collate_fn = pseudo_collate |
|
|
return collate_fn |
|
|
|
|
|
@abstractmethod |
|
|
def _init_pipeline(self, cfg: ConfigType) -> Callable: |
|
|
"""Initialize the test pipeline. |
|
|
|
|
|
Return a pipeline to handle various input data, such as ``str``, |
|
|
``np.ndarray``. It is an abstract method in BaseInferencer, and should |
|
|
be implemented in subclasses. |
|
|
|
|
|
The returned pipeline will be used to process a single data. |
|
|
It will be used in :meth:`preprocess` like this: |
|
|
|
|
|
.. code-block:: python |
|
|
def preprocess(self, inputs, batch_size, **kwargs): |
|
|
... |
|
|
dataset = map(self.pipeline, dataset) |
|
|
... |
|
|
""" |
|
|
|
|
|
def _init_visualizer(self, cfg: ConfigType) -> Optional[Visualizer]: |
|
|
"""Initialize visualizers. |
|
|
|
|
|
Args: |
|
|
cfg (ConfigType): Config containing the visualizer information. |
|
|
|
|
|
Returns: |
|
|
Visualizer or None: Visualizer initialized with config. |
|
|
""" |
|
|
if 'visualizer' not in cfg: |
|
|
return None |
|
|
timestamp = str(datetime.timestamp(datetime.now())) |
|
|
name = cfg.visualizer.get('name', timestamp) |
|
|
if Visualizer.check_instance_created(name): |
|
|
name = f'{name}-{timestamp}' |
|
|
cfg.visualizer.name = name |
|
|
return VISUALIZERS.build(cfg.visualizer) |
|
|
|
|
|
def _get_chunk_data(self, inputs: Iterable, chunk_size: int): |
|
|
"""Get batch data from dataset. |
|
|
|
|
|
Args: |
|
|
inputs (Iterable): An iterable dataset. |
|
|
chunk_size (int): Equivalent to batch size. |
|
|
|
|
|
Yields: |
|
|
list: batch data. |
|
|
""" |
|
|
inputs_iter = iter(inputs) |
|
|
while True: |
|
|
try: |
|
|
chunk_data = [] |
|
|
for _ in range(chunk_size): |
|
|
processed_data = next(inputs_iter) |
|
|
chunk_data.append(processed_data) |
|
|
yield chunk_data |
|
|
except StopIteration: |
|
|
if chunk_data: |
|
|
yield chunk_data |
|
|
break |
|
|
|
|
|
def _dispatch_kwargs(self, **kwargs) -> Tuple[Dict, Dict, Dict, Dict]: |
|
|
"""Dispatch kwargs to preprocess(), forward(), visualize() and |
|
|
postprocess() according to the actual demands. |
|
|
|
|
|
Returns: |
|
|
Tuple[Dict, Dict, Dict, Dict]: kwargs passed to preprocess, |
|
|
forward, visualize and postprocess respectively. |
|
|
""" |
|
|
|
|
|
method_kwargs = self.preprocess_kwargs | self.forward_kwargs | \ |
|
|
self.visualize_kwargs | self.postprocess_kwargs |
|
|
|
|
|
union_kwargs = method_kwargs | set(kwargs.keys()) |
|
|
if union_kwargs != method_kwargs: |
|
|
unknown_kwargs = union_kwargs - method_kwargs |
|
|
raise ValueError( |
|
|
f'unknown argument {unknown_kwargs} for `preprocess`, ' |
|
|
'`forward`, `visualize` and `postprocess`') |
|
|
|
|
|
preprocess_kwargs = {} |
|
|
forward_kwargs = {} |
|
|
visualize_kwargs = {} |
|
|
postprocess_kwargs = {} |
|
|
|
|
|
for key, value in kwargs.items(): |
|
|
if key in self.preprocess_kwargs: |
|
|
preprocess_kwargs[key] = value |
|
|
elif key in self.forward_kwargs: |
|
|
forward_kwargs[key] = value |
|
|
elif key in self.visualize_kwargs: |
|
|
visualize_kwargs[key] = value |
|
|
else: |
|
|
postprocess_kwargs[key] = value |
|
|
|
|
|
return ( |
|
|
preprocess_kwargs, |
|
|
forward_kwargs, |
|
|
visualize_kwargs, |
|
|
postprocess_kwargs, |
|
|
) |
|
|
|
|
|
@staticmethod |
|
|
def _get_models_from_metafile(dir: str): |
|
|
"""Load model config defined in metafile from package path. |
|
|
|
|
|
Args: |
|
|
dir (str): Path to the directory of Config. It requires the |
|
|
directory ``Config``, file ``model-index.yml`` exists in the |
|
|
``dir``. |
|
|
|
|
|
Yields: |
|
|
dict: Model config defined in metafile. |
|
|
""" |
|
|
meta_indexes = load(osp.join(dir, 'model-index.yml')) |
|
|
for meta_path in meta_indexes['Import']: |
|
|
|
|
|
meta_path = osp.join(dir, meta_path) |
|
|
metainfo = load(meta_path) |
|
|
yield from metainfo['Models'] |
|
|
|
|
|
@staticmethod |
|
|
def list_models(scope: Optional[str] = None, patterns: str = r'.*'): |
|
|
"""List models defined in metafile of corresponding packages. |
|
|
|
|
|
Args: |
|
|
scope (str, optional): The scope to which the model belongs. |
|
|
Defaults to None. |
|
|
patterns (str, optional): Regular expressions for the searched |
|
|
models. Once matched with ``Alias`` or ``Name`` filed in |
|
|
metafile, corresponding model will be added to the return list. |
|
|
Defaults to '.*'. |
|
|
|
|
|
Returns: |
|
|
dict: Model dict with model name and its alias. |
|
|
""" |
|
|
matched_models = [] |
|
|
if scope is None: |
|
|
default_scope = DefaultScope.get_current_instance() |
|
|
assert default_scope is not None, ( |
|
|
'scope should be initialized if you want ' |
|
|
'to load config from metafile.') |
|
|
assert scope in MODULE2PACKAGE, ( |
|
|
f'{scope} not in {MODULE2PACKAGE}!, please make pass a valid ' |
|
|
'scope.') |
|
|
root_or_mim_dir = BaseInferencer._get_repo_or_mim_dir(scope) |
|
|
for model_cfg in BaseInferencer._get_models_from_metafile( |
|
|
root_or_mim_dir): |
|
|
model_name = [model_cfg['Name']] |
|
|
model_name.extend(model_cfg.get('Alias', [])) |
|
|
for name in model_name: |
|
|
if re.match(patterns, name) is not None: |
|
|
matched_models.append(name) |
|
|
output_str = '' |
|
|
for name in matched_models: |
|
|
output_str += f'model_name: {name}\n' |
|
|
print_log(output_str, logger='current') |
|
|
return matched_models |
|
|
|