Instructions to use Gertlek/DetectiveSAM with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- sam2
How to use Gertlek/DetectiveSAM with sam2:
```python
# Use SAM2 with images
import torch
from sam2.sam2_image_predictor import SAM2ImagePredictor

predictor = SAM2ImagePredictor.from_pretrained("Gertlek/DetectiveSAM")

with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    predictor.set_image(<your_image>)
    masks, _, _ = predictor.predict(<input_prompts>)
```

```python
# Use SAM2 with videos
import torch
from sam2.sam2_video_predictor import SAM2VideoPredictor

predictor = SAM2VideoPredictor.from_pretrained("Gertlek/DetectiveSAM")

with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    state = predictor.init_state(<your_video>)

    # add new prompts and instantly get the output on the same frame
    frame_idx, object_ids, masks = predictor.add_new_points(state, <your_prompts>)

    # propagate the prompts to get masklets throughout the video
    for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
        ...
```
- Notebooks
- Google Colab
- Kaggle
| from __future__ import annotations | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
# One tensor per feature scale; the blocks below index it with a scale index.
FeaturePyramid = list[torch.Tensor]
# One list of stream tensors per feature scale (outer index: scale, inner: stream).
StreamPyramid = list[list[torch.Tensor]]
class SharedAdapter(nn.Module):
    """Applies a residual adapter to each feature scale.

    For every scale there is an independent stack of 1x1 convolutions:
    ``mlps_tune`` (streams -> hidden), ``mlps_bottleneck`` (hidden -> hidden)
    and ``mlp_up`` (hidden -> channels).  The result is added to the
    unadapted feature map as a residual.

    Args:
        in_channels_list: Channel count of the feature map at each scale.
        hidden_dim: Width of the adapter's hidden representation.
        dropout_rate: Probability used by the channel-wise dropout.
        max_streams: Number of stream feature maps the input convolution is
            sized for; values below 1 are clamped to 1.
    """

    def __init__(
        self,
        in_channels_list: list[int],
        hidden_dim: int,
        dropout_rate: float = 0.1,
        max_streams: int = 2,
    ) -> None:
        super().__init__()
        max_streams = max(max_streams, 1)
        # Input projection sized for ``max_streams`` channel-concatenated streams.
        self.mlps_tune = nn.ModuleList(
            nn.Conv2d(max_streams * channels, hidden_dim, kernel_size=1)
            for channels in in_channels_list
        )
        self.mlps_bottleneck = nn.ModuleList(
            nn.Sequential(nn.Conv2d(hidden_dim, hidden_dim, kernel_size=1), nn.GELU())
            for _ in in_channels_list
        )
        # Output projection back to the scale's own channel count.
        self.mlp_up = nn.ModuleList(
            nn.Conv2d(hidden_dim, channels, kernel_size=1)
            for channels in in_channels_list
        )
        self.activation = nn.GELU()
        self.dropout = nn.Dropout2d(p=dropout_rate)

    def forward(
        self,
        stream_features: list[torch.Tensor],
        unadapted: torch.Tensor,
        scale_idx: int,
    ) -> torch.Tensor:
        """Return ``unadapted`` plus the adapter's residual for one scale.

        Args:
            stream_features: Stream feature maps for this scale; they are
                concatenated along the channel dimension (dim 1).  When the
                list is empty, ``unadapted`` is used as the only stream.
            unadapted: Feature map the residual is added to.
            scale_idx: Index selecting the per-scale layers.

        Returns:
            ``unadapted + delta`` where ``delta`` is the adapter output.
        """
        tune = self.mlps_tune[scale_idx]
        fused_streams = torch.cat(stream_features, dim=1) if stream_features else unadapted

        # The input conv is built for ``max_streams`` streams.  When fewer are
        # supplied (including the empty-list fallback), zero-fill the missing
        # channel slots instead of failing with a channel-mismatch error.
        # Supplying more channels than expected is still rejected by the conv.
        missing_channels = tune.in_channels - fused_streams.shape[1]
        if missing_channels > 0:
            fused_streams = F.pad(fused_streams, (0, 0, 0, 0, 0, missing_channels))

        hidden = tune(fused_streams)
        hidden = self.activation(hidden)
        hidden = self.dropout(hidden)
        hidden = self.mlps_bottleneck[scale_idx](hidden)
        delta = self.mlp_up[scale_idx](hidden)
        return unadapted + delta
