Spaces:
Sleeping
Sleeping
| """Adapted from Fast R-CNN | |
| Written by Sergey Karayev | |
| Licensed under The MIT License | |
| Copyright (c) 2015 Microsoft. | |
| """ | |
| import numpy as np | |
| from skimage.measure import regionprops | |
| def _union_slice(a: tuple[slice], b: tuple[slice]): | |
| """Returns the union of slice tuples a and b.""" | |
| starts = tuple(min(_a.start, _b.start) for _a, _b in zip(a, b)) | |
| stops = tuple(max(_a.stop, _b.stop) for _a, _b in zip(a, b)) | |
| return tuple(slice(start, stop) for start, stop in zip(starts, stops)) | |
def get_labels_with_overlap(gt_frame, res_frame):
    """Find ground-truth/result label pairs whose bounding boxes overlap.

    Both frames are cast to ``uint16`` before regions are extracted, so label
    values that do not fit in 16 bits are not preserved.

    Args:
        gt_frame (np.ndarray): ground truth segmentation for a single frame
        res_frame (np.ndarray): result segmentation for a given frame

    Returns:
        overlapping_gt_labels: labels of gt boxes that overlap with res boxes
        overlapping_res_labels: labels of res boxes that overlap with gt boxes,
            aligned element-wise with ``overlapping_gt_labels``
        intersections_over_gt: list of float, one per pair:
            (elements shared by the gt and res regions) / (elements in the gt region)

        When either frame contains no labelled region, three empty lists are
        returned. Otherwise the two label collections are ``uint16`` arrays.
    """
    gt_frame = gt_frame.astype(np.uint16, copy=False)
    res_frame = res_frame.astype(np.uint16, copy=False)

    gt_regions = regionprops(gt_frame)
    res_regions = regionprops(res_frame)
    if not gt_regions or not res_regions:
        return [], [], []

    gt_labels = np.array([int(region.label) for region in gt_regions], dtype=np.uint16)
    res_labels = np.array([int(region.label) for region in res_regions], dtype=np.uint16)
    gt_bboxes = np.array([region.bbox for region in gt_regions], dtype=np.float64)
    res_bboxes = np.array([region.bbox for region in res_regions], dtype=np.float64)

    # 3D frames have 6-column boxes and need the volumetric variant.
    pairwise_overlap = compute_overlap_3D if gt_frame.ndim == 3 else compute_overlap
    overlaps = pairwise_overlap(gt_bboxes, res_bboxes)  # indexed [gt_bbox, res_bbox]

    # Positions (0-based box numbers) of every pair whose boxes overlap at all.
    gt_idx, res_idx = (idx.astype(np.uint16) for idx in np.nonzero(overlaps))

    def _intersection_over_gt(i, j):
        # Compare the two regions inside the smallest window covering both boxes.
        window = _union_slice(gt_regions[i].slice, res_regions[j].slice)
        gt_mask = gt_frame[window] == gt_labels[i]
        res_mask = res_frame[window] == res_labels[j]
        return np.count_nonzero(gt_mask & res_mask) / np.count_nonzero(gt_mask)

    ratios = [_intersection_over_gt(i, j) for i, j in zip(gt_idx, res_idx)]
    return gt_labels[gt_idx], res_labels[res_idx], ratios
def compute_overlap(boxes: np.ndarray, query_boxes: np.ndarray) -> np.ndarray:
    """Compute pairwise intersection-over-union between two sets of 2D boxes.

    Each row is read as ``(min0, min1, max0, max1)`` and every side length is
    taken as ``max - min + 1``, so boxes that merely touch still produce a
    small positive overlap.

    NOTE: this module optionally recompiles this function with numba and
    types the locals ``N``, ``K``, ``overlaps``, ``iw``, ``ih``, ``ua``,
    ``n``, ``k`` and ``box_area`` by name. Keep those names and keep the body
    to plain scalar loops.

    Args:
        boxes: (N, 4) ndarray of float
        query_boxes: (K, 4) ndarray of float

    Returns:
        overlaps: (N, K) ndarray of overlap between boxes and query_boxes;
            entries stay 0 where a pair does not intersect.
    """
    N = boxes.shape[0]
    K = query_boxes.shape[0]
    overlaps = np.zeros((N, K), dtype=np.float64)

    for k in range(K):
        q_min0 = query_boxes[k, 0]
        q_min1 = query_boxes[k, 1]
        q_max0 = query_boxes[k, 2]
        q_max1 = query_boxes[k, 3]
        box_area = (q_max0 - q_min0 + 1) * (q_max1 - q_min1 + 1)

        for n in range(N):
            # Extent of the intersection along the first axis.
            iw = min(boxes[n, 2], q_max0) - max(boxes[n, 0], q_min0) + 1
            if not iw > 0:
                continue

            # Extent of the intersection along the second axis.
            ih = min(boxes[n, 3], q_max1) - max(boxes[n, 1], q_min1) + 1
            if not ih > 0:
                continue

            inter = iw * ih
            # Union area = area(box) + area(query) - intersection.
            ua = np.float64(
                (boxes[n, 2] - boxes[n, 0] + 1) * (boxes[n, 3] - boxes[n, 1] + 1)
                + box_area
                - inter
            )
            overlaps[n, k] = inter / ua

    return overlaps
