| | |
| | import itertools |
| | import warnings |
| | from typing import Any, Dict, List, Tuple, Union |
| | import torch |
| |
|
| |
|
class Instances:
    """
    This class represents a list of instances in an image.
    It stores the attributes of instances (e.g., boxes, masks, labels, scores) as "fields".
    All fields must have the same ``__len__`` which is the number of instances.

    All other (non-field) attributes of this class are considered private:
    they must start with '_' and are not modifiable by a user.

    Some basic usage:

    1. Set/get/check a field:

       .. code-block:: python

          instances.gt_boxes = Boxes(...)
          print(instances.pred_masks)  # a tensor of shape (N, H, W)
          print('gt_masks' in instances)

    2. ``len(instances)`` returns the number of instances
    3. Indexing: ``instances[indices]`` will apply the indexing on all the fields
       and returns a new :class:`Instances`.
       Typically, ``indices`` is an integer vector of indices,
       or a binary mask of length ``num_instances``

       .. code-block:: python

          category_3_detections = instances[instances.pred_classes == 3]
          confident_detections = instances[instances.scores > 0.9]
    """

    def __init__(self, image_size: Tuple[int, int], **kwargs: Any):
        """
        Args:
            image_size (height, width): the spatial size of the image.
            kwargs: fields to add to this `Instances`. Each one is added via
                :meth:`set`, so all of them must have the same length.
        """
        self._image_size = image_size
        self._fields: Dict[str, Any] = {}
        for k, v in kwargs.items():
            self.set(k, v)

    @property
    def image_size(self) -> Tuple[int, int]:
        """
        Returns:
            tuple: height, width
        """
        return self._image_size

    def __setattr__(self, name: str, val: Any) -> None:
        # Underscore-prefixed names are stored as ordinary (private) attributes;
        # any other assignment becomes a field and goes through the length check.
        if name.startswith("_"):
            super().__setattr__(name, val)
        else:
            self.set(name, val)

    def __getattr__(self, name: str) -> Any:
        # Only called when normal attribute lookup fails. The explicit
        # "_fields" guard prevents infinite recursion on an object whose
        # __init__ never ran: `self._fields` below would otherwise re-enter here.
        if name == "_fields" or name not in self._fields:
            raise AttributeError("Cannot find field '{}' in the given Instances!".format(name))
        return self._fields[name]

    def set(self, name: str, value: Any) -> None:
        """
        Set the field named `name` to `value`.
        The length of `value` must be the number of instances,
        and must agree with other existing fields in this object.

        Raises:
            AssertionError: if `value` has a different length than the
                fields already stored.
        """
        # Any warnings emitted while taking len(value) are recorded and
        # discarded. NOTE(review): presumably meant to silence tracer warnings
        # when this runs under tracing — confirm.
        with warnings.catch_warnings(record=True):
            data_len = len(value)
        if len(self._fields):
            assert (
                len(self) == data_len
            ), "Adding a field of length {} to a Instances of length {}".format(data_len, len(self))
        self._fields[name] = value

    def has(self, name: str) -> bool:
        """
        Returns:
            bool: whether the field called `name` exists.
        """
        return name in self._fields

    def __contains__(self, name: str) -> bool:
        """
        Support ``name in instances``; equivalent to :meth:`has`.

        Without this method Python falls back to ``__iter__`` for the ``in``
        operator, which raises ``NotImplementedError`` for this class.
        """
        return self.has(name)

    def remove(self, name: str) -> None:
        """
        Remove the field called `name`.

        Raises:
            KeyError: if no such field exists.
        """
        del self._fields[name]

    def get(self, name: str) -> Any:
        """
        Returns the field called `name`.

        Raises:
            KeyError: if no such field exists.
        """
        return self._fields[name]

    def get_fields(self) -> Dict[str, Any]:
        """
        Returns:
            dict: a dict which maps names (str) to data of the fields

        Modifying the returned dict will modify this instance.
        """
        return self._fields

    def to(self, *args: Any, **kwargs: Any) -> "Instances":
        """
        Returns:
            Instances: all fields are called with a `to(device)`, if the field has this method.
            Fields without a `to` method are carried over by reference.
        """
        ret = Instances(self._image_size)
        for k, v in self._fields.items():
            if hasattr(v, "to"):
                v = v.to(*args, **kwargs)
            ret.set(k, v)
        return ret

    def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> "Instances":
        """
        Args:
            item: an index-like object and will be used to index all the fields.

        Returns:
            Instances: a new object where every field is indexed by `item`.
            A plain ``int`` index yields an `Instances` holding exactly one
            instance, not the raw per-field values.

        Raises:
            IndexError: if an ``int`` index is out of range.
        """
        # Exact type check: bool and non-builtin integer scalars are not
        # treated as int here and are passed unchanged to each field.
        if type(item) == int:
            if item >= len(self) or item < -len(self):
                raise IndexError("Instances index out of range!")
            else:
                # A slice with step len(self) selects exactly one element,
                # so every field keeps its leading dimension (length 1).
                item = slice(item, None, len(self))

        ret = Instances(self._image_size)
        for k, v in self._fields.items():
            ret.set(k, v[item])
        return ret

    def __len__(self) -> int:
        # All fields share the same length, so the first one is enough.
        # `__len__` is called directly instead of `len()`, which would force
        # the result to a Python int. NOTE(review): presumably to stay
        # friendly to tracing — confirm.
        for v in self._fields.values():
            return v.__len__()
        raise NotImplementedError("Empty Instances does not support __len__!")

    def __iter__(self):
        # Iteration is deliberately unsupported.
        raise NotImplementedError("`Instances` object is not iterable!")

    @staticmethod
    def cat(instance_lists: List["Instances"]) -> "Instances":
        """
        Concatenate several `Instances` field by field.

        Args:
            instance_lists (list[Instances]): non-empty list; every element must
                have the same image size and the same field names as the first.

        Returns:
            Instances: if the list has a single element, that same object is
            returned (no copy).

        Raises:
            ValueError: if a field's type cannot be concatenated (not a tensor,
                not a list, and without a ``cat`` classmethod/staticmethod).
        """
        assert all(isinstance(i, Instances) for i in instance_lists)
        assert len(instance_lists) > 0
        if len(instance_lists) == 1:
            return instance_lists[0]

        image_size = instance_lists[0].image_size
        # Tensor `==` is elementwise, so the tuple-equality check only runs
        # for non-tensor image sizes.
        if not isinstance(image_size, torch.Tensor):
            for i in instance_lists[1:]:
                assert i.image_size == image_size
        ret = Instances(image_size)
        for k in instance_lists[0]._fields.keys():
            values = [i.get(k) for i in instance_lists]
            v0 = values[0]
            # Dispatch on the type of the first element's field.
            if isinstance(v0, torch.Tensor):
                values = torch.cat(values, dim=0)
            elif isinstance(v0, list):
                values = list(itertools.chain(*values))
            elif hasattr(type(v0), "cat"):
                values = type(v0).cat(values)
            else:
                raise ValueError("Unsupported type {} for concatenation".format(type(v0)))
            ret.set(k, values)
        return ret

    def __str__(self) -> str:
        # Note: len(self) raises NotImplementedError when there are no fields.
        s = self.__class__.__name__ + "("
        s += "num_instances={}, ".format(len(self))
        s += "image_height={}, ".format(self._image_size[0])
        s += "image_width={}, ".format(self._image_size[1])
        s += "fields=[{}])".format(", ".join(f"{k}: {v}" for k, v in self._fields.items()))
        return s

    __repr__ = __str__
| |
|