| |
| """ |
| Export DiT Transformer with unrolled ODE solver to ONNX format. |
| |
| The DiT transformer is the core denoising model in SAM Audio. It uses a flow-based |
| generative model with an ODE solver. For ONNX export, we unroll the fixed-step |
| midpoint ODE solver into a static computation graph. |
| |
| The default configuration uses: |
| - method: "midpoint" |
| - step_size: 2/32 (0.0625) |
| - integration range: [0, 1] |
| - total steps: 16 |
| |
| This creates a single ONNX model that performs the complete denoising process, |
| taking noise and conditioning as input and producing denoised audio features. |
| |
| Usage: |
| python -m onnx_export.export_dit --output-dir onnx_models --verify |
| """ |
|
|
| import os |
| import math |
| import argparse |
| import torch |
| import torch.nn as nn |
| from typing import Optional |
|
|
|
|
class SinusoidalEmbedding(nn.Module):
    """Sinusoidal timestep embedding (identical to SAMAudio implementation)."""

    def __init__(self, dim, theta=10000):
        """Precompute the inverse-frequency ladder for a `dim`-wide embedding.

        Args:
            dim: Embedding width; must be even (split into cos/sin halves).
            theta: Base wavelength controlling the frequency spectrum.
        """
        super().__init__()
        assert (dim % 2) == 0
        n_freqs = dim // 2
        # Geometric ladder theta**(-k / n_freqs) for k in [0, n_freqs).
        freqs = torch.exp(
            -math.log(theta) * torch.arange(n_freqs).float() / n_freqs
        )
        self.register_buffer("inv_freq", freqs, persistent=False)

    def forward(self, x, pos=None):
        """Return a (len(pos), dim) table of [cos | sin] features.

        `x` is only consulted when `pos` is omitted, in which case positions
        0..x.shape[1]-1 on x's device are used.
        """
        if pos is None:
            pos = torch.arange(x.shape[1], device=x.device)
        # Outer product pos x inv_freq (equivalent to einsum "i,j->ij").
        angles = pos.unsqueeze(-1) * self.inv_freq
        return torch.cat((angles.cos(), angles.sin()), dim=-1)
|
|
|
|
class EmbedAnchors(nn.Module):
    """Anchor embedding (identical to SAMAudio implementation)."""

    def __init__(self, num_embeddings: int, embedding_dim: int, out_dim: int):
        """Build the anchor lookup table, learned gate, and output projection.

        Args:
            num_embeddings: Number of real anchor ids; one extra row is
                reserved as the padding index.
            embedding_dim: Width of the anchor lookup vectors.
            out_dim: Width the anchor vectors are projected to (model dim).
        """
        super().__init__()
        self.embed = nn.Embedding(
            num_embeddings + 1, embedding_dim, padding_idx=num_embeddings
        )
        # Gate starts at 0 so the module is an exact identity at init.
        self.gate = nn.Parameter(torch.tensor([0.0]))
        self.proj = nn.Linear(embedding_dim, out_dim, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        anchor_ids: Optional[torch.Tensor] = None,
        anchor_alignment: Optional[torch.Tensor] = None,
    ):
        """Add gated, aligned anchor embeddings to x; pass-through without anchors."""
        if anchor_ids is None:
            return x
        # Spread each anchor id onto the frames it is aligned with, then embed.
        aligned_ids = anchor_ids.gather(1, anchor_alignment)
        anchor_vecs = self.proj(self.embed(aligned_ids))
        return x + self.gate.tanh() * anchor_vecs
