|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import itertools |
|
|
from collections.abc import Sized |
|
|
from typing import Any, List, Union |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
|
|
|
from mmengine.device import get_device |
|
|
from .base_data_element import BaseDataElement |
|
|
|
|
|
# Declare the alias names up front; the concrete value of each alias is
# bound in exactly one of the branches below, chosen by ``get_device()``.
BoolTypeTensor: Union[Any]
LongTypeTensor: Union[Any]

# Each alias pairs the plain torch tensor type with the tensor type of the
# active backend. ``InstanceData.__getitem__`` later feeds
# ``BoolTypeTensor.__args__`` / ``IndexType.__args__`` to ``isinstance``.
if get_device() == 'npu':
    BoolTypeTensor = Union[torch.BoolTensor, torch.npu.BoolTensor]
    LongTypeTensor = Union[torch.LongTensor, torch.npu.LongTensor]
elif get_device() == 'mlu':
    BoolTypeTensor = Union[torch.BoolTensor, torch.mlu.BoolTensor]
    LongTypeTensor = Union[torch.LongTensor, torch.mlu.LongTensor]
else:
    # Every device other than 'npu' and 'mlu' uses the CUDA tensor types.
    BoolTypeTensor = Union[torch.BoolTensor, torch.cuda.BoolTensor]
    LongTypeTensor = Union[torch.LongTensor, torch.cuda.LongTensor]

