| |
| import copy |
| from typing import Any, Iterator, Optional, Tuple, Type, Union |
|
|
| import numpy as np |
| import torch |
|
|
|
|
class BaseDataElement:
    """A base data interface that supports Tensor-like and dict-like
    operations.

    Typical data elements are predicted results or ground truth labels
    of a task, such as predicted bboxes, instance masks, semantic
    segmentation masks, etc. Because groundtruth labels and predicted results
    often have similar properties (for example, the predicted bboxes and the
    groundtruth bboxes), MMEngine uses the same abstract data interface to
    encapsulate predicted results and groundtruth labels, and it is recommended
    to use different name conventions to distinguish them, such as using
    ``gt_instances`` and ``pred_instances`` to distinguish between labels and
    predicted results. Additionally, we distinguish data elements at instance
    level, pixel level, and label level. Each of these types has its own
    characteristics. Therefore, MMEngine defines the base class
    ``BaseDataElement``, and implements ``InstanceData``, ``PixelData``, and
    ``LabelData`` inheriting from ``BaseDataElement`` to represent different
    types of ground truth labels or predictions.

    Another common data element is sample data. A sample data consists of input
    data (such as an image) and its annotations and predictions. In general,
    an image can have multiple types of annotations and/or predictions at the
    same time (for example, both pixel-level semantic segmentation annotations
    and instance-level detection bboxes annotations). All labels and
    predictions of a training sample are often passed between Dataset, Model,
    Visualizer, and Evaluator components. In order to simplify the interface
    between components, we can treat them as a large data element and
    encapsulate them. Such data elements are generally called XXDataSample in
    the OpenMMLab. Therefore, similar to `nn.Module`, the `BaseDataElement`
    allows `BaseDataElement` as its attribute. Such a class generally
    encapsulates all the data of a sample in the algorithm library, and its
    attributes generally are various types of data elements. For example,
    MMDetection builds on the BaseDataElement to encapsulate all the data
    elements of the sample labeling and prediction of a sample in the
    algorithm library.

    The attributes in ``BaseDataElement`` are divided into two parts,
    the ``metainfo`` and the ``data`` respectively.

        - ``metainfo``: Usually contains the
          information about the image such as filename,
          image_shape, pad_shape, etc. The attributes can be accessed or
          modified by dict-like or object-like operations, such as
          ``.`` (for data access and modification), ``in``, ``del``,
          ``pop(str)``, ``get(str)``, ``metainfo_keys()``,
          ``metainfo_values()``, ``metainfo_items()``, ``set_metainfo()`` (for
          set or change key-value pairs in metainfo).

        - ``data``: Annotations or model predictions are
          stored. The attributes can be accessed or modified by
          dict-like or object-like operations, such as
          ``.``, ``in``, ``del``, ``pop(str)``, ``get(str)``, ``keys()``,
          ``values()``, ``items()``. Users can also apply tensor-like
          methods to all :obj:`torch.Tensor` in the ``data_fields``,
          such as ``.cuda()``, ``.cpu()``, ``.numpy()``, ``.to()``,
          ``to_tensor()``, ``.detach()``.

    Args:
        metainfo (dict, optional): A dict contains the meta information
            of single image, such as ``dict(img_shape=(512, 512, 3),
            scale_factor=(1, 1, 1, 1))``. Defaults to None.
        kwargs (dict, optional): A dict contains annotations of single image or
            model predictions. Defaults to None.

    Examples:
        >>> import torch
        >>> from mmengine.structures import BaseDataElement
        >>> gt_instances = BaseDataElement()
        >>> bboxes = torch.rand((5, 4))
        >>> scores = torch.rand((5,))
        >>> img_id = 0
        >>> img_shape = (800, 1333)
        >>> gt_instances = BaseDataElement(
        ...     metainfo=dict(img_id=img_id, img_shape=img_shape),
        ...     bboxes=bboxes, scores=scores)
        >>> gt_instances = BaseDataElement(
        ...     metainfo=dict(img_id=img_id, img_shape=(640, 640)))

        >>> # new
        >>> gt_instances1 = gt_instances.new(
        ...     metainfo=dict(img_id=1, img_shape=(640, 640)),
        ...     bboxes=torch.rand((5, 4)),
        ...     scores=torch.rand((5,)))
        >>> gt_instances2 = gt_instances1.new()

        >>> # add and process property
        >>> gt_instances = BaseDataElement()
        >>> gt_instances.set_metainfo(dict(img_id=9, img_shape=(100, 100)))
        >>> assert 'img_shape' in gt_instances.metainfo_keys()
        >>> assert 'img_shape' in gt_instances
        >>> assert 'img_shape' not in gt_instances.keys()
        >>> assert 'img_shape' in gt_instances.all_keys()
        >>> print(gt_instances.img_shape)
        (100, 100)
        >>> gt_instances.scores = torch.rand((5,))
        >>> assert 'scores' in gt_instances.keys()
        >>> assert 'scores' in gt_instances
        >>> assert 'scores' in gt_instances.all_keys()
        >>> assert 'scores' not in gt_instances.metainfo_keys()
        >>> print(gt_instances.scores)
        tensor([0.5230, 0.7885, 0.2426, 0.3911, 0.4876])
        >>> gt_instances.bboxes = torch.rand((5, 4))
        >>> assert 'bboxes' in gt_instances.keys()
        >>> assert 'bboxes' in gt_instances
        >>> assert 'bboxes' in gt_instances.all_keys()
        >>> assert 'bboxes' not in gt_instances.metainfo_keys()
        >>> print(gt_instances.bboxes)
        tensor([[0.0900, 0.0424, 0.1755, 0.4469],
                [0.8648, 0.0592, 0.3484, 0.0913],
                [0.5808, 0.1909, 0.6165, 0.7088],
                [0.5490, 0.4209, 0.9416, 0.2374],
                [0.3652, 0.1218, 0.8805, 0.7523]])

        >>> # delete and change property
        >>> gt_instances = BaseDataElement(
        ...     metainfo=dict(img_id=0, img_shape=(640, 640)),
        ...     bboxes=torch.rand((6, 4)), scores=torch.rand((6,)))
        >>> gt_instances.set_metainfo(dict(img_shape=(1280, 1280)))
        >>> gt_instances.img_shape  # (1280, 1280)
        >>> gt_instances.bboxes = gt_instances.bboxes * 2
        >>> gt_instances.get('img_shape', None)  # (1280, 1280)
        >>> gt_instances.get('bboxes', None)  # 6x4 tensor
        >>> del gt_instances.img_shape
        >>> del gt_instances.bboxes
        >>> assert 'img_shape' not in gt_instances
        >>> assert 'bboxes' not in gt_instances
        >>> gt_instances.pop('img_shape', None)  # None
        >>> gt_instances.pop('bboxes', None)  # None

        >>> # Tensor-like
        >>> cuda_instances = gt_instances.cuda()
        >>> cuda_instances = gt_instances.to('cuda:0')
        >>> cpu_instances = cuda_instances.cpu()
        >>> cpu_instances = cuda_instances.to('cpu')
        >>> fp16_instances = cuda_instances.to(
        ...     device=None, dtype=torch.float16, non_blocking=False,
        ...     copy=False, memory_format=torch.preserve_format)
        >>> cpu_instances = cuda_instances.detach()
        >>> np_instances = cpu_instances.numpy()

        >>> # print
        >>> metainfo = dict(img_shape=(800, 1196, 3))
        >>> gt_instances = BaseDataElement(
        ...     metainfo=metainfo, det_labels=torch.LongTensor([0, 1, 2, 3]))
        >>> sample = BaseDataElement(metainfo=metainfo,
        ...                          gt_instances=gt_instances)
        >>> print(sample)
        <BaseDataElement(
            META INFORMATION
            img_shape: (800, 1196, 3)
            DATA FIELDS
            gt_instances: <BaseDataElement(
                    META INFORMATION
                    img_shape: (800, 1196, 3)
                    DATA FIELDS
                    det_labels: tensor([0, 1, 2, 3])
                ) at 0x7f0ec5eadc70>
        ) at 0x7f0fea49e130>

        >>> # inheritance
        >>> class DetDataSample(BaseDataElement):
        ...     @property
        ...     def proposals(self):
        ...         return self._proposals
        ...     @proposals.setter
        ...     def proposals(self, value):
        ...         self.set_field(value, '_proposals', dtype=BaseDataElement)
        ...     @proposals.deleter
        ...     def proposals(self):
        ...         del self._proposals
        ...     @property
        ...     def gt_instances(self):
        ...         return self._gt_instances
        ...     @gt_instances.setter
        ...     def gt_instances(self, value):
        ...         self.set_field(value, '_gt_instances',
        ...                        dtype=BaseDataElement)
        ...     @gt_instances.deleter
        ...     def gt_instances(self):
        ...         del self._gt_instances
        ...     @property
        ...     def pred_instances(self):
        ...         return self._pred_instances
        ...     @pred_instances.setter
        ...     def pred_instances(self, value):
        ...         self.set_field(value, '_pred_instances',
        ...                        dtype=BaseDataElement)
        ...     @pred_instances.deleter
        ...     def pred_instances(self):
        ...         del self._pred_instances
        >>> det_sample = DetDataSample()
        >>> proposals = BaseDataElement(bboxes=torch.rand((5, 4)))
        >>> det_sample.proposals = proposals
        >>> assert 'proposals' in det_sample
        >>> assert det_sample.proposals == proposals
        >>> del det_sample.proposals
        >>> assert 'proposals' not in det_sample
        >>> with self.assertRaises(AssertionError):
        ...     det_sample.proposals = torch.rand((5, 4))
    """

    def __init__(self, *, metainfo: Optional[dict] = None, **kwargs) -> None:
        # The two registries must exist before any field is assigned because
        # ``set_field`` consults them on every assignment.
        self._metainfo_fields: set = set()
        self._data_fields: set = set()

        if metainfo is not None:
            self.set_metainfo(metainfo=metainfo)
        if kwargs:
            self.set_data(kwargs)

    def set_metainfo(self, metainfo: dict) -> None:
        """Set or change key-value pairs in ``metainfo_field`` by parameter
        ``metainfo``.

        The input is deep-copied, so later mutation of the caller's dict (or
        of the objects inside it) does not affect this element.

        Args:
            metainfo (dict): A dict contains the meta information
                of image, such as ``img_shape``, ``scale_factor``, etc.

        Raises:
            AttributeError: If a key is already registered as a data field.
        """
        assert isinstance(
            metainfo,
            dict), f'metainfo should be a ``dict`` but got {type(metainfo)}'
        meta = copy.deepcopy(metainfo)
        for k, v in meta.items():
            self.set_field(name=k, value=v, field_type='metainfo', dtype=None)

    def set_data(self, data: dict) -> None:
        """Set or change key-value pairs in ``data_field`` by parameter
        ``data``.

        Values are stored by reference (no copy is made).

        Args:
            data (dict): A dict contains annotations of image or
                model predictions.

        Raises:
            AttributeError: If a key is already registered as a metainfo
                field.
        """
        assert isinstance(
            data, dict), f'data should be a `dict` but got {type(data)}'
        for k, v in data.items():
            # Use ``setattr()`` rather than ``self.set_field()`` so that
            # property setters defined by subclasses are honoured.
            setattr(self, k, v)

    def update(self, instance: 'BaseDataElement') -> None:
        """The update() method updates the BaseDataElement with the elements
        from another BaseDataElement object.

        Args:
            instance (BaseDataElement): Another BaseDataElement object for
                update the current object.
        """
        assert isinstance(
            instance, BaseDataElement
        ), f'instance should be a `BaseDataElement` but got {type(instance)}'
        self.set_metainfo(dict(instance.metainfo_items()))
        self.set_data(dict(instance.items()))

    def new(self,
            *,
            metainfo: Optional[dict] = None,
            **kwargs) -> 'BaseDataElement':
        """Return a new data element with same type. If ``metainfo`` and
        ``data`` are None, the new data element will have same metainfo and
        data. If metainfo or data is not None, the new result will overwrite it
        with the input value.

        Args:
            metainfo (dict, optional): A dict contains the meta information
                of image, such as ``img_shape``, ``scale_factor``, etc.
                Defaults to None.
            kwargs (dict): A dict contains annotations of image or
                model predictions.

        Returns:
            BaseDataElement: A new data element with same type.
        """
        new_data = self.__class__()

        if metainfo is not None:
            new_data.set_metainfo(metainfo)
        else:
            new_data.set_metainfo(dict(self.metainfo_items()))
        if kwargs:
            new_data.set_data(kwargs)
        else:
            new_data.set_data(dict(self.items()))
        return new_data

    def clone(self):
        """Copy the current data element.

        Metainfo is deep-copied (``set_metainfo`` deep-copies its input),
        whereas data values are shared by reference with the original
        element.

        Returns:
            BaseDataElement: The copy of current data element.
        """
        clone_data = self.__class__()
        clone_data.set_metainfo(dict(self.metainfo_items()))
        clone_data.set_data(dict(self.items()))
        return clone_data

    def keys(self) -> list:
        """
        Returns:
            list: Contains all keys in data_fields.
        """
        # A property ``name`` is assumed to store its value in the attribute
        # ``'_' + name``; both end up registered as data fields, so the
        # private backing name is filtered out here.
        private_keys = {
            '_' + key
            for key in self._data_fields
            if isinstance(getattr(type(self), key, None), property)
        }
        return list(self._data_fields - private_keys)

    def metainfo_keys(self) -> list:
        """
        Returns:
            list: Contains all keys in metainfo_fields.
        """
        return list(self._metainfo_fields)

    def values(self) -> list:
        """
        Returns:
            list: Contains all values in data.
        """
        return [getattr(self, k) for k in self.keys()]

    def metainfo_values(self) -> list:
        """
        Returns:
            list: Contains all values in metainfo.
        """
        return [getattr(self, k) for k in self.metainfo_keys()]

    def all_keys(self) -> list:
        """
        Returns:
            list: Contains all keys in metainfo and data.
        """
        return self.metainfo_keys() + self.keys()

    def all_values(self) -> list:
        """
        Returns:
            list: Contains all values in metainfo and data.
        """
        return self.metainfo_values() + self.values()

    def all_items(self) -> Iterator[Tuple[str, Any]]:
        """
        Returns:
            iterator: An iterator object whose element is (key, value) tuple
            pairs for ``metainfo`` and ``data``.
        """
        for k in self.all_keys():
            yield (k, getattr(self, k))

    def items(self) -> Iterator[Tuple[str, Any]]:
        """
        Returns:
            iterator: An iterator object whose element is (key, value) tuple
            pairs for ``data``.
        """
        for k in self.keys():
            yield (k, getattr(self, k))

    def metainfo_items(self) -> Iterator[Tuple[str, Any]]:
        """
        Returns:
            iterator: An iterator object whose element is (key, value) tuple
            pairs for ``metainfo``.
        """
        for k in self.metainfo_keys():
            yield (k, getattr(self, k))

    @property
    def metainfo(self) -> dict:
        """dict: A dict contains metainfo of current data element."""
        return dict(self.metainfo_items())

    def __setattr__(self, name: str, value: Any):
        """setattr is only used to set data.

        The two registry attributes can be assigned exactly once (in
        ``__init__``); every other name is stored as a data field.
        """
        if name in ('_metainfo_fields', '_data_fields'):
            if not hasattr(self, name):
                super().__setattr__(name, value)
            else:
                raise AttributeError(f'{name} has been used as a '
                                     'private attribute, which is immutable.')
        else:
            self.set_field(
                name=name, value=value, field_type='data', dtype=None)

    def __delattr__(self, item: str):
        """Delete the item in dataelement.

        Args:
            item (str): The key to delete.
        """
        if item in ('_metainfo_fields', '_data_fields'):
            raise AttributeError(f'{item} has been used as a '
                                 'private attribute, which is immutable.')
        # Delete the attribute first: if this raises, the registries are
        # left untouched.
        super().__delattr__(item)
        if item in self._metainfo_fields:
            self._metainfo_fields.remove(item)
        elif item in self._data_fields:
            self._data_fields.remove(item)

    # dict-like ``del element[key]`` behaves the same as ``del element.key``.
    __delitem__ = __delattr__

    def get(self, key, default=None) -> Any:
        """Get property in data and metainfo as the same as python."""
        # Use ``getattr()`` rather than ``self.__dict__.get()`` so that
        # properties defined by subclasses can be retrieved as well.
        return getattr(self, key, default)

    def pop(self, *args) -> Any:
        """Pop property in data and metainfo as the same as python.

        Args:
            *args: ``(name,)`` or ``(name, default)``.

        Returns:
            Any: The removed value, or ``default`` when ``name`` is neither
            a metainfo nor a data field and a default was supplied.

        Raises:
            TypeError: If called without arguments.
            KeyError: If ``name`` is not a field and no default was given.
        """
        if not args:
            raise TypeError('``pop`` expected at least 1 argument, got 0')
        assert len(args) < 3, '``pop`` get more than 2 arguments'
        name = args[0]
        if name in self._metainfo_fields or name in self._data_fields:
            # Go through getattr/delattr instead of ``self.__dict__`` so
            # that property-backed fields (stored under ``'_' + name``) are
            # handled and ``__delattr__`` keeps the registries in sync.
            value = getattr(self, name)
            delattr(self, name)
            return value
        # with default value
        if len(args) == 2:
            return args[1]
        # Only keys registered in metainfo or data may be popped, so do not
        # fall back to ``self.__dict__.pop``.
        raise KeyError(f'{args[0]} is not contained in metainfo or data')

    def __contains__(self, item: str) -> bool:
        """Whether the item is in dataelement.

        Args:
            item (str): The key to inquire.
        """
        return item in self._data_fields or item in self._metainfo_fields

    def set_field(self,
                  value: Any,
                  name: str,
                  dtype: Optional[Union[Type, Tuple[Type, ...]]] = None,
                  field_type: str = 'data') -> None:
        """Special method for set union field, used as property.setter
        functions.

        Args:
            value (Any): The value to store.
            name (str): The attribute name to store ``value`` under.
            dtype (type or tuple of types, optional): If given, ``value``
                must be an instance of it. Defaults to None.
            field_type (str): Either ``'metainfo'`` or ``'data'``.
                Defaults to ``'data'``.

        Raises:
            AttributeError: If ``name`` is already registered in the other
                kind of field.
        """
        assert field_type in ['metainfo', 'data']
        if dtype is not None:
            assert isinstance(
                value,
                dtype), f'{value} should be a {dtype} but got {type(value)}'

        if field_type == 'metainfo':
            if name in self._data_fields:
                raise AttributeError(
                    f'Cannot set {name} to be a field of metainfo '
                    f'because {name} is already a data field')
            registry = self._metainfo_fields
        else:
            if name in self._metainfo_fields:
                raise AttributeError(
                    f'Cannot set {name} to be a field of data '
                    f'because {name} is already a metainfo field')
            registry = self._data_fields
        # Assign before registering: the assignment may raise (for example a
        # property setter rejecting the value), and the name must not be
        # left registered without a backing attribute.
        super().__setattr__(name, value)
        registry.add(name)

    def _apply_to_tensors(self, func) -> 'BaseDataElement':
        """Return a copy of this element with ``func`` applied to every data
        value that is a :obj:`torch.Tensor` or a ``BaseDataElement``.

        Args:
            func (callable): Called with one such value; its return value
                replaces that value in the copy.

        Returns:
            BaseDataElement: The converted copy; ``self`` is not modified.
        """
        new_data = self.new()
        for k, v in self.items():
            if isinstance(v, (torch.Tensor, BaseDataElement)):
                new_data.set_data({k: func(v)})
        return new_data

    # Tensor-like methods
    def to(self, *args, **kwargs) -> 'BaseDataElement':
        """Apply same name function to all tensors in data_fields."""
        new_data = self.new()
        for k, v in self.items():
            # Duck-typed on purpose: anything exposing ``to`` is converted.
            if hasattr(v, 'to'):
                new_data.set_data({k: v.to(*args, **kwargs)})
        return new_data

    # Tensor-like methods
    def cpu(self) -> 'BaseDataElement':
        """Convert all tensors to CPU in data."""
        return self._apply_to_tensors(lambda v: v.cpu())

    # Tensor-like methods
    def cuda(self) -> 'BaseDataElement':
        """Convert all tensors to GPU in data."""
        return self._apply_to_tensors(lambda v: v.cuda())

    # Tensor-like methods
    def musa(self) -> 'BaseDataElement':
        """Convert all tensors to musa in data."""
        return self._apply_to_tensors(lambda v: v.musa())

    # Tensor-like methods
    def npu(self) -> 'BaseDataElement':
        """Convert all tensors to NPU in data."""
        return self._apply_to_tensors(lambda v: v.npu())

    def mlu(self) -> 'BaseDataElement':
        """Convert all tensors to MLU in data."""
        return self._apply_to_tensors(lambda v: v.mlu())

    # Tensor-like methods
    def detach(self) -> 'BaseDataElement':
        """Detach all tensors in data."""
        return self._apply_to_tensors(lambda v: v.detach())

    # Tensor-like methods
    def numpy(self) -> 'BaseDataElement':
        """Convert all tensors to np.ndarray in data."""
        return self._apply_to_tensors(lambda v: v.detach().cpu().numpy())

    def to_tensor(self) -> 'BaseDataElement':
        """Convert all np.ndarray to tensor in data."""
        new_data = self.new()
        for k, v in self.items():
            if isinstance(v, np.ndarray):
                new_data.set_data({k: torch.from_numpy(v)})
            elif isinstance(v, BaseDataElement):
                new_data.set_data({k: v.to_tensor()})
        return new_data

    def to_dict(self) -> dict:
        """Convert BaseDataElement to dict.

        Nested ``BaseDataElement`` values are converted recursively; metainfo
        and data share one flat namespace in the result.
        """
        return {
            k: v.to_dict() if isinstance(v, BaseDataElement) else v
            for k, v in self.all_items()
        }

    def __repr__(self) -> str:
        """Represent the object."""

        def _addindent(s_: str, num_spaces: int) -> str:
            """This func is modified from `pytorch` https://github.com/pytorch/
            pytorch/blob/b17b2b1cc7b017c3daaeff8cc7ec0f514d42ec37/torch/nn/modu
            les/module.py#L29.

            Args:
                s_ (str): The string to add spaces.
                num_spaces (int): The num of space to add.

            Returns:
                str: The string after add indent.
            """
            lines = s_.split('\n')
            # don't do anything for single-line stuff
            if len(lines) == 1:
                return s_
            first = lines.pop(0)
            indented = '\n'.join((num_spaces * ' ') + line for line in lines)
            return first + '\n' + indented

        def dump(obj: Any) -> str:
            """Represent the object.

            Args:
                obj (Any): The obj to represent.

            Returns:
                str: The represented str.
            """
            _repr = ''
            if isinstance(obj, dict):
                for k, v in obj.items():
                    _repr += f'\n{k}: {_addindent(dump(v), 4)}'
            elif isinstance(obj, BaseDataElement):
                # Section headers use the same 4-space indent that
                # ``_addindent`` gives the items listed beneath them.
                _repr += '\n\n    META INFORMATION'
                metainfo_items = dict(obj.metainfo_items())
                _repr += _addindent(dump(metainfo_items), 4)
                _repr += '\n\n    DATA FIELDS'
                items = dict(obj.items())
                _repr += _addindent(dump(items), 4)
                classname = obj.__class__.__name__
                _repr = f'<{classname}({_repr}\n) at {hex(id(obj))}>'
            else:
                _repr += repr(obj)
            return _repr

        return dump(self)
|
|