Instructions to use Gertlek/DetectiveSAM with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- sam2
How to use Gertlek/DetectiveSAM with sam2:
```python
# Use SAM2 with images
import torch
from sam2.sam2_image_predictor import SAM2ImagePredictor

predictor = SAM2ImagePredictor.from_pretrained("Gertlek/DetectiveSAM")

with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    predictor.set_image(<your_image>)
    masks, _, _ = predictor.predict(<input_prompts>)
```

```python
# Use SAM2 with videos
import torch
from sam2.sam2_video_predictor import SAM2VideoPredictor

predictor = SAM2VideoPredictor.from_pretrained("Gertlek/DetectiveSAM")

with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    state = predictor.init_state(<your_video>)

    # add new prompts and instantly get the output on the same frame
    frame_idx, object_ids, masks = predictor.add_new_points(state, <your_prompts>)

    # propagate the prompts to get masklets throughout the video
    for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
        ...
```
- Notebooks
- Google Colab
- Kaggle
| from __future__ import annotations | |
| import argparse | |
| import json | |
| from pathlib import Path | |
| from tqdm import tqdm | |
| from detectivesam_inference.dataset import PairDataset | |
| from detectivesam_inference.metrics import compute_f1, compute_iou, summarize_results | |
| from detectivesam_inference.runtime import DetectiveSAMRunner, get_repo_root | |
| from detectivesam_inference.visualization import save_prediction_outputs | |
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line options for the DetectiveSAM evaluation script.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, so existing ``parse_args()`` calls behave
            exactly as before; passing an explicit list allows programmatic
            use and testing without touching ``sys.argv``.

    Returns:
        The parsed namespace. The default dataset and output locations are
        built relative to the directory returned by ``get_repo_root()``.
    """
    repo_root = get_repo_root()
    parser = argparse.ArgumentParser(
        description="Evaluate DetectiveSAM on a dataset root with source/target/mask folders."
    )
    parser.add_argument(
        "--checkpoint",
        default="detective_sam",
        help="Checkpoint path or alias. Built-in aliases: detective_sam, detective_sam_sota.",
    )
    parser.add_argument(
        "--dataset-root",
        default=str(repo_root / "demo" / "cocoglide"),
        help="Dataset directory containing source/target/mask folders (default: %(default)s).",
    )
    parser.add_argument(
        "--output-dir",
        default=str(repo_root / "outputs" / "eval_demo"),
        help="Directory for summary.json and the visualizations/ subfolder (default: %(default)s).",
    )
    parser.add_argument(
        "--device",
        default=None,
        help="Device forwarded to DetectiveSAMRunner (default: %(default)s).",
    )
    parser.add_argument(
        "--threshold",
        type=float,
        default=0.5,
        help="Decision threshold passed to DetectiveSAMRunner.predict_sample (default: %(default)s).",
    )
    parser.add_argument(
        "--max-samples",
        type=int,
        default=None,
        help="Limit on the number of samples, forwarded to PairDataset as max_samples (default: %(default)s).",
    )
    parser.add_argument(
        "--num-visualizations",
        type=int,
        default=4,
        help="Save visualizations for the first N samples (default: %(default)s).",
    )
    return parser.parse_args(argv)
def _json_default(obj: object) -> object:
    """Fallback serializer for ``json.dump``: unwrap scalar-like objects via ``.item()``.

    NumPy scalars such as ``numpy.float32`` (and 0-d tensors) are not JSON
    serializable, but they expose ``.item()``, which returns a plain Python
    value. If the metric helpers happen to return such values, this converts
    them instead of letting ``json.dump`` fail after the entire dataset has
    already been evaluated. Objects that ``json`` can serialize natively never
    reach this hook, so output for already-working runs is unchanged. Any
    other unsupported object still raises ``TypeError``.
    """
    item = getattr(obj, "item", None)
    if callable(item):
        return item()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def _write_summary(output_dir: Path, payload: dict) -> Path:
    """Write ``payload`` as indented UTF-8 JSON to ``output_dir / "summary.json"``.

    Returns:
        The path of the written file.
    """
    summary_path = output_dir / "summary.json"
    with summary_path.open("w", encoding="utf-8") as handle:
        json.dump(payload, handle, indent=2, default=_json_default)
    return summary_path


def main() -> None:
    """Evaluate DetectiveSAM over a dataset and write per-sample and summary metrics.

    Steps:
      1. Parse CLI options and construct a ``DetectiveSAMRunner`` for ``--checkpoint``.
      2. Build a ``PairDataset`` using the runner config's image size and
         perturbation settings.
      3. Predict every sample. IoU/F1 are computed when the sample has a
         ground-truth mask and are recorded as ``None`` otherwise.
         Visualizations are saved for the first ``--num-visualizations`` samples.
      4. Write ``summary.json`` (checkpoint, dataset root, threshold, summary,
         per-sample results) to ``--output-dir`` and print the summary.
    """
    args = parse_args()
    runner = DetectiveSAMRunner(checkpoint_path=args.checkpoint, device=args.device)
    dataset = PairDataset(
        root_dir=args.dataset_root,
        img_size=runner.config.img_size,
        perturbation_type=runner.config.perturbation_type,
        perturbation_intensity=runner.config.perturbation_intensity,
        max_samples=args.max_samples,
    )

    output_dir = Path(args.output_dir)
    vis_dir = output_dir / "visualizations"
    # parents=True also creates output_dir itself.
    vis_dir.mkdir(parents=True, exist_ok=True)

    per_sample_results: list[dict[str, float | str | None]] = []
    for index, sample in enumerate(tqdm(dataset, desc="Evaluating")):
        prediction = runner.predict_sample(sample, threshold=args.threshold)

        if sample.mask is None:
            # No ground truth: the sample is still predicted (and possibly
            # visualized), but its metrics are recorded as None.
            gt_mask = None
            iou = f1 = None
        else:
            # NOTE(review): sample.mask appears to be a tensor (it is
            # squeezed and converted via .numpy()); it is cast to uint8 for
            # the metric helpers — confirm mask value range against PairDataset.
            gt_mask = sample.mask.squeeze().numpy().astype("uint8")
            iou = compute_iou(prediction.pred_mask, gt_mask)
            f1 = compute_f1(prediction.pred_mask, gt_mask)

        per_sample_results.append({"name": sample.name, "iou": iou, "f1": f1})

        if index < args.num_visualizations:
            save_prediction_outputs(
                output_dir=vis_dir,
                name=sample.name,
                source_image=sample.source_image,
                target_image=sample.target_image,
                probability_map=prediction.probability,
                pred_mask=prediction.pred_mask,
                gt_mask=gt_mask,
            )

    payload = {
        "checkpoint": str(runner.checkpoint_path.resolve()),
        "dataset_root": str(Path(args.dataset_root).resolve()),
        "threshold": args.threshold,
        "summary": summarize_results(per_sample_results),
        "samples": per_sample_results,
    }
    summary_path = _write_summary(output_dir, payload)
    print(json.dumps(payload["summary"], indent=2, default=_json_default))
    print(f"Detailed results written to {summary_path}")


if __name__ == "__main__":
    main()