# sam3-text-onnx / export_sam3_decoder.py
# danilobukvic's picture
# Upload 6 files
# 3a8f269 verified
"""Export SAM3's decoder pipeline to ONNX.
Phase 3c — the hard one. Bundles geometry_encoder (skipped for now since
text-only prompts don't use it), detr_encoder, detr_decoder, mask_decoder,
and dot_product_scoring into a single ONNX file that takes pre-computed
vision FPN features + projected text features as input.
The point: avoid re-running the heavy vision/text encoders. The browser will
run vision_encoder.onnx and text_encoder.onnx once each, cache the outputs,
then call this decoder.onnx for the actual segmentation per image+prompt.
Inputs (all tensors, no structured types):
fpn_hidden_state_0,1,2: 3 FPN levels at spatial scales 288, 144, 72
fpn_position_encoding_0,1,2: matching position encodings
text_features: [B, 32, 256] projected text features
attention_mask: [B, 32] int64 (1=real token, 0=pad)
Outputs:
pred_masks: [B, num_queries, H, W]
pred_boxes: [B, num_queries, 4] xyxy format
pred_logits: [B, num_queries] classification scores
"""
from pathlib import Path
import torch
from torch import nn
from transformers import AutoTokenizer, Sam3Model
from transformers.models.sam3.image_processing_sam3 import Sam3ImageProcessor
from transformers.models.sam3.modeling_sam3 import (
inverse_sigmoid,
box_cxcywh_to_xyxy,
)
from PIL import Image
# Hugging Face Hub identifier of the checkpoint to load.
MODEL_ID = "facebook/sam3"

# Export destination: a directory relative to the current working directory
# and the decoder graph file inside it.
OUTPUT_DIR = Path("sam3-onnx-test")
OUTPUT_FILE = OUTPUT_DIR.joinpath("decoder.onnx")

# Created at import time so the export can write into it; a pre-existing
# directory is not an error.
OUTPUT_DIR.mkdir(exist_ok=True)
class WrappedDecoder(nn.Module):
    """ONNX-friendly wrapper around SAM3's text-prompted decoding stages.

    Chains ``detr_encoder`` → ``detr_decoder`` → box head / dot-product
    scoring → ``mask_decoder``, reusing the submodules of an already loaded
    ``Sam3Model``. Geometry prompts are not supported: the text features are
    the only prompt. All inputs and outputs are plain tensors so the module
    can be traced for ONNX export.
    """

    def __init__(self, full_model: Sam3Model):
        """Borrow the decoding submodules (and their weights) from *full_model*."""
        super().__init__()
        self.detr_encoder = full_model.detr_encoder
        self.detr_decoder = full_model.detr_decoder
        self.mask_decoder = full_model.mask_decoder
        self.dot_product_scoring = full_model.dot_product_scoring

    def forward(
        self,
        fpn_hidden_state_0: torch.Tensor,
        fpn_hidden_state_1: torch.Tensor,
        fpn_hidden_state_2: torch.Tensor,
        fpn_position_encoding_0: torch.Tensor,
        fpn_position_encoding_1: torch.Tensor,
        fpn_position_encoding_2: torch.Tensor,
        text_features: torch.Tensor,
        attention_mask: torch.Tensor,
    ):
        """Run the text-only decoding path.

        Args:
            fpn_hidden_state_0, fpn_hidden_state_1, fpn_hidden_state_2:
                FPN feature maps. All three go to the mask decoder; only
                level 2 is fed to the DETR encoder.
            fpn_position_encoding_0, fpn_position_encoding_1,
            fpn_position_encoding_2:
                Position encodings for the FPN levels. Only level 2 is
                consumed here; levels 0 and 1 are accepted but unused.
            text_features: Text prompt features, used as the sole prompt.
            attention_mask: Token mask for ``text_features``; converted to
                bool before use.

        Returns:
            Tuple ``(pred_masks, pred_boxes, pred_logits)``. Boxes (xyxy) and
            logits are taken from the last entry of the decoder's
            per-layer outputs.
        """
        prompt_mask = attention_mask.bool()

        # Stage 1: DETR encoder on the last FPN level together with the text.
        encoded = self.detr_encoder(
            vision_features=[fpn_hidden_state_2],
            text_features=text_features,
            vision_pos_embeds=[fpn_position_encoding_2],
            text_mask=prompt_mask,
        )

        # Stage 2: DETR decoder turns the encoded features into object queries.
        decoded = self.detr_decoder(
            vision_features=encoded.last_hidden_state,
            text_features=encoded.text_features,
            vision_pos_encoding=encoded.pos_embeds_flattened,
            text_mask=prompt_mask,
            spatial_shapes=encoded.spatial_shapes,
        )

        # Stage 3: boxes. The box head predicts offsets that are added to the
        # reference boxes in logit space, squashed back with a sigmoid, and
        # converted from cxcywh to xyxy.
        box_deltas = self.detr_decoder.box_head(decoded.intermediate_hidden_states)
        reference_logits = inverse_sigmoid(decoded.reference_boxes)
        boxes_cxcywh = (reference_logits + box_deltas).sigmoid()
        boxes_xyxy = box_cxcywh_to_xyxy(boxes_cxcywh)

        # Stage 4: per-query scores from the query/text dot product.
        logits_per_layer = self.dot_product_scoring(
            decoder_hidden_states=decoded.intermediate_hidden_states,
            text_features=encoded.text_features,
            text_mask=prompt_mask,
        ).squeeze(-1)

        # Keep only the final decoder layer's predictions.
        final_logits = logits_per_layer[-1]
        final_boxes = boxes_xyxy[-1]
        final_queries = decoded.intermediate_hidden_states[-1]

        # Stage 5: the mask decoder produces the segmentation masks from the
        # final queries and all three FPN levels.
        masks = self.mask_decoder(
            decoder_queries=final_queries,
            backbone_features=[
                fpn_hidden_state_0,
                fpn_hidden_state_1,
                fpn_hidden_state_2,
            ],
            encoder_hidden_states=encoded.last_hidden_state,
            prompt_features=text_features,
            prompt_mask=prompt_mask,
        )
        return masks.pred_masks, final_boxes, final_logits
