Spaces:
Sleeping
Sleeping
| import numpy as np | |
| from .np_box_list import * | |
| EPSILON = 1e-7 | |
class MaskList(BoxList):
    """A BoxList that additionally carries one full-image mask per box.

    The constructor takes both the box coordinates and the masks. The masks
    are stored in the inherited ``data`` dictionary under the key 'masks'.
    """

    def __init__(self, box_data, mask_data):
        """Builds the collection from boxes and their masks.

        Args:
            box_data: a numpy array of shape [N, 4] holding box coordinates.
                It is validated by the BoxList constructor.
            mask_data: a numpy array of shape [N, height, width] and dtype
                uint8 holding one mask per box, with values in {0, 1}. Each
                mask is expected to cover the whole image.

        Raises:
            ValueError: if mask_data is not a numpy array, is not rank 3, is
                not of dtype uint8, or does not hold exactly one mask per
                box. The BoxList constructor may raise for invalid box_data.
        """
        super(MaskList, self).__init__(box_data)

        # The checks run in a fixed order so the most basic problem is the
        # one that gets reported.
        is_array = isinstance(mask_data, np.ndarray)
        if not is_array:
            raise ValueError('Mask data must be a numpy array.')

        rank = len(mask_data.shape)
        if rank != 3:
            raise ValueError('Invalid dimensions for mask data.')

        if mask_data.dtype != np.uint8:
            raise ValueError('Invalid data type for mask data: uint8 is required.')

        num_masks = mask_data.shape[0]
        num_boxes = box_data.shape[0]
        if num_masks != num_boxes:
            raise ValueError('There should be the same number of boxes and masks.')

        self.data['masks'] = mask_data

    def get_masks(self):
        """Returns the stored masks.

        Returns:
            a numpy array of shape [N, height, width] holding the masks.
        """
        masks = self.get_field('masks')
        return masks
def boxlist_to_masklist(boxlist):
    """Builds a MaskList from a BoxList that carries a 'masks' field.

    Args:
        boxlist: a BoxList object holding a 'masks' field.

    Returns:
        A MaskList with the same boxes and masks, plus every other extra
        field of the input copied across by reference.

    Raises:
        ValueError: if boxlist does not contain `masks` as a field.
    """
    if not boxlist.has_field('masks'):
        raise ValueError('boxlist does not contain mask field.')

    converted = MaskList(
        box_data=boxlist.get(),
        mask_data=boxlist.get_field('masks'))

    # 'masks' was already installed by the constructor; copy the rest.
    for field_name in boxlist.get_extra_fields():
        if field_name == 'masks':
            continue
        converted.data[field_name] = boxlist.get_field(field_name)
    return converted
def area_mask(masks):
    """Returns the area (sum of values) of each mask.

    Args:
        masks: numpy array of shape [N, height, width] and dtype np.uint8
            holding N masks with values in {0, 1}.

    Returns:
        a float32 numpy array of shape [N] holding the area of each mask.

    Raises:
        ValueError: if masks.dtype is not np.uint8.
    """
    is_uint8 = masks.dtype == np.uint8
    if not is_uint8:
        raise ValueError('Masks type should be np.uint8')
    # Accumulate in float32 so large masks do not overflow the uint8 range.
    return masks.sum(axis=(1, 2), dtype=np.float32)
def intersection_mask(masks1, masks2):
    """Returns the pairwise intersection areas of two mask collections.

    Args:
        masks1: numpy array of shape [N, height, width] and dtype np.uint8
            holding N masks with values in {0, 1}.
        masks2: numpy array of shape [M, height, width] and dtype np.uint8
            holding M masks with values in {0, 1}.

    Returns:
        a float32 numpy array of shape [N, M]; entry (i, j) is the
        intersection area of masks1[i] and masks2[j].

    Raises:
        ValueError: if masks1 or masks2 is not of type np.uint8.
    """
    both_uint8 = masks1.dtype == np.uint8 and masks2.dtype == np.uint8
    if not both_uint8:
        raise ValueError('masks1 and masks2 should be of type np.uint8')

    num_first, num_second = masks1.shape[0], masks2.shape[0]
    overlaps = np.zeros([num_first, num_second], dtype=np.float32)
    # One pair at a time keeps peak memory at a single [height, width] buffer.
    for row, first in enumerate(masks1):
        for col, second in enumerate(masks2):
            # For {0, 1} masks the element-wise minimum is the logical AND.
            overlaps[row, col] = np.minimum(first, second).sum(dtype=np.float32)
    return overlaps
