| |
| |
| |
| from typing import Tuple |
|
|
| import torch |
| from torch import Tensor |
| from torch.autograd import Function |
|
|
| from ..utils import ext_loader |
|
|
# Small constant added to the divisor in ``box_intersection`` when the
# intersection points are computed, so that the division stays finite.
EPSILON = 1e-8
# Load the compiled '_ext' extension and require that it exposes the
# vertex-sorting forward op used by ``SortVertices``.
ext_module = ext_loader.load_ext('_ext',
                                 ['diff_iou_rotated_sort_vertices_forward'])
|
|
|
|
class SortVertices(Function):
    """Autograd wrapper around the compiled vertex-sorting op.

    ``forward`` hands its three arguments unchanged to
    ``ext_module.diff_iou_rotated_sort_vertices_forward`` and returns the
    op's result. No gradient flows through this function.
    """

    @staticmethod
    def forward(ctx, vertices, mask, num_valid):
        """Run the extension op and return its output.

        Args:
            ctx: Autograd context object.
            vertices: Candidate vertices passed to the extension op.
            mask: Validity mask passed to the extension op.
            num_valid: Number of valid vertices passed to the extension op.

        Returns:
            The tensor produced by the extension op (the caller in this
            file treats it as sorted vertex indices).
        """
        sorted_idx = ext_module.diff_iou_rotated_sort_vertices_forward(
            vertices, mask, num_valid)
        # The output is flagged as non-differentiable, except when the
        # reported torch version is the literal string 'parrots'.
        running_on_parrots = torch.__version__ == 'parrots'
        if not running_on_parrots:
            ctx.mark_non_differentiable(sorted_idx)
        return sorted_idx

    @staticmethod
    def backward(ctx, gradout):
        """Return an empty tuple: no gradients are produced."""
        return tuple()
|
|
|
|
def box_intersection(corners1: Tensor,
                     corners2: Tensor) -> Tuple[Tensor, Tensor]:
    """Intersect every edge of box1 with every edge of box2.

    Convention: two edges with a zero determinant (parallel or collinear)
    never yield an intersection point.

    Args:
        corners1 (Tensor): (B, N, 4, 2) First batch of boxes.
        corners2 (Tensor): (B, N, 4, 2) Second batch of boxes.

    Returns:
        Tuple:
         - Tensor: (B, N, 4, 4, 1, 2) Intersection points, zeroed where the
           mask is False. The singleton axis is left over from ``split``;
           ``build_vertices`` flattens it away with ``view``.
         - Tensor: (B, N, 4, 4, 1) Valid intersections mask.
    """
    # Edge i runs from corner i to corner (i + 1) % 4. Concatenating both
    # end points gives rows of (x_start, y_start, x_end, y_end).
    next_corner = [1, 2, 3, 0]
    edges1 = torch.cat([corners1, corners1[:, :, next_corner, :]], dim=3)
    edges2 = torch.cat([corners2, corners2[:, :, next_corner, :]], dim=3)

    # Insert broadcast axes so each edge of box1 is paired with each edge
    # of box2.
    x1, y1, x2, y2 = edges1.unsqueeze(3).split([1, 1, 1, 1], dim=-1)
    x3, y3, x4, y4 = edges2.unsqueeze(2).split([1, 1, 1, 1], dim=-1)

    # Line-line intersection in parametric form: ``t`` is the position
    # along the box1 edge, ``u`` the position along the box2 edge.
    det = (x1 - x2) * (y3 - y4) - (y1 - y2) * (x3 - x4)
    t_num = (x1 - x3) * (y3 - y4) - (y1 - y3) * (x3 - x4)
    u_num = (x1 - x2) * (y1 - y3) - (y1 - y2) * (x1 - x3)
    degenerate = det == .0

    # Force the parameters out of (0, 1) wherever the determinant vanishes
    # so those pairs are rejected by the mask below.
    t = t_num / det
    t[degenerate] = -1.
    u = -u_num / det
    u[degenerate] = -1.

    on_edge1 = (t > 0) & (t < 1)
    on_edge2 = (u > 0) & (u < 1)
    mask = on_edge1 * on_edge2

    # Recompute t with EPSILON in the divisor for the coordinates
    # themselves, keeping the division finite.
    t = t_num / (det + EPSILON)
    points = torch.stack([x1 + t * (x2 - x1), y1 + t * (y2 - y1)], dim=-1)
    points = points * mask.float().unsqueeze(-1)
    return points, mask
