| |
| from typing import List, Optional |
|
|
| import torch |
|
|
| try: |
| import mmpretrain |
| from mmpretrain.models.classifiers import ImageClassifier |
| except ImportError: |
| mmpretrain = None |
| ImageClassifier = object |
|
|
| from mmdet.registry import MODELS |
| from mmdet.structures import ReIDDataSample |
|
|
|
|
@MODELS.register_module()
class BaseReID(ImageClassifier):
    """Base model for re-identification.

    Thin wrapper around mmpretrain's ``ImageClassifier``. The only extra
    behaviour is that :meth:`forward` also accepts a single-clip 5-D input of
    shape ``(1, T, C, H, W)``. It drops the leading dimension before
    delegating to the parent ``forward``.
    """

    def __init__(self, *args, **kwargs):
        """Build the model, failing early if mmpretrain is not installed.

        Args:
            *args: Positional arguments forwarded to ``ImageClassifier``.
            **kwargs: Keyword arguments forwarded to ``ImageClassifier``.

        Raises:
            RuntimeError: If ``mmpretrain`` could not be imported.
        """
        # When the optional import failed, ``ImageClassifier`` is bound to a
        # plain ``object`` placeholder. Fail here with an actionable message
        # instead of an obscure error from ``object.__init__``.
        if mmpretrain is None:
            raise RuntimeError('Please run "pip install openmim" and '
                               'run "mim install mmpretrain" to '
                               'install mmpretrain first.')
        super().__init__(*args, **kwargs)

    def forward(self,
                inputs: torch.Tensor,
                data_samples: Optional[List[ReIDDataSample]] = None,
                mode: str = 'tensor'):
        """The unified entry for a forward process in both training and test.

        The method should accept three modes: "tensor", "predict" and "loss":

        - "tensor": Forward the whole network and return tensor or tuple of
          tensor without any post-processing, same as a common nn.Module.
        - "predict": Forward and return the predictions, which are fully
          processed to a list of :obj:`ReIDDataSample`.
        - "loss": Forward and return a dict of losses according to the given
          inputs and data samples.

        Note that this method handles neither back propagation nor optimizer
        updating; both are done in the :meth:`train_step`.

        Args:
            inputs (torch.Tensor): The input tensor with shape
                (N, C, H, W) or (N, T, C, H, W). For the 5-D form only
                N == 1 is supported; the leading dimension is removed so the
                parent ``forward`` receives a (T, C, H, W) tensor.
            data_samples (List[ReIDDataSample], optional): The annotation
                data of every sample. It's required if ``mode="loss"``.
                Defaults to None.
            mode (str): Return what kind of value. Defaults to 'tensor'.

        Returns:
            The return type depends on ``mode``.

            - If ``mode="tensor"``, return a tensor or a tuple of tensor.
            - If ``mode="predict"``, return a list of
              :obj:`ReIDDataSample`.
            - If ``mode="loss"``, return a dict of tensor.

        Raises:
            ValueError: If ``inputs`` is 5-D and its first dimension is not 1.
        """
        if inputs.dim() == 5:
            # Only a single clip per call is supported: its T frames become
            # the leading dimension seen by the parent ``forward``. Validate
            # explicitly, because an ``assert`` is stripped under
            # ``python -O``, and ``inputs[0]`` would then silently discard
            # every clip but the first.
            num_clips = inputs.size(0)
            if num_clips != 1:
                raise ValueError(
                    'BaseReID.forward only supports a batch size of 1 for '
                    '5-D inputs of shape (N, T, C, H, W), '
                    f'but got N={num_clips}.')
            inputs = inputs[0]
        return super().forward(inputs, data_samples, mode)
|
|