def iou_mask(masks1, masks2):
    """Computes pairwise intersection-over-union between mask collections.

    Args:
        masks1: numpy array of shape [N, height, width] and dtype np.uint8
            holding N masks with values in {0, 1}.
        masks2: numpy array of shape [M, height, width] and dtype np.uint8
            holding M masks with values in {0, 1}.

    Returns:
        a numpy array of shape [N, M] holding the pairwise IoU scores.

    Raises:
        ValueError: if masks1 or masks2 is not of type np.uint8.
    """
    if masks1.dtype != np.uint8 or masks2.dtype != np.uint8:
        raise ValueError('masks1 and masks2 should be of type np.uint8')
    # Use the mask helpers of this module; the unqualified `intersection` and
    # `area` names are not the mask implementations.
    intersect = intersection_mask(masks1, masks2)
    area1 = area_mask(masks1)
    area2 = area_mask(masks2)
    # Broadcast [N, 1] + [1, M] to get every pairwise sum of areas.
    union = np.expand_dims(area1, axis=1) + np.expand_dims(area2, axis=0) - intersect
    # Clamp the denominator so two empty masks give 0 instead of NaN.
    return intersect / np.maximum(union, EPSILON)
def ioa_mask(masks1, masks2):
    """Computes pairwise intersection-over-area between mask collections.

    Intersection-over-area (ioa) between two masks, mask1 and mask2, is
    defined as their intersection area over mask2's area. Note that ioa is
    not symmetric, that is, IOA(mask1, mask2) != IOA(mask2, mask1).

    Args:
        masks1: numpy array of shape [N, height, width] and dtype np.uint8
            holding N masks with values in {0, 1}.
        masks2: numpy array of shape [M, height, width] and dtype np.uint8
            holding M masks with values in {0, 1}.

    Returns:
        a numpy array of shape [N, M] holding the pairwise ioa scores.

    Raises:
        ValueError: if masks1 or masks2 is not of type np.uint8.
    """
    if masks1.dtype != np.uint8 or masks2.dtype != np.uint8:
        raise ValueError('masks1 and masks2 should be of type np.uint8')
    # Use the mask helpers of this module; the unqualified `intersection` and
    # `area` names are not the mask implementations.
    intersect = intersection_mask(masks1, masks2)
    areas = np.expand_dims(area_mask(masks2), axis=0)  # [1, M]
    # EPSILON keeps the division finite when a mask in masks2 is empty.
    return intersect / (areas + EPSILON)
def area_masklist(masklist):
    """Returns the area of every mask held by a MaskList.

    Args:
        masklist: MaskList holding N boxes and masks.

    Returns:
        a numpy array of shape [N] holding the mask areas.
    """
    masks = masklist.get_masks()
    return area_mask(masks)
def intersection_masklist(masklist1, masklist2):
    """Returns pairwise mask intersection areas between two MaskLists.

    Args:
        masklist1: MaskList holding N boxes and masks.
        masklist2: MaskList holding M boxes and masks.

    Returns:
        a numpy array of shape [N, M] holding the pairwise intersection
        areas.
    """
    first_masks = masklist1.get_masks()
    second_masks = masklist2.get_masks()
    return intersection_mask(first_masks, second_masks)
def iou_masklist(masklist1, masklist2):
    """Returns pairwise mask intersection-over-union between two MaskLists.

    Args:
        masklist1: MaskList holding N boxes and masks.
        masklist2: MaskList holding M boxes and masks.

    Returns:
        a numpy array of shape [N, M] holding the pairwise iou scores.
    """
    first_masks = masklist1.get_masks()
    second_masks = masklist2.get_masks()
    return iou_mask(first_masks, second_masks)