|
|
|
|
def box1_in_box2(corners1: Tensor, corners2: Tensor) -> Tensor:
    """Test which corners of box1 fall inside box2.

    Convention: a corner lying exactly on an edge of the other box counts
    as inside (a tolerance of 1e-6 is applied on both ends).

    Args:
        corners1 (Tensor): (B, N, 4, 2) First batch of boxes.
        corners2 (Tensor): (B, N, 4, 2) Second batch of boxes.

    Returns:
        Tensor: (B, N, 4) Boolean mask, True where the corner of box1 is
        inside box2.
    """
    # A is the reference corner of box2; AB and AD are its adjacent edges.
    origin = corners2[:, :, 0:1, :]
    edge_ab = corners2[:, :, 1:2, :] - origin
    edge_ad = corners2[:, :, 3:4, :] - origin
    to_point = corners1 - origin

    # Projection of AM onto each edge, normalised by the squared edge
    # length: a value in [0, 1] means M lies between the two parallel sides.
    frac_ab = (torch.sum(edge_ab * to_point, dim=-1) /
               torch.sum(edge_ab * edge_ab, dim=-1))
    frac_ad = (torch.sum(edge_ad * to_point, dim=-1) /
               torch.sum(edge_ad * edge_ad, dim=-1))

    lower, upper = -1e-6, 1 + 1e-6
    within_ab = (frac_ab > lower) & (frac_ab < upper)
    within_ad = (frac_ad > lower) & (frac_ad < upper)
    return within_ab & within_ad
|
|
|
|
def box_in_box(corners1: Tensor, corners2: Tensor) -> Tuple[Tensor, Tensor]:
    """Run the corner-containment test in both directions.

    Args:
        corners1 (Tensor): (B, N, 4, 2) First batch of boxes.
        corners2 (Tensor): (B, N, 4, 2) Second batch of boxes.

    Returns:
        Tuple:
         - Tensor: (B, N, 4) True if i-th corner of box1 is in box2.
         - Tensor: (B, N, 4) True if i-th corner of box2 is in box1.
    """
    return (box1_in_box2(corners1, corners2),
            box1_in_box2(corners2, corners1))
|
|
|
|
def build_vertices(corners1: Tensor, corners2: Tensor, c1_in_2: Tensor,
                   c2_in_1: Tensor, intersections: Tensor,
                   valid_mask: Tensor) -> Tuple[Tensor, Tensor]:
    """Collect all candidate vertices of the intersection polygon.

    The candidates are, in order: the 4 corners of box1, the 4 corners of
    box2 and the 16 edge-edge intersection points. The mask is laid out
    the same way.

    Args:
        corners1 (Tensor): (B, N, 4, 2) First batch of boxes.
        corners2 (Tensor): (B, N, 4, 2) Second batch of boxes.
        c1_in_2 (Tensor): (B, N, 4) True if i-th corner of box1 is in box2.
        c2_in_1 (Tensor): (B, N, 4) True if i-th corner of box2 is in box1.
        intersections (Tensor): (B, N, 4, 4, 2) Intersections.
        valid_mask (Tensor): (B, N, 4, 4) Valid intersections mask.

    Returns:
        Tuple:
         - Tensor: (B, N, 24, 2) Vertices of intersection area;
               only some elements are valid.
         - Tensor: (B, N, 24) Mask of valid elements in vertices.
    """
    batch, num_boxes = corners1.size(0), corners1.size(1)

    # Flatten everything after (B, N) so the intersection points line up
    # with the corner lists along dim 2.
    flat_points = intersections.view(batch, num_boxes, -1, 2)
    flat_valid = valid_mask.view(batch, num_boxes, -1)

    vertices = torch.cat((corners1, corners2, flat_points), dim=2)
    mask = torch.cat((c1_in_2, c2_in_1, flat_valid), dim=2)
    return vertices, mask
