# Gertlek's picture
# Publish DetectiveSAM inference bundle
# 7b474fb verified
from __future__ import annotations
import argparse
import json
from pathlib import Path
from detectivesam_inference.dataset import prepare_sample
from detectivesam_inference.metrics import compute_f1, compute_iou
from detectivesam_inference.runtime import DetectiveSAMRunner, get_repo_root
from detectivesam_inference.visualization import save_prediction_outputs
# Stem of the bundled CocoGlide demo files (source/target/mask share this name).
DEFAULT_DEMO_NAME = "banana_28809"


def resolve_demo_defaults(repo_root: Path) -> tuple[Path | None, Path, Path | None, str]:
    """Choose the demo inputs to use when no target is given explicitly.

    If ``demo/user_image/demo_input.png`` exists under ``repo_root`` it is used
    on its own; otherwise the bundled CocoGlide files named after
    ``DEFAULT_DEMO_NAME`` are returned.

    Args:
        repo_root: Directory that contains the ``demo`` folder.

    Returns:
        A ``(source, target, mask, mode)`` tuple. With a user image present the
        result is ``(None, user_image, None, "single_image")``; otherwise it is
        the CocoGlide source, target and mask paths with mode ``"pair"``.
    """
    demo_dir = repo_root / "demo"

    user_image = demo_dir / "user_image" / "demo_input.png"
    if user_image.exists():
        return None, user_image, None, "single_image"

    bundled_dir = demo_dir / "cocoglide"
    file_name = f"{DEFAULT_DEMO_NAME}.png"
    source, target, mask = (bundled_dir / kind / file_name for kind in ("source", "target", "mask"))
    return source, target, mask, "pair"
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the single-sample prediction script.

    Every option is optional. ``--output-dir`` defaults to
    ``<repo root>/outputs/predict_demo`` and ``--threshold`` is parsed as a
    float that defaults to ``0.5``.

    Returns:
        The namespace produced by ``argparse`` for the current process arguments.
    """
    default_output_dir = get_repo_root() / "outputs" / "predict_demo"

    # Declared in the order they should appear in ``--help``.
    option_specs = (
        (
            "--checkpoint",
            {
                "default": "detective_sam",
                "help": "Checkpoint path or alias. Built-in aliases: detective_sam, detective_sam_sota.",
            },
        ),
        (
            "--source",
            {
                "default": None,
                "help": "Optional source image. If omitted, target is reused as source.",
            },
        ),
        (
            "--target",
            {
                "default": None,
                "help": "Target image. If omitted, uses demo/user_image/demo_input.png when present, else falls back to the bundled CocoGlide pair.",
            },
        ),
        ("--mask", {"default": None, "help": "Optional ground-truth mask for metrics."}),
        ("--output-dir", {"default": str(default_output_dir)}),
        ("--device", {"default": None}),
        ("--threshold", {"type": float, "default": 0.5}),
    )

    cli = argparse.ArgumentParser(description="Run DetectiveSAM on a single source/target pair.")
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()
def resolve_input_paths(args: argparse.Namespace, repo_root: Path) -> tuple[Path, Path, Path | None, str]:
    """Combine command-line paths with the demo defaults.

    With ``--target`` given, the run is ``"custom"``: the source falls back to
    the target and the mask to ``None``. Without it, the target comes from
    ``resolve_demo_defaults`` and the source/mask fall back to the demo values
    (the demo target doubles as source when the demo has no source).
    ``--source`` and ``--mask`` override their fallback whenever they are
    non-empty.

    Args:
        args: Parsed options; ``target``, ``source`` and ``mask`` are read.
        repo_root: Directory passed on to ``resolve_demo_defaults``.

    Returns:
        A ``(source_path, target_path, mask_path, mode)`` tuple.
    """
    fallback_source, fallback_target, fallback_mask, fallback_mode = resolve_demo_defaults(repo_root)

    if args.target is None:
        mode = fallback_mode
        target = fallback_target
        default_source = fallback_source or fallback_target
        default_mask = fallback_mask
    else:
        mode = "custom"
        target = Path(args.target)
        default_source = target
        default_mask = None

    source = Path(args.source) if args.source else default_source
    mask = Path(args.mask) if args.mask else default_mask
    return source, target, mask, mode
def main() -> None:
    """Run one prediction and write its visual outputs and a JSON summary.

    Steps: parse the command line, resolve the input paths, load the runner,
    prepare the sample, predict, compute IoU/F1 when a ground-truth mask is
    available, then save the prediction outputs and
    ``<sample name>_summary.json`` into the output directory. The summary is
    also printed to stdout.
    """
    args = parse_args()
    repo_root = get_repo_root()
    source_path, target_path, mask_path, demo_mode = resolve_input_paths(args, repo_root)
    reference_mode = "pair" if source_path != target_path else "target_as_source"

    runner = DetectiveSAMRunner(checkpoint_path=args.checkpoint, device=args.device)
    sample = prepare_sample(
        source_path=source_path,
        target_path=target_path,
        mask_path=mask_path,
        img_size=runner.config.img_size,
        perturbation_type=runner.config.perturbation_type,
        perturbation_intensity=runner.config.perturbation_intensity,
    )
    prediction = runner.predict_sample(sample, threshold=args.threshold)

    # NOTE(review): assumes sample.mask supports .squeeze().numpy() (e.g. a CPU
    # tensor) — confirm against prepare_sample.
    gt_mask = sample.mask.squeeze().numpy().astype("uint8") if sample.mask is not None else None
    if gt_mask is None:
        # Metrics need a reference mask; report them as null.
        metrics = {"iou": None, "f1": None}
    else:
        metrics = {
            "iou": compute_iou(prediction.pred_mask, gt_mask),
            "f1": compute_f1(prediction.pred_mask, gt_mask),
        }

    summary = {
        "sample": sample.name,
        "checkpoint": str(runner.checkpoint_path.resolve()),
        "demo_mode": demo_mode,
        "reference_mode": reference_mode,
        "source": str(sample.source_path),
        "target": str(sample.target_path),
        "mask": str(sample.mask_path) if sample.mask_path is not None else None,
        "threshold": args.threshold,
        "metrics": metrics,
    }

    output_dir = Path(args.output_dir)
    # Create the directory before anything is written into it, so the run does
    # not depend on save_prediction_outputs creating it.
    output_dir.mkdir(parents=True, exist_ok=True)
    save_prediction_outputs(
        output_dir=output_dir,
        name=sample.name,
        source_image=sample.source_image,
        target_image=sample.target_image,
        probability_map=prediction.probability,
        pred_mask=prediction.pred_mask,
        gt_mask=gt_mask,
    )

    # Serialise once; the same text goes to the summary file and to stdout.
    summary_text = json.dumps(summary, indent=2)
    (output_dir / f"{sample.name}_summary.json").write_text(summary_text, encoding="utf-8")
    print(summary_text)


if __name__ == "__main__":
    main()