Instructions to use Gertlek/DetectiveSAM with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- sam2
How to use Gertlek/DetectiveSAM with sam2:
# Use SAM2 with images
import torch
from sam2.sam2_image_predictor import SAM2ImagePredictor

predictor = SAM2ImagePredictor.from_pretrained("Gertlek/DetectiveSAM")

with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    predictor.set_image(<your_image>)
    masks, _, _ = predictor.predict(<input_prompts>)

# Use SAM2 with videos
import torch
from sam2.sam2_video_predictor import SAM2VideoPredictor

predictor = SAM2VideoPredictor.from_pretrained("Gertlek/DetectiveSAM")

with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    state = predictor.init_state(<your_video>)

    # add new prompts and instantly get the output on the same frame
    frame_idx, object_ids, masks = predictor.add_new_points(state, <your_prompts>)

    # propagate the prompts to get masklets throughout the video
    for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
        ...
- Notebooks
- Google Colab
- Kaggle
File size: 3,560 Bytes
7b474fb | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 | from __future__ import annotations
from contextlib import nullcontext
from dataclasses import dataclass
from pathlib import Path
import numpy as np
import torch
from hydra import initialize_config_dir
from hydra.core.global_hydra import GlobalHydra
from detectivesam_inference.checkpoint import (
InferenceConfig,
load_inference_config,
resolve_checkpoint_path,
resolve_repo_path,
)
from detectivesam_inference.dataset import PreparedSample
from detectivesam_inference.models.forgerylocalizer import ForgeryLocalizer
@dataclass(frozen=True)
class PredictionResult:
    """Immutable container for the output of one DetectiveSAM prediction.

    Instances are created by ``DetectiveSAMRunner.predict_sample``; fields
    cannot be reassigned after construction (the dataclass is frozen).
    """

    # Per-pixel probability map (sigmoid of the model logits in predict_sample).
    probability: np.ndarray
    # Binary mask obtained by thresholding ``probability`` (uint8 in predict_sample).
    pred_mask: np.ndarray
def get_repo_root() -> Path:
    """Return the directory two levels above this module's resolved location.

    The module lives inside a package directory; its parent is treated as the
    repository root against which checkpoint and config paths are resolved.
    """
    module_path = Path(__file__).resolve()
    package_dir = module_path.parent
    return package_dir.parent
