| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | """ |
| | Utilities for bounding box manipulation and GIoU. |
| | """ |
| | import torch |
| |
|
| | from torchvision.ops.boxes import box_area |
| |
|
| |
|
def box_cxcywh_to_xyxy(x):
    """Convert boxes from (center_x, center_y, width, height) to corner form.

    The last dimension of ``x`` is split into four components; the result has
    the same shape, with the last dimension holding (x0, y0, x1, y1).
    """
    center_x, center_y, width, height = x.unbind(-1)
    half_w = 0.5 * width
    half_h = 0.5 * height
    corners = (
        center_x - half_w,
        center_y - half_h,
        center_x + half_w,
        center_y + half_h,
    )
    return torch.stack(corners, dim=-1)
| |
|
| |
|
def box_xyxy_to_cxcywh(x):
    """Convert boxes from corner form (x0, y0, x1, y1) to center/size form.

    The last dimension of ``x`` is split into four components; the result has
    the same shape, with the last dimension holding
    (center_x, center_y, width, height).
    """
    left, top, right, bottom = x.unbind(-1)
    center_x = (left + right) / 2
    center_y = (top + bottom) / 2
    width = right - left
    height = bottom - top
    return torch.stack((center_x, center_y, width, height), dim=-1)
| |
|
| |
|
| | |
def box_iou(boxes1, boxes2):
    """Compute pairwise intersection-over-union between two sets of boxes.

    Both inputs are indexed as 2-D tensors whose last dimension holds
    (x0, y0, x1, y1).

    Returns a tuple ``(iou, union)``; each is a pairwise matrix with one row
    per box in ``boxes1`` and one column per box in ``boxes2``.
    """
    areas_a = box_area(boxes1)
    areas_b = box_area(boxes2)

    # Corners of the overlap rectangle for every (boxes1[i], boxes2[j]) pair;
    # the inserted axis on boxes1 broadcasts against boxes2.
    overlap_min = torch.max(boxes1[:, None, :2], boxes2[:, :2])
    overlap_max = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])

    # Disjoint pairs give a negative extent; clamp so their overlap is zero.
    overlap_w, overlap_h = (overlap_max - overlap_min).clamp(min=0).unbind(-1)
    inter = overlap_w * overlap_h

    union = areas_a[:, None] + areas_b - inter

    return inter / union, union
| |
|
| |
|
def generalized_box_iou(boxes1, boxes2):
    """
    Generalized IoU from https://giou.stanford.edu/

    Both inputs must be in [x0, y0, x1, y1] format with x1 >= x0 and
    y1 >= y0; this is asserted before any computation.

    Returns a [N, M] pairwise matrix, where N = len(boxes1)
    and M = len(boxes2)
    """
    # Reject boxes whose max corner lies before their min corner.
    assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
    assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
    iou, union = box_iou(boxes1, boxes2)

    # Smallest rectangle enclosing each (boxes1[i], boxes2[j]) pair.
    hull_min = torch.min(boxes1[:, None, :2], boxes2[:, :2])
    hull_max = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])

    hull_w, hull_h = (hull_max - hull_min).clamp(min=0).unbind(-1)
    hull_area = hull_w * hull_h

    # Penalise IoU by the fraction of the hull not covered by the union.
    return iou - (hull_area - union) / hull_area
| |
|
| |
|
def masks_to_boxes(masks):
    """Compute the bounding boxes around the provided masks.

    The masks should be in format [N, H, W] where N is the number of masks
    and (H, W) are the spatial dimensions. Any element that converts to
    ``True`` under ``masks.bool()`` counts as foreground.

    Returns a [N, 4] tensor with the boxes in xyxy format, expressed in pixel
    indices. An empty input (``masks.numel() == 0``) yields a [0, 4] tensor on
    the masks' device. A mask that is entirely zero yields min coordinates of
    1e8 and max coordinates of 0.
    """
    if masks.numel() == 0:
        return torch.zeros((0, 4), device=masks.device)

    h, w = masks.shape[-2:]

    # Coordinate ramps are created on the masks' device so the products below
    # do not mix devices, and are shaped to broadcast over [N, H, W] directly
    # (this replaces the meshgrid call, which warns on recent torch versions
    # when no ``indexing`` argument is given).
    y = torch.arange(0, h, dtype=torch.float, device=masks.device).view(1, h, 1)
    x = torch.arange(0, w, dtype=torch.float, device=masks.device).view(1, 1, w)

    background = ~(masks.bool())

    # Max: background pixels contribute 0, so the largest product is the
    # largest foreground coordinate. Min: background is pushed to a large
    # sentinel so it never wins.
    x_mask = masks * x
    x_max = x_mask.flatten(1).max(-1)[0]
    x_min = x_mask.masked_fill(background, 1e8).flatten(1).min(-1)[0]

    y_mask = masks * y
    y_max = y_mask.flatten(1).max(-1)[0]
    y_min = y_mask.masked_fill(background, 1e8).flatten(1).min(-1)[0]

    return torch.stack([x_min, y_min, x_max, y_max], 1)
| |
|