Spaces:
Sleeping
Sleeping
| # Copyright 2020 Google Research. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ============================================================================== | |
| """Bounding Box List definition. | |
| BoxList represents a list of bounding boxes as PyTorch | |
| tensors, where each bounding box is represented as a row of 4 numbers, | |
| [y_min, x_min, y_max, x_max]. It is assumed that all bounding boxes | |
| within a given list correspond to a single image. See also | |
| box_list.py for common box related operations (such as area, iou, etc). | |
| Optionally, users can add additional related fields (such as weights). | |
| We assume the following things to be true about fields: | |
| * they correspond to boxes in the box_list along the 0th dimension | |
| * they have inferable rank at graph construction time | |
| * all dimensions except for possibly the 0th can be inferred | |
| (i.e., not None) at graph construction time. | |
| Some other notes: | |
| * Following tensorflow conventions, we use height, width ordering, | |
| and correspondingly, y,x (or ymin, xmin, ymax, xmax) ordering | |
| * Tensors are always provided as (flat) [N, 4] tensors. | |
| """ | |
| import torch | |
| from typing import Optional, List, Dict | |
class BoxList(object):
    """Collection of bounding boxes plus optional per-box fields.

    Every tensor is stored in ``self.data``, a dict keyed by field name. The
    ``'boxes'`` entry is created by the constructor and holds the ``[N, 4]``
    corner tensor; any other entry is an "extra" field added by the caller.

    Several accessors are plain methods rather than properties; the original
    author's notes say this is for TorchScript compatibility.
    """

    data: Dict[str, torch.Tensor]

    def __init__(self, boxes: torch.Tensor):
        """Constructs box collection.

        Args:
            boxes: a float32 tensor of shape [N, 4] representing box corners.

        Raises:
            ValueError: if `boxes` is not 2-D with a last dimension of 4, or
                if its dtype is not torch.float32.
        """
        if len(boxes.shape) != 2 or boxes.shape[-1] != 4:
            raise ValueError('Invalid dimensions for box data.')
        if boxes.dtype != torch.float32:
            # The message used to say "tf.float32" (left over from the
            # TensorFlow original); the check is against torch.float32.
            raise ValueError('Invalid tensor type: should be torch.float32')
        self.data = {'boxes': boxes}

    def num_boxes(self) -> int:
        """Returns the number of boxes held in the collection.

        Returns:
            the size of the 0th dimension of the 'boxes' tensor.
        """
        return self.data['boxes'].shape[0]

    def get_all_fields(self):
        """Returns the names of all fields, including 'boxes'."""
        return self.data.keys()

    def get_extra_fields(self) -> List[str]:
        """Returns all non-box fields (i.e., everything not named 'boxes')."""
        # Explicit loop kept on purpose: the original author's note says
        # TorchScript does not support comprehensions yet.
        extra: List[str] = []
        for k in self.data.keys():
            if k != 'boxes':
                extra.append(k)
        return extra

    def add_field(self, field: str, field_data: torch.Tensor):
        """Adds a field to the box list, replacing any existing entry.

        This method can be used to add related box data such as
        weights/labels, etc.

        Args:
            field: a string key to access the data via `get_field`.
            field_data: a tensor containing the data to store in the BoxList.
        """
        self.data[field] = field_data

    def has_field(self, field: str) -> bool:
        """Returns True if `field` is a key of the underlying data dict."""
        return field in self.data

    # Not a @property: kept as a method for TorchScript compatibility.
    def boxes(self) -> torch.Tensor:
        """Convenience function for accessing box coordinates.

        Returns:
            the tensor stored under 'boxes' (shape [N, 4] when set through
            the constructor or `set_boxes`).
        """
        return self.get_field('boxes')

    # Not a @boxes.setter: kept as a method for TorchScript compatibility.
    def set_boxes(self, boxes: torch.Tensor):
        """Convenience function for setting box coordinates.

        Only the shape is validated here; unlike the constructor, the dtype
        is not checked.

        Args:
            boxes: a tensor of shape [N, 4] representing box corners.

        Raises:
            ValueError: if invalid dimensions for bbox data.
        """
        if len(boxes.shape) != 2 or boxes.shape[-1] != 4:
            raise ValueError('Invalid dimensions for box data.')
        self.data['boxes'] = boxes

    def get_field(self, field: str) -> torch.Tensor:
        """Returns the tensor stored under `field`.

        Args:
            field: name of the field to read; pass 'boxes' for the box
                coordinates. The argument is required.

        Returns:
            the tensor stored under `field`.

        Raises:
            ValueError: if the field does not exist.
        """
        if not self.has_field(field):
            raise ValueError(f'field {field} does not exist')
        return self.data[field]

    def set_field(self, field: str, value: torch.Tensor):
        """Overwrites the value of an existing field.

        Use `add_field` to create a new field; this method refuses to.

        Args:
            field: (string) name of the field to set value.
            value: the value to assign to the field.

        Raises:
            ValueError: if the box_list does not have specified field.
        """
        if not self.has_field(field):
            raise ValueError(f'field {field} does not exist')
        self.data[field] = value

    def get_center_coordinates_and_sizes(self) -> List[torch.Tensor]:
        """Computes the center coordinates, height and width of the boxes.

        Returns:
            a list of 4 1-D tensors [ycenter, xcenter, height, width].
        """
        box_corners = self.boxes()
        # Transpose [N, 4] -> [4, N] so unbind yields one 1-D tensor per
        # coordinate.
        ymin, xmin, ymax, xmax = box_corners.t().unbind()
        width = xmax - xmin
        height = ymax - ymin
        ycenter = ymin + height / 2.
        xcenter = xmin + width / 2.
        return [ycenter, xcenter, height, width]

    def transpose_coordinates(self):
        """Swaps the y and x coordinates of every box, in place.

        Columns [c0, c1, c2, c3] of the stored boxes become [c1, c0, c3, c2].
        """
        y_min, x_min, y_max, x_max = self.boxes().chunk(4, dim=1)
        self.set_boxes(torch.cat([x_min, y_min, x_max, y_max], 1))

    def as_tensor_dict(self, fields: Optional[List[str]] = None) -> Dict[str, torch.Tensor]:
        """Retrieves specified fields as a dictionary of tensors.

        Args:
            fields: (optional) list of fields to return in the dictionary.
                If None (default), all fields are returned.

        Returns:
            tensor_dict: a new dictionary mapping each requested field name
                to its stored tensor (the tensors themselves are not copied).

        Raises:
            ValueError: if specified field is not contained in boxlist.
        """
        tensor_dict: Dict[str, torch.Tensor] = {}
        if fields is None:
            fields = self.get_all_fields()
        for field in fields:
            if not self.has_field(field):
                raise ValueError('boxlist must contain all specified fields')
            # Presence was checked just above, so read the dict directly
            # instead of re-validating through get_field().
            tensor_dict[field] = self.data[field]
        return tensor_dict

    # Not a @property: kept as a method, matching boxes().
    def device(self) -> torch.device:
        """Returns the device of the 'boxes' tensor."""
        return self.data['boxes'].device