Spaces:
Running
on
Zero
Running
on
Zero
| import torch | |
| from torchvision import ops | |
| from torchvision.ops.boxes import box_area | |
| import torch.nn.functional as F | |
def boxes_with_scores(density_map, tlrb, sort=False, validate=False):
    """Extract bounding boxes and objectness scores from a density map.

    Local maxima of ``density_map`` (found with a 3x3 max-pool instead of a
    scikit-image local-peak search) that exceed a per-image threshold become
    box centers; ``tlrb`` supplies per-pixel box offsets around each center.

    Args:
        density_map: tensor of shape (B, 1, H, W) with per-pixel scores.
        tlrb: per-pixel box offsets, channels first — assumes shape
            (B, 4, H, W) with channels (top, left, right, bottom);
            TODO confirm against the caller.
        sort: if True, order each image's boxes by descending score.
        validate: if True, threshold at per-image max / 8 instead of the
            per-image median.

    Returns:
        (out_batch, ref_points_batch): per-image list of dicts with
        "pred_boxes" — (1, N, 4) boxes in normalized xyxy — and "box_v" —
        (1, N) peak scores — plus a per-image (2, N) tensor of integer
        peak coordinates (row, col).
    """
    B, C, _, _ = density_map.shape  # B, 1, H, W
    # maxpool instead of scikit local peak: a pixel equal to its 3x3-pooled
    # value is a local maximum
    pooled = F.max_pool2d(density_map, 3, 1, 1)
    # per-image threshold (median during training, max/8 for validation)
    if validate:
        batch_thresh = torch.max(density_map.reshape(B, -1), dim=-1).values.view(B, C, 1, 1) / 8
    else:
        batch_thresh = torch.median(density_map.reshape(B, -1), dim=-1).values.view(B, C, 1, 1)
    # binary mask of selected boxes: local maxima above the threshold
    mask = (pooled == density_map) & (density_map > batch_thresh)
    # need this for loop to have the same output structure
    # can be vectorized otherwise
    out_batch = []
    ref_points_batch = []
    for i in range(B):
        # select the masked density maps and box offsets
        bbox_scores = density_map[i, mask[i]]
        # nonzero() yields (channel, row, col); keep only (row, col)
        ref_points = mask[i].nonzero()[:, -2:]
        # normalize center locations to [0, 1) by dividing by (H, W)
        bbox_centers = ref_points / torch.tensor(mask.shape[2:], device=mask.device)
        # select masked box offsets, permute to keep channels last
        tlrb_ = tlrb[i].permute(1, 2, 0)
        bbox_offsets = tlrb_[mask[i].permute(1, 2, 0).expand_as(tlrb_)].reshape(-1, 4)
        # vectorised calculation of the boxes = [ref_points_transposed[1] / ...] in original
        # flip (row, col) -> (x, y), tile to (x, y, x, y), then subtract the
        # first two offsets and add the last two to get xyxy corners
        sign = torch.tensor([-1, -1, 1, 1], device=mask.device)
        bbox_xyxy = bbox_centers.flip(-1).repeat(1, 2) + sign * bbox_offsets
        # sort by bbox score if needed -- this matches the original
        if sort:
            perm = torch.argsort(bbox_scores, descending=True)
            bbox_scores = bbox_scores[perm]
            bbox_xyxy = bbox_xyxy[perm]
            ref_points = ref_points[perm]
        out_batch.append({
            "pred_boxes": bbox_xyxy.unsqueeze(0),
            "box_v": bbox_scores.unsqueeze(0)
        })
        ref_points_batch.append(ref_points.T)
    return out_batch, ref_points_batch
def box_cxcywh_to_xyxy(x):
    """Convert boxes from (cx, cy, w, h) to (x0, y0, x1, y1) format."""
    center_x, center_y, width, height = x.unbind(-1)
    half_w = 0.5 * width
    half_h = 0.5 * height
    corners = (
        center_x - half_w,
        center_y - half_h,
        center_x + half_w,
        center_y + half_h,
    )
    return torch.stack(corners, dim=-1)
