File size: 4,750 Bytes
7b474fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
from __future__ import annotations

import argparse
import json
from pathlib import Path

from detectivesam_inference.dataset import prepare_sample
from detectivesam_inference.metrics import compute_f1, compute_iou
from detectivesam_inference.runtime import DetectiveSAMRunner, get_repo_root
from detectivesam_inference.visualization import save_prediction_outputs


DEFAULT_DEMO_NAME = "banana_28809"


def resolve_demo_defaults(repo_root: Path) -> tuple[Path | None, Path, Path | None, str]:
    user_demo_target = repo_root / "demo" / "user_image" / "demo_input.png"
    fallback_source = repo_root / "demo" / "cocoglide" / "source" / f"{DEFAULT_DEMO_NAME}.png"
    fallback_target = repo_root / "demo" / "cocoglide" / "target" / f"{DEFAULT_DEMO_NAME}.png"
    fallback_mask = repo_root / "demo" / "cocoglide" / "mask" / f"{DEFAULT_DEMO_NAME}.png"

    if user_demo_target.exists():
        return None, user_demo_target, None, "single_image"
    return fallback_source, fallback_target, fallback_mask, "pair"


def parse_args() -> argparse.Namespace:
    repo_root = get_repo_root()
    parser = argparse.ArgumentParser(description="Run DetectiveSAM on a single source/target pair.")
    parser.add_argument(
        "--checkpoint",
        default="detective_sam",
        help="Checkpoint path or alias. Built-in aliases: detective_sam, detective_sam_sota.",
    )
    parser.add_argument("--source", default=None, help="Optional source image. If omitted, target is reused as source.")
    parser.add_argument(
        "--target",
        default=None,
        help="Target image. If omitted, uses demo/user_image/demo_input.png when present, else falls back to the bundled CocoGlide pair.",
    )
    parser.add_argument("--mask", default=None, help="Optional ground-truth mask for metrics.")
    parser.add_argument("--output-dir", default=str(repo_root / "outputs" / "predict_demo"))
    parser.add_argument("--device", default=None)
    parser.add_argument("--threshold", type=float, default=0.5)
    return parser.parse_args()


def resolve_input_paths(args: argparse.Namespace, repo_root: Path) -> tuple[Path, Path, Path | None, str]:
    demo_source, demo_target, demo_mask, demo_mode = resolve_demo_defaults(repo_root)
    if args.target is not None:
        target_path = Path(args.target)
        source_path = Path(args.source) if args.source else target_path
        mask_path = Path(args.mask) if args.mask else None
        return source_path, target_path, mask_path, "custom"

    target_path = demo_target
    source_path = Path(args.source) if args.source else (demo_source or target_path)
    mask_path = Path(args.mask) if args.mask else demo_mask
    return source_path, target_path, mask_path, demo_mode


def main() -> None:
    args = parse_args()
    repo_root = get_repo_root()
    source_path, target_path, mask_path, demo_mode = resolve_input_paths(args, repo_root)
    reference_mode = "pair" if source_path != target_path else "target_as_source"

    runner = DetectiveSAMRunner(checkpoint_path=args.checkpoint, device=args.device)
    sample = prepare_sample(
        source_path=source_path,
        target_path=target_path,
        mask_path=mask_path,
        img_size=runner.config.img_size,
        perturbation_type=runner.config.perturbation_type,
        perturbation_intensity=runner.config.perturbation_intensity,
    )
    prediction = runner.predict_sample(sample, threshold=args.threshold)

    gt_mask = sample.mask.squeeze().numpy().astype("uint8") if sample.mask is not None else None
    summary = {
        "sample": sample.name,
        "checkpoint": str(runner.checkpoint_path.resolve()),
        "demo_mode": demo_mode,
        "reference_mode": reference_mode,
        "source": str(sample.source_path),
        "target": str(sample.target_path),
        "mask": str(sample.mask_path) if sample.mask_path is not None else None,
        "threshold": args.threshold,
        "metrics": {
            "iou": compute_iou(prediction.pred_mask, gt_mask) if gt_mask is not None else None,
            "f1": compute_f1(prediction.pred_mask, gt_mask) if gt_mask is not None else None,
        },
    }

    output_dir = Path(args.output_dir)
    save_prediction_outputs(
        output_dir=output_dir,
        name=sample.name,
        source_image=sample.source_image,
        target_image=sample.target_image,
        probability_map=prediction.probability,
        pred_mask=prediction.pred_mask,
        gt_mask=gt_mask,
    )
    output_dir.mkdir(parents=True, exist_ok=True)
    with (output_dir / f"{sample.name}_summary.json").open("w", encoding="utf-8") as handle:
        json.dump(summary, handle, indent=2)

    print(json.dumps(summary, indent=2))


if __name__ == "__main__":
    main()