| |
|
|
| import torch, math |
|
|
|
|
|
|
def ciou(bboxes1, bboxes2):
    """Return the row-paired Complete-IoU (CIoU) loss ``1 - CIoU``.

    Each input is a tensor of shape ``(N, 4)`` whose columns are read as
    ``(cx, cy, w, h)``. Both inputs are first passed through ``torch.sigmoid``.
    The centres are then used as-is, and the sizes are decoded as
    ``exp(sigmoid(raw))``.
    NOTE(review): that decoding confines widths/heights to ``(1, e)`` while
    centres lie in ``(0, 1)``; confirm this matches the caller's box encoding.

    Row ``i`` of ``bboxes1`` is compared with row ``i`` of ``bboxes2``. The two
    inputs must therefore have the same number of rows, or broadcast against
    each other (e.g. one of them has a single row).

    Args:
        bboxes1: Box tensor of shape ``(N, 4)``.
        bboxes2: Box tensor of shape ``(N, 4)`` (or broadcastable to it).

    Returns:
        A 1-D tensor of ``1 - CIoU`` values in ``[0, 2]``, one per pair.
        If either input has no rows, a zero tensor of shape
        ``(len(bboxes1), len(bboxes2))`` is returned (kept for backward
        compatibility).
    """
    bboxes1 = torch.sigmoid(bboxes1)
    bboxes2 = torch.sigmoid(bboxes2)
    rows = bboxes1.shape[0]
    cols = bboxes2.shape[0]
    if rows * cols == 0:
        # Nothing to compare. Allocate on the inputs' device/dtype, and never
        # build a rows x cols buffer on the non-empty path (the math below is
        # row-paired, not pairwise).
        return bboxes1.new_zeros((rows, cols))

    # CIoU is symmetric in its arguments, so the inputs are never reordered.
    w1 = torch.exp(bboxes1[:, 2])
    h1 = torch.exp(bboxes1[:, 3])
    w2 = torch.exp(bboxes2[:, 2])
    h2 = torch.exp(bboxes2[:, 3])
    cx1, cy1 = bboxes1[:, 0], bboxes1[:, 1]
    cx2, cy2 = bboxes2[:, 0], bboxes2[:, 1]

    # Box edges derived from centre and size.
    l1, r1 = cx1 - w1 / 2, cx1 + w1 / 2
    t1, b1 = cy1 - h1 / 2, cy1 + h1 / 2
    l2, r2 = cx2 - w2 / 2, cx2 + w2 / 2
    t2, b2 = cy2 - h2 / 2, cy2 + h2 / 2

    # Intersection over union (intersection is 0 for disjoint boxes).
    inter_w = torch.clamp(torch.min(r1, r2) - torch.max(l1, l2), min=0)
    inter_h = torch.clamp(torch.min(b1, b2) - torch.max(t1, t2), min=0)
    inter_area = inter_w * inter_h
    union = w1 * h1 + w2 * h2 - inter_area
    iou = inter_area / union

    # Centre distance normalised by the squared diagonal of the smallest
    # enclosing box.
    enc_w = torch.clamp(torch.max(r1, r2) - torch.min(l1, l2), min=0)
    enc_h = torch.clamp(torch.max(b1, b2) - torch.min(t1, t2), min=0)
    c_diag = enc_w ** 2 + enc_h ** 2
    center_dist = (cx2 - cx1) ** 2 + (cy2 - cy1) ** 2
    u = center_dist / c_diag

    # Aspect-ratio consistency term.
    v = (4 / (math.pi ** 2)) * torch.pow(
        torch.atan(w2 / h2) - torch.atan(w1 / h1), 2
    )
    with torch.no_grad():
        # alpha is treated as a constant (no gradient) and only applies when
        # IoU > 0.5. For exactly coincident boxes the denominator is 0
        # (iou == 1, v == 0), which made alpha 0/0 = NaN; clamping it to a tiny
        # positive value yields alpha == 0 there and leaves every other case
        # unchanged.
        S = (iou > 0.5).to(iou.dtype)
        denom = torch.clamp(1 - iou + v, min=torch.finfo(iou.dtype).eps)
        alpha = S * v / denom
    cious = iou - u - alpha * v
    cious = torch.clamp(cious, min=-1.0, max=1.0)
    return 1 - cious
