#!/usr/bin/env python3 """ Export DiT Transformer with unrolled ODE solver to ONNX format. The DiT transformer is the core denoising model in SAM Audio. It uses a flow-based generative model with an ODE solver. For ONNX export, we unroll the fixed-step midpoint ODE solver into a static computation graph. The default configuration uses: - method: "midpoint" - step_size: 2/32 (0.0625) - integration range: [0, 1] - total steps: 16 This creates a single ONNX model that performs the complete denoising process, taking noise and conditioning as input and producing denoised audio features. Usage: python -m onnx_export.export_dit --output-dir onnx_models --verify """ import os import math import argparse import torch import torch.nn as nn from typing import Optional class SinusoidalEmbedding(nn.Module): """Sinusoidal timestep embedding (identical to SAMAudio implementation).""" def __init__(self, dim, theta=10000): super().__init__() assert (dim % 2) == 0 half_dim = dim // 2 inv_freq = torch.exp( -math.log(theta) * torch.arange(half_dim).float() / half_dim ) self.register_buffer("inv_freq", inv_freq, persistent=False) def forward(self, x, pos=None): if pos is None: seq_len, device = x.shape[1], x.device pos = torch.arange(seq_len, device=device) emb = torch.einsum("i, j -> i j", pos, self.inv_freq) emb = torch.cat((emb.cos(), emb.sin()), dim=-1) return emb class EmbedAnchors(nn.Module): """Anchor embedding (identical to SAMAudio implementation).""" def __init__(self, num_embeddings: int, embedding_dim: int, out_dim: int): super().__init__() self.embed = nn.Embedding( num_embeddings + 1, embedding_dim, padding_idx=num_embeddings ) self.gate = nn.Parameter(torch.tensor([0.0])) self.proj = nn.Linear(embedding_dim, out_dim, bias=False) def forward( self, x: torch.Tensor, anchor_ids: Optional[torch.Tensor] = None, anchor_alignment: Optional[torch.Tensor] = None, ): if anchor_ids is None: return x embs = self.embed(anchor_ids.gather(1, anchor_alignment)) proj = self.proj(embs) return x + self.gate.tanh() * proj class DiTSingleStepWrapper(nn.Module): """ Wrapper for DiT that performs a single forward pass (one ODE evaluation). This mirrors the SAMAudio.forward() method exactly. """ def __init__( self, transformer: nn.Module, proj: nn.Module, align_masked_video: nn.Module, embed_anchors: nn.Module, timestep_emb: nn.Module, memory_proj: nn.Module, ): super().__init__() self.transformer = transformer self.proj = proj self.align_masked_video = align_masked_video self.embed_anchors = embed_anchors self.timestep_emb = timestep_emb self.memory_proj = memory_proj def forward( self, noisy_audio: torch.Tensor, time: torch.Tensor, audio_features: torch.Tensor, text_features: torch.Tensor, text_mask: torch.Tensor, masked_video_features: torch.Tensor, anchor_ids: torch.Tensor, anchor_alignment: torch.Tensor, audio_pad_mask: torch.Tensor, ) -> torch.Tensor: """ Single forward pass of the DiT (one ODE function evaluation). This exactly mirrors SAMAudio.forward() method. """ # Align inputs (concatenate noisy_audio with audio_features) # Same as SAMAudio.align_inputs() x = torch.cat( [ noisy_audio, torch.zeros_like(audio_features), audio_features, ], dim=2, ) projected = self.proj(x) aligned = self.align_masked_video(projected, masked_video_features) aligned = self.embed_anchors(aligned, anchor_ids, anchor_alignment) # Timestep embedding and memory # Same as SAMAudio.forward() timestep_emb_val = self.timestep_emb(time, pos=time).unsqueeze(1) memory = self.memory_proj(text_features) + timestep_emb_val # Transformer forward output = self.transformer( aligned, time, padding_mask=audio_pad_mask, memory=memory, memory_padding_mask=text_mask, ) return output class UnrolledDiTWrapper(nn.Module): """ DiT wrapper with unrolled midpoint ODE solver. The midpoint method computes: k1 = f(t, y) k2 = f(t + h/2, y + h/2 * k1) y_new = y + h * k2 With step_size=0.0625 and range [0,1], we have 16 steps. """ def __init__( self, single_step: DiTSingleStepWrapper, num_steps: int = 16, ): super().__init__() self.single_step = single_step self.num_steps = num_steps self.step_size = 1.0 / num_steps def forward( self, noise: torch.Tensor, audio_features: torch.Tensor, text_features: torch.Tensor, text_mask: torch.Tensor, masked_video_features: torch.Tensor, anchor_ids: torch.Tensor, anchor_alignment: torch.Tensor, audio_pad_mask: torch.Tensor, ) -> torch.Tensor: """Complete denoising using unrolled midpoint ODE solver.""" B = noise.shape[0] h = self.step_size y = noise t = torch.zeros(B, device=noise.device, dtype=noise.dtype) for step in range(self.num_steps): # k1 = f(t, y) k1 = self.single_step( y, t, audio_features, text_features, text_mask, masked_video_features, anchor_ids, anchor_alignment, audio_pad_mask ) # k2 = f(t + h/2, y + h/2 * k1) t_mid = t + h / 2 y_mid = y + (h / 2) * k1 k2 = self.single_step( y_mid, t_mid, audio_features, text_features, text_mask, masked_video_features, anchor_ids, anchor_alignment, audio_pad_mask ) # y = y + h * k2 y = y + h * k2 t = t + h return y def load_sam_audio_components(model_id: str = "facebook/sam-audio-small", device: str = "cpu"): """ Load SAM Audio components needed for DiT export. Since we can't load the full SAMAudio model (missing perception_models), we construct the components directly and load weights from checkpoint. """ import json import sys import types import importlib.util from huggingface_hub import hf_hub_download print(f"Loading SAM Audio components from {model_id}...") # Download config config_path = hf_hub_download(repo_id=model_id, filename="config.json") with open(config_path) as f: config = json.load(f) # Download checkpoint checkpoint_path = hf_hub_download(repo_id=model_id, filename="checkpoint.pt") # Use our standalone config that doesn't have 'core' dependencies from onnx_export.standalone_config import TransformerConfig sam_audio_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) # Create fake module hierarchy so transformer.py's relative imports work if 'sam_audio' not in sys.modules: sam_audio_pkg = types.ModuleType('sam_audio') sam_audio_pkg.__path__ = [os.path.join(sam_audio_path, 'sam_audio')] sys.modules['sam_audio'] = sam_audio_pkg if 'sam_audio.model' not in sys.modules: model_pkg = types.ModuleType('sam_audio.model') model_pkg.__path__ = [os.path.join(sam_audio_path, 'sam_audio', 'model')] sys.modules['sam_audio.model'] = model_pkg # Register our standalone config as sam_audio.model.config if 'sam_audio.model.config' not in sys.modules: import onnx_export.standalone_config as standalone_config sys.modules['sam_audio.model.config'] = standalone_config # Now import transformer module - it will use our standalone config transformer_spec = importlib.util.spec_from_file_location( "sam_audio.model.transformer", os.path.join(sam_audio_path, "sam_audio", "model", "transformer.py") ) transformer_module = importlib.util.module_from_spec(transformer_spec) sys.modules['sam_audio.model.transformer'] = transformer_module transformer_spec.loader.exec_module(transformer_module) DiT = transformer_module.DiT # Import align module align_spec = importlib.util.spec_from_file_location( "sam_audio.model.align", os.path.join(sam_audio_path, "sam_audio", "model", "align.py") ) align_module = importlib.util.module_from_spec(align_spec) sys.modules['sam_audio.model.align'] = align_module align_spec.loader.exec_module(align_module) AlignModalities = align_module.AlignModalities # Create transformer transformer_config = TransformerConfig(**config.get("transformer", {})) transformer = DiT(transformer_config) # Calculate dimensions in_channels = config.get("in_channels", 768) num_anchors = config.get("num_anchors", 3) anchor_embedding_dim = config.get("anchor_embedding_dim", 128) # Get vision encoder dim for align_masked_video vision_config = config.get("vision_encoder", {}) vision_dim = vision_config.get("dim", 768) # Create components exactly as SAMAudio does proj = nn.Linear(in_channels, transformer_config.d_model) align_masked_video = AlignModalities(vision_dim, transformer_config.d_model) embed_anchors = EmbedAnchors(num_anchors, anchor_embedding_dim, transformer_config.d_model) timestep_emb = SinusoidalEmbedding(transformer_config.d_model) # Memory projection for text features text_encoder_config = config.get("text_encoder", {}) text_encoder_dim = text_encoder_config.get("dim", 1024) # google/flan-t5-large memory_proj = nn.Linear(text_encoder_dim, transformer_config.d_model) # Load weights from checkpoint print("Loading weights from checkpoint...") state_dict = torch.load(checkpoint_path, map_location="cpu", mmap=True) # Filter and load weights for each component transformer_state = {} proj_state = {} align_state = {} embed_anchors_state = {} memory_proj_state = {} for key, value in state_dict.items(): if key.startswith("transformer."): new_key = key[len("transformer."):] transformer_state[new_key] = value elif key.startswith("proj."): new_key = key[len("proj."):] proj_state[new_key] = value elif key.startswith("align_masked_video."): new_key = key[len("align_masked_video."):] align_state[new_key] = value elif key.startswith("embed_anchors."): new_key = key[len("embed_anchors."):] embed_anchors_state[new_key] = value elif key.startswith("memory_proj."): new_key = key[len("memory_proj."):] memory_proj_state[new_key] = value transformer.load_state_dict(transformer_state) proj.load_state_dict(proj_state) align_masked_video.load_state_dict(align_state) embed_anchors.load_state_dict(embed_anchors_state) memory_proj.load_state_dict(memory_proj_state) print(f" ✓ Loaded transformer weights ({len(transformer_state)} tensors)") print(f" ✓ Loaded component weights") # Create single step wrapper single_step = DiTSingleStepWrapper( transformer=transformer, proj=proj, align_masked_video=align_masked_video, embed_anchors=embed_anchors, timestep_emb=timestep_emb, memory_proj=memory_proj, ).eval().to(device) return single_step, config def create_sample_inputs(batch_size: int = 1, seq_len: int = 25, device: str = "cpu"): """Create sample inputs for tracing.""" latent_dim = 128 text_dim = 768 # T5-base hidden size (SAM Audio was trained with 768-dim text) vision_dim = 1024 # Vision encoder dim from config text_len = 77 return { "noisy_audio": torch.randn(batch_size, seq_len, 2 * latent_dim, device=device), "time": torch.zeros(batch_size, device=device), "audio_features": torch.randn(batch_size, seq_len, 2 * latent_dim, device=device), "text_features": torch.randn(batch_size, text_len, text_dim, device=device), "text_mask": torch.ones(batch_size, text_len, dtype=torch.bool, device=device), "masked_video_features": torch.zeros(batch_size, vision_dim, seq_len, device=device), "anchor_ids": torch.zeros(batch_size, seq_len, dtype=torch.long, device=device), "anchor_alignment": torch.zeros(batch_size, seq_len, dtype=torch.long, device=device), "audio_pad_mask": torch.ones(batch_size, seq_len, dtype=torch.bool, device=device), } def export_dit_single_step( single_step: DiTSingleStepWrapper, output_path: str, opset_version: int = 21, device: str = "cpu", fp16: bool = False, ): """Export single-step DiT to ONNX (for runtime ODE solving).""" import onnx print(f"Exporting DiT single-step to {output_path}...") # Convert to FP16 if requested if fp16: print(" Converting model to FP16...") single_step = single_step.half() sample_inputs = create_sample_inputs(device=device) # Convert float inputs to FP16 if exporting in FP16 if fp16: for key, value in sample_inputs.items(): if value.dtype == torch.float32: sample_inputs[key] = value.half() torch.onnx.export( single_step, tuple(sample_inputs.values()), output_path, input_names=list(sample_inputs.keys()), output_names=["velocity"], dynamic_axes={ "noisy_audio": {0: "batch_size", 1: "seq_len"}, "time": {0: "batch_size"}, "audio_features": {0: "batch_size", 1: "seq_len"}, "text_features": {0: "batch_size", 1: "text_len"}, "text_mask": {0: "batch_size", 1: "text_len"}, "masked_video_features": {0: "batch_size", 2: "seq_len"}, "anchor_ids": {0: "batch_size", 1: "seq_len"}, "anchor_alignment": {0: "batch_size", 1: "seq_len"}, "audio_pad_mask": {0: "batch_size", 1: "seq_len"}, "velocity": {0: "batch_size", 1: "seq_len"}, }, opset_version=opset_version, do_constant_folding=True, dynamo=True, external_data=True, ) print(" ✓ DiT single-step exported successfully") # When using external_data=True, we can't run check_model on a model # loaded without external data - the checker validates data references. # Since torch.onnx.export with dynamo=True already validates the model, # we just verify the files exist. external_data_path = output_path + ".data" if os.path.exists(external_data_path): print(f" ✓ External data file exists ({os.path.getsize(external_data_path) / 1e9:.2f} GB)") else: raise RuntimeError(f"External data file missing: {external_data_path}") # Verify the ONNX file structure is valid (without loading weights) model = onnx.load(output_path, load_external_data=False) print(f" ✓ ONNX model structure loaded ({len(model.graph.node)} nodes)") return True def verify_dit_single_step( single_step: DiTSingleStepWrapper, onnx_path: str, device: str = "cpu", tolerance: float = 1e-3, ) -> bool: """Verify single-step ONNX output matches PyTorch.""" import onnxruntime as ort import numpy as np print("Verifying DiT single-step output...") sample_inputs = create_sample_inputs(device=device) # PyTorch output with torch.no_grad(): pytorch_output = single_step(**sample_inputs).cpu().numpy() # ONNX Runtime output sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"]) onnx_inputs = {} for name, tensor in sample_inputs.items(): if tensor.dtype == torch.bool: onnx_inputs[name] = tensor.cpu().numpy().astype(bool) elif tensor.dtype == torch.long: onnx_inputs[name] = tensor.cpu().numpy().astype(np.int64) else: onnx_inputs[name] = tensor.cpu().numpy().astype(np.float32) onnx_output = sess.run(["velocity"], onnx_inputs)[0] # Compare max_diff = np.abs(pytorch_output - onnx_output).max() mean_diff = np.abs(pytorch_output - onnx_output).mean() print(f" Max difference: {max_diff:.2e}") print(f" Mean difference: {mean_diff:.2e}") if max_diff < tolerance: print(f" ✓ Verification passed (tolerance: {tolerance})") return True else: print(f" ✗ Verification failed (tolerance: {tolerance})") return False def main(): parser = argparse.ArgumentParser(description="Export DiT Transformer to ONNX") parser.add_argument( "--model-id", type=str, default="facebook/sam-audio-small", help="SAM Audio model ID from HuggingFace", ) parser.add_argument( "--output-dir", type=str, default="onnx_models", help="Output directory for ONNX models", ) parser.add_argument( "--num-steps", type=int, default=16, help="Number of ODE solver steps (default: 16)", ) parser.add_argument( "--opset", type=int, default=21, help="ONNX opset version (default: 21)", ) parser.add_argument( "--device", type=str, default="cpu", help="Device to use for export (default: cpu)", ) parser.add_argument( "--verify", action="store_true", help="Verify ONNX output matches PyTorch", ) parser.add_argument( "--tolerance", type=float, default=1e-3, help="Tolerance for verification (default: 1e-3)", ) parser.add_argument( "--fp16", action="store_true", help="Export model in FP16 precision (half the size)", ) args = parser.parse_args() # Create output directory os.makedirs(args.output_dir, exist_ok=True) # Load components single_step, config = load_sam_audio_components(args.model_id, args.device) print(f"\nDiT Configuration:") print(f" Model: {args.model_id}") print(f" ODE steps: {args.num_steps}") print(f" Step size: {1.0/args.num_steps:.4f}") # Export single-step model single_step_path = os.path.join(args.output_dir, "dit_single_step.onnx") export_dit_single_step( single_step, single_step_path, opset_version=args.opset, device=args.device, fp16=args.fp16, ) if args.fp16: print(f" ✓ Model exported in FP16 precision") # Verify single-step if args.verify: verify_dit_single_step( single_step, single_step_path, device=args.device, tolerance=args.tolerance, ) print(f"\n✓ Export complete! Model saved to {args.output_dir}") if __name__ == "__main__": main()