def main() -> None:
    """Load SAM3, build example inputs with the real encoders, and export.

    Runs the vision and text encoders on a dummy grey image and a short
    prompt to obtain example decoder inputs, smoke-tests ``WrappedDecoder``
    in PyTorch, exports it to ``OUTPUT_FILE`` with the legacy (trace-based)
    ONNX exporter, and finally lists the files in ``OUTPUT_DIR``.
    """
    print(f"Loading {MODEL_ID} ...")
    model = Sam3Model.from_pretrained(MODEL_ID)
    model.eval()

    # Example inputs come from the actual encoders rather than random
    # tensors, so shapes and value distributions match what the decoder
    # expects.
    print("\nBuilding real inputs by running the encoders ...")
    processor = Sam3ImageProcessor.from_pretrained(MODEL_ID)
    grey_image = Image.new("RGB", (640, 480), color=(128, 128, 128))
    pixel_values = processor(images=grey_image, return_tensors="pt")["pixel_values"]

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    # Pad/truncate the prompt to the text encoder's maximum length.
    tokens = tokenizer(
        "seed",
        return_tensors="pt",
        padding="max_length",
        max_length=int(model.config.text_config.max_position_embeddings),
        truncation=True,
    )
    attention_mask = tokens["attention_mask"]

    with torch.no_grad():
        vision_out = model.vision_encoder(pixel_values)
        text_out = model.get_text_features(
            input_ids=tokens["input_ids"], attention_mask=attention_mask
        )

    # The decoder uses FPN[:-1] (first 3 of 4 levels).
    fpn_h = vision_out.fpn_hidden_states[:-1]
    fpn_p = vision_out.fpn_position_encoding[:-1]
    text_features = text_out.pooler_output

    print(f" FPN hidden states: {len(fpn_h)} tensors")
    for level, feature_map in enumerate(fpn_h):
        print(f" [{level}] shape={tuple(feature_map.shape)}")
    print(f" text_features: {tuple(text_features.shape)}")
    print(f" attention_mask: {tuple(attention_mask.shape)}")

    # Positional inputs in the order WrappedDecoder.forward expects; shared
    # by the smoke test and the export.
    example_inputs = (
        fpn_h[0], fpn_h[1], fpn_h[2],
        fpn_p[0], fpn_p[1], fpn_p[2],
        text_features,
        attention_mask,
    )

    # Smoke test: make sure the wrapper runs in PyTorch before exporting.
    print("\nSmoke testing wrapped decoder in PyTorch ...")
    wrapped = WrappedDecoder(model).eval()
    with torch.no_grad():
        pred_masks, pred_boxes, pred_logits = wrapped(*example_inputs)
    print(f" pred_masks: shape={tuple(pred_masks.shape)} dtype={pred_masks.dtype}")
    print(f" pred_boxes: shape={tuple(pred_boxes.shape)} dtype={pred_boxes.dtype}")
    print(f" pred_logits: shape={tuple(pred_logits.shape)} dtype={pred_logits.dtype}")
    logits_mean = pred_logits.mean().item()
    logits_std = pred_logits.std().item()
    print(f" logits mean={logits_mean:.4f} std={logits_std:.4f}")

    # ONNX graph I/O names, in the same order as ``example_inputs``.
    hidden_names = [f"fpn_hidden_state_{level}" for level in range(3)]
    position_names = [f"fpn_position_encoding_{level}" for level in range(3)]
    input_names = [*hidden_names, *position_names, "text_features", "attention_mask"]
    output_names = ["pred_masks", "pred_boxes", "pred_logits"]

    # Dynamic axes: batch everywhere; per-level height/width on the FPN
    # tensors (a position encoding shares its level's symbols); sequence
    # length on the text inputs.
    dynamic_axes = {}
    for group in (hidden_names, position_names):
        for level, name in enumerate(group):
            dynamic_axes[name] = {0: "batch", 2: f"h{level}", 3: f"w{level}"}
    for name in ("text_features", "attention_mask"):
        dynamic_axes[name] = {0: "batch", 1: "text_seq"}
    for name in output_names:
        dynamic_axes[name] = {0: "batch"}

    print(f"\nExporting to {OUTPUT_FILE} ...")
    torch.onnx.export(
        wrapped,
        example_inputs,
        str(OUTPUT_FILE),
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        opset_version=18,
        do_constant_folding=True,
        verbose=False,
        # SAM3's attention layers call .reshape() on transposed tensors with
        # a dynamic batch dim, which the dynamo exporter cannot trace (the
        # view-vs-copy decision is not symbolic). The legacy torch.jit.trace
        # path copes with this pattern, so it is forced here.
        dynamo=False,
    )

    bytes_per_mb = 1024 * 1024
    graph_mb = OUTPUT_FILE.stat().st_size / bytes_per_mb
    print(f"\n✅ Exported decoder: {OUTPUT_FILE} ({graph_mb:.1f} MB graph)")

    print("\nFiles in output dir:")
    for entry in sorted(OUTPUT_DIR.iterdir()):
        entry_mb = entry.stat().st_size / bytes_per_mb
        print(f" {entry.name}: {entry_mb:.1f} MB")
# Run the export only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()