# DetectiveSAM / tests/test_regression.py
# Published by Gertlek in commit 7b474fb ("Publish DetectiveSAM inference bundle").
from __future__ import annotations
import math
from pathlib import Path
import pytest
from detectivesam_inference.checkpoint import resolve_checkpoint_path
from detectivesam_inference.dataset import PairDataset, prepare_sample
from detectivesam_inference.metrics import compute_f1, compute_iou, summarize_results
from detectivesam_inference.runtime import DetectiveSAMRunner, get_repo_root
def assert_exact(value: float | None, expected: float | None, *, abs_tol: float = 1e-12) -> None:
assert value is not None
assert expected is not None
assert math.isclose(value, expected, rel_tol=0.0, abs_tol=abs_tol)
@pytest.fixture(scope="module")
def repo_root() -> Path:
    """Repository root as reported by ``get_repo_root``, computed once per module."""
    root = get_repo_root()
    return root
@pytest.fixture(scope="module")
def baseline_runner() -> DetectiveSAMRunner:
    """CPU runner built from the ``detective_sam`` checkpoint, shared by the whole module."""
    runner = DetectiveSAMRunner(device="cpu", checkpoint_path="detective_sam")
    return runner
@pytest.fixture(scope="module")
def sota_runner() -> DetectiveSAMRunner:
    """CPU runner built from the ``detective_sam_sota`` checkpoint, shared by the whole module."""
    runner = DetectiveSAMRunner(device="cpu", checkpoint_path="detective_sam_sota")
    return runner
def predict_metrics(
    runner: DetectiveSAMRunner,
    *,
    source_path: Path,
    target_path: Path,
    mask_path: Path,
) -> tuple[float, float]:
    """Predict a single source/target pair and score it against its mask.

    The sample is prepared with the runner's own config (image size and
    perturbation type/intensity), predicted at a fixed 0.5 threshold, and
    compared with the ground-truth mask cast to ``uint8``.

    Returns:
        ``(iou, f1)`` as produced by ``compute_iou`` and ``compute_f1``.
    """
    cfg = runner.config
    sample = prepare_sample(
        source_path=source_path,
        target_path=target_path,
        mask_path=mask_path,
        img_size=cfg.img_size,
        perturbation_type=cfg.perturbation_type,
        perturbation_intensity=cfg.perturbation_intensity,
    )
    predicted = runner.predict_sample(sample, threshold=0.5).pred_mask
    ground_truth = sample.mask.squeeze().numpy().astype("uint8")
    iou = compute_iou(predicted, ground_truth)
    f1 = compute_f1(predicted, ground_truth)
    return iou, f1
def test_checkpoint_alias_resolution(repo_root: Path) -> None:
    """Each checkpoint alias must resolve to its concrete file under ``checkpoints/``."""
    checkpoint_dir = repo_root / "checkpoints"
    expected_paths = {
        "detective_sam": checkpoint_dir / "model_epoch22_batch999_score1.1114.pth",
        "detective_sam_sota": checkpoint_dir / "detective_sam_sota.pth",
    }
    for alias, expected_path in expected_paths.items():
        assert resolve_checkpoint_path(alias, repo_root) == expected_path
def test_baseline_banana_demo_metrics(repo_root: Path, baseline_runner: DetectiveSAMRunner) -> None:
    """Pin IoU/F1 of the ``detective_sam`` runner on the cocoglide ``banana_28809`` demo pair."""
    case_dir = repo_root / "demo" / "cocoglide"
    image_name = "banana_28809.png"
    iou, f1 = predict_metrics(
        baseline_runner,
        source_path=case_dir / "source" / image_name,
        target_path=case_dir / "target" / image_name,
        mask_path=case_dir / "mask" / image_name,
    )
    expected_iou, expected_f1 = 0.8566427949370513, 0.9227868680750683
    assert_exact(iou, expected_iou)
    assert_exact(f1, expected_f1)
