File size: 3,981 Bytes
1fc7794
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
"""Unit tests for evaluation matching and metrics."""

import numpy as np
import pytest

from src.evaluate import match_detections_to_gt, compute_f1, compute_average_precision


class TestComputeF1:
    """Tests for ``compute_f1(tp, fp, fn)``, which returns ``(f1, precision, recall)``."""

    def test_perfect_score(self):
        """No false positives or misses: F1, precision and recall are all 1."""
        f1, p, r = compute_f1(10, 0, 0)
        assert f1 == pytest.approx(1.0, abs=0.001)
        assert p == pytest.approx(1.0, abs=0.001)
        assert r == pytest.approx(1.0, abs=0.001)

    def test_zero_detections(self):
        """Only misses (fn > 0, tp = fp = 0): recall and F1 are 0."""
        f1, p, r = compute_f1(0, 0, 10)
        assert f1 == pytest.approx(0.0, abs=0.01)
        assert r == pytest.approx(0.0, abs=0.01)

    def test_all_false_positives(self):
        """Only false alarms (fp > 0, tp = fn = 0): precision is 0."""
        f1, p, r = compute_f1(0, 10, 0)
        assert p == pytest.approx(0.0, abs=0.01)

    def test_asymmetric_counts(self):
        """Non-degenerate counts where precision != recall pin the formula and return order."""
        # tp=6, fp=2, fn=4 -> precision = 6/8 = 0.75, recall = 6/10 = 0.6,
        # F1 = 2PR / (P + R) = 2*tp / (2*tp + fp + fn) = 12/18 = 2/3.
        f1, p, r = compute_f1(6, 2, 4)
        assert p == pytest.approx(0.75, abs=0.001)
        assert r == pytest.approx(0.6, abs=0.001)
        assert f1 == pytest.approx(2 / 3, abs=0.001)


class TestMatchDetections:
    """Tests for class-aware, radius-limited matching of detections to GT points.

    ``match_detections_to_gt`` returns a per-class dict (keys ``"6nm"`` and
    ``"12nm"``) of counts ``tp`` / ``fp`` / ``fn`` plus an ``f1`` score. Each
    test checks the full count triple it controls, so a result cannot pass by
    getting only one counter right.
    """

    def test_perfect_matching(self):
        """Detections at exact GT locations should all match."""
        gt_6nm = np.array([[100.0, 100.0], [200.0, 200.0]])
        gt_12nm = np.array([[300.0, 300.0]])

        dets = [
            {"x": 100.0, "y": 100.0, "class": "6nm", "conf": 0.9},
            {"x": 200.0, "y": 200.0, "class": "6nm", "conf": 0.8},
            {"x": 300.0, "y": 300.0, "class": "12nm", "conf": 0.7},
        ]

        results = match_detections_to_gt(dets, gt_6nm, gt_12nm)
        assert results["6nm"]["tp"] == 2
        assert results["6nm"]["fp"] == 0
        assert results["6nm"]["fn"] == 0
        assert results["12nm"]["tp"] == 1
        assert results["12nm"]["fp"] == 0
        assert results["12nm"]["fn"] == 0

    def test_wrong_class_no_match(self):
        """Detection near GT but wrong class should not match."""
        gt_6nm = np.array([[100.0, 100.0]])
        gt_12nm = np.empty((0, 2))

        dets = [
            {"x": 100.0, "y": 100.0, "class": "12nm", "conf": 0.9},
        ]

        results = match_detections_to_gt(dets, gt_6nm, gt_12nm)
        # The 6nm GT is missed; the 12nm detection has no 12nm GT to match.
        assert results["6nm"]["tp"] == 0
        assert results["6nm"]["fn"] == 1  # missed
        assert results["12nm"]["tp"] == 0
        assert results["12nm"]["fp"] == 1  # false positive

    def test_beyond_radius_no_match(self):
        """Detection beyond match radius should not match."""
        gt_6nm = np.array([[100.0, 100.0]])
        gt_12nm = np.empty((0, 2))

        dets = [
            {"x": 120.0, "y": 100.0, "class": "6nm", "conf": 0.9},  # 20px away > 9px radius
        ]

        results = match_detections_to_gt(
            dets, gt_6nm, gt_12nm, match_radii={"6nm": 9.0, "12nm": 15.0}
        )
        assert results["6nm"]["tp"] == 0
        assert results["6nm"]["fp"] == 1
        assert results["6nm"]["fn"] == 1

    def test_within_radius_matches(self):
        """Detection within match radius should match."""
        gt_6nm = np.array([[100.0, 100.0]])
        gt_12nm = np.empty((0, 2))

        dets = [
            {"x": 105.0, "y": 100.0, "class": "6nm", "conf": 0.9},  # 5px away < 9px
        ]

        results = match_detections_to_gt(
            dets, gt_6nm, gt_12nm, match_radii={"6nm": 9.0, "12nm": 15.0}
        )
        assert results["6nm"]["tp"] == 1
        assert results["6nm"]["fp"] == 0
        assert results["6nm"]["fn"] == 0

    def test_no_detections(self):
        """No detections: all GT are false negatives."""
        gt_6nm = np.array([[100.0, 100.0], [200.0, 200.0]])
        results = match_detections_to_gt([], gt_6nm, np.empty((0, 2)))
        assert results["6nm"]["tp"] == 0
        assert results["6nm"]["fp"] == 0
        assert results["6nm"]["fn"] == 2
        assert results["6nm"]["f1"] == pytest.approx(0.0, abs=0.01)

    def test_no_ground_truth(self):
        """No GT: all detections are false positives."""
        dets = [{"x": 100.0, "y": 100.0, "class": "6nm", "conf": 0.9}]
        results = match_detections_to_gt(dets, np.empty((0, 2)), np.empty((0, 2)))
        assert results["6nm"]["tp"] == 0
        assert results["6nm"]["fp"] == 1
        assert results["6nm"]["fn"] == 0