def compute_overlap_3D(boxes: np.ndarray, query_boxes: np.ndarray) -> np.ndarray:
    """Compute pairwise intersection-over-union between two sets of 3D boxes.

    Each row is read as ``(min0, min1, min2, max0, max1, max2)`` and every
    side length is taken as ``max - min + 1``, so boxes that merely touch
    still produce a small positive overlap.

    NOTE: this module optionally recompiles this function with numba and
    types the locals ``N``, ``K``, ``overlaps``, ``id_``, ``iw``, ``ih``,
    ``ua``, ``n``, ``k`` and ``box_volume`` by name. Keep those names and keep
    the body to plain scalar loops.

    Args:
        boxes: (N, 6) ndarray of float
        query_boxes: (K, 6) ndarray of float

    Returns:
        overlaps: (N, K) ndarray of overlap between boxes and query_boxes;
            entries stay 0 where a pair does not intersect.
    """
    N = boxes.shape[0]
    K = query_boxes.shape[0]
    overlaps = np.zeros((N, K), dtype=np.float64)

    for k in range(K):
        q_min0 = query_boxes[k, 0]
        q_min1 = query_boxes[k, 1]
        q_min2 = query_boxes[k, 2]
        q_max0 = query_boxes[k, 3]
        q_max1 = query_boxes[k, 4]
        q_max2 = query_boxes[k, 5]
        box_volume = (
            (q_max0 - q_min0 + 1) * (q_max1 - q_min1 + 1) * (q_max2 - q_min2 + 1)
        )

        for n in range(N):
            # Extent of the intersection along the first axis.
            id_ = min(boxes[n, 3], q_max0) - max(boxes[n, 0], q_min0) + 1
            if not id_ > 0:
                continue

            # Extent of the intersection along the second axis.
            iw = min(boxes[n, 4], q_max1) - max(boxes[n, 1], q_min1) + 1
            if not iw > 0:
                continue

            # Extent of the intersection along the third axis.
            ih = min(boxes[n, 5], q_max2) - max(boxes[n, 2], q_min2) + 1
            if not ih > 0:
                continue

            inter = iw * ih * id_
            # Union volume = volume(box) + volume(query) - intersection.
            ua = np.float64(
                (boxes[n, 3] - boxes[n, 0] + 1)
                * (boxes[n, 4] - boxes[n, 1] + 1)
                * (boxes[n, 5] - boxes[n, 2] + 1)
                + box_volume
                - inter
            )
            overlaps[n, k] = inter / ua

    return overlaps
# Optional acceleration: when numba can be imported, the pure-Python
# compute_overlap / compute_overlap_3D defined above are rebound to
# njit-compiled versions. Otherwise they are left as they are and a warning
# is emitted.
try:
    import numba
except ImportError:
    import os
    import warnings

    # Any non-empty value of NO_JIT_WARNING silences the warning.
    if not os.getenv("NO_JIT_WARNING", False):
        warnings.warn(
            "Numba not installed, falling back to slower numpy implementation. "
            "Install numba for a significant speedup. Set the environment "
            "variable NO_JIT_WARNING=1 to disable this warning.",
            stacklevel=2,
        )
else:
    # compute_overlap 2d and 3d have the same signature.
    # Two variants are declared: C-contiguous 2D float64 arrays given in
    # string form, and the same layout built from array types flagged
    # readonly=True.
    signature = [
        "f8[:,::1](f8[:,::1], f8[:,::1])",
        numba.types.Array(numba.float64, 2, "C", readonly=True)(
            numba.types.Array(numba.float64, 2, "C", readonly=True),
            numba.types.Array(numba.float64, 2, "C", readonly=True),
        ),
    ]
    # variables that appear in the body of each function; the keys must match
    # the local variable names used inside compute_overlap / compute_overlap_3D
    common_locals = {
        "N": numba.uint64,
        "K": numba.uint64,
        "overlaps": numba.types.Array(numba.float64, 2, "C"),
        "iw": numba.float64,
        "ih": numba.float64,
        "ua": numba.float64,
        "n": numba.uint64,
        "k": numba.uint64,
    }
    # Rebind the module-level names so callers transparently get the compiled
    # versions. Each function adds the locals that only it uses.
    compute_overlap = numba.njit(
        signature,
        locals={**common_locals, "box_area": numba.float64},
        fastmath=True,
        nogil=True,
        boundscheck=False,
    )(compute_overlap)
    compute_overlap_3D = numba.njit(
        signature,
        locals={**common_locals, "id_": numba.float64, "box_volume": numba.float64},
        fastmath=True,
        nogil=True,
        boundscheck=False,
    )(compute_overlap_3D)