def test_sota_flux_demo_metrics(repo_root: Path, sota_runner: DetectiveSAMRunner) -> None:
    """Pin IoU/F1 of the ``detective_sam_sota`` runner on the ``flux_test`` demo pair ``548``."""
    case_dir = repo_root / "demo" / "flux_test"
    image_name = "548.png"
    iou, f1 = predict_metrics(
        sota_runner,
        source_path=case_dir / "source" / image_name,
        target_path=case_dir / "target" / image_name,
        mask_path=case_dir / "mask" / image_name,
    )
    expected_iou, expected_f1 = 0.8703024868799283, 0.9306542583192329
    assert_exact(iou, expected_iou)
    assert_exact(f1, expected_f1)
def test_sota_qwen_demo_metrics(repo_root: Path, sota_runner: DetectiveSAMRunner) -> None:
    """Pin IoU/F1 of the ``detective_sam_sota`` runner on the ``qwen_test`` demo pair ``166``."""
    case_dir = repo_root / "demo" / "qwen_test"
    image_name = "166.png"
    iou, f1 = predict_metrics(
        sota_runner,
        source_path=case_dir / "source" / image_name,
        target_path=case_dir / "target" / image_name,
        mask_path=case_dir / "mask" / image_name,
    )
    expected_iou, expected_f1 = 0.8297306693388413, 0.9069429542203147
    assert_exact(iou, expected_iou)
    assert_exact(f1, expected_f1)
def test_baseline_cocoglide_eval_summary(repo_root: Path, baseline_runner: DetectiveSAMRunner) -> None:
    """Regression check for a full evaluation pass over ``demo/cocoglide``.

    Every pair yielded by ``PairDataset`` is predicted by the baseline runner at
    a 0.5 threshold. The test then pins three things: which samples were
    evaluated, the aggregate from ``summarize_results``, and each sample's
    IoU/F1.
    """
    expected_by_name = {
        "airplane_139871": (0.41829717560376584, 0.5898582931686339),
        "banana_28809": (0.8566427949370513, 0.9227868680750683),
        "giraffe_296969": (0.22833093957714018, 0.3717743031951054),
        "train_221213": (0.547253866814856, 0.7073872989458688),
        "tv_453722": (0.49576175830770386, 0.662888665997994),
    }

    dataset = PairDataset(
        root_dir=repo_root / "demo" / "cocoglide",
        img_size=baseline_runner.config.img_size,
        perturbation_type=baseline_runner.config.perturbation_type,
        perturbation_intensity=baseline_runner.config.perturbation_intensity,
    )
    per_sample_results: list[dict[str, float | str | None]] = []
    for sample in dataset:
        prediction = baseline_runner.predict_sample(sample, threshold=0.5)
        # Same ground-truth conversion that ``predict_metrics`` applies.
        true_mask = sample.mask.squeeze().numpy().astype("uint8")
        per_sample_results.append(
            {
                "name": sample.name,
                "iou": compute_iou(prediction.pred_mask, true_mask),
                "f1": compute_f1(prediction.pred_mask, true_mask),
            }
        )

    # The per-sample loop below only checks names that were actually produced.
    # A bare count cannot tell a duplicated sample from a missing one, so pin
    # the exact collection of evaluated names (each expected name exactly once).
    evaluated_names = sorted(result["name"] for result in per_sample_results)
    assert evaluated_names == sorted(expected_by_name)

    summary = summarize_results(per_sample_results)
    assert summary["num_samples"] == 5
    assert summary["num_samples_with_gt"] == 5
    assert_exact(summary["mean_iou"], 0.5092573070481035)
    assert_exact(summary["mean_f1"], 0.6509390858765342)

    for result in per_sample_results:
        expected_iou, expected_f1 = expected_by_name[result["name"]]
        assert_exact(result["iou"], expected_iou)
        assert_exact(result["f1"], expected_f1)