# All index kinds accepted by ``InstanceData.__getitem__``.
IndexType: Union[Any] = Union[str, slice, int, list, LongTypeTensor,
                              BoolTypeTensor, np.ndarray]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class InstanceData(BaseDataElement):
    """Data structure for instance-level annotations or predictions.

    Subclass of :class:`BaseDataElement`. All values in `data_fields`
    should have the same length. This design refers to
    https://github.com/facebookresearch/detectron2/blob/master/detectron2/structures/instances.py # noqa E501
    InstanceData also supports extra functions: ``index``, ``slice`` and
    ``cat`` for data fields. The type of a value in a data field can be a
    base data structure such as `torch.Tensor`, `numpy.ndarray`, `list`,
    `str`, `tuple`, or a customized data structure that has ``__len__``,
    ``__getitem__`` and ``cat`` attributes.

    Examples:
        >>> # custom data structure
        >>> class TmpObject:
        ...     def __init__(self, tmp) -> None:
        ...         assert isinstance(tmp, list)
        ...         self.tmp = tmp
        ...     def __len__(self):
        ...         return len(self.tmp)
        ...     def __getitem__(self, item):
        ...         if isinstance(item, int):
        ...             if item >= len(self) or item < -len(self):  # type:ignore
        ...                 raise IndexError(f'Index {item} out of range!')
        ...             else:
        ...                 # keep the dimension
        ...                 item = slice(item, None, len(self))
        ...         return TmpObject(self.tmp[item])
        ...     @staticmethod
        ...     def cat(tmp_objs):
        ...         assert all(isinstance(results, TmpObject) for results in tmp_objs)
        ...         if len(tmp_objs) == 1:
        ...             return tmp_objs[0]
        ...         tmp_list = [tmp_obj.tmp for tmp_obj in tmp_objs]
        ...         tmp_list = list(itertools.chain(*tmp_list))
        ...         new_data = TmpObject(tmp_list)
        ...         return new_data
        ...     def __repr__(self):
        ...         return str(self.tmp)
        >>> from mmengine.structures import InstanceData
        >>> import numpy as np
        >>> import torch
        >>> img_meta = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3))
        >>> instance_data = InstanceData(metainfo=img_meta)
        >>> 'img_shape' in instance_data
        True
        >>> instance_data.det_labels = torch.LongTensor([2, 3])
        >>> instance_data["det_scores"] = torch.Tensor([0.8, 0.7])
        >>> instance_data.bboxes = torch.rand((2, 4))
        >>> instance_data.polygons = TmpObject([[1, 2, 3, 4], [5, 6, 7, 8]])
        >>> len(instance_data)
        2
        >>> print(instance_data)
        <InstanceData(
            META INFORMATION
            img_shape: (800, 1196, 3)
            pad_shape: (800, 1216, 3)
            DATA FIELDS
            det_labels: tensor([2, 3])
            det_scores: tensor([0.8000, 0.7000])
            bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188],
                        [0.8101, 0.3105, 0.5123, 0.6263]])
            polygons: [[1, 2, 3, 4], [5, 6, 7, 8]]
        ) at 0x7fb492de6280>
        >>> sorted_results = instance_data[instance_data.det_scores.sort().indices]
        >>> sorted_results.det_scores
        tensor([0.7000, 0.8000])
        >>> print(instance_data[instance_data.det_scores > 0.75])
        <InstanceData(
            META INFORMATION
            img_shape: (800, 1196, 3)
            pad_shape: (800, 1216, 3)
            DATA FIELDS
            det_labels: tensor([2])
            det_scores: tensor([0.8000])
            bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188]])
            polygons: [[1, 2, 3, 4]]
        ) at 0x7f64ecf0ec40>
        >>> print(instance_data[instance_data.det_scores > 1])
        <InstanceData(
            META INFORMATION
            img_shape: (800, 1196, 3)
            pad_shape: (800, 1216, 3)
            DATA FIELDS
            det_labels: tensor([], dtype=torch.int64)
            det_scores: tensor([])
            bboxes: tensor([], size=(0, 4))
            polygons: []
        ) at 0x7f660a6a7f70>
        >>> print(instance_data.cat([instance_data, instance_data]))
        <InstanceData(
            META INFORMATION
            img_shape: (800, 1196, 3)
            pad_shape: (800, 1216, 3)
            DATA FIELDS
            det_labels: tensor([2, 3, 2, 3])
            det_scores: tensor([0.8000, 0.7000, 0.8000, 0.7000])
            bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188],
                        [0.8101, 0.3105, 0.5123, 0.6263],
                        [0.4997, 0.7707, 0.0595, 0.4188],
                        [0.8101, 0.3105, 0.5123, 0.6263]])
            polygons: [[1, 2, 3, 4], [5, 6, 7, 8], [1, 2, 3, 4], [5, 6, 7, 8]]
        ) at 0x7f203542feb0>
    """

    def __setattr__(self, name: str, value: Sized):
        """setattr is only used to set data.

        The value must have the attribute of `__len__` and have the same
        length as this `InstanceData` (checked only once the instance
        already holds data, i.e. ``len(self) > 0``).

        Args:
            name (str): Name of the field to set.
            value (Sized): Value to store under ``name``.

        Raises:
            AttributeError: If ``name`` is ``_metainfo_fields`` or
                ``_data_fields`` and that attribute already exists.
            AssertionError: If ``value`` is not ``Sized`` or its length
                differs from the current length of the instance.
        """
        if name in ('_metainfo_fields', '_data_fields'):
            # The two bookkeeping attributes may be created once and are
            # read-only afterwards.
            if not hasattr(self, name):
                super().__setattr__(name, value)
            else:
                raise AttributeError(f'{name} has been used as a '
                                     'private attribute, which is immutable.')

        else:
            assert isinstance(value,
                              Sized), 'value must contain `__len__` attribute'

            if len(self) > 0:
                assert len(value) == len(self), 'The length of ' \
                                                f'values {len(value)} is ' \
                                                'not consistent with ' \
                                                'the length of this ' \
                                                ':obj:`InstanceData` ' \
                                                f'{len(self)}'
            super().__setattr__(name, value)

    # ``obj[name] = value`` goes through the same validation as attribute
    # assignment.
    __setitem__ = __setattr__

    def __getitem__(self, item: IndexType) -> 'InstanceData':
        """Index, slice or mask every data field along its first dimension.

        Args:
            item (str, int, list, :obj:`slice`, :obj:`numpy.ndarray`,
                :obj:`torch.LongTensor`, :obj:`torch.BoolTensor`):
                Get the corresponding values according to item. An empty
                list selects no instances.

        Returns:
            :obj:`InstanceData`: Corresponding values. When ``item`` is a
            ``str`` the attribute of that name is returned instead.

        Raises:
            IndexError: If ``item`` is an ``int`` outside
                ``[-len(self), len(self))``.
            ValueError: If a tensor index is used and a field is neither a
                tensor, an array, a ``str``/``list``/``tuple`` nor an object
                providing both ``__getitem__`` and ``cat``.
            AssertionError: If ``item`` is not one of the supported index
                types, a tensor index is not 1-D, or a boolean mask does
                not match ``len(self)``.
        """
        assert isinstance(item, IndexType.__args__)
        if isinstance(item, list):
            # ``np.array([])`` infers float64, which cannot be used as a
            # torch index. Pin the dtype for the empty case so that an empty
            # list selects zero instances instead of raising.
            if item:
                item = np.array(item)
            else:
                item = np.array(item, dtype=np.int64)
        if isinstance(item, np.ndarray):
            # The default integer type of numpy is platform dependent (int32
            # on Windows, int64 on Linux); int32 index arrays are widened to
            # int64 before being handed to torch.
            item = item.astype(np.int64) if item.dtype == np.int32 else item
            item = torch.from_numpy(item)

        if isinstance(item, str):
            return getattr(self, item)

        if isinstance(item, int):
            if item >= len(self) or item < -len(self):
                raise IndexError(f'Index {item} out of range!')
            else:
                # Keep the dimension: a step of ``len(self)`` makes the
                # slice select exactly one element.
                item = slice(item, None, len(self))

        new_data = self.__class__(metainfo=self.metainfo)
        if isinstance(item, torch.Tensor):
            assert item.dim() == 1, 'Only support to get the' \
                                    ' values along the first dimension.'
            # Loop-invariant: whether the index is a boolean mask.
            is_bool_mask = isinstance(item, BoolTypeTensor.__args__)
            if is_bool_mask:
                assert len(item) == len(self), 'The shape of the ' \
                                               'input(BoolTensor) ' \
                                               f'{len(item)} ' \
                                               'does not match the shape ' \
                                               'of the indexed tensor ' \
                                               'in results_field ' \
                                               f'{len(self)} at ' \
                                               'first dimension.'

            for k, v in self.items():
                if isinstance(v, torch.Tensor):
                    new_data[k] = v[item]
                elif isinstance(v, np.ndarray):
                    new_data[k] = v[item.cpu().numpy()]
                elif isinstance(
                        v, (str, list, tuple)) or (hasattr(v, '__getitem__')
                                                   and hasattr(v, 'cat')):
                    # These containers cannot be indexed by a tensor, so the
                    # index is turned into a list of positions, each position
                    # is sliced out separately and the pieces are joined.
                    if is_bool_mask:
                        indexes = torch.nonzero(item).view(
                            -1).cpu().numpy().tolist()
                    else:
                        indexes = item.cpu().numpy().tolist()
                    slice_list = []
                    if indexes:
                        for index in indexes:
                            slice_list.append(slice(index, None, len(v)))
                    else:
                        # Nothing selected: take an empty slice so the
                        # result keeps the container type.
                        slice_list.append(slice(None, 0, None))
                    r_list = [v[s] for s in slice_list]
                    if isinstance(v, (str, list, tuple)):
                        new_value = r_list[0]
                        for r in r_list[1:]:
                            new_value = new_value + r
                    else:
                        new_value = v.cat(r_list)
                    new_data[k] = new_value
                else:
                    raise ValueError(
                        f'The type of `{k}` is `{type(v)}`, which has no '
                        'attribute of `cat`, so it does not '
                        'support slice with `bool`')

        else:
            # ``item`` is a slice here (ints were converted above).
            for k, v in self.items():
                new_data[k] = v[item]
        return new_data

    @staticmethod
    def cat(instances_list: List['InstanceData']) -> 'InstanceData':
        """Concat the instances of all :obj:`InstanceData` in the list.

        Note: To ensure that cat returns as expected, make sure that
        all elements in the list have exactly the same keys.

        Args:
            instances_list (list[:obj:`InstanceData`]): A list
                of :obj:`InstanceData`.

        Returns:
            :obj:`InstanceData`: The concatenated result, carrying the
            metainfo of the first element. A single-element list returns
            that element itself (not a copy).

        Raises:
            ValueError: If a field is not a tensor, an array, a
                ``str``/``list``/``tuple`` and has no ``cat`` attribute.
            AssertionError: If the list is empty, contains a non
                ``InstanceData`` element, or the elements' keys differ.
        """
        assert all(
            isinstance(results, InstanceData) for results in instances_list)
        assert len(instances_list) > 0
        if len(instances_list) == 1:
            return instances_list[0]

        # The keys reported by ``all_keys()`` must be exactly the same for
        # every element: same count per element, and the union of all keys
        # is no larger than the first element's key set.
        field_keys_list = [
            instances.all_keys() for instances in instances_list
        ]
        assert len({len(field_keys) for field_keys in field_keys_list}) \
               == 1 and len(set(itertools.chain(*field_keys_list))) \
               == len(field_keys_list[0]), 'There are different keys in ' \
                                           '`instances_list`, which may ' \
                                           'cause the cat operation ' \
                                           'to fail. Please make sure all ' \
                                           'elements in `instances_list` ' \
                                           'have the exact same key.'

        new_data = instances_list[0].__class__(
            metainfo=instances_list[0].metainfo)
        for k in instances_list[0].keys():
            values = [results[k] for results in instances_list]
            # The type of the first element's value picks the strategy.
            v0 = values[0]
            if isinstance(v0, torch.Tensor):
                new_values = torch.cat(values, dim=0)
            elif isinstance(v0, np.ndarray):
                new_values = np.concatenate(values, axis=0)
            elif isinstance(v0, (str, list, tuple)):
                # Copy first so a list held by the first element is not
                # extended in place by ``+=``.
                new_values = v0[:]
                for v in values[1:]:
                    new_values += v
            elif hasattr(v0, 'cat'):
                new_values = v0.cat(values)
            else:
                raise ValueError(
                    f'The type of `{k}` is `{type(v0)}` which has no '
                    'attribute of `cat`')
            new_data[k] = new_values
        return new_data

    def __len__(self) -> int:
        """int: The length of InstanceData.

        This is the length of the first data field, or 0 when no data field
        has been set.
        """
        if len(self._data_fields) > 0:
            return len(self.values()[0])
        else:
            return 0
|
|
|