|
|
|
|
def sort_indices(vertices: Tensor, mask: Tensor) -> Tensor:
    """Order the valid vertices of each polygon via ``SortVertices``.

    Note:
        Why 9? The polygon has at most 8 vertices, +1 to repeat the first
        element and close the loop. The index should have the following
        structure:
            (A, B, C, ... , A, X, X, X)
        where X is the index of an arbitrary element among the last 16
        entries (intersections, not corners) whose value is 0 and whose
        mask is False, so it contributes zero value and zero gradient.

    Args:
        vertices (Tensor): (B, N, 24, 2) Box vertices.
        mask (Tensor): (B, N, 24) Mask.

    Returns:
        Tensor: (B, N, 9) Sorted indices.
    """
    valid_count = mask.int().sum(dim=2).int()

    # Centre the vertices on the mean of the valid ones before sorting.
    # When a pair has no valid vertex this divides 0 by 0.
    weights = mask.float().unsqueeze(-1)
    summed = (vertices * weights).sum(dim=2, keepdim=True)
    centroid = summed / valid_count[..., None, None]
    centered = vertices - centroid

    order = SortVertices.apply(centered, mask, valid_count)
    return order.long()
|
|
|
|
def calculate_area(idx_sorted: Tensor,
                   vertices: Tensor) -> Tuple[Tensor, Tensor]:
    """Compute the polygon area from sorted vertex indices.

    Args:
        idx_sorted (Tensor): (B, N, 9) Sorted vertex ids.
        vertices (Tensor): (B, N, 24, 2) Vertices.

    Returns:
        Tuple:
         - Tensor (B, N): Area of intersection.
         - Tensor: (B, N, 9, 2) Vertices of polygon with zero padding.
    """
    # The same index is needed for the x and the y coordinate.
    gather_idx = torch.stack((idx_sorted, idx_sorted), dim=-1)
    polygon = torch.gather(vertices, 2, gather_idx)

    # Shoelace formula over consecutive vertex pairs.
    start, end = polygon[:, :, :-1], polygon[:, :, 1:]
    cross = start[..., 0] * end[..., 1] - start[..., 1] * end[..., 0]
    signed_twice_area = cross.sum(dim=2)
    return signed_twice_area.abs() / 2, polygon
|
|
|
|
def oriented_box_intersection_2d(corners1: Tensor,
                                 corners2: Tensor) -> Tuple[Tensor, Tensor]:
    """Compute the overlap area of two batches of rotated 2d boxes.

    Pipeline: edge-edge intersections -> corner containment -> candidate
    vertex list -> vertex ordering -> polygon area.

    Args:
        corners1 (Tensor): (B, N, 4, 2) First batch of boxes.
        corners2 (Tensor): (B, N, 4, 2) Second batch of boxes.

    Returns:
        Tuple:
         - Tensor (B, N): Area of intersection.
         - Tensor (B, N, 9, 2): Vertices of polygon with zero padding.
    """
    cross_points, cross_valid = box_intersection(corners1, corners2)
    inside_1_in_2, inside_2_in_1 = box_in_box(corners1, corners2)
    candidates, candidate_mask = build_vertices(
        corners1, corners2, inside_1_in_2, inside_2_in_1, cross_points,
        cross_valid)
    order = sort_indices(candidates, candidate_mask)
    return calculate_area(order, candidates)