def box_xyxy_to_cxcywh(x):
    """Convert boxes from (x0, y0, x1, y1) to (cx, cy, w, h) format."""
    x0, y0, x1, y1 = x.unbind(-1)
    components = [
        (x0 + x1) / 2,  # center x
        (y0 + y1) / 2,  # center y
        x1 - x0,        # width
        y1 - y0,        # height
    ]
    return torch.stack(components, dim=-1)
# modified from torchvision to also return the union
def box_iou(boxes1, boxes2):
    """Pairwise IoU of two sets of xyxy boxes; also returns the union areas.

    Returns an ([N, M] iou, [N, M] union) pair for N boxes1 vs M boxes2.
    """
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)

    # intersection rectangle of every (boxes1, boxes2) pair
    top_left = torch.max(boxes1[:, None, :2], boxes2[:, :2])      # [N,M,2]
    bottom_right = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]
    inter_wh = (bottom_right - top_left).clamp(min=0)             # [N,M,2]
    intersection = inter_wh[..., 0] * inter_wh[..., 1]            # [N,M]

    # epsilon guards against division by zero for degenerate boxes
    union = area1[:, None] + area2 - intersection + 1e-16         # [N,M]
    return intersection / union, union
def generalized_box_iou(boxes1, boxes2):
    """
    Generalized IoU from https://giou.stanford.edu/
    The boxes should be in [x0, y0, x1, y1] format
    Returns a [N, M] pairwise matrix, where N = len(boxes1)
    and M = len(boxes2)
    """
    # degenerate boxes give inf / nan results, so do an early check
    assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
    assert (boxes2[:, 2:] >= boxes2[:, :2]).all()

    iou, union = box_iou(boxes1, boxes2)

    # smallest axis-aligned box enclosing each pair
    enclose_tl = torch.min(boxes1[:, None, :2], boxes2[:, :2])
    enclose_br = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
    enclose_wh = (enclose_br - enclose_tl).clamp(min=0)           # [N,M,2]
    enclose_area = enclose_wh[..., 0] * enclose_wh[..., 1] + 1e-16  # [N,M]

    # GIoU = IoU - (enclosing area not covered by the union) / enclosing area
    return iou - (enclose_area - union) / enclose_area
def masks_to_boxes(masks):
    """Compute the bounding boxes around the provided masks.

    The masks should be in format [N, H, W] where N is the number of masks,
    (H, W) are the spatial dimensions.

    Returns a [N, 4] tensor with the boxes in xyxy format.
    """
    if masks.numel() == 0:
        return torch.zeros((0, 4), device=masks.device)

    h, w = masks.shape[-2:]
    # Build the coordinate grids on the masks' device: the original created
    # them on the CPU, which raised a device-mismatch error for CUDA masks.
    y = torch.arange(0, h, dtype=torch.float, device=masks.device)
    x = torch.arange(0, w, dtype=torch.float, device=masks.device)
    # indexing='ij' pins the legacy torch.meshgrid default (row-major grids)
    # explicitly, avoiding the deprecation warning without changing results.
    y, x = torch.meshgrid(y, x, indexing='ij')

    # max over masked coordinates gives the far edge; for the near edge,
    # fill the background with a huge value first so min ignores it
    x_mask = masks * x.unsqueeze(0)
    x_max = x_mask.flatten(1).max(-1)[0]
    x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]

    y_mask = masks * y.unsqueeze(0)
    y_max = y_mask.flatten(1).max(-1)[0]
    y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]

    return torch.stack([x_min, y_min, x_max, y_max], 1)
