"""Precision/recall evaluation at IoU >= 0.5 with class matching.""" import torch def compute_precision_recall(det_boxes, det_labels, gt_boxes, gt_labels): """Compute TP, FP counts for one image. Returns (tp, fp, n_gt).""" n_gt = len(gt_labels) if len(det_boxes) == 0: return 0, 0, n_gt if n_gt == 0: return 0, len(det_boxes), 0 x1 = torch.maximum(det_boxes[:, None, 0], gt_boxes[None, :, 0]) y1 = torch.maximum(det_boxes[:, None, 1], gt_boxes[None, :, 1]) x2 = torch.minimum(det_boxes[:, None, 2], gt_boxes[None, :, 2]) y2 = torch.minimum(det_boxes[:, None, 3], gt_boxes[None, :, 3]) inter_area = (x2 - x1).clamp(min=0) * (y2 - y1).clamp(min=0) det_area = (det_boxes[:, 2] - det_boxes[:, 0]) * (det_boxes[:, 3] - det_boxes[:, 1]) gt_area = (gt_boxes[:, 2] - gt_boxes[:, 0]) * (gt_boxes[:, 3] - gt_boxes[:, 1]) union = det_area[:, None] + gt_area[None, :] - inter_area iou = inter_area / union.clamp(min=1e-6) tp, fp = 0, 0 matched_gt = set() for di in range(len(det_boxes)): best_iou, best_gi = iou[di].max(0) gi = best_gi.item() if best_iou.item() >= 0.5 and gi not in matched_gt and det_labels[di] == gt_labels[gi]: tp += 1 matched_gt.add(gi) else: fp += 1 return tp, fp, n_gt