|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import warnings |
|
|
from typing import List, Sequence, Union |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
|
|
|
from .base_data_element import BaseDataElement |
|
|
|
|
|
|
|
|
class PixelData(BaseDataElement):
    """Data structure for pixel-level annotations or predictions.

    All data items in ``data_fields`` of ``PixelData`` meet the following
    requirements:

    - They all have 3 dimensions in orders of channel, height, and width.
    - They should have the same height and width.

    Examples:
        >>> import random
        >>> import numpy as np
        >>> import torch
        >>> metainfo = dict(
        ...     img_id=random.randint(0, 100),
        ...     img_shape=(random.randint(400, 600), random.randint(400, 600)))
        >>> image = np.random.randint(0, 255, (4, 20, 40))
        >>> featmap = torch.randint(0, 255, (10, 20, 40))
        >>> pixel_data = PixelData(metainfo=metainfo,
        ...                        image=image,
        ...                        featmap=featmap)
        >>> print(pixel_data.shape)
        (20, 40)

        >>> # slice
        >>> slice_data = pixel_data[10:20, 20:40]
        >>> assert slice_data.shape == (10, 20)
        >>> slice_data = pixel_data[10, 20]
        >>> assert slice_data.shape == (1, 1)

        >>> # set
        >>> pixel_data.map3 = torch.randint(0, 255, (20, 40))
        >>> assert tuple(pixel_data.map3.shape) == (1, 20, 40)
        >>> # The dimension must be 3 or 2
        >>> pixel_data.map2 = torch.randint(0, 255, (1, 3, 20, 40))
        Traceback (most recent call last):
            ...
        AssertionError: The dim of value must be 2 or 3, but got 4
    """

    def __setattr__(self, name: str, value: Union[torch.Tensor, np.ndarray]):
        """Set attributes of ``PixelData``.

        If ``value`` is 2-dimensional and its shape matches the stored height
        and width, a leading channel dimension is added automatically and a
        warning is emitted.

        Args:
            name (str): The key to access the value, stored in `PixelData`.
            value (Union[torch.Tensor, np.ndarray]): The value to store.
                The type of value must be `torch.Tensor` or `np.ndarray`,
                and its shape must meet the requirements of `PixelData`.

        Raises:
            AttributeError: If ``name`` is ``'_metainfo_fields'`` or
                ``'_data_fields'`` and that attribute already exists.
            AssertionError: If ``value`` is not a tensor/array, is not 2- or
                3-dimensional, or its last two dimensions differ from
                :attr:`shape`.
        """
        if name in ('_metainfo_fields', '_data_fields'):
            # These two bookkeeping attributes may be assigned exactly once.
            if hasattr(self, name):
                raise AttributeError(f'{name} has been used as a '
                                     'private attribute, which is immutable.')
            super().__setattr__(name, value)
            return

        assert isinstance(value, (torch.Tensor, np.ndarray)), \
            f'Can not set {type(value)}, only support' \
            f' {(torch.Tensor, np.ndarray)}'

        # Check the rank before the spatial shape so that, e.g., a 1-D input
        # is reported as a rank error rather than a height/width mismatch.
        assert value.ndim in [
            2, 3
        ], f'The dim of value must be 2 or 3, but got {value.ndim}'

        if self.shape:
            assert tuple(value.shape[-2:]) == self.shape, (
                'The height and width of '
                f'values {tuple(value.shape[-2:])} is '
                'not consistent with '
                'the shape of this '
                ':obj:`PixelData` '
                f'{self.shape}')

        if value.ndim == 2:
            value = value[None]
            # After the expansion, ``value.shape[-2:]`` is still the original
            # (height, width), so the message reports "old -> new" correctly.
            warnings.warn('The shape of value will convert from '
                          f'{value.shape[-2:]} to {value.shape}')
        super().__setattr__(name, value)

    @staticmethod
    def _index_to_slice(index: Union[int, slice], size) -> slice:
        """Convert one spatial index into a dimension-preserving slice.

        Args:
            index (int | slice): Index for the height or width dimension.
                Python ints and NumPy integer scalars are accepted. Negative
                integers count from the end.
            size (int, optional): Length of the indexed dimension, or
                ``None`` when no data field is stored yet.

        Returns:
            slice: ``index`` itself if it is a slice. For an integer ``i``, a
            slice selecting exactly element ``i``, so the dimension is kept
            with length 1.

        Raises:
            IndexError: If an integer index is outside ``[-size, size)``.
            TypeError: If ``index`` is neither an integer nor a slice.
        """
        if isinstance(index, slice):
            return index
        if isinstance(index, (int, np.integer)):
            index = int(index)
            if size is None:
                # No data is stored, so this slice is never applied.
                return slice(index, None)
            if not -size <= index < size:
                raise IndexError(f'Index {index} is out of range for a '
                                 f'dimension of size {size}')
            index %= size  # normalise negative indices
            return slice(index, index + 1)
        raise TypeError(
            'The type of element in input must be int or slice, '
            f'but got {type(index)}')

    def __getitem__(self, item: Sequence[Union[int, slice]]) -> 'PixelData':
        """Slice every data field along its height and width dimensions.

        Args:
            item (Sequence[Union[int, slice]]): A ``(height_index,
                width_index)`` tuple. Each element is an integer or a
                ``slice``. Integers keep the dimension with length 1 instead
                of dropping it.

        Returns:
            :obj:`PixelData`: A new instance constructed with this instance's
            ``metainfo`` and holding the sliced data fields. The channel
            dimension is always kept whole.

        Raises:
            TypeError: If ``item`` is not a tuple, or one of its elements is
                neither an integer nor a slice.
            IndexError: If an integer index is out of range.
            AssertionError: If ``item`` does not have exactly 2 elements.
        """
        if not isinstance(item, tuple):
            raise TypeError(
                f'Unsupported type {type(item)} for slicing PixelData')
        assert len(item) == 2, 'Only support to slice height and width'

        sizes = self.shape or (None, None)
        # Leading full slice keeps the channel dimension untouched.
        slices: List[slice] = [slice(None, None, None)]
        slices.extend(
            self._index_to_slice(index, size)
            for index, size in zip(item, sizes))
        index = tuple(slices)

        new_data = self.__class__(metainfo=self.metainfo)
        for key, value in self.items():
            setattr(new_data, key, value[index])
        return new_data

    @property
    def shape(self):
        """tuple or None: The ``(height, width)`` shared by all data fields,
        taken from the first stored value, or ``None`` if no data field has
        been set yet."""
        if len(self._data_fields) > 0:
            return tuple(self.values()[0].shape[-2:])
        return None
|
|
|
|
|
|
|
|
|