| import numpy as np | |
class BoxList:
    """A set of bounding boxes for one image, plus named auxiliary fields.

    Boxes are stored as an (N, 4) float32 tensor in either 'xyxy'
    (corner/corner) or 'xywh' (corner/size) mode, together with the image
    ``size`` as (width, height). Mode conversions use the inclusive-pixel
    convention (width = x_max - x_min + 1).

    Fix over the original: an unsupported ``mode`` in convert / area /
    split_to_xyxy now raises ValueError instead of crashing with
    UnboundLocalError or silently returning None.
    """

    def __init__(self, box, image_size, mode='xyxy'):
        # Accept tensors, numpy arrays, or nested sequences; normalize to
        # float32 on the input's device (CPU for non-tensors).
        device = box.device if hasattr(box, 'device') else 'cpu'
        if torch.is_tensor(box):
            box = torch.as_tensor(box, dtype=torch.float32, device=device)
        else:
            box = torch.as_tensor(np.array(box), dtype=torch.float32, device=device)
        self.box = box
        self.size = image_size  # (width, height)
        self.mode = mode
        self.fields = {}  # extra per-box data, e.g. labels or scores

    def convert(self, mode):
        """Return a copy of this BoxList in ``mode`` ('xyxy' or 'xywh')."""
        if mode == self.mode:
            return self
        x_min, y_min, x_max, y_max = self.split_to_xyxy()
        if mode == 'xyxy':
            box = torch.cat([x_min, y_min, x_max, y_max], -1)
        elif mode == 'xywh':
            remove = 1  # inclusive-pixel convention
            box = torch.cat(
                [x_min, y_min, x_max - x_min + remove, y_max - y_min + remove], -1
            )
        else:
            raise ValueError(f"unsupported mode: {mode!r}")
        box = BoxList(box, self.size, mode=mode)
        box.copy_field(self)
        return box

    def copy_field(self, box):
        """Copy all auxiliary fields from another BoxList into this one."""
        for k, v in box.fields.items():
            self.fields[k] = v

    def area(self):
        """Per-box areas; inclusive-pixel convention in 'xyxy' mode."""
        box = self.box
        if self.mode == 'xyxy':
            remove = 1
            return (box[:, 2] - box[:, 0] + remove) * (box[:, 3] - box[:, 1] + remove)
        elif self.mode == 'xywh':
            return box[:, 2] * box[:, 3]
        raise ValueError(f"unsupported mode: {self.mode!r}")

    def split_to_xyxy(self):
        """Return (x_min, y_min, x_max, y_max), each of shape (N, 1)."""
        if self.mode == 'xyxy':
            x_min, y_min, x_max, y_max = self.box.split(1, dim=-1)
            return x_min, y_min, x_max, y_max
        elif self.mode == 'xywh':
            remove = 1
            x_min, y_min, w, h = self.box.split(1, dim=-1)
            return (
                x_min,
                y_min,
                # clamp keeps degenerate (zero-size) boxes from inverting
                x_min + (w - remove).clamp(min=0),
                y_min + (h - remove).clamp(min=0),
            )
        raise ValueError(f"unsupported mode: {self.mode!r}")

    def __len__(self):
        return self.box.shape[0]

    def __getitem__(self, index):
        # NOTE: fields are intentionally NOT carried over, matching the
        # original behavior.
        box = BoxList(self.box[index], self.size, self.mode)
        return box

    def resize(self, size, *args, **kwargs):
        """Return a copy rescaled to the new image ``size`` (width, height)."""
        ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(size, self.size))
        if ratios[0] == ratios[1]:
            # uniform scaling: a single multiply, mode preserved directly
            ratio = ratios[0]
            scaled = self.box * ratio
            box = BoxList(scaled, size, mode=self.mode)
            for k, v in self.fields.items():
                # assumes non-tensor fields expose .resize — TODO confirm
                if not isinstance(v, torch.Tensor):
                    v = v.resize(size, *args, **kwargs)
                box.fields[k] = v
            return box

        # non-uniform scaling: work in xyxy, then convert back
        ratio_w, ratio_h = ratios
        x_min, y_min, x_max, y_max = self.split_to_xyxy()
        scaled_x_min = x_min * ratio_w
        scaled_x_max = x_max * ratio_w
        scaled_y_min = y_min * ratio_h
        scaled_y_max = y_max * ratio_h
        scaled = torch.cat([scaled_x_min, scaled_y_min, scaled_x_max, scaled_y_max], -1)
        box = BoxList(scaled, size, mode='xyxy')
        for k, v in self.fields.items():
            # assumes non-tensor fields expose .resize — TODO confirm
            if not isinstance(v, torch.Tensor):
                v = v.resize(size, *args, **kwargs)
            box.fields[k] = v
        return box.convert(self.mode)

    def clip(self, remove_empty=True):
        """Clamp boxes in place to the image bounds; optionally drop empties."""
        remove = 1
        max_width = self.size[0] - remove
        max_height = self.size[1] - remove
        self.box[:, 0].clamp_(min=0, max=max_width)
        self.box[:, 1].clamp_(min=0, max=max_height)
        self.box[:, 2].clamp_(min=0, max=max_width)
        self.box[:, 3].clamp_(min=0, max=max_height)
        if remove_empty:
            box = self.box
            keep = (box[:, 3] > box[:, 1]) & (box[:, 2] > box[:, 0])
            return self[keep]
        else:
            return self

    def to(self, device):
        """Return a copy with the boxes (and movable fields) on ``device``."""
        box = BoxList(self.box.to(device), self.size, self.mode)
        for k, v in self.fields.items():
            if hasattr(v, 'to'):
                v = v.to(device)
            box.fields[k] = v
        return box
