Instructions to use Gertlek/DetectiveSAM with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- sam2
How to use Gertlek/DetectiveSAM with sam2:
```python
# Use SAM2 with images
import torch
from sam2.sam2_image_predictor import SAM2ImagePredictor

predictor = SAM2ImagePredictor.from_pretrained("Gertlek/DetectiveSAM")

with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    predictor.set_image(<your_image>)
    masks, _, _ = predictor.predict(<input_prompts>)
```

```python
# Use SAM2 with videos
import torch
from sam2.sam2_video_predictor import SAM2VideoPredictor

predictor = SAM2VideoPredictor.from_pretrained("Gertlek/DetectiveSAM")

with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    state = predictor.init_state(<your_video>)

    # add new prompts and instantly get the output on the same frame
    frame_idx, object_ids, masks = predictor.add_new_points(state, <your_prompts>)

    # propagate the prompts to get masklets throughout the video
    for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
        ...
```
- Notebooks
- Google Colab
- Kaggle
| from __future__ import annotations | |
| import hashlib | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| import torchvision.transforms.functional as TF | |
| from PIL import Image | |
| from sam2.utils.transforms import SAM2Transforms | |
| from torch.utils.data import Dataset | |
| from detectivesam_inference.perturbations import ( | |
| add_gaussian_noise_deterministic, | |
| apply_blur_to_image_tensor, | |
| apply_jpeg_compression_to_tensor, | |
| ) | |
# Folded (as its string form) into the hashed key built by build_sample_seed,
# so changing this value changes every per-sample seed derived in this module.
# NOTE(review): the name suggests compatibility with an older training setup;
# nothing else in this file reads it -- confirm before changing.
LEGACY_CONTRASTIVE_FLAG = False
@dataclass
class PreparedSample:
    """One source/target image pair, preprocessed and ready for the model.

    Instances are built with keyword arguments (see ``prepare_sample``), so the
    class has to be a dataclass in order to get a generated ``__init__``.
    """

    # Stem of the target file name.
    name: str
    # Locations the sample was loaded from; mask_path is None when no mask exists.
    source_path: Path
    target_path: Path
    mask_path: Path | None
    # RGB images after resizing to the requested square size.
    source_image: Image.Image
    target_image: Image.Image
    # Target image after the SAM2 input transform.
    orig: torch.Tensor
    # One transformed tensor per requested perturbation, in request order.
    streams: list[torch.Tensor]
    # Binary ground-truth mask as float32 with a leading channel axis, or None.
    mask: torch.Tensor | None
def parse_perturbation_types(perturbation_type: str) -> list[str]:
    """Split a perturbation specification into individual perturbation names.

    The exact string ``"none"`` yields an empty list. Otherwise the spec is
    split on ``"+"`` if it contains one, else on ``"/"`` if it contains one,
    else treated as a single name. Every resulting piece is whitespace-stripped.
    """
    if perturbation_type == "none":
        return []

    # "+" takes precedence over "/" when both appear in the spec.
    separator = next((sep for sep in ("+", "/") if sep in perturbation_type), None)
    if separator is None:
        return [perturbation_type.strip()]
    return [piece.strip() for piece in perturbation_type.split(separator)]
def compute_perturbation_params(perturbation_intensity: float) -> dict[str, float | int]:
    """Derive per-perturbation settings from a single intensity value.

    Returns a dict with:
      * ``blur_sigma``   -- intensity scaled by 2.0
      * ``jpeg_quality`` -- ``95 - intensity * 56.67`` truncated to int, never below 10
      * ``noise_std``    -- intensity scaled by 0.2
    """
    sigma = perturbation_intensity * 2.0

    quality = int(95 - (perturbation_intensity * 56.67))
    if quality < 10:
        # Clamp so strong intensities still produce a usable JPEG quality.
        quality = 10

    std = perturbation_intensity * 0.2
    return {"blur_sigma": sigma, "jpeg_quality": quality, "noise_std": std}