class RefineBlock(nn.Module):
    """Refines the coarse mask with low-level features.

    Two 3x3 conv / GELU / dropout stages run on the channel-wise
    concatenation of the attention and low-level features; a final 1x1
    convolution turns the result into a correction that is added to the
    coarse mask.

    Args:
        hidden_dim: Channel count of the attention features and of the
            intermediate activations.
        low_channels: Channel count of the low-level features.
        out_channels: Channel count of the produced correction.
        dropout_rate: Probability used by both channel-wise dropout layers.
    """

    def __init__(
        self,
        hidden_dim: int,
        low_channels: int,
        out_channels: int = 1,
        dropout_rate: float = 0.0,
    ) -> None:
        super().__init__()
        merged_channels = hidden_dim + low_channels

        # Stage 1: mix attention and low-level channels.
        self.conv1 = nn.Conv2d(merged_channels, hidden_dim, kernel_size=3, padding=1)
        self.activation1 = nn.GELU()
        self.dropout1 = nn.Dropout2d(p=dropout_rate)

        # Stage 2: further spatial mixing at constant width.
        self.conv2 = nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1)
        self.activation2 = nn.GELU()
        self.dropout2 = nn.Dropout2d(p=dropout_rate)

        # Pointwise head that emits the mask correction.
        self.conv3 = nn.Conv2d(hidden_dim, out_channels, kernel_size=1)

    def forward(
        self,
        attention_features: torch.Tensor,
        low_features: torch.Tensor,
        coarse_upsampled: torch.Tensor,
    ) -> torch.Tensor:
        """Return ``coarse_upsampled`` plus the predicted correction.

        Args:
            attention_features: Features concatenated first along dim 1.
            low_features: Features concatenated second along dim 1.
            coarse_upsampled: Mask the correction is added to.
        """
        stages = (
            (self.conv1, self.activation1, self.dropout1),
            (self.conv2, self.activation2, self.dropout2),
        )
        features = torch.cat((attention_features, low_features), dim=1)
        for conv, activation, dropout in stages:
            features = dropout(activation(conv(features)))
        return coarse_upsampled + self.conv3(features)
