Instructions to use Gertlek/DetectiveSAM with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- sam2
How to use Gertlek/DetectiveSAM with sam2:
```python
# Use SAM2 with images
import torch
from sam2.sam2_image_predictor import SAM2ImagePredictor

predictor = SAM2ImagePredictor.from_pretrained("Gertlek/DetectiveSAM")

with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    predictor.set_image(<your_image>)
    masks, _, _ = predictor.predict(<input_prompts>)
```

```python
# Use SAM2 with videos
import torch
from sam2.sam2_video_predictor import SAM2VideoPredictor

predictor = SAM2VideoPredictor.from_pretrained("Gertlek/DetectiveSAM")

with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    state = predictor.init_state(<your_video>)

    # add new prompts and instantly get the output on the same frame
    frame_idx, object_ids, masks = predictor.add_new_points(state, <your_prompts>)

    # propagate the prompts to get masklets throughout the video
    for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
        ...
```
- Notebooks
- Google Colab
- Kaggle
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Any | |
| import yaml | |
# Repo-relative checkpoint used when the caller does not name one.
DEFAULT_CHECKPOINT: Path = Path("checkpoints/model_epoch22_batch999_score1.1114.pth")

# Short names accepted in place of a checkpoint path; values are repo-relative.
CHECKPOINT_ALIASES: dict[str, Path] = {
    "detective_sam": DEFAULT_CHECKPOINT,
    "detective_sam_sota": Path("checkpoints/detective_sam_sota.pth"),
}
@dataclass
class InferenceConfig:
    """Inference-time settings read from a checkpoint's params file.

    ``load_inference_config`` builds instances with keyword arguments, so the
    class needs the ``__init__`` that ``@dataclass`` generates; a bare class
    with only annotations would reject those arguments with ``TypeError``.
    """

    # Looked up under the "training_config" section by load_inference_config.
    img_size: int
    # Looked up under the "model_config" section.
    prompt_dim: int
    downscale: int
    dropout_rate: float
    # Looked up under the "data_config" section.
    perturbation_type: str
    perturbation_intensity: float
    # Looked up under the "sam_config" section.
    sam_config_file: str
    sam_checkpoint: str

    def max_streams(self) -> int:
        """Return the stream count derived from ``perturbation_type``."""
        return count_perturbation_streams(self.perturbation_type)
def resolve_checkpoint_path(checkpoint_value: str | Path | None, repo_root: str | Path) -> Path:
    """Turn a checkpoint name, alias, or path into a concrete ``Path``.

    Resolution order:
      1. ``None`` -> the default checkpoint under ``repo_root``.
      2. A key of ``CHECKPOINT_ALIASES`` -> the aliased file under ``repo_root``.
      3. An absolute path -> returned unchanged.
      4. A relative path that exists from the current directory -> resolved.
      5. The path under ``repo_root`` if it exists.
      6. ``repo_root/checkpoints/<value>.pth`` if it exists.
      7. Otherwise the (non-existent) path under ``repo_root``.
    """
    root = Path(repo_root)
    if checkpoint_value is None:
        return root / DEFAULT_CHECKPOINT

    name = str(checkpoint_value)
    alias_target = CHECKPOINT_ALIASES.get(name)
    if alias_target is not None:
        return root / alias_target

    requested = Path(checkpoint_value)
    if requested.is_absolute():
        return requested
    if requested.exists():
        # Only this branch normalises the result with resolve().
        return requested.resolve()

    in_repo = root / requested
    by_name = root / "checkpoints" / f"{name}.pth"
    for option in (in_repo, by_name):
        if option.exists():
            return option

    # Nothing exists; hand back the repo-relative guess.
    return in_repo
