phanerozoic's picture
Restructure: one folder per head, shared losses/utils, registry runner
ca63835 verified
"""Precision/recall evaluation at IoU >= 0.5 with class matching."""
import torch
def _as_float_boxes(boxes):
    """Return *boxes* as a floating-point tensor.

    Integer coordinates are promoted to float32 so that the IoU division and the
    ``clamp(min=1e-6)`` guard behave correctly. Float inputs keep their dtype.
    """
    boxes = torch.as_tensor(boxes)
    return boxes if boxes.is_floating_point() else boxes.float()


def _box_iou(boxes_a, boxes_b):
    """Return the ``(len(boxes_a), len(boxes_b))`` pairwise IoU matrix.

    Boxes are rows of corner coordinates ``[x1, y1, x2, y2]``. The union is
    clamped away from zero, so a pair of zero-area boxes yields IoU 0, not NaN.
    """
    x1 = torch.maximum(boxes_a[:, None, 0], boxes_b[None, :, 0])
    y1 = torch.maximum(boxes_a[:, None, 1], boxes_b[None, :, 1])
    x2 = torch.minimum(boxes_a[:, None, 2], boxes_b[None, :, 2])
    y2 = torch.minimum(boxes_a[:, None, 3], boxes_b[None, :, 3])
    inter = (x2 - x1).clamp(min=0) * (y2 - y1).clamp(min=0)
    area_a = (boxes_a[:, 2] - boxes_a[:, 0]) * (boxes_a[:, 3] - boxes_a[:, 1])
    area_b = (boxes_b[:, 2] - boxes_b[:, 0]) * (boxes_b[:, 3] - boxes_b[:, 1])
    union = area_a[:, None] + area_b[None, :] - inter
    return inter / union.clamp(min=1e-6)


def compute_precision_recall(det_boxes, det_labels, gt_boxes, gt_labels, *, iou_threshold=0.5):
    """Compute TP, FP counts for one image. Returns (tp, fp, n_gt).

    Matching rule: each detection is compared only against ground truths of
    its own class and takes the one with the highest IoU. The detection is a
    true positive if that IoU is >= ``iou_threshold`` and no other detection
    has already claimed the same ground truth. Every remaining detection
    (wrong place, no same-class GT, or duplicate of a claimed GT) is a false
    positive. A detection's best GT does not depend on which GTs are already
    claimed, so the counts do not depend on detection order.

    Args:
        det_boxes: Detected boxes as rows of ``[x1, y1, x2, y2]``. Integer
            coordinates are promoted to float.
        det_labels: Class id per detection (tensor or sequence of numbers).
        gt_boxes: Ground-truth boxes in the same corner format.
        gt_labels: Class id per ground truth.
        iou_threshold: Minimum IoU for a match (keyword-only, default 0.5).

    Returns:
        ``(tp, fp, n_gt)`` as Python ints, where ``tp + fp == len(det_boxes)``
        and ``n_gt == len(gt_labels)``.
    """
    n_gt = len(gt_labels)
    n_det = len(det_boxes)
    if n_det == 0:
        return 0, 0, n_gt
    if n_gt == 0:
        return 0, n_det, 0

    iou = _box_iou(_as_float_boxes(det_boxes), _as_float_boxes(gt_boxes))

    # Exclude other-class GTs *before* taking the argmax. A wrong-class box
    # that overlaps more must not mask a valid same-class match.
    det_cls = torch.as_tensor(det_labels, device=iou.device).reshape(-1, 1)
    gt_cls = torch.as_tensor(gt_labels, device=iou.device).reshape(1, -1)
    iou = iou.masked_fill(det_cls != gt_cls, -1.0)

    best_iou, best_gt = iou.max(dim=1)
    hit = best_iou >= iou_threshold
    # Each GT can be claimed once. Further detections whose best match is an
    # already-claimed GT are duplicates, so TP = number of distinct GTs hit.
    tp = int(torch.unique(best_gt[hit]).numel())
    return tp, n_det - tp, n_gt