def select_device(device: str | None = None) -> torch.device:
    """Resolve the torch device to run inference on.

    Args:
        device: Explicit device spec (e.g. ``"cpu"``, ``"cuda:1"``). When
            ``None``, CUDA is chosen if available, otherwise CPU.

    Returns:
        The corresponding ``torch.device``.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(device)
def initialize_sam2_config(config_dir: str | Path) -> None:
    """Initialize Hydra's global config search to point at ``config_dir``.

    Repeated calls with the same (resolved) directory are no-ops while Hydra
    stays initialized. A different directory clears the existing global Hydra
    instance before re-initializing it.

    The last directory this function initialized is remembered on the
    function object itself (``initialize_sam2_config._current_dir``).
    NOTE(review): if other code clears and re-initializes Hydra with a
    different directory, that cached value can go stale and a later call with
    the old directory would return early — confirm no caller does this.
    """
    target_dir = str(Path(config_dir).resolve())
    global_hydra = GlobalHydra.instance()
    if global_hydra.is_initialized():
        if getattr(initialize_sam2_config, "_current_dir", None) == target_dir:
            return
        global_hydra.clear()
    # Called without ``with`` so the initialization persists globally.
    initialize_config_dir(config_dir=target_dir, version_base=None)
    initialize_sam2_config._current_dir = target_dir
class DetectiveSAMRunner:
    """Load a DetectiveSAM checkpoint and run forgery-localization inference.

    On construction the runner resolves the checkpoint path against the
    repository root, selects a device, loads the inference config via
    ``load_inference_config`` and builds a ``ForgeryLocalizer`` in eval mode.
    """

    def __init__(
        self,
        checkpoint_path: str | Path | None = None,
        device: str | None = None,
    ) -> None:
        """Build the runner and load the model.

        Args:
            checkpoint_path: Checkpoint location, resolved by
                ``resolve_checkpoint_path`` against the repo root (how ``None``
                is handled is up to that helper).
            device: Torch device string; ``None`` selects CUDA when available,
                otherwise CPU.
        """
        self.repo_root = get_repo_root()
        self.checkpoint_path = resolve_checkpoint_path(checkpoint_path, self.repo_root)
        self.device = select_device(device)
        self.config = load_inference_config(self.checkpoint_path)
        self.model = self._load_model()

    def _load_model(self) -> ForgeryLocalizer:
        """Construct the ForgeryLocalizer from the config and load its weights."""
        sam_config_path = resolve_repo_path(self.config.sam_config_file, self.repo_root)
        sam_checkpoint_path = resolve_repo_path(self.config.sam_checkpoint, self.repo_root)
        # Point Hydra at the SAM2 config directory; only the bare file name is
        # passed on below (presumably ForgeryLocalizer resolves it via Hydra).
        initialize_sam2_config(sam_config_path.parent)
        model = ForgeryLocalizer(
            sam_config=sam_config_path.name,
            sam_checkpoint=str(sam_checkpoint_path),
            prompt_dim=self.config.prompt_dim,
            downscale=self.config.downscale,
            dropout_rate=self.config.dropout_rate,
            max_streams=self.config.max_streams,
            device=str(self.device),
        ).to(self.device)
        # SECURITY NOTE(review): weights_only=False lets torch.load unpickle
        # arbitrary Python objects, which can execute code. Only load
        # checkpoints from trusted sources.
        checkpoint = torch.load(self.checkpoint_path, map_location=self.device, weights_only=False)
        # Accept either a checkpoint dict carrying a "model" entry or a bare state dict.
        state_dict = (
            checkpoint["model"]
            if isinstance(checkpoint, dict) and "model" in checkpoint
            else checkpoint
        )
        model.load_state_dict(state_dict)
        model.eval()
        return model

    def autocast_context(self):
        """Return a CUDA autocast context on GPU devices, a no-op context otherwise."""
        if self.device.type == "cuda":
            return torch.amp.autocast(device_type="cuda")
        return nullcontext()

    @staticmethod
    def _logits_to_probability(logits: torch.Tensor) -> np.ndarray:
        """Convert raw logits to a squeezed NumPy probability map.

        Logits produced under CUDA autocast may be float16/bfloat16. Those are
        upcast to float32 before the sigmoid so the probabilities keep full
        precision and ``Tensor.numpy()`` does not fail on bfloat16, which NumPy
        has no dtype for. Other dtypes (e.g. float32/float64) are left as-is.
        """
        if logits.dtype in (torch.float16, torch.bfloat16):
            logits = logits.float()
        # squeeze() drops every size-1 dimension (batch, channel, ...).
        return torch.sigmoid(logits).squeeze().detach().cpu().numpy()

    def predict_sample(
        self,
        sample: PreparedSample,
        threshold: float = 0.5,
    ) -> PredictionResult:
        """Run the model on one prepared sample.

        Args:
            sample: Prepared sample; ``sample.orig`` and every tensor in
                ``sample.streams`` receive a leading batch dimension of 1.
            threshold: Probabilities strictly greater than this become 1 in
                ``pred_mask``.

        Returns:
            PredictionResult holding the squeezed probability map and a uint8
            binary mask of the same shape.
        """
        orig = sample.orig.unsqueeze(0).to(self.device)
        streams = [stream.unsqueeze(0).to(self.device) for stream in sample.streams]
        with torch.inference_mode():
            with self.autocast_context():
                logits = self.model(orig, streams)
            # Post-process outside autocast so the upcast to float32 sticks.
            probability = self._logits_to_probability(logits)
        pred_mask = (probability > threshold).astype("uint8")
        return PredictionResult(probability=probability, pred_mask=pred_mask)
|