Mask Generation
Transformers
ONNX
sam3
sam-3
image-segmentation
text-promptable
open-vocabulary
concept-segmentation
Instructions to use danilobukvic/sam3-text-onnx with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use danilobukvic/sam3-text-onnx with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline
pipe = pipeline("mask-generation", model="danilobukvic/sam3-text-onnx")
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("danilobukvic/sam3-text-onnx", dtype="auto")
- Notebooks
- Google Colab
- Kaggle
File size: 9,458 Bytes
"""Export SAM3's decoder pipeline to ONNX.
Phase 3c — the hard one. Bundles detr_encoder, detr_decoder, mask_decoder,
and dot_product_scoring into a single ONNX file that takes pre-computed
vision FPN features + projected text features as input. (geometry_encoder
is skipped for now, since text-only prompts don't use it.)
The point: avoid re-running the heavy vision/text encoders. The browser will
run vision_encoder.onnx and text_encoder.onnx once each, cache the outputs,
then call this decoder.onnx for the actual segmentation per image+prompt.
Inputs (all tensors, no structured types):
fpn_hidden_state_0,1,2: 3 FPN levels at spatial scales 288, 144, 72
fpn_position_encoding_0,1,2: matching position encodings
text_features: [B, 32, 256] projected text features
attention_mask: [B, 32] int64 (1=real token, 0=pad)
Outputs:
pred_masks: [B, num_queries, H, W]
pred_boxes: [B, num_queries, 4] xyxy format
pred_logits: [B, num_queries] raw classification logits (query·text dot-product scores)
"""
from pathlib import Path
import torch
from torch import nn
from transformers import AutoTokenizer, Sam3Model
from transformers.models.sam3.image_processing_sam3 import Sam3ImageProcessor
from transformers.models.sam3.modeling_sam3 import (
inverse_sigmoid,
box_cxcywh_to_xyxy,
)
from PIL import Image
# Hugging Face Hub checkpoint that supplies the weights and preprocessors.
MODEL_ID = "facebook/sam3"