|
|
|
|
class DiTSingleStepWrapper(nn.Module):
    """
    Wrapper for DiT that performs a single forward pass (one ODE evaluation).

    This mirrors the SAMAudio.forward() method exactly.
    """

    def __init__(
        self,
        transformer: nn.Module,
        proj: nn.Module,
        align_masked_video: nn.Module,
        embed_anchors: nn.Module,
        timestep_emb: nn.Module,
        memory_proj: nn.Module,
    ):
        """Collect the sub-modules composing one denoising evaluation."""
        super().__init__()
        self.transformer = transformer
        self.proj = proj
        self.align_masked_video = align_masked_video
        self.embed_anchors = embed_anchors
        self.timestep_emb = timestep_emb
        self.memory_proj = memory_proj

    def forward(
        self,
        noisy_audio: torch.Tensor,
        time: torch.Tensor,
        audio_features: torch.Tensor,
        text_features: torch.Tensor,
        text_mask: torch.Tensor,
        masked_video_features: torch.Tensor,
        anchor_ids: torch.Tensor,
        anchor_alignment: torch.Tensor,
        audio_pad_mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Single forward pass of the DiT (one ODE function evaluation).

        This exactly mirrors SAMAudio.forward() method.
        """
        # Channel-wise stack: [noisy latent | zero placeholder | conditioning].
        stacked = torch.cat(
            (noisy_audio, torch.zeros_like(audio_features), audio_features),
            dim=2,
        )

        # Project to transformer width, then fuse video features and anchor cues.
        hidden = self.proj(stacked)
        hidden = self.align_masked_video(hidden, masked_video_features)
        hidden = self.embed_anchors(hidden, anchor_ids, anchor_alignment)

        # Text memory shifted by the (broadcast) timestep embedding.
        t_emb = self.timestep_emb(time, pos=time).unsqueeze(1)
        memory = self.memory_proj(text_features) + t_emb

        return self.transformer(
            hidden,
            time,
            padding_mask=audio_pad_mask,
            memory=memory,
            memory_padding_mask=text_mask,
        )
|
|
|
|
class UnrolledDiTWrapper(nn.Module):
    """
    DiT wrapper with unrolled midpoint ODE solver.

    The midpoint method computes:
        k1 = f(t, y)
        k2 = f(t + h/2, y + h/2 * k1)
        y_new = y + h * k2

    With step_size=0.0625 and range [0,1], we have 16 steps.
    """

    def __init__(
        self,
        single_step: "DiTSingleStepWrapper",
        num_steps: int = 16,
    ):
        """Wrap one ODE evaluation and fix the number of uniform steps on [0, 1]."""
        super().__init__()
        self.single_step = single_step
        self.num_steps = num_steps
        self.step_size = 1.0 / num_steps

    def forward(
        self,
        noise: torch.Tensor,
        audio_features: torch.Tensor,
        text_features: torch.Tensor,
        text_mask: torch.Tensor,
        masked_video_features: torch.Tensor,
        anchor_ids: torch.Tensor,
        anchor_alignment: torch.Tensor,
        audio_pad_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Complete denoising using unrolled midpoint ODE solver."""
        # Conditioning is constant across steps; bundle it once.
        cond = (
            audio_features,
            text_features,
            text_mask,
            masked_video_features,
            anchor_ids,
            anchor_alignment,
            audio_pad_mask,
        )
        h = self.step_size
        state = noise
        t = torch.zeros(noise.shape[0], device=noise.device, dtype=noise.dtype)

        for _ in range(self.num_steps):
            # Midpoint rule: slope at t, then slope re-evaluated half a step in.
            slope_start = self.single_step(state, t, *cond)
            slope_mid = self.single_step(
                state + (h / 2) * slope_start, t + h / 2, *cond
            )
            state = state + h * slope_mid
            t = t + h

        return state
|
|
|
|
def load_sam_audio_components(model_id: str = "facebook/sam-audio-small", device: str = "cpu"):
    """
    Load SAM Audio components needed for DiT export.

    Since we can't load the full SAMAudio model (missing perception_models),
    we construct the components directly and load weights from checkpoint.

    Args:
        model_id: HuggingFace Hub repo id holding config.json and checkpoint.pt.
        device: Device the assembled wrapper is moved to after loading.

    Returns:
        Tuple of (DiTSingleStepWrapper in eval mode on `device`, parsed config dict).
    """
    import json
    import sys
    import types
    import importlib.util
    from huggingface_hub import hf_hub_download

    print(f"Loading SAM Audio components from {model_id}...")

    # Download and parse the model configuration.
    config_path = hf_hub_download(repo_id=model_id, filename="config.json")
    with open(config_path) as f:
        config = json.load(f)

    # Download the full checkpoint (one flat state dict for all sub-modules).
    checkpoint_path = hf_hub_download(repo_id=model_id, filename="checkpoint.pt")

    from onnx_export.standalone_config import TransformerConfig

    # Repository root: parent of the directory containing this file.
    sam_audio_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

    # Pre-register stub packages in sys.modules so the sam_audio.model.*
    # source files can be imported below without importing the real
    # sam_audio package (which would pull in perception_models).
    if 'sam_audio' not in sys.modules:
        sam_audio_pkg = types.ModuleType('sam_audio')
        sam_audio_pkg.__path__ = [os.path.join(sam_audio_path, 'sam_audio')]
        sys.modules['sam_audio'] = sam_audio_pkg

    if 'sam_audio.model' not in sys.modules:
        model_pkg = types.ModuleType('sam_audio.model')
        model_pkg.__path__ = [os.path.join(sam_audio_path, 'sam_audio', 'model')]
        sys.modules['sam_audio.model'] = model_pkg

    # Alias the standalone config module in place of sam_audio.model.config,
    # which transformer.py / align.py import internally.
    if 'sam_audio.model.config' not in sys.modules:
        import onnx_export.standalone_config as standalone_config
        sys.modules['sam_audio.model.config'] = standalone_config

    # Import DiT straight from its source file, bypassing package __init__.
    transformer_spec = importlib.util.spec_from_file_location(
        "sam_audio.model.transformer",
        os.path.join(sam_audio_path, "sam_audio", "model", "transformer.py")
    )
    transformer_module = importlib.util.module_from_spec(transformer_spec)
    sys.modules['sam_audio.model.transformer'] = transformer_module
    transformer_spec.loader.exec_module(transformer_module)
    DiT = transformer_module.DiT

    # Same trick for the modality-alignment module.
    align_spec = importlib.util.spec_from_file_location(
        "sam_audio.model.align",
        os.path.join(sam_audio_path, "sam_audio", "model", "align.py")
    )
    align_module = importlib.util.module_from_spec(align_spec)
    sys.modules['sam_audio.model.align'] = align_module
    align_spec.loader.exec_module(align_module)
    AlignModalities = align_module.AlignModalities

    # Core denoising transformer.
    transformer_config = TransformerConfig(**config.get("transformer", {}))
    transformer = DiT(transformer_config)

    # Component dimensions; defaults are used when the config omits a key.
    # NOTE(review): defaults assume in_channels=768, 3 anchors, 128-d anchor
    # embeddings — confirm against the published checkpoint config.
    in_channels = config.get("in_channels", 768)
    num_anchors = config.get("num_anchors", 3)
    anchor_embedding_dim = config.get("anchor_embedding_dim", 128)

    vision_config = config.get("vision_encoder", {})
    vision_dim = vision_config.get("dim", 768)

    # Input projection and conditioning modules around the transformer.
    proj = nn.Linear(in_channels, transformer_config.d_model)
    align_masked_video = AlignModalities(vision_dim, transformer_config.d_model)
    embed_anchors = EmbedAnchors(num_anchors, anchor_embedding_dim, transformer_config.d_model)
    timestep_emb = SinusoidalEmbedding(transformer_config.d_model)

    text_encoder_config = config.get("text_encoder", {})
    text_encoder_dim = text_encoder_config.get("dim", 1024)
    memory_proj = nn.Linear(text_encoder_dim, transformer_config.d_model)

    # NOTE(review): torch.load unpickles the checkpoint; this trusts the hub
    # repo. Consider weights_only=True if the installed torch supports it.
    print("Loading weights from checkpoint...")
    state_dict = torch.load(checkpoint_path, map_location="cpu", mmap=True)

    # Split the flat checkpoint into per-module state dicts by key prefix.
    transformer_state = {}
    proj_state = {}
    align_state = {}
    embed_anchors_state = {}
    memory_proj_state = {}

    for key, value in state_dict.items():
        if key.startswith("transformer."):
            new_key = key[len("transformer."):]
            transformer_state[new_key] = value
        elif key.startswith("proj."):
            new_key = key[len("proj."):]
            proj_state[new_key] = value
        elif key.startswith("align_masked_video."):
            new_key = key[len("align_masked_video."):]
            align_state[new_key] = value
        elif key.startswith("embed_anchors."):
            new_key = key[len("embed_anchors."):]
            embed_anchors_state[new_key] = value
        elif key.startswith("memory_proj."):
            new_key = key[len("memory_proj."):]
            memory_proj_state[new_key] = value

    # Strict loading (default) will surface any shape/key mismatch here.
    transformer.load_state_dict(transformer_state)
    proj.load_state_dict(proj_state)
    align_masked_video.load_state_dict(align_state)
    embed_anchors.load_state_dict(embed_anchors_state)
    memory_proj.load_state_dict(memory_proj_state)

    print(f" ✓ Loaded transformer weights ({len(transformer_state)} tensors)")
    print(f" ✓ Loaded component weights")

    # Assemble the single-step wrapper used for ONNX export.
    single_step = DiTSingleStepWrapper(
        transformer=transformer,
        proj=proj,
        align_masked_video=align_masked_video,
        embed_anchors=embed_anchors,
        timestep_emb=timestep_emb,
        memory_proj=memory_proj,
    ).eval().to(device)

    return single_step, config
|
|
|
|
def create_sample_inputs(
    batch_size: int = 1,
    seq_len: int = 25,
    device: str = "cpu",
    latent_dim: int = 128,
    text_dim: int = 1024,
    vision_dim: int = 768,
    text_len: int = 77,
):
    """Create sample inputs for tracing.

    The dictionary keys and ordering must match the `input_names` /
    `dynamic_axes` used in export_dit_single_step().

    Bug fix: the original hard-coded text_dim=768 and vision_dim=1024,
    which is swapped relative to load_sam_audio_components() (its defaults
    build memory_proj = Linear(1024, d_model) and AlignModalities(768, ...)),
    so tracing would fail with a shape mismatch. Defaults now match the
    loader, and the dimensions are parameters so non-default configs can be
    traced too.

    Args:
        batch_size: Leading batch dimension of every tensor.
        seq_len: Number of audio frames.
        device: Device the tensors are created on.
        latent_dim: Audio latent width; audio tensors are 2*latent_dim wide.
        text_dim: Text-encoder feature width (memory_proj input size).
        vision_dim: Vision-encoder feature width (align module input size).
        text_len: Number of text tokens.

    Returns:
        Dict of named tensors suitable as tracing inputs for the DiT wrapper.
    """
    return {
        "noisy_audio": torch.randn(batch_size, seq_len, 2 * latent_dim, device=device),
        # ODE time, one scalar per batch element (t=0 at the start of solving).
        "time": torch.zeros(batch_size, device=device),
        "audio_features": torch.randn(batch_size, seq_len, 2 * latent_dim, device=device),
        "text_features": torch.randn(batch_size, text_len, text_dim, device=device),
        "text_mask": torch.ones(batch_size, text_len, dtype=torch.bool, device=device),
        # Note the (batch, channels, frames) layout — seq_len is axis 2 here.
        "masked_video_features": torch.zeros(batch_size, vision_dim, seq_len, device=device),
        "anchor_ids": torch.zeros(batch_size, seq_len, dtype=torch.long, device=device),
        "anchor_alignment": torch.zeros(batch_size, seq_len, dtype=torch.long, device=device),
        "audio_pad_mask": torch.ones(batch_size, seq_len, dtype=torch.bool, device=device),
    }
|
|
|
|
def export_dit_single_step(
    single_step: DiTSingleStepWrapper,
    output_path: str,
    opset_version: int = 21,
    device: str = "cpu",
    fp16: bool = False,
):
    """Export single-step DiT to ONNX (for runtime ODE solving).

    Args:
        single_step: Assembled DiT wrapper (one ODE function evaluation).
        output_path: Destination .onnx path; weights are written as external data.
        opset_version: ONNX opset to target.
        device: Device used to create the tracing sample inputs.
        fp16: Convert the model (and float sample inputs) to half precision
            before export.

    Returns:
        True on success.

    Raises:
        RuntimeError: If the expected external-data file was not written.
    """
    import onnx

    print(f"Exporting DiT single-step to {output_path}...")

    # NOTE(review): Module.half() converts parameters in place, so the
    # caller's instance also becomes FP16 — a later verify step against the
    # same instance will then run in half precision. Confirm intended.
    if fp16:
        print(" Converting model to FP16...")
        single_step = single_step.half()

    sample_inputs = create_sample_inputs(device=device)

    # Keep float sample inputs in sync with the FP16 model; bool/int
    # inputs (masks, ids) stay unchanged.
    if fp16:
        for key, value in sample_inputs.items():
            if value.dtype == torch.float32:
                sample_inputs[key] = value.half()

    # Every axis that varies at runtime (batch, audio frames, text tokens)
    # is declared dynamic; all other dimensions are baked into the graph.
    torch.onnx.export(
        single_step,
        tuple(sample_inputs.values()),
        output_path,
        input_names=list(sample_inputs.keys()),
        output_names=["velocity"],
        dynamic_axes={
            "noisy_audio": {0: "batch_size", 1: "seq_len"},
            "time": {0: "batch_size"},
            "audio_features": {0: "batch_size", 1: "seq_len"},
            "text_features": {0: "batch_size", 1: "text_len"},
            "text_mask": {0: "batch_size", 1: "text_len"},
            "masked_video_features": {0: "batch_size", 2: "seq_len"},
            "anchor_ids": {0: "batch_size", 1: "seq_len"},
            "anchor_alignment": {0: "batch_size", 1: "seq_len"},
            "audio_pad_mask": {0: "batch_size", 1: "seq_len"},
            "velocity": {0: "batch_size", 1: "seq_len"},
        },
        opset_version=opset_version,
        do_constant_folding=True,
        dynamo=True,
        external_data=True,
    )

    print(" ✓ DiT single-step exported successfully")

    # Sanity-check the external weight file. Assumes the exporter writes
    # weights to `<output_path>.data` — TODO confirm for the installed torch
    # version; external-data naming is exporter-dependent.
    external_data_path = output_path + ".data"
    if os.path.exists(external_data_path):
        print(f" ✓ External data file exists ({os.path.getsize(external_data_path) / 1e9:.2f} GB)")
    else:
        raise RuntimeError(f"External data file missing: {external_data_path}")

    # Parse only the graph structure (weights skipped) to confirm the model loads.
    model = onnx.load(output_path, load_external_data=False)
    print(f" ✓ ONNX model structure loaded ({len(model.graph.node)} nodes)")

    return True
|
|
|
|
def verify_dit_single_step(
    single_step: DiTSingleStepWrapper,
    onnx_path: str,
    device: str = "cpu",
    tolerance: float = 1e-3,
) -> bool:
    """Verify single-step ONNX output matches PyTorch.

    Runs the same sample batch through the eager model and the exported
    graph, then compares element-wise absolute differences to `tolerance`.
    """
    import onnxruntime as ort
    import numpy as np

    print("Verifying DiT single-step output...")

    sample_inputs = create_sample_inputs(device=device)

    # Reference output from the eager PyTorch model.
    with torch.no_grad():
        reference = single_step(**sample_inputs).cpu().numpy()

    # CPU execution provider keeps the comparison deterministic.
    session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])

    def _as_onnx_array(tensor):
        # Map torch dtypes onto the dtypes the exported graph expects.
        arr = tensor.cpu().numpy()
        if tensor.dtype == torch.bool:
            return arr.astype(bool)
        if tensor.dtype == torch.long:
            return arr.astype(np.int64)
        return arr.astype(np.float32)

    feeds = {name: _as_onnx_array(t) for name, t in sample_inputs.items()}
    candidate = session.run(["velocity"], feeds)[0]

    abs_err = np.abs(reference - candidate)
    max_diff = abs_err.max()
    mean_diff = abs_err.mean()

    print(f" Max difference: {max_diff:.2e}")
    print(f" Mean difference: {mean_diff:.2e}")

    if max_diff < tolerance:
        print(f" ✓ Verification passed (tolerance: {tolerance})")
        return True
    print(f" ✗ Verification failed (tolerance: {tolerance})")
    return False
