| """Precision/recall evaluation at IoU >= 0.5 with class matching.""" | |
| import torch | |
def compute_precision_recall(det_boxes, det_labels, gt_boxes, gt_labels, *, iou_thresh=0.5):
    """Count true/false positives for one image by greedy, class-aware IoU matching.

    Each detection, taken in the order given, is compared against the
    ground-truth boxes of its *own class* only, and the one with the highest
    IoU is selected. The detection is a true positive when that IoU is at
    least ``iou_thresh`` and the ground truth has not already been claimed by
    an earlier detection. Otherwise it is a false positive, so duplicate
    detections of one object count as false positives.

    Matching is greedy, so the result depends on detection order. Callers
    presumably pass detections sorted by descending confidence; TODO confirm
    at the call sites.

    Args:
        det_boxes: ``(N, 4)`` boxes laid out as ``(x1, y1, x2, y2)``.
        det_labels: length-``N`` tensor or sequence of detection class ids.
        gt_boxes: ``(M, 4)`` ground-truth boxes, same layout as ``det_boxes``.
        gt_labels: length-``M`` tensor or sequence of ground-truth class ids.
        iou_thresh: minimum IoU for a detection to count as a match
            (default 0.5).

    Returns:
        ``(tp, fp, n_gt)`` as Python ints. Precision ``tp / (tp + fp)`` and
        recall ``tp / n_gt`` can be computed from these counts.
    """
    n_gt = len(gt_labels)
    n_det = len(det_boxes)
    if n_det == 0:
        return 0, 0, n_gt
    if n_gt == 0:
        # Nothing to match against: every detection is a false positive.
        return 0, n_det, 0

    det_boxes = torch.as_tensor(det_boxes)
    # Integer boxes would make the float clamp/division below fail or truncate.
    if not det_boxes.is_floating_point():
        det_boxes = det_boxes.float()
    gt_boxes = torch.as_tensor(gt_boxes, device=det_boxes.device).to(det_boxes.dtype)
    det_labels = torch.as_tensor(det_labels, device=det_boxes.device).reshape(-1)
    gt_labels = torch.as_tensor(gt_labels, device=det_boxes.device).reshape(-1)

    # Pairwise intersection rectangle between every detection and every GT: (N, M).
    x1 = torch.maximum(det_boxes[:, None, 0], gt_boxes[None, :, 0])
    y1 = torch.maximum(det_boxes[:, None, 1], gt_boxes[None, :, 1])
    x2 = torch.minimum(det_boxes[:, None, 2], gt_boxes[None, :, 2])
    y2 = torch.minimum(det_boxes[:, None, 3], gt_boxes[None, :, 3])
    inter_area = (x2 - x1).clamp(min=0) * (y2 - y1).clamp(min=0)
    det_area = (det_boxes[:, 2] - det_boxes[:, 0]) * (det_boxes[:, 3] - det_boxes[:, 1])
    gt_area = (gt_boxes[:, 2] - gt_boxes[:, 0]) * (gt_boxes[:, 3] - gt_boxes[:, 1])
    union = det_area[:, None] + gt_area[None, :] - inter_area
    # Guard against zero-area unions (degenerate boxes).
    iou = inter_area / union.clamp(min=1e-6)

    # Restrict candidates to same-class GTs *before* picking the best overlap.
    # Otherwise a higher-IoU box of another class hides a valid same-class
    # match. Cross-class pairs get IoU -1, so they can never win the argmax
    # or pass the threshold.
    same_class = det_labels[:, None] == gt_labels[None, :]
    iou = iou.masked_fill(~same_class, -1.0)

    best_iou, best_gi = iou.max(dim=1)

    tp, fp = 0, 0
    matched_gt = set()
    for overlap, gi in zip(best_iou.tolist(), best_gi.tolist()):
        if overlap >= iou_thresh and gi not in matched_gt:
            tp += 1
            matched_gt.add(gi)
        else:
            # Below threshold, no same-class GT, or GT already claimed.
            fp += 1
    return tp, fp, n_gt