def resolve_repo_path(path_value: str | Path, repo_root: str | Path) -> Path:
    """Locate ``path_value`` relative to the repository.

    Absolute paths are returned untouched. Relative paths are tried under
    ``repo_root`` first, then by file name inside ``repo_root/sam2configs``.
    When neither exists, the ``repo_root``-relative path is returned anyway.
    """
    target = Path(path_value)
    if target.is_absolute():
        return target

    root = Path(repo_root)
    in_repo = root / target
    in_sam_configs = root / "sam2configs" / target.name
    existing = (option for option in (in_repo, in_sam_configs) if option.exists())
    return next(existing, in_repo)
def load_inference_config(checkpoint_path: str | Path) -> InferenceConfig:
    """Build an ``InferenceConfig`` from the params file next to a checkpoint.

    Each setting is looked up with ``_resolve_param`` and coerced to the
    field's type. ``prompt_dim`` and ``dropout_rate`` fall back to the keys
    ``prompt`` and ``dropout`` before using their built-in defaults.
    """
    params = _load_params_file(checkpoint_path)

    def lookup(key: str, section: str, default: Any) -> Any:
        return _resolve_param(params, key, section=section, default=default)

    img_size = int(lookup("img_size", "training_config", 512))

    prompt_fallback = lookup("prompt", "model_config", 128)
    prompt_dim = int(lookup("prompt_dim", "model_config", prompt_fallback))

    downscale = int(lookup("downscale", "model_config", 16))

    dropout_fallback = lookup("dropout", "model_config", 0.1)
    dropout_rate = float(lookup("dropout_rate", "model_config", dropout_fallback))

    perturbation_type = str(lookup("perturbation_type", "data_config", "none"))
    perturbation_intensity = float(lookup("perturbation_intensity", "data_config", 0.0))

    sam_config_file = str(lookup("sam_config_file", "sam_config", "sam2.1_hiera_b+.yaml"))
    sam_checkpoint = str(
        lookup("sam_checkpoint", "sam_config", "sam2configs/sam2.1_hiera_base_plus.pt")
    )

    return InferenceConfig(
        img_size=img_size,
        prompt_dim=prompt_dim,
        downscale=downscale,
        dropout_rate=dropout_rate,
        perturbation_type=perturbation_type,
        perturbation_intensity=perturbation_intensity,
        sam_config_file=sam_config_file,
        sam_checkpoint=sam_checkpoint,
    )
def count_perturbation_streams(perturbation_type: str) -> int:
    """Count the parts of a perturbation spec.

    ``"none"`` yields 0. A spec containing ``+`` is split on ``+``; otherwise a
    spec containing ``/`` is split on ``/``. Anything else counts as one part.
    """
    if perturbation_type == "none":
        return 0
    # "+" takes precedence over "/" when both appear.
    for separator in ("+", "/"):
        if separator in perturbation_type:
            return perturbation_type.count(separator) + 1
    return 1
| def _load_params_file(checkpoint_path: str | Path) -> dict[str, Any]: | |
| checkpoint_path = Path(checkpoint_path) | |
| candidate_paths = [ | |
| checkpoint_path.with_name(f"{checkpoint_path.stem}_params.yaml"), | |
| checkpoint_path.parent / "model_params.yaml", | |
| ] | |
| for candidate in candidate_paths: | |
| if candidate.exists(): | |
| with candidate.open("r", encoding="utf-8") as handle: | |
| loaded = yaml.safe_load(handle) | |
| if not isinstance(loaded, dict): | |
| raise ValueError(f"Checkpoint params file must deserialize to a mapping: {candidate}") | |
| return loaded | |
| raise FileNotFoundError( | |
| f"Could not find a params file for checkpoint {checkpoint_path}. " | |
| f"Checked: {', '.join(str(path) for path in candidate_paths)}" | |
| ) | |
| def _resolve_param( | |
| params: dict[str, Any], | |
| key: str, | |
| *, | |
| section: str, | |
| default: Any, | |
| ) -> Any: | |
| if key in params: | |
| return params[key] | |
| return params.get(section, {}).get(key, default) | |