|
|
def diou(bboxes1, bboxes2):
    """Return the row-paired Distance-IoU (DIoU) loss ``1 - DIoU``.

    Each input is a tensor of shape ``(N, 4)`` whose columns are read as
    ``(cx, cy, w, h)``. Both inputs are first passed through ``torch.sigmoid``.
    The centres are then used as-is, and the sizes are decoded as
    ``exp(sigmoid(raw))``.
    NOTE(review): that decoding confines widths/heights to ``(1, e)`` while
    centres lie in ``(0, 1)``; confirm this matches the caller's box encoding.

    Row ``i`` of ``bboxes1`` is compared with row ``i`` of ``bboxes2``. The two
    inputs must therefore have the same number of rows, or broadcast against
    each other (e.g. one of them has a single row).

    Args:
        bboxes1: Box tensor of shape ``(N, 4)``.
        bboxes2: Box tensor of shape ``(N, 4)`` (or broadcastable to it).

    Returns:
        A 1-D tensor of ``1 - DIoU`` values in ``[0, 2]``, one per pair.
        If either input has no rows, a zero tensor of shape
        ``(len(bboxes1), len(bboxes2))`` is returned (kept for backward
        compatibility).
    """
    bboxes1 = torch.sigmoid(bboxes1)
    bboxes2 = torch.sigmoid(bboxes2)
    rows = bboxes1.shape[0]
    cols = bboxes2.shape[0]
    if rows * cols == 0:
        # Nothing to compare. Allocate on the inputs' device/dtype, and never
        # build a rows x cols buffer on the non-empty path (the math below is
        # row-paired, not pairwise).
        return bboxes1.new_zeros((rows, cols))

    # DIoU is symmetric in its arguments, so the inputs are never reordered.
    w1 = torch.exp(bboxes1[:, 2])
    h1 = torch.exp(bboxes1[:, 3])
    w2 = torch.exp(bboxes2[:, 2])
    h2 = torch.exp(bboxes2[:, 3])
    cx1, cy1 = bboxes1[:, 0], bboxes1[:, 1]
    cx2, cy2 = bboxes2[:, 0], bboxes2[:, 1]

    # Box edges derived from centre and size.
    l1, r1 = cx1 - w1 / 2, cx1 + w1 / 2
    t1, b1 = cy1 - h1 / 2, cy1 + h1 / 2
    l2, r2 = cx2 - w2 / 2, cx2 + w2 / 2
    t2, b2 = cy2 - h2 / 2, cy2 + h2 / 2

    # Intersection over union (intersection is 0 for disjoint boxes).
    inter_w = torch.clamp(torch.min(r1, r2) - torch.max(l1, l2), min=0)
    inter_h = torch.clamp(torch.min(b1, b2) - torch.max(t1, t2), min=0)
    inter_area = inter_w * inter_h
    union = w1 * h1 + w2 * h2 - inter_area
    iou = inter_area / union

    # Centre distance normalised by the squared diagonal of the smallest
    # enclosing box.
    enc_w = torch.clamp(torch.max(r1, r2) - torch.min(l1, l2), min=0)
    enc_h = torch.clamp(torch.max(b1, b2) - torch.min(t1, t2), min=0)
    c_diag = enc_w ** 2 + enc_h ** 2
    center_dist = (cx2 - cx1) ** 2 + (cy2 - cy1) ** 2
    u = center_dist / c_diag

    dious = iou - u
    dious = torch.clamp(dious, min=-1.0, max=1.0)
    return 1 - dious
|
|
|
|
if __name__ == "__main__":
    # Smoke test: score ten random row-paired boxes with both losses.
    # (The leftover `ipdb.set_trace()` breakpoint was removed: it halted the
    # script and required the third-party `ipdb` package.)
    x = torch.rand(10, 4)
    y = torch.rand(10, 4)
    cxy = ciou(x, y)
    dxy = diou(x, y)
    print(cxy.shape, dxy.shape)
|
|