Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import itertools | |
| import numpy as np | |
| import torch | |
| from .general_data import GeneralData | |
| class InstanceData(GeneralData): | |
| """Data structure for instance-level annnotations or predictions. | |
| Subclass of :class:`GeneralData`. All value in `data_fields` | |
| should have the same length. This design refer to | |
| https://github.com/facebookresearch/detectron2/blob/master/detectron2/structures/instances.py # noqa E501 | |
| Examples: | |
| >>> from mmdet.core import InstanceData | |
| >>> import numpy as np | |
| >>> img_meta = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3)) | |
| >>> results = InstanceData(img_meta) | |
| >>> img_shape in results | |
| True | |
| >>> results.det_labels = torch.LongTensor([0, 1, 2, 3]) | |
| >>> results["det_scores"] = torch.Tensor([0.01, 0.7, 0.6, 0.3]) | |
| >>> results["det_masks"] = np.ndarray(4, 2, 2) | |
| >>> len(results) | |
| 4 | |
| >>> print(resutls) | |
| <InstanceData( | |
| META INFORMATION | |
| pad_shape: (800, 1216, 3) | |
| img_shape: (800, 1196, 3) | |
| PREDICTIONS | |
| shape of det_labels: torch.Size([4]) | |
| shape of det_masks: (4, 2, 2) | |
| shape of det_scores: torch.Size([4]) | |
| ) at 0x7fe26b5ca990> | |
| >>> sorted_results = results[results.det_scores.sort().indices] | |
| >>> sorted_results.det_scores | |
| tensor([0.0100, 0.3000, 0.6000, 0.7000]) | |
| >>> sorted_results.det_labels | |
| tensor([0, 3, 2, 1]) | |
| >>> print(results[results.scores > 0.5]) | |
| <InstanceData( | |
| META INFORMATION | |
| pad_shape: (800, 1216, 3) | |
| img_shape: (800, 1196, 3) | |
| PREDICTIONS | |
| shape of det_labels: torch.Size([2]) | |
| shape of det_masks: (2, 2, 2) | |
| shape of det_scores: torch.Size([2]) | |
| ) at 0x7fe26b6d7790> | |
| >>> results[results.det_scores > 0.5].det_labels | |
| tensor([1, 2]) | |
| >>> results[results.det_scores > 0.5].det_scores | |
| tensor([0.7000, 0.6000]) | |
| """ | |
| def __setattr__(self, name, value): | |
| if name in ('_meta_info_fields', '_data_fields'): | |
| if not hasattr(self, name): | |
| super().__setattr__(name, value) | |
| else: | |
| raise AttributeError( | |
| f'{name} has been used as a ' | |
| f'private attribute, which is immutable. ') | |
| else: | |
| assert isinstance(value, (torch.Tensor, np.ndarray, list)), \ | |
| f'Can set {type(value)}, only support' \ | |
| f' {(torch.Tensor, np.ndarray, list)}' | |
| if self._data_fields: | |
| assert len(value) == len(self), f'the length of ' \ | |
| f'values {len(value)} is ' \ | |
| f'not consistent with' \ | |
| f' the length ' \ | |
| f'of this :obj:`InstanceData` ' \ | |
| f'{len(self)} ' | |
| super().__setattr__(name, value) | |
| def __getitem__(self, item): | |
| """ | |
| Args: | |
| item (str, obj:`slice`, | |
| obj`torch.LongTensor`, obj:`torch.BoolTensor`): | |
| get the corresponding values according to item. | |
| Returns: | |
| obj:`InstanceData`: Corresponding values. | |
| """ | |
| assert len(self), ' This is a empty instance' | |
| assert isinstance( | |
| item, (str, slice, int, torch.LongTensor, torch.BoolTensor)) | |
| if isinstance(item, str): | |
| return getattr(self, item) | |
| if type(item) == int: | |
| if item >= len(self) or item < -len(self): | |
| raise IndexError(f'Index {item} out of range!') | |
| else: | |
| # keep the dimension | |
| item = slice(item, None, len(self)) | |
| new_data = self.new() | |
| if isinstance(item, (torch.Tensor)): | |
| assert item.dim() == 1, 'Only support to get the' \ | |
| ' values along the first dimension.' | |
| if isinstance(item, torch.BoolTensor): | |
| assert len(item) == len(self), f'The shape of the' \ | |
| f' input(BoolTensor)) ' \ | |
| f'{len(item)} ' \ | |
| f' does not match the shape ' \ | |
| f'of the indexed tensor ' \ | |
| f'in results_filed ' \ | |
| f'{len(self)} at ' \ | |
| f'first dimension. ' | |
| for k, v in self.items(): | |
| if isinstance(v, torch.Tensor): | |
| new_data[k] = v[item] | |
| elif isinstance(v, np.ndarray): | |
| new_data[k] = v[item.cpu().numpy()] | |
| elif isinstance(v, list): | |
| r_list = [] | |
| # convert to indexes from boolTensor | |
| if isinstance(item, torch.BoolTensor): | |
| indexes = torch.nonzero(item).view(-1) | |
| else: | |
| indexes = item | |
| for index in indexes: | |
| r_list.append(v[index]) | |
| new_data[k] = r_list | |
| else: | |
| # item is a slice | |
| for k, v in self.items(): | |
| new_data[k] = v[item] | |
| return new_data | |
| def cat(instances_list): | |
| """Concat the predictions of all :obj:`InstanceData` in the list. | |
| Args: | |
| instances_list (list[:obj:`InstanceData`]): A list | |
| of :obj:`InstanceData`. | |
| Returns: | |
| obj:`InstanceData` | |
| """ | |
| assert all( | |
| isinstance(results, InstanceData) for results in instances_list) | |
| assert len(instances_list) > 0 | |
| if len(instances_list) == 1: | |
| return instances_list[0] | |
| new_data = instances_list[0].new() | |
| for k in instances_list[0]._data_fields: | |
| values = [results[k] for results in instances_list] | |
| v0 = values[0] | |
| if isinstance(v0, torch.Tensor): | |
| values = torch.cat(values, dim=0) | |
| elif isinstance(v0, np.ndarray): | |
| values = np.concatenate(values, axis=0) | |
| elif isinstance(v0, list): | |
| values = list(itertools.chain(*values)) | |
| else: | |
| raise ValueError( | |
| f'Can not concat the {k} which is a {type(v0)}') | |
| new_data[k] = values | |
| return new_data | |
| def __len__(self): | |
| if len(self._data_fields): | |
| for v in self.values(): | |
| return len(v) | |
| else: | |
| raise AssertionError('This is an empty `InstanceData`.') | |