File size: 3,427 Bytes
7b474fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
from __future__ import annotations

import argparse
import json
from pathlib import Path

from tqdm import tqdm

from detectivesam_inference.dataset import PairDataset
from detectivesam_inference.metrics import compute_f1, compute_iou, summarize_results
from detectivesam_inference.runtime import DetectiveSAMRunner, get_repo_root
from detectivesam_inference.visualization import save_prediction_outputs


def parse_args() -> argparse.Namespace:
    """Parse and validate the command-line options for dataset evaluation.

    The default ``--dataset-root`` and ``--output-dir`` are resolved relative
    to the repository root returned by ``get_repo_root()``.

    Returns:
        The parsed options namespace.

    Exits via ``parser.error`` (exit status 2, usage message on stderr) when
    ``--threshold`` is outside [0, 1] or NaN, or when ``--max-samples`` is
    negative.
    """
    repo_root = get_repo_root()
    parser = argparse.ArgumentParser(description="Evaluate DetectiveSAM on a dataset root with source/target/mask folders.")
    parser.add_argument(
        "--checkpoint",
        default="detective_sam",
        help="Checkpoint path or alias. Built-in aliases: detective_sam, detective_sam_sota.",
    )
    parser.add_argument(
        "--dataset-root",
        default=str(repo_root / "demo" / "cocoglide"),
        help="Dataset root containing source/target/mask folders.",
    )
    parser.add_argument(
        "--output-dir",
        default=str(repo_root / "outputs" / "eval_demo"),
        help="Directory that receives summary.json and a visualizations/ subfolder.",
    )
    parser.add_argument(
        "--device",
        default=None,
        help="Device forwarded to DetectiveSAMRunner.",
    )
    parser.add_argument(
        "--threshold",
        type=float,
        default=0.5,
        help="Threshold in [0, 1] used by the runner to derive the predicted mask.",
    )
    parser.add_argument(
        "--max-samples",
        type=int,
        default=None,
        help="Maximum number of samples to evaluate (forwarded to PairDataset).",
    )
    parser.add_argument(
        "--num-visualizations",
        type=int,
        default=4,
        help="Save visualizations for the first N samples.",
    )
    args = parser.parse_args()

    # NOTE(review): [0, 1] assumes the runner thresholds a probability map
    # (it exposes `prediction.probability`) -- confirm. NaN fails both
    # comparisons and is therefore rejected as well.
    if not 0.0 <= args.threshold <= 1.0:
        parser.error(f"--threshold must be within [0, 1], got {args.threshold}")
    # A negative limit has no sensible meaning and could be misread by slicing.
    if args.max_samples is not None and args.max_samples < 0:
        parser.error(f"--max-samples must be non-negative, got {args.max_samples}")
    return args


def _mask_to_numpy(mask):
    """Return *mask* as a squeezed uint8 NumPy array, or None when there is no mask.

    NOTE(review): ``astype("uint8")`` truncates, so this presumably expects a
    mask that is already binary (0/1); fractional values would become 0.
    Confirm against PairDataset.
    """
    if mask is None:
        return None
    return mask.squeeze().numpy().astype("uint8")


def _score(metric, pred_mask, gt_mask) -> float | None:
    """Apply ``metric(pred_mask, gt_mask)`` and return the result as a builtin float.

    Returns None when *gt_mask* is None (there is no ground truth to score
    against). The ``float`` coercion keeps the value JSON-serializable even if
    the metric returns a NumPy scalar such as ``float32``, which ``json.dump``
    rejects with TypeError.
    """
    if gt_mask is None:
        return None
    return float(metric(pred_mask, gt_mask))


def main() -> None:
    """Evaluate a DetectiveSAM checkpoint on a dataset and write the metrics to disk.

    Per-sample IoU/F1 (None for samples without a ground-truth mask) and an
    aggregate summary go to ``<output-dir>/summary.json``. Visualizations for
    the first ``--num-visualizations`` samples go to
    ``<output-dir>/visualizations``. The summary is also printed to stdout.
    """
    args = parse_args()
    runner = DetectiveSAMRunner(checkpoint_path=args.checkpoint, device=args.device)
    # Preprocessing settings come from the checkpoint's config, so the
    # evaluation inputs match what the runner was configured with.
    dataset = PairDataset(
        root_dir=args.dataset_root,
        img_size=runner.config.img_size,
        perturbation_type=runner.config.perturbation_type,
        perturbation_intensity=runner.config.perturbation_intensity,
        max_samples=args.max_samples,
    )

    output_dir = Path(args.output_dir)
    vis_dir = output_dir / "visualizations"
    # parents=True also creates output_dir itself.
    vis_dir.mkdir(parents=True, exist_ok=True)

    per_sample_results: list[dict[str, float | str | None]] = []
    for index, sample in enumerate(tqdm(dataset, desc="Evaluating")):
        prediction = runner.predict_sample(sample, threshold=args.threshold)
        gt_mask = _mask_to_numpy(sample.mask)
        per_sample_results.append(
            {
                "name": sample.name,
                "iou": _score(compute_iou, prediction.pred_mask, gt_mask),
                "f1": _score(compute_f1, prediction.pred_mask, gt_mask),
            }
        )

        if index < args.num_visualizations:
            save_prediction_outputs(
                output_dir=vis_dir,
                name=sample.name,
                source_image=sample.source_image,
                target_image=sample.target_image,
                probability_map=prediction.probability,
                pred_mask=prediction.pred_mask,
                gt_mask=gt_mask,
            )

    summary_path = output_dir / "summary.json"
    payload = {
        "checkpoint": str(runner.checkpoint_path.resolve()),
        "dataset_root": str(Path(args.dataset_root).resolve()),
        "threshold": args.threshold,
        "summary": summarize_results(per_sample_results),
        "samples": per_sample_results,
    }
    with summary_path.open("w", encoding="utf-8") as handle:
        json.dump(payload, handle, indent=2)

    print(json.dumps(payload["summary"], indent=2))
    print(f"Detailed results written to {summary_path}")


# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()