def ioa_masklist(masklist1, masklist2):
    """Returns pairwise mask intersection-over-area between two MaskLists.

    The ioa of mask1 and mask2 is their intersection area divided by the
    area of mask2, so the measure is not symmetric:
    IOA(mask1, mask2) != IOA(mask2, mask1).

    Args:
        masklist1: MaskList holding N boxes and masks.
        masklist2: MaskList holding M boxes and masks.

    Returns:
        a numpy array of shape [N, M] holding the pairwise ioa scores.
    """
    first_masks = masklist1.get_masks()
    second_masks = masklist2.get_masks()
    return ioa_mask(first_masks, second_masks)
def gather_masklist(masklist, indices, fields=None):
    """Gathers boxes, masks and fields from a MaskList according to indices.

    By default, gather returns the boxes corresponding to the input index
    list, as well as all additional fields stored in the masklist (indexing
    into the first dimension). One can optionally gather from only a subset
    of fields; the 'masks' field is always gathered so that the result is a
    valid MaskList.

    Args:
        masklist: MaskList holding N boxes.
        indices: a 1-d numpy array of type int_.
        fields: (optional) sequence of field names to also gather from. If
            None (default), all fields are gathered from. The sequence passed
            in is never modified.

    Returns:
        a MaskList corresponding to the subset of the input masklist
        specified by indices.

    Raises:
        ValueError: if a specified field is not contained in masklist or if
            the indices are not of type int_ (raised by gather_boxlist).
    """
    if fields is not None and 'masks' not in fields:
        # Build a new list rather than appending in place, so the caller's
        # argument is left untouched.
        fields = list(fields) + ['masks']
    gathered = gather_boxlist(boxlist=masklist, indices=indices, fields=fields)
    return boxlist_to_masklist(gathered)
def sort_by_field_masklist(masklist, field, order=SortOrder.DESCEND):
    """Reorders a MaskList according to one of its scalar fields.

    A typical use is putting the boxes in descending score order.

    Args:
        masklist: MaskList holding N boxes.
        field: name of the MaskList field to sort by.
        order: (optional) a SortOrder value; defaults to SortOrder.DESCEND.

    Returns:
        A MaskList whose boxes, masks and fields follow the requested order
        of `field`.
    """
    # Delegate the sort to the box-list implementation, then restore the
    # MaskList type.
    ordered = sort_by_field_boxlist(boxlist=masklist, field=field, order=order)
    return boxlist_to_masklist(ordered)
