Mask Generation
Transformers
ONNX
sam3
sam-3
image-segmentation
text-promptable
open-vocabulary
concept-segmentation
Instructions to use danilobukvic/sam3-text-onnx with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use danilobukvic/sam3-text-onnx with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("mask-generation", model="danilobukvic/sam3-text-onnx")

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("danilobukvic/sam3-text-onnx", dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| """Export SAM3's decoder pipeline to ONNX. | |
| Phase 3c — the hard one. Bundles detr_encoder, detr_decoder, mask_decoder, | |
| and dot_product_scoring into a single ONNX file that takes pre-computed | |
| vision FPN features + projected text features as input. geometry_encoder is | |
| skipped for now since text-only prompts don't use it. | |
| The point: avoid re-running the heavy vision/text encoders. The browser will | |
| run vision_encoder.onnx and text_encoder.onnx once each, cache the outputs, | |
| then call this decoder.onnx for the actual segmentation per image+prompt. | |
| Inputs (all tensors, no structured types): | |
| fpn_hidden_state_0,1,2: 3 FPN levels at spatial scales 288, 144, 72 | |
| fpn_position_encoding_0,1,2: matching position encodings | |
| text_features: [B, 32, 256] projected text features | |
| attention_mask: [B, 32] int64 (1=real token, 0=pad) | |
| Outputs: | |
| pred_masks: [B, num_queries, H, W] | |
| pred_boxes: [B, num_queries, 4] xyxy format | |
| pred_logits: [B, num_queries] classification scores | |
| """ | |
| from pathlib import Path | |
| import torch | |
| from torch import nn | |
| from transformers import AutoTokenizer, Sam3Model | |
| from transformers.models.sam3.image_processing_sam3 import Sam3ImageProcessor | |
| from transformers.models.sam3.modeling_sam3 import ( | |
| inverse_sigmoid, | |
| box_cxcywh_to_xyxy, | |
| ) | |
| from PIL import Image | |
# Model identifier handed to the ``from_pretrained`` loaders.
MODEL_ID = "facebook/sam3"

# Export destination. NOTE: the directory is created as an import-time side
# effect, so merely importing this module touches the filesystem.
OUTPUT_DIR = Path("sam3-onnx-test")
OUTPUT_FILE = OUTPUT_DIR.joinpath("decoder.onnx")
OUTPUT_DIR.mkdir(exist_ok=True)
class WrappedDecoder(nn.Module):
    """Flat-tensor wrapper around SAM3's text-prompted decoding stack.

    Chains detr_encoder -> detr_decoder -> box head / dot-product scoring ->
    mask_decoder, reusing the sub-modules of an already-loaded ``Sam3Model``.
    Geometry prompts are not handled: the projected text features are the
    only prompt. All inputs and outputs are plain tensors so the module can
    be exported to ONNX.
    """

    def __init__(self, full_model: Sam3Model):
        super().__init__()
        # Borrow the sub-modules by reference, in a fixed registration order.
        for module_name in (
            "detr_encoder",
            "detr_decoder",
            "mask_decoder",
            "dot_product_scoring",
        ):
            setattr(self, module_name, getattr(full_model, module_name))

    def forward(
        self,
        fpn_hidden_state_0: torch.Tensor,
        fpn_hidden_state_1: torch.Tensor,
        fpn_hidden_state_2: torch.Tensor,
        fpn_position_encoding_0: torch.Tensor,
        fpn_position_encoding_1: torch.Tensor,
        fpn_position_encoding_2: torch.Tensor,
        text_features: torch.Tensor,
        attention_mask: torch.Tensor,
    ):
        """Run the decoder stack and return ``(pred_masks, pred_boxes, pred_logits)``.

        Only the last FPN level (index 2) and its position encoding are fed to
        the DETR encoder; all three hidden-state levels go to the mask decoder.
        ``fpn_position_encoding_0`` and ``fpn_position_encoding_1`` are accepted
        for a symmetric signature but are not read by this method.

        ``attention_mask`` is cast to bool and used as the prompt mask.
        Boxes are converted from cxcywh to xyxy, and boxes/logits are taken
        from the last entry of the decoder's intermediate outputs.
        """
        backbone_levels = [fpn_hidden_state_0, fpn_hidden_state_1, fpn_hidden_state_2]
        prompt_mask = attention_mask.bool()

        # Stage 1: fuse the last FPN level with the text prompt.
        encoded = self.detr_encoder(
            vision_features=[fpn_hidden_state_2],
            text_features=text_features,
            vision_pos_embeds=[fpn_position_encoding_2],
            text_mask=prompt_mask,
        )

        # Stage 2: decode object queries against the fused features.
        decoded = self.detr_decoder(
            vision_features=encoded.last_hidden_state,
            text_features=encoded.text_features,
            vision_pos_encoding=encoded.pos_embeds_flattened,
            text_mask=prompt_mask,
            spatial_shapes=encoded.spatial_shapes,
        )
        query_states = decoded.intermediate_hidden_states

        # Stage 3: boxes = sigmoid(inverse_sigmoid(reference) + offset), then xyxy.
        box_deltas = self.detr_decoder.box_head(query_states)
        reference_logits = inverse_sigmoid(decoded.reference_boxes)
        boxes_cxcywh = (reference_logits + box_deltas).sigmoid()
        boxes_xyxy = box_cxcywh_to_xyxy(boxes_cxcywh)

        # Stage 4: score each query against the text features.
        scores = self.dot_product_scoring(
            decoder_hidden_states=query_states,
            text_features=encoded.text_features,
            text_mask=prompt_mask,
        ).squeeze(-1)

        # Keep only the last entry of each stacked intermediate output.
        pred_logits = scores[-1]
        pred_boxes = boxes_xyxy[-1]
        final_queries = query_states[-1]

        # Stage 5: turn the final queries into segmentation masks.
        masks = self.mask_decoder(
            decoder_queries=final_queries,
            backbone_features=backbone_levels,
            encoder_hidden_states=encoded.last_hidden_state,
            prompt_features=text_features,
            prompt_mask=prompt_mask,
        )
        return masks.pred_masks, pred_boxes, pred_logits
def main() -> None:
    """Build real decoder inputs, smoke-test ``WrappedDecoder``, and export it to ONNX.

    Steps:
      1. Load ``MODEL_ID``, run its vision encoder on a grey 640x480 dummy
         image and its text encoder on the prompt "seed" (padded to the text
         config's ``max_position_embeddings``).
      2. Run the wrapped decoder once in PyTorch and print output shapes and
         logit statistics.
      3. Export to ``OUTPUT_FILE`` with the legacy (non-dynamo) exporter and
         list the files in ``OUTPUT_DIR`` with their sizes.

    Raises:
        ValueError: if dropping the last FPN level does not leave exactly the
            three levels the wrapped decoder's signature expects.
    """
    print(f"Loading {MODEL_ID} ...")
    model = Sam3Model.from_pretrained(MODEL_ID)
    model.eval()

    # Build real inputs end-to-end using the actual vision + text encoders.
    # We don't want to fabricate fake FPN tensors — they have to match the
    # exact shape and statistical distribution the decoder was trained on.
    print("\nBuilding real inputs by running the encoders ...")
    image_processor = Sam3ImageProcessor.from_pretrained(MODEL_ID)
    dummy_pil = Image.new("RGB", (640, 480), color=(128, 128, 128))
    pixel_values = image_processor(images=dummy_pil, return_tensors="pt")["pixel_values"]

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    max_len = int(model.config.text_config.max_position_embeddings)
    encoded = tokenizer(
        "seed",
        return_tensors="pt",
        padding="max_length",
        max_length=max_len,
        truncation=True,
    )

    with torch.no_grad():
        vision_out = model.vision_encoder(pixel_values)
        text_out = model.get_text_features(
            input_ids=encoded["input_ids"], attention_mask=encoded["attention_mask"]
        )

    # The decoder uses FPN[:-1] (first 3 of 4 levels)
    fpn_h = vision_out.fpn_hidden_states[:-1]
    fpn_p = vision_out.fpn_position_encoding[:-1]
    if len(fpn_h) != 3 or len(fpn_p) != 3:
        # Fail loudly instead of an opaque IndexError (too few levels) or a
        # silent export of the wrong levels (too many).
        raise ValueError(
            "WrappedDecoder expects exactly 3 FPN levels, got "
            f"{len(fpn_h)} hidden states and {len(fpn_p)} position encodings"
        )
    text_features = text_out.pooler_output
    attention_mask = encoded["attention_mask"]

    print(f" FPN hidden states: {len(fpn_h)} tensors")
    for i, t in enumerate(fpn_h):
        print(f" [{i}] shape={tuple(t.shape)}")
    print(f" text_features: {tuple(text_features.shape)}")
    print(f" attention_mask: {tuple(attention_mask.shape)}")

    # Single source of truth for the positional inputs, so the smoke test and
    # the export trace are guaranteed to see the same tensors in the same order.
    example_inputs = (
        fpn_h[0], fpn_h[1], fpn_h[2],
        fpn_p[0], fpn_p[1], fpn_p[2],
        text_features,
        attention_mask,
    )

    # Smoke test the wrapped decoder
    print("\nSmoke testing wrapped decoder in PyTorch ...")
    wrapped = WrappedDecoder(model).eval()
    with torch.no_grad():
        pred_masks, pred_boxes, pred_logits = wrapped(*example_inputs)
    print(f" pred_masks: shape={tuple(pred_masks.shape)} dtype={pred_masks.dtype}")
    print(f" pred_boxes: shape={tuple(pred_boxes.shape)} dtype={pred_boxes.dtype}")
    print(f" pred_logits: shape={tuple(pred_logits.shape)} dtype={pred_logits.dtype}")
    print(f" logits mean={pred_logits.mean().item():.4f} std={pred_logits.std().item():.4f}")

    # Export
    print(f"\nExporting to {OUTPUT_FILE} ...")
    torch.onnx.export(
        wrapped,
        example_inputs,
        str(OUTPUT_FILE),
        input_names=[
            "fpn_hidden_state_0", "fpn_hidden_state_1", "fpn_hidden_state_2",
            "fpn_position_encoding_0", "fpn_position_encoding_1", "fpn_position_encoding_2",
            "text_features",
            "attention_mask",
        ],
        output_names=["pred_masks", "pred_boxes", "pred_logits"],
        dynamic_axes={
            "fpn_hidden_state_0": {0: "batch", 2: "h0", 3: "w0"},
            "fpn_hidden_state_1": {0: "batch", 2: "h1", 3: "w1"},
            "fpn_hidden_state_2": {0: "batch", 2: "h2", 3: "w2"},
            "fpn_position_encoding_0": {0: "batch", 2: "h0", 3: "w0"},
            "fpn_position_encoding_1": {0: "batch", 2: "h1", 3: "w1"},
            "fpn_position_encoding_2": {0: "batch", 2: "h2", 3: "w2"},
            "text_features": {0: "batch", 1: "text_seq"},
            "attention_mask": {0: "batch", 1: "text_seq"},
            # Axes not listed here are pinned to the traced example's size, so
            # the mask's spatial axes must be declared dynamic as well or they
            # stay fixed while the FPN inputs they are decoded from may vary.
            "pred_masks": {0: "batch", 2: "mask_h", 3: "mask_w"},
            "pred_boxes": {0: "batch"},
            "pred_logits": {0: "batch"},
        },
        opset_version=18,
        do_constant_folding=True,
        verbose=False,
        # SAM3's attention layers use .reshape() on transposed tensors with a
        # dynamic batch dim, which trips PyTorch's new dynamo exporter (it can't
        # trace the view-vs-copy decision symbolically). The legacy torch.jit.trace
        # path handles this pattern fine. Force it.
        dynamo=False,
    )

    size_mb = OUTPUT_FILE.stat().st_size / (1024 * 1024)
    print(f"\n✅ Exported decoder: {OUTPUT_FILE} ({size_mb:.1f} MB graph)")
    print("\nFiles in output dir:")
    for f in sorted(OUTPUT_DIR.iterdir()):
        size_mb = f.stat().st_size / (1024 * 1024)
        print(f" {f.name}: {size_mb:.1f} MB")


if __name__ == "__main__":
    main()