def remove_small_box(boxlist, min_size):
    """Filter out boxes whose width or height is below ``min_size``."""
    xywh = boxlist.convert('xywh').box
    _, _, widths, heights = xywh.unbind(dim=1)
    keep_mask = (widths >= min_size) & (heights >= min_size)
    return boxlist[keep_mask.nonzero().squeeze(1)]
def boxlist_nms(boxlist, scores, threshold, max_proposal=-1):
    """Apply NMS at ``threshold``, keeping at most ``max_proposal`` boxes.

    A non-positive threshold disables NMS entirely. The result is returned
    in the input's original mode.
    """
    if threshold <= 0:
        return boxlist
    original_mode = boxlist.mode
    xyxy = boxlist.convert('xyxy')
    keep = ops.nms(xyxy.box, scores, threshold)
    if max_proposal > 0:
        keep = keep[:max_proposal]
    return xyxy[keep].convert(original_mode)
def compute_location(features):
    """Return a one-element list of per-cell (x, y) locations for a feature map.

    Uses stride 1 and the feature map's own device; list-wrapped to mirror a
    multi-level (FPN-style) interface.
    """
    _, _, height, width = features.shape
    return [compute_location_per_level(height, width, 1, features.device)]
def compute_location_per_level(height, width, stride, device):
    """Compute the (x, y) input-image location of every cell in a feature grid.

    Args:
        height: number of grid rows.
        width: number of grid columns.
        stride: spacing between adjacent cells, in input-image pixels.
        device: torch device for the returned tensor.

    Returns:
        (height * width, 2) float32 tensor of (x, y) coordinates in row-major
        order (x varies fastest), shifted by stride // 2 to hit cell centers.
    """
    shift_x = torch.arange(
        0, width * stride, step=stride, dtype=torch.float32, device=device
    )
    shift_y = torch.arange(
        0, height * stride, step=stride, dtype=torch.float32, device=device
    )
    # indexing='ij' pins the legacy torch.meshgrid default explicitly,
    # avoiding the deprecation warning without changing the output.
    shift_y, shift_x = torch.meshgrid(shift_y, shift_x, indexing='ij')
    shift_x = shift_x.reshape(-1)
    shift_y = shift_y.reshape(-1)
    location = torch.stack((shift_x, shift_y), 1) + stride // 2
    return location