def non_max_suppression_mask(masklist, max_output_size=10000, iou_threshold=1.0, score_threshold=-10.0):
    """Greedy non maximum suppression based on mask overlap.

    Repeatedly selects the highest-scoring remaining entry and discards every
    remaining entry whose mask IOU (intersection over union) with it exceeds
    iou_threshold.

    Args:
        masklist: MaskList holding N boxes and masks. Must contain a 'scores'
            field representing detection scores. All scores belong to the
            same class.
        max_output_size: maximum number of retained boxes.
        iou_threshold: intersection over union threshold, in [0, 1]. A value
            of exactly 1.0 disables suppression.
        score_threshold: minimum score threshold. Entries whose score is not
            greater than this value are removed first. The default of -10 is
            low enough to keep practically everything.

    Returns:
        a MaskList holding M boxes where M <= max_output_size, sorted by
        descending score.

    Raises:
        ValueError: if the 'scores' field does not exist.
        ValueError: if iou_threshold is not in [0, 1].
        ValueError: if max_output_size < 0.
        ValueError: if masklist is not a MaskList (raised by
            filter_scores_greater_than_masklist).
    """
    if not masklist.has_field('scores'):
        raise ValueError('Field scores does not exist')
    if iou_threshold < 0. or iou_threshold > 1.0:
        raise ValueError('IOU threshold must be in [0, 1]')
    if max_output_size < 0:
        raise ValueError('max_output_size must be bigger than 0.')

    # Use the MaskList variants of filter/sort so the working object keeps
    # its masks accessor and the function always returns a MaskList.
    masklist = filter_scores_greater_than_masklist(masklist, score_threshold)
    if masklist.num_boxes() == 0:
        return masklist

    masklist = sort_by_field_masklist(masklist, 'scores')

    # Prevent further computation if NMS is disabled.
    if iou_threshold == 1.0:
        if masklist.num_boxes() > max_output_size:
            return gather_masklist(masklist, np.arange(max_output_size))
        return masklist

    masks = masklist.get_masks()
    num_masks = masklist.num_boxes()
    # is_index_valid[i] is True while entry i is neither selected nor
    # suppressed.
    is_index_valid = np.ones(num_masks, dtype=bool)
    selected_indices = []
    for i in range(num_masks):
        if len(selected_indices) >= max_output_size:
            break
        if not is_index_valid[i]:
            continue
        selected_indices.append(i)
        is_index_valid[i] = False
        valid_indices = np.where(is_index_valid)[0]
        if valid_indices.size == 0:
            break
        # IOU of the newly selected mask against every remaining candidate.
        intersect_over_union = iou_mask(
            np.expand_dims(masks[i], axis=0), masks[valid_indices])
        intersect_over_union = np.squeeze(intersect_over_union, axis=0)
        # Every entry at valid_indices is currently True, so the comparison
        # alone decides which candidates survive.
        is_index_valid[valid_indices] = intersect_over_union <= iou_threshold
    # An explicit integer dtype keeps an empty selection usable as indices.
    return gather_masklist(masklist, np.array(selected_indices, dtype=int))


def multi_class_non_max_suppression_mask(masklist, score_thresh, iou_thresh, max_output_size):
    """Multi-class version of mask-based non maximum suppression.

    Runs non_max_suppression_mask independently for each class for which
    scores are provided, then merges the per-class results and sorts them by
    descending score.

    Args:
        masklist: MaskList holding N boxes and masks. Must contain a 'scores'
            field that is either 1-dimensional (single class) or
            2-dimensional with shape [num_boxes, num_classes].
        score_thresh: scalar threshold for score (low scoring entries are
            removed before suppression).
        iou_thresh: scalar threshold for mask IOU, in [0, 1] (entries that
            have high IOU overlap with previously selected ones are removed).
        max_output_size: maximum number of retained boxes per class.

    Returns:
        a MaskList holding M boxes with a rank-1 'scores' field sorted in
        decreasing order and a rank-1 'classes' field holding the class index
        of each box.

    Raises:
        ValueError: if iou_thresh is not in [0, 1], if masklist is not a
            MaskList, or if it does not have a valid 'scores' field.
    """
    if not 0 <= iou_thresh <= 1.0:
        raise ValueError('thresh must be between 0 and 1')
    if not isinstance(masklist, MaskList):
        raise ValueError('masklist must be a masklist')
    if not masklist.has_field('scores'):
        raise ValueError('input masklist must have \'scores\' field')

    scores = masklist.get_field('scores')
    if len(scores.shape) == 1:
        # Treat a flat score vector as a single class.
        scores = np.reshape(scores, [-1, 1])
    elif len(scores.shape) != 2:
        raise ValueError('scores field must be of rank 1 or 2')

    num_boxes = masklist.num_boxes()
    num_scores, num_classes = scores.shape
    if num_boxes != num_scores:
        raise ValueError('Incorrect scores field length: actual vs expected.')

    selected_boxes_list = []
    for class_idx in range(num_classes):
        # A fresh MaskList per class, carrying only that class' scores.
        masklist_and_class_scores = MaskList(
            box_data=masklist.get(), mask_data=masklist.get_masks())
        masklist_and_class_scores.add_field('scores', scores[:, class_idx])
        # The mask NMS applies the score threshold itself, so no separate
        # pre-filtering pass is needed.
        nms_result = non_max_suppression_mask(
            masklist_and_class_scores,
            max_output_size=max_output_size,
            iou_threshold=iou_thresh,
            score_threshold=score_thresh)
        nms_result.add_field(
            'classes',
            np.zeros_like(nms_result.get_field('scores')) + class_idx)
        selected_boxes_list.append(nms_result)

    selected_boxes = concatenate_boxlist(selected_boxes_list)
    sorted_boxes = sort_by_field_boxlist(selected_boxes, 'scores')
    return boxlist_to_masklist(boxlist=sorted_boxes)
