"""Export SAM3's decoder pipeline to ONNX.

Phase 3c — the hard one. Bundles geometry_encoder (skipped for now since
text-only prompts don't use it), detr_encoder, detr_decoder, mask_decoder, and
dot_product_scoring into a single ONNX file that takes pre-computed vision FPN
features + projected text features as input.

The point: avoid re-running the heavy vision/text encoders. The browser will
run vision_encoder.onnx and text_encoder.onnx once each, cache the outputs,
then call this decoder.onnx for the actual segmentation per image+prompt.

Inputs (all tensors, no structured types):
    fpn_hidden_state_0,1,2:      3 FPN levels at spatial scales 288, 144, 72
    fpn_position_encoding_0,1,2: matching position encodings
    text_features:               [B, 32, 256] projected text features
    attention_mask:              [B, 32] int64 (1=real token, 0=pad)

Outputs:
    pred_masks:  [B, num_queries, H, W]
    pred_boxes:  [B, num_queries, 4] xyxy format
    pred_logits: [B, num_queries] classification scores
"""

from pathlib import Path

import torch
from torch import nn
from transformers import AutoTokenizer, Sam3Model
from transformers.models.sam3.image_processing_sam3 import Sam3ImageProcessor
from transformers.models.sam3.modeling_sam3 import (
    inverse_sigmoid,
    box_cxcywh_to_xyxy,
)
from PIL import Image

OUTPUT_DIR = Path("sam3-onnx-test")
OUTPUT_DIR.mkdir(exist_ok=True)
OUTPUT_FILE = OUTPUT_DIR / "decoder.onnx"
MODEL_ID = "facebook/sam3"

# ONNX graph input names, in the positional order of WrappedDecoder.forward().
# Single source of truth so the export call cannot drift from the signature.
_INPUT_NAMES = (
    "fpn_hidden_state_0",
    "fpn_hidden_state_1",
    "fpn_hidden_state_2",
    "fpn_position_encoding_0",
    "fpn_position_encoding_1",
    "fpn_position_encoding_2",
    "text_features",
    "attention_mask",
)

# ONNX graph output names, in the order WrappedDecoder.forward() returns them.
_OUTPUT_NAMES = ("pred_masks", "pred_boxes", "pred_logits")


