File size: 1,337 Bytes
ca63835
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
"""Precision/recall evaluation at IoU >= 0.5 with class matching."""

import torch


def _pairwise_iou(boxes_a, boxes_b):
    """Return the (len(boxes_a), len(boxes_b)) IoU matrix of two corner-format box sets.

    Boxes are (x_min, y_min, x_max, y_max) rows; areas are computed from those corners.
    """
    # Integer (pixel) boxes are upcast so the area/IoU arithmetic and the clamp below
    # operate on floating-point tensors.
    if not boxes_a.is_floating_point():
        boxes_a = boxes_a.float()
    if not boxes_b.is_floating_point():
        boxes_b = boxes_b.float()

    x1 = torch.maximum(boxes_a[:, None, 0], boxes_b[None, :, 0])
    y1 = torch.maximum(boxes_a[:, None, 1], boxes_b[None, :, 1])
    x2 = torch.minimum(boxes_a[:, None, 2], boxes_b[None, :, 2])
    y2 = torch.minimum(boxes_a[:, None, 3], boxes_b[None, :, 3])
    inter_area = (x2 - x1).clamp(min=0) * (y2 - y1).clamp(min=0)
    area_a = (boxes_a[:, 2] - boxes_a[:, 0]) * (boxes_a[:, 3] - boxes_a[:, 1])
    area_b = (boxes_b[:, 2] - boxes_b[:, 0]) * (boxes_b[:, 3] - boxes_b[:, 1])
    union = area_a[:, None] + area_b[None, :] - inter_area
    # Guard against division by zero for degenerate (zero-area) box pairs.
    return inter_area / union.clamp(min=1e-6)


def _same_class_mask(det_labels, gt_labels, device):
    """Return a (n_det, n_gt) bool tensor that is True where the two labels are equal."""
    try:
        det = torch.as_tensor(det_labels, device=device)
        gt = torch.as_tensor(gt_labels, device=device)
    except (TypeError, ValueError):
        # Labels torch cannot hold (e.g. class-name strings): compare in Python.
        return torch.tensor(
            [[bool(d == g) for g in gt_labels] for d in det_labels],
            dtype=torch.bool,
            device=device,
        )
    return det[:, None] == gt[None, :]


def compute_precision_recall(det_boxes, det_labels, gt_boxes, gt_labels, iou_threshold=0.5):
    """Compute TP and FP counts for one image. Returns (tp, fp, n_gt).

    Matching rule: each detection is compared only with ground-truth boxes of its
    own class. It is a true positive when its highest-IoU same-class GT has
    IoU >= ``iou_threshold`` and that GT was not already claimed by another
    detection; otherwise it is a false positive. A detection whose best GT is
    already claimed does not fall back to its second-best GT. Because of that, the
    counts do not depend on detection order: TP is the number of distinct GTs that
    are some qualifying detection's best match.

    Args:
        det_boxes: (n_det, 4) boxes as (x_min, y_min, x_max, y_max).
        det_labels: n_det class labels, compared with ``==`` against ``gt_labels``.
        gt_boxes: (n_gt, 4) ground-truth boxes in the same format, assumed to be
            on the same device as ``det_boxes``.
        gt_labels: n_gt ground-truth class labels.
        iou_threshold: minimum IoU for a match; expected to lie in [0, 1].

    Returns:
        ``(tp, fp, n_gt)`` as Python ints, where ``tp + fp == len(det_boxes)``.
        The number of false negatives is ``n_gt - tp``.
    """
    n_gt = len(gt_labels)
    n_det = len(det_boxes)
    if n_det == 0:
        return 0, 0, n_gt
    if n_gt == 0:
        return 0, n_det, 0

    det_boxes = torch.as_tensor(det_boxes)
    gt_boxes = torch.as_tensor(gt_boxes, device=det_boxes.device)
    iou = _pairwise_iou(det_boxes, gt_boxes)

    # Cross-class pairs can never match; -1 keeps them below any threshold in [0, 1]
    # so the argmax below only selects among same-class GTs.
    same_class = _same_class_mask(det_labels, gt_labels, iou.device)
    iou = iou.masked_fill(~same_class, -1.0)

    # torch.max(dim=...) returns the first index on ties, matching the original loop.
    best_iou, best_gt = iou.max(dim=1)
    qualifies = best_iou >= iou_threshold
    # Each GT can be claimed once: extra detections sharing a best GT are duplicates (FP).
    tp = int(torch.unique(best_gt[qualifies]).numel())
    return tp, n_det - tp, n_gt