|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from collections import abc |
|
|
from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, Union |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
from mmengine.structures import BaseDataElement, PixelData |
|
|
from mmengine.utils import is_list_of |
|
|
|
|
|
# Union of the index-like types (names, slices, integers, lists, integer/bool
# tensors on CPU or CUDA, and numpy arrays).
# NOTE(review): not referenced in the visible part of this file - presumably
# kept for parity with mmengine's data structures; confirm before removing.
IndexType = Union[str, slice, int, list, torch.LongTensor,
                  torch.cuda.LongTensor, torch.BoolTensor,
                  torch.cuda.BoolTensor, np.ndarray]
|
|
|
|
|
|
|
|
class MultilevelPixelData(BaseDataElement):
    """Data structure for multi-level pixel-wise annotations or predictions.

    All data items in ``data_fields`` of ``MultilevelPixelData`` are lists
    of np.ndarray or torch.Tensor, and should meet the following requirements:

    - Have the same length, which is the number of levels
    - At each level, the data should have 3 dimensions in order of channel,
        height and width
    - At each level, the data should have the same height and width

    Internally every level is stored as one ``PixelData`` under the reserved
    attribute name ``_level_{i}``; each data field is additionally kept on
    the instance as the original sequence.

    Examples:
        >>> metainfo = dict(num_keypoints=17)
        >>> sizes = [(64, 48), (128, 96), (256, 192)]
        >>> heatmaps = [np.random.rand(17, h, w) for h, w in sizes]
        >>> masks = [torch.rand(1, h, w) for h, w in sizes]
        >>> data = MultilevelPixelData(metainfo=metainfo,
        ...                            heatmaps=heatmaps,
        ...                            masks=masks)

        >>> # get data item
        >>> heatmaps = data.heatmaps  # A list of 3 numpy.ndarrays
        >>> masks = data.masks  # A list of 3 torch.Tensors

        >>> # get level
        >>> data_l0 = data[0]  # PixelData with fields 'heatmaps' and 'masks'
        >>> data_last = data[-1]  # negative indices count from the end
        >>> data.nlevel
        3

        >>> # get shape
        >>> data.shape
        ((64, 48), (128, 96), (256, 192))

        >>> # set
        >>> offset_maps = [torch.rand(2, h, w) for h, w in sizes]
        >>> data.offset_maps = offset_maps
    """

    def __init__(self, *, metainfo: Optional[dict] = None, **kwargs) -> None:
        """Create the element and register ``metainfo`` / data fields.

        Args:
            metainfo (dict, optional): Meta information forwarded to the
                base class. Defaults to None.
            **kwargs: Data fields forwarded to the base class.
        """
        # Stored via ``object.__setattr__`` so that ``_nlevel`` is a plain
        # attribute and is never treated as a data field.
        # NOTE(review): presumably ``BaseDataElement.__setattr__`` routes
        # through ``set_field`` - confirm against mmengine.
        object.__setattr__(self, '_nlevel', None)
        super().__init__(metainfo=metainfo, **kwargs)

    @property
    def nlevel(self) -> Optional[int]:
        """Return the level number.

        Returns:
            Optional[int]: The level number, or ``None`` if the data has not
            been assigned.
        """
        return self._nlevel

    def __getitem__(self, item: Union[int, str, list,
                                      slice]) -> Union[PixelData, Sequence]:
        """Get one level (int index) or one field (str key).

        Args:
            item (int | str): A level index (negative values count from the
                last level) or the name of a data/metainfo field.

        Returns:
            PixelData | Sequence: The ``PixelData`` of the requested level,
            or the value of the requested field.

        Raises:
            IndexError: If ``item`` is an int and no level exists at that
                index (including when no data has been assigned yet).
            KeyError: If ``item`` is a str that is not a field.
            NotImplementedError: For any other index type.
        """
        if isinstance(item, int):
            nlevel = self.nlevel
            index = item
            # Normalise negative indices; without this a negative index
            # would look up a non-existent ``_level_-k`` attribute.
            if nlevel is not None and index < 0:
                index += nlevel
            if nlevel is None or not 0 <= index < nlevel:
                raise IndexError(
                    f'Level index {item} out of range ({nlevel})')
            return self.get(f'_level_{index}')

        if isinstance(item, str):
            if item not in self:
                raise KeyError(item)
            return getattr(self, item)

        raise NotImplementedError(
            f'{self.__class__.__name__} does not support index type '
            f'{type(item)}')

    def levels(self) -> List[PixelData]:
        """Return the per-level ``PixelData`` objects in level order.

        Returns:
            List[PixelData]: One entry per level, or an empty list if no
            data has been assigned.
        """
        if self.nlevel:
            return [self[i] for i in range(self.nlevel)]
        return []

    @property
    def shape(self) -> Optional[Tuple[Tuple]]:
        """Get the shape of multi-level pixel data.

        Returns:
            Optional[tuple]: A tuple of data shape at each level, or ``None``
            if the data has not been assigned.
        """
        if self.nlevel is None:
            return None

        return tuple(level.shape for level in self.levels())

    def set_data(self, data: dict) -> None:
        """Set or change key-value pairs in ``data_field`` by parameter
        ``data``.

        Args:
            data (dict): A dict contains annotations of image or
                model predictions.
        """
        assert isinstance(data,
                          dict), f'data should be a `dict` but got {data}'
        for k, v in data.items():
            self.set_field(v, k, field_type='data')

    def set_field(self,
                  value: Any,
                  name: str,
                  dtype: Optional[Union[Type, Tuple[Type, ...]]] = None,
                  field_type: str = 'data') -> None:
        """Special method for set union field, used as property.setter
        functions.

        Args:
            value (Any): The value to store. For a data field it must be a
                non-empty sequence whose first element is a
                ``numpy.ndarray`` or ``torch.Tensor``.
            name (str): The field name. Names starting with ``_level_`` are
                reserved.
            dtype (Type | Tuple[Type, ...], optional): If given, ``value``
                must be an instance of it. Defaults to None.
            field_type (str): ``'metainfo'`` or ``'data'``. Defaults to
                ``'data'``.

        Raises:
            AttributeError: If ``name`` is reserved, or already registered
                as a field of the other type.
            TypeError: If a data value is not a sequence of arrays/tensors.
            ValueError: If a data value is empty.
        """
        assert field_type in ['metainfo', 'data']
        if dtype is not None:
            assert isinstance(
                value,
                dtype), f'{value} should be a {dtype} but got {type(value)}'

        if name.startswith('_level_'):
            raise AttributeError(
                f'Cannot set {name} to be a field because the pattern '
                '<_level_{n}> is reserved for inner data field')

        if field_type == 'metainfo':
            if name in self._data_fields:
                raise AttributeError(
                    f'Cannot set {name} to be a field of metainfo '
                    f'because {name} is already a data field')
            self._metainfo_fields.add(name)

        else:
            if name in self._metainfo_fields:
                raise AttributeError(
                    f'Cannot set {name} to be a field of data '
                    f'because {name} is already a metainfo field')

            if not isinstance(value, abc.Sequence):
                raise TypeError(
                    'The value should be a sequence (of numpy.ndarray or '
                    f'torch.Tensor), but got a {type(value)}')

            if len(value) == 0:
                raise ValueError('Setting empty value is not allowed')

            if not isinstance(value[0], (torch.Tensor, np.ndarray)):
                raise TypeError(
                    'The value should be a sequence of numpy.ndarray or '
                    f'torch.Tensor, but got a sequence of {type(value[0])}')

            if self.nlevel is not None:
                assert len(value) == self.nlevel, (
                    f'The length of the value ({len(value)}) should match '
                    f'the number of the levels ({self.nlevel})')
            else:
                # The first data field fixes the number of levels and
                # creates one empty PixelData container per level.
                object.__setattr__(self, '_nlevel', len(value))
                for i in range(self.nlevel):
                    object.__setattr__(self, f'_level_{i}', PixelData())

            # Per-level validation/storage is delegated to PixelData.
            for i, v in enumerate(value):
                self[i].set_field(v, name, field_type='data')

            self._data_fields.add(name)

        # Both metainfo and data values are also kept directly on the
        # instance so that plain attribute access returns them.
        object.__setattr__(self, name, value)

    def __delattr__(self, item: str):
        """Delete the item in dataelement.

        Args:
            item (str): The key to delete.

        Raises:
            AttributeError: If ``item`` is one of the private field
                registries, or is not a metainfo or data field.
        """
        if item in ('_metainfo_fields', '_data_fields'):
            raise AttributeError(f'{item} has been used as a '
                                 'private attribute, which is immutable. ')

        if item in self._metainfo_fields:
            super().__delattr__(item)
            return

        if item not in self._data_fields:
            raise AttributeError(
                f'\'{self.__class__.__name__}\' object has no attribute '
                f'\'{item}\'')

        for level in self.levels():
            level.__delattr__(item)
        self._data_fields.remove(item)
        # ``set_field`` also stored the sequence on the instance; drop it so
        # the deleted field is no longer reachable as an attribute.
        self.__dict__.pop(item, None)

    def __getattr__(self, name):
        """Fallback attribute lookup that gathers a data field per level.

        Only invoked when normal attribute lookup fails. The two registry
        names are rejected explicitly so that probing for them before they
        exist raises ``AttributeError`` instead of recursing.

        Raises:
            AttributeError: If ``name`` is not a registered data field.
        """
        if name in {'_data_fields', '_metainfo_fields'
                    } or name not in self._data_fields:
            raise AttributeError(
                f'\'{self.__class__.__name__}\' object has no attribute '
                f'\'{name}\'')

        return [getattr(level, name) for level in self.levels()]

    def pop(self, *args) -> Any:
        """Pop property in data and metainfo as the same as python.

        Args:
            *args: ``(name,)`` or ``(name, default)``.

        Returns:
            Any: The metainfo value, the list of per-level values for a data
            field, or ``default`` when ``name`` is not a field and a default
            was given.

        Raises:
            KeyError: If ``name`` is not a field and no default was given.
        """
        assert len(args) < 3, '``pop`` get more than 2 arguments'
        name = args[0]

        if name in self._metainfo_fields:
            self._metainfo_fields.remove(name)
            return self.__dict__.pop(*args)

        if name in self._data_fields:
            self._data_fields.remove(name)
            popped = [level.pop(*args) for level in self.levels()]
            # Also discard the sequence kept on the instance by
            # ``set_field``; otherwise the popped field stays reachable.
            self.__dict__.pop(name, None)
            return popped

        if len(args) == 2:
            return args[1]

        raise KeyError(f'{args[0]} is not contained in metainfo or data')

    def _convert(self, apply_to: Type,
                 func: Callable[[Any], Any]) -> 'MultilevelPixelData':
        """Convert data items with the given function.

        Args:
            apply_to (Type): The type of data items to apply the conversion
            func (Callable): The conversion function that takes a data item
                as the input and return the converted result

        Returns:
            MultilevelPixelData: the converted data element.
        """
        # NOTE(review): ``self.new()`` is assumed to return a copy that
        # already carries this element's metainfo and data - confirm against
        # mmengine's ``BaseDataElement.new``.
        new_data = self.new()
        for k, v in self.items():
            if is_list_of(v, apply_to):
                converted = [func(_v) for _v in v]
                new_data.set_data({k: converted})
        return new_data

    def cpu(self) -> 'MultilevelPixelData':
        """Convert all tensors to CPU in data."""
        return self._convert(apply_to=torch.Tensor, func=lambda x: x.cpu())

    def cuda(self) -> 'MultilevelPixelData':
        """Convert all tensors to GPU in data."""
        return self._convert(apply_to=torch.Tensor, func=lambda x: x.cuda())

    def detach(self) -> 'MultilevelPixelData':
        """Detach all tensors in data."""
        return self._convert(apply_to=torch.Tensor, func=lambda x: x.detach())

    def numpy(self) -> 'MultilevelPixelData':
        """Convert all tensors to np.ndarray in data."""
        return self._convert(
            apply_to=torch.Tensor, func=lambda x: x.detach().cpu().numpy())

    def to_tensor(self) -> 'MultilevelPixelData':
        """Convert all np.ndarray to tensors in data."""
        return self._convert(
            apply_to=np.ndarray, func=lambda x: torch.from_numpy(x))

    def to(self, *args, **kwargs) -> 'MultilevelPixelData':
        """Apply same name function to all tensors in data_fields.

        Args:
            *args: Positional arguments forwarded to each item's ``to``.
            **kwargs: Keyword arguments forwarded to each item's ``to``.

        Returns:
            MultilevelPixelData: A new element whose data fields with a
            ``to`` method have been converted.
        """
        new_data = self.new()
        for k, v in self.items():
            # The first element decides for the whole field: fields whose
            # items lack ``to`` are left untouched.
            if hasattr(v[0], 'to'):
                moved = [v_.to(*args, **kwargs) for v_ in v]
                new_data.set_data({k: moved})
        return new_data
|
|
|