| import torch |
|
|
|
|
def pairwise_iou(pred_masks, gt_masks, eps=1e-6):
    """Compute the IoU between every predicted mask and every ground-truth mask.

    Args:
        pred_masks: Tensor of shape (N, H, W) holding N predicted masks.
            Masks are expected to be binary (bool or 0/1); they are cast to
            float32 before use. N may be 0.
        gt_masks: Tensor of shape (M, H, W) holding M ground-truth masks with
            the same number of pixels per mask as ``pred_masks``. M may be 0.
        eps: Constant added to the union to avoid division by zero. If
            ``None``, the union is clamped to a minimum of 1 instead, which
            leaves the ratio exact for binary masks.

    Returns:
        Float32 tensor of shape (N, M); entry ``[i, j]`` is the IoU of
        ``pred_masks[i]`` and ``gt_masks[j]``. With a positive ``eps`` (or
        ``eps=None``) a pair of empty masks scores 0.
    """
    # flatten(1) instead of reshape(N, -1): reshape cannot infer the -1
    # dimension for a zero-element tensor, so an empty set of masks would
    # raise instead of producing an empty (0, M) / (N, 0) result.
    pred_flat = pred_masks.flatten(1).float()
    gt_flat = gt_masks.flatten(1).float()

    # For 0/1 masks the dot product of two flattened masks counts the pixels
    # that are set in both, i.e. the intersection area.
    intersection = torch.matmul(pred_flat, gt_flat.t())

    area_pred = pred_flat.sum(dim=1, keepdim=True)  # (N, 1)
    area_gt = gt_flat.sum(dim=1, keepdim=True).t()  # (1, M)

    # Broadcasting (N, 1) + (1, M) gives the pairwise (N, M) union.
    union = area_pred + area_gt - intersection
    if eps is None:
        return intersection / union.clamp(min=1)
    return intersection / (union + eps)
|
|
|
|
def pairwise_iom(pred_masks, gt_masks, eps=1e-8):
    """Compute intersection-over-minimum-area for every pred/gt mask pair.

    For each pair the intersection area is divided by the smaller of the two
    mask areas, so a mask fully contained in another scores 1 regardless of
    how much larger the containing mask is.

    Args:
        pred_masks: Tensor of shape (N, H, W) holding N predicted masks.
            Masks are expected to be binary (bool or 0/1); they are cast to
            float32 before use. N may be 0.
        gt_masks: Tensor of shape (M, H, W) holding M ground-truth masks with
            the same number of pixels per mask as ``pred_masks``. M may be 0.
        eps: Constant added to the denominator to avoid division by zero. If
            ``None``, the denominator is clamped to a minimum of 1 instead,
            matching the ``eps=None`` behaviour of ``pairwise_iou``.

    Returns:
        Float32 tensor of shape (N, M); entry ``[i, j]`` is
        ``|pred_i ∩ gt_j| / min(|pred_i|, |gt_j|)``. With a positive ``eps``
        (or ``eps=None``) a pair involving an empty mask scores 0.
    """
    # flatten(1) instead of reshape(N, -1): reshape cannot infer the -1
    # dimension for a zero-element tensor, so an empty set of masks would
    # raise instead of producing an empty result.
    pred_flat = pred_masks.flatten(1).float()
    gt_flat = gt_masks.flatten(1).float()

    # For 0/1 masks the dot product of two flattened masks counts the pixels
    # that are set in both, i.e. the intersection area.
    intersection = torch.matmul(pred_flat, gt_flat.t())

    area_pred = pred_flat.sum(dim=1, keepdim=True)  # (N, 1)
    area_gt = gt_flat.sum(dim=1, keepdim=True).t()  # (1, M)

    # The gt areas must be a row vector so the element-wise minimum
    # broadcasts to a pairwise (N, M) matrix. Comparing (N, 1) with (M, 1)
    # fails for N != M and pairs the wrong masks when N == M.
    min_area = torch.min(area_pred, area_gt)
    if eps is None:
        return intersection / min_area.clamp(min=1)
    return intersection / (min_area + eps)
|
|