File size: 5,081 Bytes
7b474fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
from __future__ import annotations

import math
from pathlib import Path

import pytest

from detectivesam_inference.checkpoint import resolve_checkpoint_path
from detectivesam_inference.dataset import PairDataset, prepare_sample
from detectivesam_inference.metrics import compute_f1, compute_iou, summarize_results
from detectivesam_inference.runtime import DetectiveSAMRunner, get_repo_root


def assert_exact(value: float | None, expected: float | None, *, abs_tol: float = 1e-12) -> None:
    """Assert that ``value`` matches ``expected`` up to an absolute tolerance.

    Both operands must be present: ``None`` on either side fails the
    assertion. The relative tolerance is pinned to zero, so ``abs_tol`` is
    the only slack allowed between the two numbers.

    Raises:
        AssertionError: if either operand is ``None`` or the two values
            differ by more than ``abs_tol``.
    """
    assert value is not None and expected is not None
    assert math.isclose(value, expected, abs_tol=abs_tol, rel_tol=0.0)


@pytest.fixture(scope="module")
def repo_root() -> Path:
    """Module-scoped fixture exposing the path returned by ``get_repo_root()``."""
    root_dir = get_repo_root()
    return root_dir


@pytest.fixture(scope="module")
def baseline_runner() -> DetectiveSAMRunner:
    """Module-scoped runner built from the ``detective_sam`` checkpoint on CPU."""
    runner = DetectiveSAMRunner(device="cpu", checkpoint_path="detective_sam")
    return runner


@pytest.fixture(scope="module")
def sota_runner() -> DetectiveSAMRunner:
    """Module-scoped runner built from the ``detective_sam_sota`` checkpoint on CPU."""
    runner = DetectiveSAMRunner(device="cpu", checkpoint_path="detective_sam_sota")
    return runner


def predict_metrics(
    runner: DetectiveSAMRunner,
    *,
    source_path: Path,
    target_path: Path,
    mask_path: Path,
) -> tuple[float, float]:
    """Run ``runner`` on one source/target/mask triple and score the prediction.

    The sample is prepared with the image size and perturbation settings read
    from ``runner.config``, predicted at a fixed threshold of 0.5, and scored
    against the sample's own mask cast to ``uint8``.

    Returns:
        ``(iou, f1)`` as computed by ``compute_iou`` and ``compute_f1``.
    """
    config = runner.config
    sample = prepare_sample(
        source_path=source_path,
        target_path=target_path,
        mask_path=mask_path,
        img_size=config.img_size,
        perturbation_type=config.perturbation_type,
        perturbation_intensity=config.perturbation_intensity,
    )
    predicted_mask = runner.predict_sample(sample, threshold=0.5).pred_mask
    reference_mask = sample.mask.squeeze().numpy().astype("uint8")
    iou = compute_iou(predicted_mask, reference_mask)
    f1 = compute_f1(predicted_mask, reference_mask)
    return iou, f1


def test_checkpoint_alias_resolution(repo_root: Path) -> None:
    """Check that each checkpoint name resolves to its expected path under ``checkpoints``."""
    checkpoint_dir = repo_root / "checkpoints"
    expected_files = {
        "detective_sam": "model_epoch22_batch999_score1.1114.pth",
        "detective_sam_sota": "detective_sam_sota.pth",
    }
    for alias, file_name in expected_files.items():
        assert resolve_checkpoint_path(alias, repo_root) == checkpoint_dir / file_name


def test_baseline_banana_demo_metrics(repo_root: Path, baseline_runner: DetectiveSAMRunner) -> None:
    """Pin the baseline runner's IoU and F1 on the ``banana_28809`` cocoglide demo triple."""
    file_name = "banana_28809.png"
    cocoglide_dir = repo_root / "demo" / "cocoglide"
    measured_iou, measured_f1 = predict_metrics(
        baseline_runner,
        source_path=cocoglide_dir / "source" / file_name,
        target_path=cocoglide_dir / "target" / file_name,
        mask_path=cocoglide_dir / "mask" / file_name,
    )
    # Reference values are compared with assert_exact's default tolerance.
    assert_exact(measured_iou, 0.8566427949370513)
    assert_exact(measured_f1, 0.9227868680750683)