def prune_non_overlapping_masklist(masklist1, masklist2, minoverlap=0.0):
    """Prunes the entries of masklist1 that overlap too little with masklist2.

    An entry of masklist1 is kept when the mask IOA (intersection with a mask
    of masklist2, divided by the area of the masklist1 mask) is at least
    minoverlap for at least one mask of masklist2.

    Args:
        masklist1: MaskList holding N boxes and masks.
        masklist2: MaskList holding M boxes and masks. When it is empty,
            every maximum overlap is taken to be 0.
        minoverlap: minimum required overlap to count masks as overlapping.

    Returns:
        A pruned MaskList holding the N' <= N retained entries of masklist1.
    """
    intersection_over_area = ioa_masklist(masklist2, masklist1)  # [M, N]
    # IOA is never negative, so `initial=0.0` leaves non-empty results
    # unchanged while making the reduction well defined when M == 0.
    max_overlap = np.amax(intersection_over_area, axis=0, initial=0.0)  # [N]
    keep_bool = np.greater_equal(max_overlap, np.array(minoverlap))
    keep_inds = np.nonzero(keep_bool)[0]
    return gather_masklist(masklist1, keep_inds)
def concatenate_masklist(masklists, fields=None):
    """Concatenates a list of MaskLists into one larger MaskList.

    Fields are concatenated as well, as long as the field shapes are equal
    except for the first dimension. The 'masks' field is always included so
    that the result is a valid MaskList.

    Args:
        masklists: list of MaskList objects.
        fields: optional sequence of field names to also concatenate. By
            default, all fields from the first MaskList in the list are
            included in the concatenation. The sequence passed in is never
            modified.

    Returns:
        a MaskList with number of boxes equal to
        sum([masklist.num_boxes() for masklist in masklists]).

    Raises:
        ValueError: if masklists is invalid (i.e., is not a list, is empty,
            or contains non-masklist objects), or if requested fields are not
            contained in all masklists (raised by concatenate_boxlist).
    """
    if fields is not None and 'masks' not in fields:
        # Build a new list rather than appending in place, so the caller's
        # argument is left untouched.
        fields = list(fields) + ['masks']
    merged = concatenate_boxlist(boxlists=masklists, fields=fields)
    return boxlist_to_masklist(merged)
def filter_scores_greater_than_masklist(masklist, thresh):
    """Keeps only the boxes and masks whose score exceeds a threshold.

    Args:
        masklist: MaskList holding N boxes and masks. Must contain a 'scores'
            field of shape [N] or [N, 1] representing detection scores.
        thresh: scalar threshold.

    Returns:
        a MaskList holding M boxes and masks where M <= N.

    Raises:
        ValueError: if masklist is not a MaskList object, if it does not have
            a scores field, or if the scores have an unsupported shape.
    """
    if not isinstance(masklist, MaskList):
        raise ValueError('masklist must be a BoxMaskList')
    if not masklist.has_field('scores'):
        raise ValueError('input masklist must have \'scores\' field')
    scores = masklist.get_field('scores')
    if len(scores.shape) > 2:
        raise ValueError('Scores should have rank 1 or 2')
    if len(scores.shape) == 2 and scores.shape[1] != 1:
        raise ValueError('Scores should have rank 1 or have shape consistent with [None, 1]')
    # Flatten first: on [N, 1] scores np.where yields row AND column index
    # arrays, and flattening both would append spurious zero indices.
    flat_scores = np.reshape(scores, [-1])
    high_score_indices = np.where(np.greater(flat_scores, thresh))[0].astype(np.int32)
    return gather_masklist(masklist, high_score_indices)