# Uploaded by Gertlek: "Publish DetectiveSAM inference bundle" (commit 7b474fb, verified).
from __future__ import annotations
import argparse
import json
from pathlib import Path
from tqdm import tqdm
from detectivesam_inference.dataset import PairDataset
from detectivesam_inference.metrics import compute_f1, compute_iou, summarize_results
from detectivesam_inference.runtime import DetectiveSAMRunner, get_repo_root
from detectivesam_inference.visualization import save_prediction_outputs
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line options for evaluating DetectiveSAM.

    Default dataset and output locations are built relative to the path
    returned by ``get_repo_root()``.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, so existing ``parse_args()`` calls behave
            exactly as before; passing a list allows programmatic use/tests.

    Returns:
        Namespace with ``checkpoint``, ``dataset_root``, ``output_dir``,
        ``device``, ``threshold``, ``max_samples`` and ``num_visualizations``.
    """
    repo_root = get_repo_root()
    parser = argparse.ArgumentParser(description="Evaluate DetectiveSAM on a dataset root with source/target/mask folders.")
    parser.add_argument(
        "--checkpoint",
        default="detective_sam",
        help="Checkpoint path or alias. Built-in aliases: detective_sam, detective_sam_sota.",
    )
    parser.add_argument(
        "--dataset-root",
        default=str(repo_root / "demo" / "cocoglide"),
        help="Dataset root directory containing the source/target/mask folders.",
    )
    parser.add_argument(
        "--output-dir",
        default=str(repo_root / "outputs" / "eval_demo"),
        help="Directory that receives summary.json and a visualizations/ subfolder.",
    )
    parser.add_argument(
        "--device",
        default=None,
        help="Device string forwarded to DetectiveSAMRunner (None when omitted).",
    )
    parser.add_argument(
        "--threshold",
        type=float,
        default=0.5,
        help="Threshold forwarded to the runner when producing predicted masks.",
    )
    parser.add_argument(
        "--max-samples",
        type=int,
        default=None,
        help="Optional cap on the number of samples, forwarded to PairDataset.",
    )
    parser.add_argument(
        "--num-visualizations",
        type=int,
        default=4,
        help="Save prediction visualizations for this many leading samples.",
    )
    return parser.parse_args(argv)
def _ground_truth_to_numpy(mask):
    """Return *mask* as a squeezed ``uint8`` NumPy array, or ``None`` if absent.

    The mask is assumed to be a torch tensor (it is converted with
    ``.numpy()``). ``detach().cpu()`` is applied first because ``.numpy()``
    raises for tensors that require grad or are not on the CPU; for a plain
    CPU tensor both calls are no-ops.
    """
    if mask is None:
        return None
    return mask.detach().cpu().squeeze().numpy().astype("uint8")


def _evaluate_dataset(
    runner,
    dataset,
    *,
    threshold: float,
    num_visualizations: int,
    vis_dir: Path,
) -> list[dict[str, float | str | None]]:
    """Predict every sample in *dataset* and collect per-sample metrics.

    Each result dict has ``name``, ``iou`` and ``f1``; the metrics are
    ``None`` for samples without a ground-truth mask. The first
    ``num_visualizations`` samples also get their outputs written to
    ``vis_dir`` via ``save_prediction_outputs``.
    """
    results: list[dict[str, float | str | None]] = []
    for index, sample in enumerate(tqdm(dataset, desc="Evaluating")):
        prediction = runner.predict_sample(sample, threshold=threshold)
        gt_mask = _ground_truth_to_numpy(sample.mask)
        has_gt = gt_mask is not None
        results.append(
            {
                "name": sample.name,
                "iou": compute_iou(prediction.pred_mask, gt_mask) if has_gt else None,
                "f1": compute_f1(prediction.pred_mask, gt_mask) if has_gt else None,
            }
        )
        if index < num_visualizations:
            save_prediction_outputs(
                output_dir=vis_dir,
                name=sample.name,
                source_image=sample.source_image,
                target_image=sample.target_image,
                probability_map=prediction.probability,
                pred_mask=prediction.pred_mask,
                gt_mask=gt_mask,
            )
    return results


def _write_json(path: Path, payload: dict) -> None:
    """Write *payload* to *path* as UTF-8 JSON indented by two spaces."""
    with path.open("w", encoding="utf-8") as handle:
        json.dump(payload, handle, indent=2)


def main() -> None:
    """Evaluate a DetectiveSAM checkpoint on a dataset and write a JSON report.

    Builds the runner from the CLI options and a ``PairDataset`` configured
    with the runner's image size and perturbation settings. It then scores
    every sample and writes ``summary.json`` into the output directory. The
    report holds the resolved checkpoint path, the resolved dataset root, the
    threshold, the aggregate summary and the per-sample results. Finally the
    aggregate summary is printed to stdout.
    """
    args = parse_args()
    runner = DetectiveSAMRunner(checkpoint_path=args.checkpoint, device=args.device)
    dataset = PairDataset(
        root_dir=args.dataset_root,
        img_size=runner.config.img_size,
        perturbation_type=runner.config.perturbation_type,
        perturbation_intensity=runner.config.perturbation_intensity,
        max_samples=args.max_samples,
    )

    output_dir = Path(args.output_dir)
    vis_dir = output_dir / "visualizations"
    summary_path = output_dir / "summary.json"
    output_dir.mkdir(parents=True, exist_ok=True)
    vis_dir.mkdir(parents=True, exist_ok=True)

    per_sample_results = _evaluate_dataset(
        runner,
        dataset,
        threshold=args.threshold,
        num_visualizations=args.num_visualizations,
        vis_dir=vis_dir,
    )

    payload = {
        "checkpoint": str(runner.checkpoint_path.resolve()),
        "dataset_root": str(Path(args.dataset_root).resolve()),
        "threshold": args.threshold,
        "summary": summarize_results(per_sample_results),
        "samples": per_sample_results,
    }
    _write_json(summary_path, payload)
    print(json.dumps(payload["summary"], indent=2))
    print(f"Detailed results written to {summary_path}")
# Script entry point. main() returns None, so SystemExit(None) exits with status 0.
if __name__ == "__main__":
    raise SystemExit(main())