def test_sota_flux_demo_metrics(repo_root: Path, sota_runner: DetectiveSAMRunner) -> None:
    """Pin the SOTA runner's IoU and F1 on sample ``548`` of the ``flux_test`` demo."""
    file_name = "548.png"
    flux_dir = repo_root / "demo" / "flux_test"
    measured_iou, measured_f1 = predict_metrics(
        sota_runner,
        source_path=flux_dir / "source" / file_name,
        target_path=flux_dir / "target" / file_name,
        mask_path=flux_dir / "mask" / file_name,
    )
    # Reference values are compared with assert_exact's default tolerance.
    assert_exact(measured_iou, 0.8703024868799283)
    assert_exact(measured_f1, 0.9306542583192329)


def test_sota_qwen_demo_metrics(repo_root: Path, sota_runner: DetectiveSAMRunner) -> None:
    """Pin the SOTA runner's IoU and F1 on sample ``166`` of the ``qwen_test`` demo."""
    file_name = "166.png"
    qwen_dir = repo_root / "demo" / "qwen_test"
    measured_iou, measured_f1 = predict_metrics(
        sota_runner,
        source_path=qwen_dir / "source" / file_name,
        target_path=qwen_dir / "target" / file_name,
        mask_path=qwen_dir / "mask" / file_name,
    )
    # Reference values are compared with assert_exact's default tolerance.
    assert_exact(measured_iou, 0.8297306693388413)
    assert_exact(measured_f1, 0.9069429542203147)


def test_baseline_cocoglide_eval_summary(repo_root: Path, baseline_runner: DetectiveSAMRunner) -> None:
    """Pin the baseline runner's summary and per-sample metrics on the cocoglide demo set.

    Every sample yielded by ``PairDataset`` is predicted at threshold 0.5 and
    scored for IoU and F1. The aggregated summary is checked first, then each
    per-sample result is checked against its reference pair looked up by name.
    """
    config = baseline_runner.config
    dataset = PairDataset(
        root_dir=repo_root / "demo" / "cocoglide",
        img_size=config.img_size,
        perturbation_type=config.perturbation_type,
        perturbation_intensity=config.perturbation_intensity,
    )

    def score(sample) -> dict[str, float | str | None]:
        # Predict first, then derive the reference mask, as one result row.
        predicted_mask = baseline_runner.predict_sample(sample, threshold=0.5).pred_mask
        reference_mask = sample.mask.squeeze().numpy().astype("uint8")
        return {
            "name": sample.name,
            "iou": compute_iou(predicted_mask, reference_mask),
            "f1": compute_f1(predicted_mask, reference_mask),
        }

    per_sample_results: list[dict[str, float | str | None]] = [score(sample) for sample in dataset]

    summary = summarize_results(per_sample_results)
    for count_key in ("num_samples", "num_samples_with_gt"):
        assert summary[count_key] == 5
    assert_exact(summary["mean_iou"], 0.5092573070481035)
    assert_exact(summary["mean_f1"], 0.6509390858765342)

    # Reference (IoU, F1) pairs keyed by sample name.
    reference_metrics = {
        "airplane_139871": (0.41829717560376584, 0.5898582931686339),
        "banana_28809": (0.8566427949370513, 0.9227868680750683),
        "giraffe_296969": (0.22833093957714018, 0.3717743031951054),
        "train_221213": (0.547253866814856, 0.7073872989458688),
        "tv_453722": (0.49576175830770386, 0.662888665997994),
    }
    for entry in per_sample_results:
        reference_iou, reference_f1 = reference_metrics[entry["name"]]
        assert_exact(entry["iou"], reference_iou)
        assert_exact(entry["f1"], reference_f1)