class WrappedDecoder(nn.Module):
    """Bundles the SAM3 decoder pipeline (detr_encoder → detr_decoder → mask_decoder).

    Skips geometry prompts entirely — text-only path. Mirrors the relevant
    portion of Sam3Model.forward() but with flat tensor I/O for ONNX export.
    """

    def __init__(self, full_model: Sam3Model):
        """Borrow the decoder-side submodules of ``full_model`` (no copies are made)."""
        super().__init__()
        self.detr_encoder = full_model.detr_encoder
        self.detr_decoder = full_model.detr_decoder
        self.mask_decoder = full_model.mask_decoder
        self.dot_product_scoring = full_model.dot_product_scoring

    def forward(
        self,
        fpn_hidden_state_0: torch.Tensor,
        fpn_hidden_state_1: torch.Tensor,
        fpn_hidden_state_2: torch.Tensor,
        fpn_position_encoding_0: torch.Tensor,
        fpn_position_encoding_1: torch.Tensor,
        fpn_position_encoding_2: torch.Tensor,
        text_features: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run the text-prompted decoder on pre-computed encoder outputs.

        Args:
            fpn_hidden_state_0: First FPN feature level.
            fpn_hidden_state_1: Second FPN feature level.
            fpn_hidden_state_2: Third FPN feature level; the only level fed
                to the DETR encoder.
            fpn_position_encoding_0: Position encoding for FPN level 0.
            fpn_position_encoding_1: Position encoding for FPN level 1.
            fpn_position_encoding_2: Position encoding for FPN level 2.
            text_features: Projected text features.
            attention_mask: Token mask (1=real token, 0=pad); converted to
                bool before use.

        Returns:
            ``(pred_masks, pred_boxes, pred_logits)`` taken from the final
            decoder layer. Boxes are in xyxy format.
        """
        fpn_hidden_states = (fpn_hidden_state_0, fpn_hidden_state_1, fpn_hidden_state_2)
        fpn_position_encoding = (
            fpn_position_encoding_0,
            fpn_position_encoding_1,
            fpn_position_encoding_2,
        )

        # Text-only path: the prompt is just the text, no geometry features
        # are concatenated onto it.
        prompt_features = text_features
        prompt_mask = attention_mask.bool()

        # 1. DETR encoder operates on the smallest (most-pooled) FPN level + text
        encoder_outputs = self.detr_encoder(
            vision_features=[fpn_hidden_states[-1]],
            text_features=prompt_features,
            vision_pos_embeds=[fpn_position_encoding[-1]],
            text_mask=prompt_mask,
        )

        # 2. DETR decoder produces object queries
        decoder_outputs = self.detr_decoder(
            vision_features=encoder_outputs.last_hidden_state,
            text_features=encoder_outputs.text_features,
            vision_pos_encoding=encoder_outputs.pos_embeds_flattened,
            text_mask=prompt_mask,
            spatial_shapes=encoder_outputs.spatial_shapes,
        )

        # 3. Box predictions: refine reference boxes via decoder's box head.
        # Offsets are added in logit space, then squashed back with sigmoid.
        all_box_offsets = self.detr_decoder.box_head(decoder_outputs.intermediate_hidden_states)
        reference_boxes_inv_sig = inverse_sigmoid(decoder_outputs.reference_boxes)
        all_pred_boxes_cxcywh = (reference_boxes_inv_sig + all_box_offsets).sigmoid()
        all_pred_boxes = box_cxcywh_to_xyxy(all_pred_boxes_cxcywh)

        # 4. Classification scores: dot product between queries and text
        all_pred_logits = self.dot_product_scoring(
            decoder_hidden_states=decoder_outputs.intermediate_hidden_states,
            text_features=encoder_outputs.text_features,
            text_mask=prompt_mask,
        ).squeeze(-1)

        # We only return the FINAL decoder layer's predictions (the typical case)
        pred_logits = all_pred_logits[-1]
        pred_boxes = all_pred_boxes[-1]
        decoder_hidden_states = decoder_outputs.intermediate_hidden_states[-1]

        # 5. Mask decoder produces the actual segmentation masks
        mask_outputs = self.mask_decoder(
            decoder_queries=decoder_hidden_states,
            backbone_features=list(fpn_hidden_states),
            encoder_hidden_states=encoder_outputs.last_hidden_state,
            prompt_features=prompt_features,
            prompt_mask=prompt_mask,
        )

        return mask_outputs.pred_masks, pred_boxes, pred_logits


def _size_mb(path: Path) -> float:
    """Return the on-disk size of ``path`` in mebibytes."""
    return path.stat().st_size / (1024 * 1024)


def _build_example_inputs(model: Sam3Model) -> tuple:
    """Run the real vision + text encoders once to get example decoder inputs.

    Returns the eight tensors in the positional order of
    ``WrappedDecoder.forward`` (and therefore of ``_INPUT_NAMES``).
    """
    # We don't want to fabricate fake FPN tensors — they have to match the
    # exact shape and statistical distribution the decoder was trained on.
    print("\nBuilding real inputs by running the encoders ...")
    image_processor = Sam3ImageProcessor.from_pretrained(MODEL_ID)
    dummy_pil = Image.new("RGB", (640, 480), color=(128, 128, 128))
    pixel_values = image_processor(images=dummy_pil, return_tensors="pt")["pixel_values"]

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    max_len = int(model.config.text_config.max_position_embeddings)
    encoded = tokenizer(
        "seed",
        return_tensors="pt",
        padding="max_length",
        max_length=max_len,
        truncation=True,
    )

    with torch.no_grad():
        vision_out = model.vision_encoder(pixel_values)
        text_out = model.get_text_features(
            input_ids=encoded["input_ids"], attention_mask=encoded["attention_mask"]
        )

    # The decoder uses FPN[:-1] (first 3 of 4 levels)
    fpn_h = vision_out.fpn_hidden_states[:-1]
    fpn_p = vision_out.fpn_position_encoding[:-1]
    text_features = text_out.pooler_output
    attention_mask = encoded["attention_mask"]

    print(f"  FPN hidden states: {len(fpn_h)} tensors")
    for i, t in enumerate(fpn_h):
        print(f"    [{i}] shape={tuple(t.shape)}")
    print(f"  text_features: {tuple(text_features.shape)}")
    print(f"  attention_mask: {tuple(attention_mask.shape)}")

    return (
        fpn_h[0],
        fpn_h[1],
        fpn_h[2],
        fpn_p[0],
        fpn_p[1],
        fpn_p[2],
        text_features,
        attention_mask,
    )


def main() -> None:
    """Load SAM3, smoke-test the wrapped decoder, and export it to OUTPUT_FILE."""
    print(f"Loading {MODEL_ID} ...")
    model = Sam3Model.from_pretrained(MODEL_ID)
    model.eval()

    # Build real inputs end-to-end using the actual vision + text encoders.
    # The same tuple feeds both the smoke test and the export trace.
    decoder_inputs = _build_example_inputs(model)

    # Smoke test the wrapped decoder
    print("\nSmoke testing wrapped decoder in PyTorch ...")
    wrapped = WrappedDecoder(model).eval()
    with torch.no_grad():
        pred_masks, pred_boxes, pred_logits = wrapped(*decoder_inputs)
    print(f"  pred_masks:  shape={tuple(pred_masks.shape)}  dtype={pred_masks.dtype}")
    print(f"  pred_boxes:  shape={tuple(pred_boxes.shape)}  dtype={pred_boxes.dtype}")
    print(f"  pred_logits: shape={tuple(pred_logits.shape)}  dtype={pred_logits.dtype}")
    print(f"  logits mean={pred_logits.mean().item():.4f}  std={pred_logits.std().item():.4f}")

    # Export
    print(f"\nExporting to {OUTPUT_FILE} ...")
    # Trace without autograd, like every other forward pass in this script:
    # gradients are never needed for export and only cost memory.
    with torch.no_grad():
        torch.onnx.export(
            wrapped,
            decoder_inputs,
            str(OUTPUT_FILE),
            input_names=list(_INPUT_NAMES),
            output_names=list(_OUTPUT_NAMES),
            dynamic_axes={
                "fpn_hidden_state_0": {0: "batch", 2: "h0", 3: "w0"},
                "fpn_hidden_state_1": {0: "batch", 2: "h1", 3: "w1"},
                "fpn_hidden_state_2": {0: "batch", 2: "h2", 3: "w2"},
                "fpn_position_encoding_0": {0: "batch", 2: "h0", 3: "w0"},
                "fpn_position_encoding_1": {0: "batch", 2: "h1", 3: "w1"},
                "fpn_position_encoding_2": {0: "batch", 2: "h2", 3: "w2"},
                "text_features": {0: "batch", 1: "text_seq"},
                "attention_mask": {0: "batch", 1: "text_seq"},
                "pred_masks": {0: "batch"},
                "pred_boxes": {0: "batch"},
                "pred_logits": {0: "batch"},
            },
            opset_version=18,
            do_constant_folding=True,
            verbose=False,
            # SAM3's attention layers use .reshape() on transposed tensors with a
            # dynamic batch dim, which trips PyTorch's new dynamo exporter (it can't
            # trace the view-vs-copy decision symbolically). The legacy torch.jit.trace
            # path handles this pattern fine. Force it.
            dynamo=False,
        )

    print(f"\n✅ Exported decoder: {OUTPUT_FILE} ({_size_mb(OUTPUT_FILE):.1f} MB graph)")
    print("\nFiles in output dir:")
    for entry in sorted(OUTPUT_DIR.iterdir()):
        print(f"  {entry.name}: {_size_mb(entry):.1f} MB")


if __name__ == "__main__":
    main()