|
|
|
|
def main():
    """CLI driver: load components, export the single-step DiT, optionally verify."""
    parser = argparse.ArgumentParser(description="Export DiT Transformer to ONNX")
    parser.add_argument("--model-id", type=str, default="facebook/sam-audio-small",
                        help="SAM Audio model ID from HuggingFace")
    parser.add_argument("--output-dir", type=str, default="onnx_models",
                        help="Output directory for ONNX models")
    parser.add_argument("--num-steps", type=int, default=16,
                        help="Number of ODE solver steps (default: 16)")
    parser.add_argument("--opset", type=int, default=21,
                        help="ONNX opset version (default: 21)")
    parser.add_argument("--device", type=str, default="cpu",
                        help="Device to use for export (default: cpu)")
    parser.add_argument("--verify", action="store_true",
                        help="Verify ONNX output matches PyTorch")
    parser.add_argument("--tolerance", type=float, default=1e-3,
                        help="Tolerance for verification (default: 1e-3)")
    parser.add_argument("--fp16", action="store_true",
                        help="Export model in FP16 precision (half the size)")
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    # Assemble the PyTorch single-step model from the HF checkpoint.
    single_step, config = load_sam_audio_components(args.model_id, args.device)

    print(f"\nDiT Configuration:")
    print(f" Model: {args.model_id}")
    print(f" ODE steps: {args.num_steps}")
    print(f" Step size: {1.0/args.num_steps:.4f}")

    # Export (and optionally verify) the single-step graph.
    target_path = os.path.join(args.output_dir, "dit_single_step.onnx")
    export_dit_single_step(
        single_step,
        target_path,
        opset_version=args.opset,
        device=args.device,
        fp16=args.fp16,
    )

    if args.fp16:
        print(f" ✓ Model exported in FP16 precision")

    if args.verify:
        verify_dit_single_step(
            single_step,
            target_path,
            device=args.device,
            tolerance=args.tolerance,
        )

    print(f"\n✓ Export complete! Model saved to {args.output_dir}")
|
|
|
|
# Script entry point (e.g. `python -m onnx_export.export_dit --output-dir ...`).
if __name__ == "__main__":
    main()
|
|