# Every exported artifact is written into this local scratch directory.
# It is created at import time, so later writes can assume it exists.
OUTPUT_DIR = Path("sam3-onnx-test")
OUTPUT_FILE = OUTPUT_DIR.joinpath("decoder.onnx")
OUTPUT_DIR.mkdir(exist_ok=True)
class WrappedDecoder(nn.Module):
    """Bundles the SAM3 decoder pipeline (detr_encoder → detr_decoder → mask_decoder).

    Skips geometry prompts entirely — text-only path. Mirrors the relevant
    portion of Sam3Model.forward() but with flat tensor I/O for ONNX export.

    The submodules are references to the ones on the loaded ``Sam3Model``
    (not copies), so the wrapper shares that model's weights. The vision and
    text encoders are not part of this module; their outputs are passed in as
    plain tensors.
    """

    def __init__(self, full_model: Sam3Model):
        """Keep references to the four SAM3 submodules the decoder path needs.

        Args:
            full_model: A loaded ``Sam3Model`` whose submodules are reused.
        """
        super().__init__()
        self.detr_encoder = full_model.detr_encoder
        self.detr_decoder = full_model.detr_decoder
        self.mask_decoder = full_model.mask_decoder
        self.dot_product_scoring = full_model.dot_product_scoring

    def forward(
        self,
        fpn_hidden_state_0: torch.Tensor,
        fpn_hidden_state_1: torch.Tensor,
        fpn_hidden_state_2: torch.Tensor,
        fpn_position_encoding_0: torch.Tensor,
        fpn_position_encoding_1: torch.Tensor,
        fpn_position_encoding_2: torch.Tensor,
        text_features: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run the text-prompted decoder on precomputed encoder outputs.

        Args:
            fpn_hidden_state_0, fpn_hidden_state_1, fpn_hidden_state_2: The
                three FPN feature maps the decoder consumes, in the order the
                vision encoder produced them (the export script passes
                ``fpn_hidden_states[:-1]``). Only the last one feeds the DETR
                encoder; all three feed the mask decoder.
            fpn_position_encoding_0, fpn_position_encoding_1,
            fpn_position_encoding_2: Position encodings matching each FPN
                level. Only the last one is read below.
                NOTE(review): inputs 0 and 1 are never used. The legacy
                TorchScript ONNX exporter may prune unused inputs from the
                graph, so callers feeding them by name could fail. Confirm
                against the exported model's input list.
            text_features: Projected text features; presumably
                [B, seq, hidden] (the module docstring states [B, 32, 256]).
                TODO confirm for other configs.
            attention_mask: Text padding mask. It is converted with ``.bool()``,
                so any nonzero entry marks a real token.

        Returns:
            ``(pred_masks, pred_boxes, pred_logits)`` for the final decoder
            layer only:
            - pred_masks: ``pred_masks`` from the mask decoder.
            - pred_boxes: xyxy boxes, converted from sigmoid-activated cxcywh
              predictions.
            - pred_logits: query/text dot-product scores with the trailing
              singleton dim squeezed.

        NOTE(review): upstream Sam3Model.forward() may also expose presence
        logits from the DETR decoder that post-processing combines with
        pred_logits. They are not exported here. Confirm whether the browser
        side needs them.
        """
        fpn_hidden_states = (fpn_hidden_state_0, fpn_hidden_state_1, fpn_hidden_state_2)
        fpn_position_encoding = (
            fpn_position_encoding_0,
            fpn_position_encoding_1,
            fpn_position_encoding_2,
        )
        text_mask = attention_mask.bool()
        # These keep the `combined_prompt_*` names used by Sam3Model.forward().
        # With geometry prompts skipped, the combined prompt is just the text.
        combined_prompt_features = text_features
        combined_prompt_mask = text_mask
        # 1. DETR encoder operates on the smallest (most-pooled) FPN level + text
        encoder_outputs = self.detr_encoder(
            vision_features=[fpn_hidden_states[-1]],
            text_features=combined_prompt_features,
            vision_pos_embeds=[fpn_position_encoding[-1]],
            text_mask=combined_prompt_mask,
        )
        # 2. DETR decoder produces object queries
        decoder_outputs = self.detr_decoder(
            vision_features=encoder_outputs.last_hidden_state,
            text_features=encoder_outputs.text_features,
            vision_pos_encoding=encoder_outputs.pos_embeds_flattened,
            text_mask=combined_prompt_mask,
            spatial_shapes=encoder_outputs.spatial_shapes,
        )
        # 3. Box predictions: refine reference boxes via decoder's box head.
        # The offsets are added in logit space (inverse_sigmoid of the
        # reference boxes) and squashed back with sigmoid, then the result
        # is converted from cxcywh to xyxy.
        all_box_offsets = self.detr_decoder.box_head(decoder_outputs.intermediate_hidden_states)
        reference_boxes_inv_sig = inverse_sigmoid(decoder_outputs.reference_boxes)
        all_pred_boxes_cxcywh = (reference_boxes_inv_sig + all_box_offsets).sigmoid()
        all_pred_boxes = box_cxcywh_to_xyxy(all_pred_boxes_cxcywh)
        # 4. Classification scores: dot product between queries and text
        # (computed for every intermediate decoder layer).
        all_pred_logits = self.dot_product_scoring(
            decoder_hidden_states=decoder_outputs.intermediate_hidden_states,
            text_features=encoder_outputs.text_features,
            text_mask=combined_prompt_mask,
        ).squeeze(-1)
        # We only return the FINAL decoder layer's predictions (the typical case)
        pred_logits = all_pred_logits[-1]
        pred_boxes = all_pred_boxes[-1]
        decoder_hidden_states = decoder_outputs.intermediate_hidden_states[-1]
        # 5. Mask decoder produces the actual segmentation masks.
        # It sees all three FPN levels plus the DETR encoder memory.
        mask_outputs = self.mask_decoder(
            decoder_queries=decoder_hidden_states,
            backbone_features=list(fpn_hidden_states),
            encoder_hidden_states=encoder_outputs.last_hidden_state,
            prompt_features=combined_prompt_features,
            prompt_mask=combined_prompt_mask,
        )
        return mask_outputs.pred_masks, pred_boxes, pred_logits
_BYTES_PER_MB = 1024 * 1024

# Decoder input names, in the exact positional order of WrappedDecoder.forward.
# _build_example_inputs returns its tuple in this same order.
_INPUT_NAMES = [
    "fpn_hidden_state_0", "fpn_hidden_state_1", "fpn_hidden_state_2",
    "fpn_position_encoding_0", "fpn_position_encoding_1", "fpn_position_encoding_2",
    "text_features",
    "attention_mask",
]
_OUTPUT_NAMES = ["pred_masks", "pred_boxes", "pred_logits"]


def _build_example_inputs(model: Sam3Model) -> tuple:
    """Run SAM3's real vision and text encoders to produce decoder inputs.

    Fabricated FPN tensors would not match the exact shape and statistical
    distribution the decoder was trained on. Instead, a flat grey dummy image
    and a one-word prompt are pushed through the actual encoders.

    Args:
        model: The loaded ``Sam3Model`` (already in eval mode).

    Returns:
        An 8-tuple ``(fpn_h0, fpn_h1, fpn_h2, fpn_p0, fpn_p1, fpn_p2,
        text_features, attention_mask)`` in the positional order of
        ``WrappedDecoder.forward``. The same tuple is used for the smoke test
        and for the export, so the two cannot drift apart.

    Raises:
        ValueError: If dropping the last FPN level does not leave exactly
            three hidden states and three position encodings.
    """
    image_processor = Sam3ImageProcessor.from_pretrained(MODEL_ID)
    dummy_pil = Image.new("RGB", (640, 480), color=(128, 128, 128))
    pixel_values = image_processor(images=dummy_pil, return_tensors="pt")["pixel_values"]

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    # Pad to the text encoder's full context so the traced graph sees the
    # same sequence length the encoder was configured for.
    max_len = int(model.config.text_config.max_position_embeddings)
    encoded = tokenizer(
        "seed",
        return_tensors="pt",
        padding="max_length",
        max_length=max_len,
        truncation=True,
    )

    with torch.no_grad():
        vision_out = model.vision_encoder(pixel_values)
        text_out = model.get_text_features(
            input_ids=encoded["input_ids"], attention_mask=encoded["attention_mask"]
        )

    # The decoder uses FPN[:-1] (first 3 of 4 levels)
    fpn_h = vision_out.fpn_hidden_states[:-1]
    fpn_p = vision_out.fpn_position_encoding[:-1]
    # WrappedDecoder takes exactly three levels. Fail loudly rather than
    # silently dropping or running out of levels.
    if len(fpn_h) != 3 or len(fpn_p) != 3:
        raise ValueError(
            "Expected 3 FPN levels after dropping the last one, got "
            f"{len(fpn_h)} hidden states and {len(fpn_p)} position encodings"
        )
    text_features = text_out.pooler_output
    attention_mask = encoded["attention_mask"]

    print(f" FPN hidden states: {len(fpn_h)} tensors")
    for i, t in enumerate(fpn_h):
        print(f" [{i}] shape={tuple(t.shape)}")
    print(f" text_features: {tuple(text_features.shape)}")
    print(f" attention_mask: {tuple(attention_mask.shape)}")

    return (*fpn_h, *fpn_p, text_features, attention_mask)


def _smoke_test(wrapped: nn.Module, example_inputs: tuple) -> None:
    """Run the wrapped decoder once in eager PyTorch and print output stats.

    This surfaces wiring errors (argument order, shape mismatches) before the
    ONNX export is attempted.
    """
    with torch.no_grad():
        pred_masks, pred_boxes, pred_logits = wrapped(*example_inputs)
    print(f" pred_masks: shape={tuple(pred_masks.shape)} dtype={pred_masks.dtype}")
    print(f" pred_boxes: shape={tuple(pred_boxes.shape)} dtype={pred_boxes.dtype}")
    print(f" pred_logits: shape={tuple(pred_logits.shape)} dtype={pred_logits.dtype}")
    print(f" logits mean={pred_logits.mean().item():.4f} std={pred_logits.std().item():.4f}")


def _export_onnx(wrapped: nn.Module, example_inputs: tuple) -> None:
    """Trace ``wrapped`` with ``example_inputs`` and write ``OUTPUT_FILE``.

    Batch, FPN spatial dims and text length are declared dynamic. The output
    batch dim is dynamic too.
    """
    torch.onnx.export(
        wrapped,
        example_inputs,
        str(OUTPUT_FILE),
        input_names=_INPUT_NAMES,
        output_names=_OUTPUT_NAMES,
        dynamic_axes={
            "fpn_hidden_state_0": {0: "batch", 2: "h0", 3: "w0"},
            "fpn_hidden_state_1": {0: "batch", 2: "h1", 3: "w1"},
            "fpn_hidden_state_2": {0: "batch", 2: "h2", 3: "w2"},
            "fpn_position_encoding_0": {0: "batch", 2: "h0", 3: "w0"},
            "fpn_position_encoding_1": {0: "batch", 2: "h1", 3: "w1"},
            "fpn_position_encoding_2": {0: "batch", 2: "h2", 3: "w2"},
            "text_features": {0: "batch", 1: "text_seq"},
            "attention_mask": {0: "batch", 1: "text_seq"},
            "pred_masks": {0: "batch"},
            "pred_boxes": {0: "batch"},
            "pred_logits": {0: "batch"},
        },
        opset_version=18,
        do_constant_folding=True,
        verbose=False,
        # SAM3's attention layers use .reshape() on transposed tensors with a
        # dynamic batch dim, which trips PyTorch's new dynamo exporter (it can't
        # trace the view-vs-copy decision symbolically). The legacy torch.jit.trace
        # path handles this pattern fine. Force it.
        dynamo=False,
    )


def _report_output_dir() -> None:
    """Print the exported graph size and the size of every file in OUTPUT_DIR.

    Only regular files are listed. Subdirectories have no meaningful
    ``st_size`` and are skipped.
    """
    graph_mb = OUTPUT_FILE.stat().st_size / _BYTES_PER_MB
    print(f"\n✅ Exported decoder: {OUTPUT_FILE} ({graph_mb:.1f} MB graph)")
    print("\nFiles in output dir:")
    for path in sorted(OUTPUT_DIR.iterdir()):
        if not path.is_file():
            continue
        file_mb = path.stat().st_size / _BYTES_PER_MB
        print(f" {path.name}: {file_mb:.1f} MB")


def main() -> None:
    """Load SAM3, build real decoder inputs, smoke-test, and export to ONNX."""
    print(f"Loading {MODEL_ID} ...")
    model = Sam3Model.from_pretrained(MODEL_ID)
    model.eval()

    print("\nBuilding real inputs by running the encoders ...")
    example_inputs = _build_example_inputs(model)

    print("\nSmoke testing wrapped decoder in PyTorch ...")
    wrapped = WrappedDecoder(model).eval()
    _smoke_test(wrapped, example_inputs)

    print(f"\nExporting to {OUTPUT_FILE} ...")
    _export_onnx(wrapped, example_inputs)

    _report_output_dir()


if __name__ == "__main__":
    main()
|