class TestAveragePrecision:
    """Tests for ``compute_average_precision(dets, gt, match_radius)``.

    Cases cover perfect ranking, a false positive ranked above the only true
    positive, and a detection that matches nothing, so that a constant AP
    cannot pass.
    """

    def test_perfect_ap(self):
        """All detections match in rank order → AP = 1.0."""
        gt = np.array([[100.0, 100.0], [200.0, 200.0]])
        dets = [
            {"x": 100.0, "y": 100.0, "class": "6nm", "conf": 0.9},
            {"x": 200.0, "y": 200.0, "class": "6nm", "conf": 0.8},
        ]
        ap = compute_average_precision(dets, gt, match_radius=9.0)
        assert ap == pytest.approx(1.0, abs=0.01)

    def test_false_positive_ranked_first(self):
        """A higher-confidence false positive ahead of the only TP halves AP."""
        gt = np.array([[100.0, 100.0]])
        dets = [
            # Far from any GT -> false positive, ranked first by confidence.
            {"x": 500.0, "y": 500.0, "class": "6nm", "conf": 0.9},
            {"x": 100.0, "y": 100.0, "class": "6nm", "conf": 0.8},
        ]
        # The single recall step (0 -> 1) happens at precision 1/2, which gives
        # AP = 0.5 under all-point, 11-point and 101-point interpolation.
        ap = compute_average_precision(dets, gt, match_radius=9.0)
        assert ap == pytest.approx(0.5, abs=0.01)

    def test_no_matching_detection(self):
        """A detection outside the radius of every GT point yields AP = 0."""
        gt = np.array([[100.0, 100.0]])
        dets = [
            {"x": 120.0, "y": 100.0, "class": "6nm", "conf": 0.9},  # 20px away > 9px
        ]
        ap = compute_average_precision(dets, gt, match_radius=9.0)
        assert ap == pytest.approx(0.0, abs=0.01)