|
|
|
|
def box2corners(box: Tensor) -> Tensor:
    """Convert rotated 2d box coordinate to corners.

    Corners are emitted in the order (+w/2, +h/2), (-w/2, +h/2),
    (-w/2, -h/2), (+w/2, -h/2) in the box frame, rotated by ``alpha``
    and then shifted by the box centre.

    Args:
        box (Tensor): (B, N, 5) with x, y, w, h, alpha. Empty batches
            (B == 0 or N == 0) are accepted and produce an empty result.

    Returns:
        Tensor: (B, N, 4, 2) Corners.
    """
    x, y, w, h, alpha = box.split([1, 1, 1, 1, 1], dim=-1)

    # ``new_tensor`` already inherits dtype and device from ``box``.
    half_x = box.new_tensor([0.5, -0.5, -0.5, 0.5]) * w
    half_y = box.new_tensor([0.5, 0.5, -0.5, -0.5]) * h
    corners = torch.stack([half_x, half_y], dim=-1)  # (B, N, 4, 2)

    sin = torch.sin(alpha)
    cos = torch.cos(alpha)
    row1 = torch.cat([cos, sin], dim=-1)
    row2 = torch.cat([-sin, cos], dim=-1)
    # Transposed rotation matrix, so that row-vector corners can be
    # right-multiplied: [cx, cy] @ rot_T.
    rot_T = torch.stack([row1, row2], dim=-2)  # (B, N, 2, 2)

    # A broadcasting matmul avoids flattening with ``view(-1, ...)``,
    # which is ambiguous (and raises) when the batch has zero elements.
    rotated = torch.matmul(corners, rot_T)

    # Out-of-place translation by the box centre.
    center = torch.cat([x, y], dim=-1).unsqueeze(-2)  # (B, N, 1, 2)
    return rotated + center
|
|
|
|
def diff_iou_rotated_2d(box1: Tensor, box2: Tensor) -> Tensor:
    """Compute the differentiable IoU of rotated 2d boxes.

    Args:
        box1 (Tensor): (B, N, 5) First box.
        box2 (Tensor): (B, N, 5) Second box.

    Returns:
        Tensor: (B, N) IoU.
    """
    overlap = oriented_box_intersection_2d(
        box2corners(box1), box2corners(box2))[0]

    # Box areas are w * h, taken from columns 2 and 3.
    area1 = box1[:, :, 2] * box1[:, :, 3]
    area2 = box2[:, :, 2] * box2[:, :, 3]
    return overlap / (area1 + area2 - overlap)
|
|
|
|
def diff_iou_rotated_3d(box3d1: Tensor, box3d2: Tensor) -> Tensor:
    """Compute the differentiable IoU of rotated 3d boxes.

    The footprint overlap is computed from columns (0, 1, 3, 4, 6) as a
    rotated 2d box; the overlap along z uses the centre in column 2 and
    the extent in column 5.

    Args:
        box3d1 (Tensor): (B, N, 3+3+1) First box (x,y,z,w,h,l,alpha).
        box3d2 (Tensor): (B, N, 3+3+1) Second box (x,y,z,w,h,l,alpha).

    Returns:
        Tensor: (B, N) IoU.
    """
    footprint_cols = [0, 1, 3, 4, 6]
    footprint1 = box2corners(box3d1[..., footprint_cols])
    footprint2 = box2corners(box3d2[..., footprint_cols])
    overlap_2d, _ = oriented_box_intersection_2d(footprint1, footprint2)

    # Extent of each box along z: centre +/- half of column 5.
    half_l1 = box3d1[..., 5] * 0.5
    top1 = box3d1[..., 2] + half_l1
    bottom1 = box3d1[..., 2] - half_l1
    half_l2 = box3d2[..., 5] * 0.5
    top2 = box3d2[..., 2] + half_l2
    bottom2 = box3d2[..., 2] - half_l2

    # Overlap of the two z intervals; negative means disjoint, so clamp.
    overlap_z = torch.clamp(
        torch.min(top1, top2) - torch.max(bottom1, bottom2), min=0.)
    overlap_3d = overlap_2d * overlap_z

    vol1 = box3d1[..., 3] * box3d1[..., 4] * box3d1[..., 5]
    vol2 = box3d2[..., 3] * box3d2[..., 4] * box3d2[..., 5]
    return overlap_3d / (vol1 + vol2 - overlap_3d)
|
|