def create_combined_mask(
    mask_rgba: np.ndarray,
) -> np.ndarray:
    """Convert a loaded mask image array into a binary ``uint8`` mask (0/1).

    Handling depends on the array layout:

    * 2-D boolean array: ``True`` becomes 1.
    * 2-D numeric array: integer-divided by 255, so for 8-bit data only the
      value 255 counts as foreground.
    * H x W x 4 with a fully opaque alpha channel: foreground is wherever the
      first channel is non-zero; the result is resized to 512 x 512.
    * H x W x 4 otherwise: foreground is wherever alpha is not positive, i.e.
      the fully transparent pixels.
    * Any other 3-D array: foreground is wherever the first channel is non-zero.

    Args:
        mask_rgba: Mask pixel data as produced by ``np.array(image)``.

    Returns:
        A ``uint8`` array of zeros and ones.
    """
    if mask_rgba.ndim == 2:
        if mask_rgba.dtype == np.bool_:
            # Integer-dividing booleans by 255 would turn every True into 0
            # and silently erase the whole mask, so map True -> 1 directly.
            return mask_rgba.astype(np.uint8)
        return (mask_rgba // 255).astype(np.uint8)

    has_alpha = mask_rgba.ndim == 3 and mask_rgba.shape[2] == 4
    if not has_alpha:
        return (mask_rgba[:, :, 0] > 0).astype(np.uint8)

    alpha = mask_rgba[:, :, 3]
    alpha_is_opaque = alpha.sum() == alpha.size * 255
    if alpha_is_opaque:
        # Alpha carries no information, so fall back to the first colour channel.
        foreground = (mask_rgba[:, :, 0] > 0).astype(np.uint8)
        # NOTE(review): this branch alone forces a 512 x 512 output regardless
        # of the input size; kept as-is for existing callers -- confirm whether
        # the fixed size is still wanted.
        return cv2.resize(foreground, (512, 512), interpolation=cv2.INTER_NEAREST).astype(np.uint8)

    # Same result as thresholding alpha at 0 and inverting: the mask is 1
    # exactly where alpha is not greater than zero.
    return np.logical_not(alpha > 0).astype(np.uint8)
def resize_triplet(
    source_image: Image.Image,
    target_image: Image.Image,
    mask_image: Image.Image | None,
    img_size: int,
) -> tuple[Image.Image, Image.Image, Image.Image | None]:
    """Resize source, target and optional mask to an ``img_size`` square.

    The two images are resampled bilinearly; the mask, when present, uses
    nearest-neighbour resampling. A missing mask is passed through as ``None``.
    """
    square = (img_size, img_size)
    resized_source = source_image.resize(square, Image.BILINEAR)
    resized_target = target_image.resize(square, Image.BILINEAR)
    resized_mask = None if mask_image is None else mask_image.resize(square, Image.NEAREST)
    return resized_source, resized_target, resized_mask
def build_streams(
    target_image: Image.Image,
    perturbation_type: str,
    perturbation_intensity: float,
    seed: int,
) -> list[torch.Tensor]:
    """Create one perturbed tensor of the target image per requested perturbation.

    Args:
        target_image: Image that is converted to a tensor and perturbed.
        perturbation_type: Specification understood by ``parse_perturbation_types``.
        perturbation_intensity: Intensity fed to ``compute_perturbation_params``.
        seed: Base seed; Gaussian noise uses ``seed`` plus the stream's position
            in the output list.

    Returns:
        Perturbed tensors in the order the perturbations were requested.

    Raises:
        ValueError: If a requested perturbation name is not supported.
    """
    requested = parse_perturbation_types(perturbation_type)
    settings = compute_perturbation_params(perturbation_intensity)
    base = TF.to_tensor(target_image)

    def perturb(kind: str, position: int) -> torch.Tensor:
        # Produce the stream for one perturbation name.
        if kind == "gaussian_blur":
            return apply_blur_to_image_tensor(base, sigma=float(settings["blur_sigma"]))
        if kind == "jpeg_compression":
            return apply_jpeg_compression_to_tensor(base, quality=int(settings["jpeg_quality"]))
        if kind == "gaussian_noise":
            return add_gaussian_noise_deterministic(
                base,
                std=float(settings["noise_std"]),
                seed=seed + position,
            )
        raise ValueError(f"Unsupported perturbation type: {kind}")

    return [perturb(kind, position) for position, kind in enumerate(requested)]
def build_sample_seed(
    source_path: Path,
    target_path: Path,
    mask_path: Path | None,
    perturbation_type: str,
    perturbation_intensity: float,
) -> int:
    """Derive a reproducible 31-bit seed for one sample configuration.

    The seed comes from a SHA-256 hash over: the grandparent directory name and
    stem of the source and target paths, the mask stem (or ``"no-mask"``), the
    perturbation type, the intensity formatted to eight decimals, and the
    module-level ``LEGACY_CONTRASTIVE_FLAG``.

    Returns:
        An integer in ``[0, 2**31)``.
    """
    mask_token = "no-mask" if mask_path is None else mask_path.stem
    key_fields = (
        source_path.parent.parent.name,
        source_path.stem,
        target_path.parent.parent.name,
        target_path.stem,
        mask_token,
        perturbation_type,
        f"{perturbation_intensity:.8f}",
        str(LEGACY_CONTRASTIVE_FLAG),
    )
    hasher = hashlib.sha256("|".join(key_fields).encode("utf-8"))
    # Only the first eight digest bytes are used, then reduced to 31 bits.
    leading_bytes = hasher.digest()[:8]
    return int.from_bytes(leading_bytes, byteorder="big", signed=False) % (2**31)
def prepare_sample(
    source_path: str | Path,
    target_path: str | Path,
    mask_path: str | Path | None,
    img_size: int,
    perturbation_type: str,
    perturbation_intensity: float,
) -> PreparedSample:
    """Load one source/target(/mask) triplet from disk and prepare it for inference.

    Steps: open the images, resize everything to an ``img_size`` square, derive
    a deterministic seed for the sample, apply the SAM2 input transform to the
    target image and to each perturbed stream, and binarise the mask if given.

    Args:
        source_path: Path to the source image.
        target_path: Path to the target image.
        mask_path: Path to the mask image, or ``None`` when there is no mask.
        img_size: Side length of the square working resolution.
        perturbation_type: Perturbation specification (see ``parse_perturbation_types``).
        perturbation_intensity: Perturbation strength (see ``compute_perturbation_params``).

    Returns:
        A ``PreparedSample`` named after the target file's stem.

    Raises:
        ValueError: If ``perturbation_type`` names an unsupported perturbation.
    """
    source_path = Path(source_path)
    target_path = Path(target_path)
    mask_path = Path(mask_path) if mask_path is not None else None

    # Image.open is lazy and keeps the file open until pixel data is read.
    # The context managers release each handle deterministically instead of
    # relying on Pillow's implicit close or on garbage collection.
    with Image.open(source_path) as opened:
        source_image = opened.convert("RGB")
    with Image.open(target_path) as opened:
        target_image = opened.convert("RGB")
    mask_image = None
    if mask_path is not None:
        with Image.open(mask_path) as opened:
            # copy() loads the pixels so the mask outlives the closed file;
            # the stored mode is kept because create_combined_mask inspects
            # the channel layout.
            mask_image = opened.copy()

    source_image, target_image, mask_image = resize_triplet(source_image, target_image, mask_image, img_size)

    sample_seed = build_sample_seed(
        source_path=source_path,
        target_path=target_path,
        mask_path=mask_path,
        perturbation_type=perturbation_type,
        perturbation_intensity=perturbation_intensity,
    )

    transforms = SAM2Transforms(resolution=img_size, mask_threshold=0.0)
    # The former .unsqueeze(0).squeeze(0) round trip returned a tensor of the
    # same shape and values, so it has been dropped.
    orig = transforms.transforms(TF.to_tensor(target_image))
    streams_raw = build_streams(
        target_image=target_image,
        perturbation_type=perturbation_type,
        perturbation_intensity=perturbation_intensity,
        seed=sample_seed,
    )
    streams = [transforms.transforms(stream) for stream in streams_raw]

    mask_tensor = None
    if mask_image is not None:
        binary_mask = create_combined_mask(
            mask_rgba=np.array(mask_image),
        )
        # create_combined_mask may return a different size than requested.
        if binary_mask.shape != (img_size, img_size):
            binary_mask = cv2.resize(binary_mask, (img_size, img_size), interpolation=cv2.INTER_NEAREST)
        # Leading axis of size 1 acts as the channel dimension.
        mask_tensor = torch.tensor(binary_mask, dtype=torch.float32).unsqueeze(0)

    return PreparedSample(
        name=target_path.stem,
        source_path=source_path,
        target_path=target_path,
        mask_path=mask_path,
        source_image=source_image,
        target_image=target_image,
        orig=orig,
        streams=streams,
        mask=mask_tensor,
    )
class PairDataset(Dataset):
    """Dataset of source/target image pairs stored under one root directory.

    Expected layout: ``<root>/source`` and ``<root>/target`` (both required) and
    ``<root>/mask`` (optional). Samples are enumerated from the target
    directory in sorted order; the matching source image and optional
    ``<stem>.png`` mask are looked up lazily in ``__getitem__``.
    """

    # Extensions accepted for target images, and tried in this order when
    # looking for a source image that shares the target's stem.
    _IMAGE_SUFFIXES = (".png", ".jpg", ".jpeg")

    def __init__(
        self,
        root_dir: str | Path,
        img_size: int,
        perturbation_type: str,
        perturbation_intensity: float,
        max_samples: int | None = None,
    ) -> None:
        """Index the target images found under ``root_dir``.

        Args:
            root_dir: Directory containing ``source/`` and ``target/``.
            img_size: Square resolution passed on to ``prepare_sample``.
            perturbation_type: Perturbation specification passed on unchanged.
            perturbation_intensity: Perturbation strength passed on unchanged.
            max_samples: If given, keep only the first ``max_samples`` targets.

        Raises:
            FileNotFoundError: If ``source/`` or ``target/`` is missing.
        """
        self.root_dir = Path(root_dir)
        self.source_dir = self.root_dir / "source"
        self.target_dir = self.root_dir / "target"
        self.mask_dir = self.root_dir / "mask"
        self.img_size = img_size
        self.perturbation_type = perturbation_type
        self.perturbation_intensity = perturbation_intensity
        if not self.source_dir.exists() or not self.target_dir.exists():
            raise FileNotFoundError(f"{root_dir} must contain source/ and target/ directories")
        # Skip directories that merely have an image-like suffix; they would
        # otherwise be listed as samples and fail when opened.
        target_files = sorted(
            path
            for path in self.target_dir.iterdir()
            if path.is_file() and path.suffix.lower() in self._IMAGE_SUFFIXES
        )
        if max_samples is not None:
            target_files = target_files[:max_samples]
        self.target_files = target_files

    def __len__(self) -> int:
        """Return the number of indexed target images."""
        return len(self.target_files)

    def _resolve_source_path(self, target_path: Path) -> Path:
        """Find the source image that belongs to ``target_path``.

        The file with the identical name is preferred; otherwise the target's
        stem is tried with each accepted image extension.

        Raises:
            FileNotFoundError: If no candidate exists in the source directory.
        """
        exact_match = self.source_dir / target_path.name
        if exact_match.exists():
            return exact_match
        for suffix in self._IMAGE_SUFFIXES:
            candidate = self.source_dir / f"{target_path.stem}{suffix}"
            if candidate.exists():
                return candidate
        raise FileNotFoundError(f"Could not find source image for {target_path.name}")

    def __getitem__(self, index: int) -> PreparedSample:
        """Load and prepare the sample at ``index``.

        Raises:
            FileNotFoundError: If no source image matches the target.
        """
        target_path = self.target_files[index]
        source_path = self._resolve_source_path(target_path)
        # Masks are optional and only looked up as "<stem>.png".
        mask_candidate = self.mask_dir / f"{target_path.stem}.png"
        mask_path = mask_candidate if mask_candidate.exists() else None
        return prepare_sample(
            source_path=source_path,
            target_path=target_path,
            mask_path=mask_path,
            img_size=self.img_size,
            perturbation_type=self.perturbation_type,
            perturbation_intensity=self.perturbation_intensity,
        )