class CoarseProcessingBlock(nn.Module):
    """Adds transformer-based coarse reasoning before refinement.

    The input is downsampled by a strided depthwise convolution followed by
    a pointwise convolution, a learned positional encoding is added, the
    result is run through a transformer encoder as a flattened sequence, and
    a sigmoid gate blends the transformer output with the downsampled
    features.

    Args:
        hidden_dim: Channel count of the input and output feature maps.
        attn_dim: Model width used inside the transformer encoder.
        n_heads: Number of attention heads.
        num_encoder_layers: Number of transformer encoder layers.
        dropout_rate: Dropout probability used throughout the block.
        downscale: Kernel size and stride of the downsampling convolution.
    """

    def __init__(
        self,
        hidden_dim: int,
        attn_dim: int,
        n_heads: int,
        num_encoder_layers: int,
        dropout_rate: float,
        downscale: int,
    ) -> None:
        super().__init__()
        self.hidden_dim = hidden_dim
        self.coarse_down = nn.Sequential(
            # Depthwise strided conv (one filter per channel) reduces resolution.
            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=downscale, stride=downscale, groups=hidden_dim),
            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=1),
            nn.GELU(),
            nn.Dropout2d(p=dropout_rate),
        )
        # Maps a 2-channel (y, x) coordinate grid to a learned positional encoding.
        self.pos_embed_conv = nn.Conv2d(2, hidden_dim, kernel_size=1)
        self.pos_dropout = nn.Dropout2d(p=dropout_rate)
        self.feat_proj = nn.Sequential(
            nn.Linear(hidden_dim, attn_dim),
            nn.Dropout(p=dropout_rate),
        )
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=attn_dim,
            nhead=n_heads,
            dim_feedforward=attn_dim * 4,
            dropout=dropout_rate,
            activation="gelu",
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_encoder_layers,
        )
        self.transformer_out = nn.Sequential(
            nn.Linear(attn_dim, hidden_dim),
            nn.Dropout(p=dropout_rate),
        )
        # Predicts a single-channel gate from [transformer output, coarse features].
        self.residual_gate_conv = nn.Sequential(
            nn.Conv2d(hidden_dim * 2, hidden_dim // 4, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Dropout2d(p=dropout_rate),
            nn.Conv2d(hidden_dim // 4, 1, kernel_size=1),
        )
        # Kept only so external code that touches this attribute keeps working.
        # It is no longer populated: caching the detached conv output cut
        # ``pos_embed_conv`` off from gradients and went stale after any
        # weight change.
        self.cached_pos_encodings: dict[tuple[int, int], torch.Tensor] = {}

    def _generate_pos_encoding(self, height: int, width: int) -> torch.Tensor:
        """Return the positional encoding for a ``height`` x ``width`` grid.

        The (y, x) coordinates are spaced linearly over [-1, 1] and passed
        through ``pos_embed_conv``; the result has a leading batch
        dimension of size 1.
        """
        weight = self.pos_embed_conv.weight
        y_pos = torch.linspace(-1, 1, height, device=weight.device).view(height, 1).expand(height, width)
        x_pos = torch.linspace(-1, 1, width, device=weight.device).view(1, width).expand(height, width)
        # Match the conv weight's dtype so a module cast with ``.half()`` or
        # ``.double()`` still works.
        pos_grid = torch.stack([y_pos, x_pos], dim=0).unsqueeze(0).to(weight.dtype)
        return self.pos_embed_conv(pos_grid)

    def _get_positional_encoding(self, batch_size: int, height: int, width: int) -> torch.Tensor:
        """Return the positional encoding expanded to ``batch_size``.

        The encoding is recomputed on every call, so it always reflects the
        current ``pos_embed_conv`` weights and stays in the autograd graph.
        """
        encoding = self._generate_pos_encoding(height, width)
        return encoding.expand(batch_size, -1, -1, -1)

    def forward(self, fused: torch.Tensor) -> torch.Tensor:
        """Return gated coarse features at the downsampled resolution.

        Args:
            fused: Feature map with ``hidden_dim`` channels.

        Returns:
            ``gate * transformer_output + (1 - gate) * coarse_features``,
            where ``gate`` is a per-pixel sigmoid value.
        """
        coarse_features = self.coarse_down(fused)
        batch_size, _, height, width = coarse_features.shape

        pos_embed = self._get_positional_encoding(batch_size, height, width)
        pos_embed = self.pos_dropout(pos_embed)
        coarse_with_position = coarse_features + pos_embed

        # (B, C, H, W) -> (B, H*W, C) for the batch-first transformer.
        feature_sequence = coarse_with_position.flatten(2).permute(0, 2, 1)
        feature_sequence = self.feat_proj(feature_sequence)
        transformer_output = self.transformer_encoder(feature_sequence)
        hidden = self.transformer_out(transformer_output)
        # Back to (B, C, H, W); ``reshape`` does not depend on memory layout.
        hidden = hidden.permute(0, 2, 1).reshape(batch_size, self.hidden_dim, height, width)

        gate_input = torch.cat([hidden, coarse_features], dim=1)
        residual_gate = torch.sigmoid(self.residual_gate_conv(gate_input))
        return residual_gate * hidden + (1 - residual_gate) * coarse_features
class FineProcessingBlock(nn.Module):
    """Produces the coarse mask and uncertainty map.

    Two 3x3 conv / GELU / dropout stages refine the incoming features; two
    separate 1x1 heads then predict single-channel coarse and uncertainty
    logits, both bilinearly resized to the requested output size.

    Args:
        hidden_dim: Channel count of the incoming and refined features.
        dropout_rate: Probability used by the channel-wise dropout layers.
    """

    def __init__(self, hidden_dim: int, dropout_rate: float) -> None:
        super().__init__()
        refinement_layers: list[nn.Module] = []
        for _ in range(2):
            refinement_layers += [
                nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1),
                nn.GELU(),
                nn.Dropout2d(p=dropout_rate),
            ]
        self.feature_refinement = nn.Sequential(*refinement_layers)
        self.coarse_head = nn.Conv2d(hidden_dim, 1, kernel_size=1)
        self.uncertainty_head = nn.Conv2d(hidden_dim, 1, kernel_size=1)

    @staticmethod
    def _resize(logits: torch.Tensor, size: tuple[int, int]) -> torch.Tensor:
        """Bilinearly resize ``logits`` to ``size`` without corner alignment."""
        return F.interpolate(logits, size=size, mode="bilinear", align_corners=False)

    def forward(
        self,
        hidden: torch.Tensor,
        output_size: tuple[int, int],
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return the refined features, coarse mask and uncertainty map.

        Args:
            hidden: Feature map with ``hidden_dim`` channels.
            output_size: Spatial size the two predictions are resized to.

        Returns:
            A tuple ``(refined, coarse_mask, uncertainty)``: the refined
            features at their input resolution, the resized coarse logits
            (no activation applied), and the resized uncertainty logits
            passed through a sigmoid.
        """
        refined = self.feature_refinement(hidden)
        coarse_mask = self._resize(self.coarse_head(refined), output_size)
        uncertainty_logits = self._resize(self.uncertainty_head(refined), output_size)
        return refined, coarse_mask, torch.sigmoid(uncertainty_logits)
class FeatureFusionBlockSpatial(nn.Module):
    """Fuses original, adapted, and perturbed features with per-pixel attention.

    At every scale the streams (adapted, unadapted, plus ``max_streams``
    extra streams) are concatenated along channels and fed to a small
    attention head that predicts one logit per stream and pixel.  The
    softmax-weighted sum of the streams is projected to ``hidden_dim``,
    resized to a common output size, and the per-scale results are merged by
    a final 1x1 projection.

    Args:
        in_channels_list: Channel count of the feature maps at each scale.
        hidden_dim: Channel count of the per-scale projections and of the
            fused output.
        dropout_rate: Probability used by the channel-wise dropout layers.
        max_streams: Number of extra streams per scale in addition to the
            adapted and unadapted feature maps.
        attn_reduction: Divisor applied to the concatenated channel count to
            obtain the attention head's hidden width.
    """

    def __init__(
        self,
        in_channels_list: list[int],
        hidden_dim: int = 128,
        dropout_rate: float = 0.1,
        max_streams: int = 2,
        attn_reduction: int = 4,
    ) -> None:
        super().__init__()
        self.num_streams = 2 + max_streams
        self.att_conv = nn.ModuleList()
        self.proj_conv = nn.ModuleList()
        for channels in in_channels_list:
            total_channels = channels * self.num_streams
            mid_channels = max(total_channels // attn_reduction, 8)
            # A grouped convolution needs out_channels divisible by groups;
            # round up so small channel counts (where the floor of 8 applies)
            # no longer fail at construction.  Widths that were already
            # divisible are left unchanged.
            remainder = mid_channels % self.num_streams
            if remainder:
                mid_channels += self.num_streams - remainder
            self.att_conv.append(
                nn.Sequential(
                    # One convolution group per stream.
                    nn.Conv2d(
                        total_channels,
                        mid_channels,
                        kernel_size=3,
                        padding=1,
                        groups=self.num_streams,
                        bias=False,
                    ),
                    nn.GELU(),
                    # One attention logit per stream.
                    nn.Conv2d(mid_channels, self.num_streams, kernel_size=1, bias=False),
                )
            )
            self.proj_conv.append(
                nn.Sequential(
                    nn.Conv2d(channels, hidden_dim, kernel_size=1),
                    nn.GELU(),
                    nn.Dropout2d(p=dropout_rate),
                )
            )
        fusion_channels = hidden_dim * len(in_channels_list)
        self.fuse_project = nn.Sequential(
            nn.Conv2d(fusion_channels, hidden_dim, kernel_size=1),
            nn.GELU(),
            nn.Dropout2d(p=dropout_rate),
        )

    def forward(
        self,
        adapted: FeaturePyramid,
        unadapted: FeaturePyramid,
        streams_unadapted: StreamPyramid,
        output_size: tuple[int, int],
    ) -> torch.Tensor:
        """Return the fused feature map at ``output_size``.

        Args:
            adapted: Adapted feature map for each scale.
            unadapted: Unadapted feature map for each scale.
            streams_unadapted: Extra stream feature maps for each scale.  The
                attention head is sized for ``max_streams`` of them per
                scale; a different count produces a channel mismatch in the
                attention convolution.
            output_size: Spatial size every scale is resized to before the
                final projection.
        """
        fused_scales = []
        for scale_idx, (att_head, projection) in enumerate(zip(self.att_conv, self.proj_conv)):
            streams = [adapted[scale_idx], unadapted[scale_idx], *streams_unadapted[scale_idx]]
            logits = att_head(torch.cat(streams, dim=1))
            # Softmax over the stream axis; the extra axis broadcasts over channels.
            weights = F.softmax(logits, dim=1).unsqueeze(2)
            fused = (torch.stack(streams, dim=1) * weights).sum(dim=1)
            fused = projection(fused)
            fused = F.interpolate(
                fused,
                size=output_size,
                mode="bilinear",
                align_corners=False,
            )
            fused_scales.append(fused)
        return self.fuse_project(torch.cat(fused_scales, dim=1))
class MaskAdapter(nn.Module):
    """Builds the prompt mask passed into the SAM decoder.

    Pipeline: multi-scale stream fusion -> coarse transformer processing ->
    coarse mask / uncertainty prediction -> refinement with low-level
    features -> per-pixel gated blend of the refined and coarse masks.

    Args:
        hidden_dim: Channel width shared by the sub-blocks.
        downscale: Downsampling factor of the coarse processing stage.
        output_resolution: ``(height, width)`` of the produced mask.
        in_channels_list: Channel count of the feature maps at each scale;
            defaults to ``[256, 32, 64]``.
        attn_dim: Model width of the transformer encoder.
        n_heads: Number of attention heads.
        num_encoder_layers: Number of transformer encoder layers.
        dropout_rate: Dropout probability used throughout.
        max_streams: Number of extra streams per scale given to the fusion
            block.
    """

    def __init__(
        self,
        hidden_dim: int = 256,
        downscale: int = 16,
        output_resolution: tuple[int, int] = (128, 128),
        in_channels_list: list[int] | None = None,
        attn_dim: int = 16,
        n_heads: int = 4,
        num_encoder_layers: int = 2,
        dropout_rate: float = 0.1,
        max_streams: int = 2,
    ) -> None:
        super().__init__()
        channels = in_channels_list or [256, 32, 64]
        self.downscale = downscale
        self.output_resolution = output_resolution
        self.feature_fusion = FeatureFusionBlockSpatial(
            in_channels_list=channels,
            hidden_dim=hidden_dim,
            dropout_rate=dropout_rate,
            max_streams=max_streams,
        )
        self.coarse_processor = CoarseProcessingBlock(
            hidden_dim=hidden_dim,
            attn_dim=attn_dim,
            n_heads=n_heads,
            num_encoder_layers=num_encoder_layers,
            dropout_rate=dropout_rate,
            downscale=downscale,
        )
        self.fine_processor = FineProcessingBlock(hidden_dim, dropout_rate)
        # Maps [coarse mask, uncertainty] to a per-pixel blend weight in (0, 1).
        self.spatial_gate = nn.Sequential(
            nn.Conv2d(2, hidden_dim // 2, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Dropout2d(p=dropout_rate),
            nn.Conv2d(hidden_dim // 2, 1, kernel_size=1),
            nn.Sigmoid(),
        )
        # ``forward`` feeds ``unadapted[1]`` to the refine head, so its input
        # width must follow the second pyramid entry rather than a fixed 32
        # (which is what the default channel list yields anyway).
        low_channels = channels[1] if len(channels) > 1 else 32
        self.refine_head = RefineBlock(
            hidden_dim=hidden_dim,
            low_channels=low_channels,
            out_channels=1,
            dropout_rate=dropout_rate,
        )

    def forward(
        self,
        adapted: FeaturePyramid,
        streams_unadapted: StreamPyramid,
        unadapted: FeaturePyramid,
    ) -> torch.Tensor:
        """Return the prompt mask at ``output_resolution``.

        Args:
            adapted: Adapted feature map for each scale.
            streams_unadapted: Extra stream feature maps for each scale.
            unadapted: Unadapted feature map for each scale; entry 1 is also
                used as the low-level input of the refine head.

        Returns:
            ``gate * refined_mask + (1 - gate) * coarse_mask``, a
            single-channel map at ``output_resolution``.
        """
        output_height, output_width = self.output_resolution
        output_size = (output_height, output_width)
        coarse_size = (output_height // self.downscale, output_width // self.downscale)

        fused = self.feature_fusion(adapted, unadapted, streams_unadapted, output_size)
        hidden = self.coarse_processor(fused)
        # Force the expected coarse resolution if the strided conv produced another size.
        if hidden.shape[-2:] != coarse_size:
            hidden = F.adaptive_avg_pool2d(hidden, coarse_size)

        hidden, coarse_mask, uncertainty_map = self.fine_processor(hidden, output_size)
        attention_features = F.interpolate(hidden, size=output_size, mode="bilinear", align_corners=False)
        low_features = F.interpolate(unadapted[1], size=output_size, mode="bilinear", align_corners=False)
        refined_mask = self.refine_head(attention_features, low_features, coarse_mask)

        spatial_gate = self.spatial_gate(torch.cat([coarse_mask, uncertainty_map], dim=1))
        return spatial_gate * refined_mask + (1 - spatial_gate) * coarse_mask