#!/usr/bin/env python3
"""
Export DiT Transformer with unrolled ODE solver to ONNX format.
The DiT transformer is the core denoising model in SAM Audio. It uses a flow-based
generative model with an ODE solver. For ONNX export, we unroll the fixed-step
midpoint ODE solver into a static computation graph.
The default configuration uses:
- method: "midpoint"
- step_size: 2/32 (0.0625)
- integration range: [0, 1]
- total steps: 16
This creates a single ONNX model that performs the complete denoising process,
taking noise and conditioning as input and producing denoised audio features.
Usage:
python -m onnx_export.export_dit --output-dir onnx_models --verify
"""
import os
import math
import argparse
import torch
import torch.nn as nn
from typing import Optional
class SinusoidalEmbedding(nn.Module):
"""Sinusoidal timestep embedding (identical to SAMAudio implementation)."""
def __init__(self, dim, theta=10000):
super().__init__()
assert (dim % 2) == 0
half_dim = dim // 2
inv_freq = torch.exp(
-math.log(theta) * torch.arange(half_dim).float() / half_dim
)
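        # persistent=False keeps inv_freq out of the state_dict, so checkpoint
        # loading neither expects nor overwrites this derived buffer.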
self.register_buffer("inv_freq", inv_freq, persistent=False)
def forward(self, x, pos=None):
if pos is None:
seq_len, device = x.shape[1], x.device
pos = torch.arange(seq_len, device=device)
emb = torch.einsum("i, j -> i j", pos, self.inv_freq)
emb = torch.cat((emb.cos(), emb.sin()), dim=-1)
return emb
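
# Illustrative SinusoidalEmbedding usage (a sketch; not executed by this script):
# for float timesteps t of shape (B,), SinusoidalEmbedding(d)(t, pos=t) returns a
# (B, d) embedding, cosines in the first d // 2 channels and sines in the rest.
# In the export path below, pos is always the timestep tensor, so the pos=None
# branch (positions derived from x's sequence length) is never traced.
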
class EmbedAnchors(nn.Module):
"""Anchor embedding (identical to SAMAudio implementation)."""
def __init__(self, num_embeddings: int, embedding_dim: int, out_dim: int):
super().__init__()
self.embed = nn.Embedding(
num_embeddings + 1, embedding_dim, padding_idx=num_embeddings
)
self.gate = nn.Parameter(torch.tensor([0.0]))
self.proj = nn.Linear(embedding_dim, out_dim, bias=False)
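        # Note: index num_embeddings above is the padding row of the embedding
        # table, and the zero-initialized gate (tanh(0) == 0) makes this module an
        # identity residual until trained weights are loaded.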
def forward(
self,
x: torch.Tensor,
anchor_ids: Optional[torch.Tensor] = None,
anchor_alignment: Optional[torch.Tensor] = None,
):
if anchor_ids is None:
return x
embs = self.embed(anchor_ids.gather(1, anchor_alignment))
proj = self.proj(embs)
return x + self.gate.tanh() * proj
class DiTSingleStepWrapper(nn.Module):
"""
Wrapper for DiT that performs a single forward pass (one ODE evaluation).
This mirrors the SAMAudio.forward() method exactly.
"""
def __init__(
self,
transformer: nn.Module,
proj: nn.Module,
align_masked_video: nn.Module,
embed_anchors: nn.Module,
timestep_emb: nn.Module,
memory_proj: nn.Module,
):
super().__init__()
self.transformer = transformer
self.proj = proj
self.align_masked_video = align_masked_video
self.embed_anchors = embed_anchors
self.timestep_emb = timestep_emb
self.memory_proj = memory_proj
def forward(
self,
noisy_audio: torch.Tensor,
time: torch.Tensor,
audio_features: torch.Tensor,
text_features: torch.Tensor,
text_mask: torch.Tensor,
masked_video_features: torch.Tensor,
anchor_ids: torch.Tensor,
anchor_alignment: torch.Tensor,
audio_pad_mask: torch.Tensor,
) -> torch.Tensor:
"""
Single forward pass of the DiT (one ODE function evaluation).
        This exactly mirrors the SAMAudio.forward() method.
"""
        # Align inputs: concatenate noisy_audio, a zero-filled placeholder block,
        # and the conditioning audio_features along the channel dim.
        # Same as SAMAudio.align_inputs()
x = torch.cat(
[
noisy_audio,
torch.zeros_like(audio_features),
audio_features,
],
dim=2,
)
projected = self.proj(x)
aligned = self.align_masked_video(projected, masked_video_features)
aligned = self.embed_anchors(aligned, anchor_ids, anchor_alignment)
# Timestep embedding and memory
# Same as SAMAudio.forward()
timestep_emb_val = self.timestep_emb(time, pos=time).unsqueeze(1)
memory = self.memory_proj(text_features) + timestep_emb_val
# Transformer forward
output = self.transformer(
aligned,
time,
padding_mask=audio_pad_mask,
memory=memory,
memory_padding_mask=text_mask,
)
return output
class UnrolledDiTWrapper(nn.Module):
"""
DiT wrapper with unrolled midpoint ODE solver.
The midpoint method computes:
k1 = f(t, y)
k2 = f(t + h/2, y + h/2 * k1)
y_new = y + h * k2
With step_size=0.0625 and range [0,1], we have 16 steps.
"""
def __init__(
self,
single_step: DiTSingleStepWrapper,
num_steps: int = 16,
):
super().__init__()
self.single_step = single_step
self.num_steps = num_steps
self.step_size = 1.0 / num_steps
def forward(
self,
noise: torch.Tensor,
audio_features: torch.Tensor,
text_features: torch.Tensor,
text_mask: torch.Tensor,
masked_video_features: torch.Tensor,
anchor_ids: torch.Tensor,
anchor_alignment: torch.Tensor,
audio_pad_mask: torch.Tensor,
) -> torch.Tensor:
"""Complete denoising using unrolled midpoint ODE solver."""
B = noise.shape[0]
h = self.step_size
y = noise
t = torch.zeros(B, device=noise.device, dtype=noise.dtype)
for step in range(self.num_steps):
# k1 = f(t, y)
k1 = self.single_step(
y, t,
audio_features, text_features, text_mask,
masked_video_features, anchor_ids, anchor_alignment, audio_pad_mask
)
# k2 = f(t + h/2, y + h/2 * k1)
t_mid = t + h / 2
y_mid = y + (h / 2) * k1
k2 = self.single_step(
y_mid, t_mid,
audio_features, text_features, text_mask,
masked_video_features, anchor_ids, anchor_alignment, audio_pad_mask
)
# y = y + h * k2
y = y + h * k2
t = t + h
return y
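
# A minimal sketch (not used by this script) of how a runtime would drive the
# exported single-step model through the same midpoint loop with onnxruntime.
# Assumes `sess` is an onnxruntime.InferenceSession over dit_single_step.onnx and
# `cond` maps the remaining export input names (audio_features, text_features,
# text_mask, masked_video_features, anchor_ids, anchor_alignment, audio_pad_mask)
# to numpy arrays.
def run_midpoint_with_onnx(sess, noise, cond, num_steps: int = 16):
    import numpy as np
    h = np.float32(1.0 / num_steps)
    y = noise
    t = np.zeros(noise.shape[0], dtype=noise.dtype)
    for _ in range(num_steps):
        # k1 = f(t, y)
        k1 = sess.run(["velocity"], {**cond, "noisy_audio": y, "time": t})[0]
        # k2 = f(t + h/2, y + h/2 * k1)
        k2 = sess.run(
            ["velocity"],
            {**cond, "noisy_audio": y + (h / 2) * k1, "time": t + h / 2},
        )[0]
        # y <- y + h * k2; advance t by one full step
        y = y + h * k2
        t = t + h
    return y
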
def load_sam_audio_components(model_id: str = "facebook/sam-audio-small", device: str = "cpu"):
"""
Load SAM Audio components needed for DiT export.
Since we can't load the full SAMAudio model (missing perception_models),
    we construct the components directly and load weights from the checkpoint.
"""
import json
import sys
import types
import importlib.util
from huggingface_hub import hf_hub_download
print(f"Loading SAM Audio components from {model_id}...")
# Download config
config_path = hf_hub_download(repo_id=model_id, filename="config.json")
with open(config_path) as f:
config = json.load(f)
# Download checkpoint
checkpoint_path = hf_hub_download(repo_id=model_id, filename="checkpoint.pt")
# Use our standalone config that doesn't have 'core' dependencies
from onnx_export.standalone_config import TransformerConfig
sam_audio_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Create fake module hierarchy so transformer.py's relative imports work
if 'sam_audio' not in sys.modules:
sam_audio_pkg = types.ModuleType('sam_audio')
sam_audio_pkg.__path__ = [os.path.join(sam_audio_path, 'sam_audio')]
sys.modules['sam_audio'] = sam_audio_pkg
if 'sam_audio.model' not in sys.modules:
model_pkg = types.ModuleType('sam_audio.model')
model_pkg.__path__ = [os.path.join(sam_audio_path, 'sam_audio', 'model')]
sys.modules['sam_audio.model'] = model_pkg
# Register our standalone config as sam_audio.model.config
if 'sam_audio.model.config' not in sys.modules:
import onnx_export.standalone_config as standalone_config
sys.modules['sam_audio.model.config'] = standalone_config
    # Now import the transformer module; it will use our standalone config
transformer_spec = importlib.util.spec_from_file_location(
"sam_audio.model.transformer",
os.path.join(sam_audio_path, "sam_audio", "model", "transformer.py")
)
transformer_module = importlib.util.module_from_spec(transformer_spec)
sys.modules['sam_audio.model.transformer'] = transformer_module
transformer_spec.loader.exec_module(transformer_module)
DiT = transformer_module.DiT
# Import align module
align_spec = importlib.util.spec_from_file_location(
"sam_audio.model.align",
os.path.join(sam_audio_path, "sam_audio", "model", "align.py")
)
align_module = importlib.util.module_from_spec(align_spec)
sys.modules['sam_audio.model.align'] = align_module
align_spec.loader.exec_module(align_module)
AlignModalities = align_module.AlignModalities
# Create transformer
transformer_config = TransformerConfig(**config.get("transformer", {}))
transformer = DiT(transformer_config)
# Calculate dimensions
in_channels = config.get("in_channels", 768)
num_anchors = config.get("num_anchors", 3)
anchor_embedding_dim = config.get("anchor_embedding_dim", 128)
# Get vision encoder dim for align_masked_video
vision_config = config.get("vision_encoder", {})
vision_dim = vision_config.get("dim", 768)
# Create components exactly as SAMAudio does
proj = nn.Linear(in_channels, transformer_config.d_model)
align_masked_video = AlignModalities(vision_dim, transformer_config.d_model)
embed_anchors = EmbedAnchors(num_anchors, anchor_embedding_dim, transformer_config.d_model)
timestep_emb = SinusoidalEmbedding(transformer_config.d_model)
# Memory projection for text features
text_encoder_config = config.get("text_encoder", {})
    text_encoder_dim = text_encoder_config.get("dim", 1024)  # fallback matches google/flan-t5-large
memory_proj = nn.Linear(text_encoder_dim, transformer_config.d_model)
# Load weights from checkpoint
print("Loading weights from checkpoint...")
state_dict = torch.load(checkpoint_path, map_location="cpu", mmap=True)
# Filter and load weights for each component
transformer_state = {}
proj_state = {}
align_state = {}
embed_anchors_state = {}
memory_proj_state = {}
for key, value in state_dict.items():
if key.startswith("transformer."):
new_key = key[len("transformer."):]
transformer_state[new_key] = value
elif key.startswith("proj."):
new_key = key[len("proj."):]
proj_state[new_key] = value
elif key.startswith("align_masked_video."):
new_key = key[len("align_masked_video."):]
align_state[new_key] = value
elif key.startswith("embed_anchors."):
new_key = key[len("embed_anchors."):]
embed_anchors_state[new_key] = value
elif key.startswith("memory_proj."):
new_key = key[len("memory_proj."):]
memory_proj_state[new_key] = value
transformer.load_state_dict(transformer_state)
proj.load_state_dict(proj_state)
align_masked_video.load_state_dict(align_state)
embed_anchors.load_state_dict(embed_anchors_state)
memory_proj.load_state_dict(memory_proj_state)
print(f" ✓ Loaded transformer weights ({len(transformer_state)} tensors)")
print(f" ✓ Loaded component weights")
# Create single step wrapper
single_step = DiTSingleStepWrapper(
transformer=transformer,
proj=proj,
align_masked_video=align_masked_video,
embed_anchors=embed_anchors,
timestep_emb=timestep_emb,
memory_proj=memory_proj,
).eval().to(device)
return single_step, config
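
# Illustrative use of UnrolledDiTWrapper with the components returned above (a
# sketch; main() below exports only the single-step model):
#   single_step, config = load_sam_audio_components("facebook/sam-audio-small")
#   inputs = create_sample_inputs()
#   noise = inputs.pop("noisy_audio")
#   inputs.pop("time")  # the unrolled solver manages t itself
#   denoised = UnrolledDiTWrapper(single_step, num_steps=16).eval()(noise, **inputs)
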
def create_sample_inputs(batch_size: int = 1, seq_len: int = 25, device: str = "cpu"):
"""Create sample inputs for tracing."""
    # These dims are fixed for tracing and must match the modules built in
    # load_sam_audio_components() (e.g. memory_proj expects
    # config["text_encoder"]["dim"] input features); tracing fails if they disagree.
    latent_dim = 128
    text_dim = 768  # text encoder hidden size (SAM Audio was trained with 768-dim text)
    vision_dim = 1024  # vision encoder dim from config
    text_len = 77
return {
"noisy_audio": torch.randn(batch_size, seq_len, 2 * latent_dim, device=device),
"time": torch.zeros(batch_size, device=device),
"audio_features": torch.randn(batch_size, seq_len, 2 * latent_dim, device=device),
"text_features": torch.randn(batch_size, text_len, text_dim, device=device),
"text_mask": torch.ones(batch_size, text_len, dtype=torch.bool, device=device),
"masked_video_features": torch.zeros(batch_size, vision_dim, seq_len, device=device),
"anchor_ids": torch.zeros(batch_size, seq_len, dtype=torch.long, device=device),
"anchor_alignment": torch.zeros(batch_size, seq_len, dtype=torch.long, device=device),
"audio_pad_mask": torch.ones(batch_size, seq_len, dtype=torch.bool, device=device),
}
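
# A hedged variant (sketch): derive the tracing dims from the loaded config rather
# than the constants above, so the sample inputs always match the modules built in
# load_sam_audio_components(). The config keys mirror that function's lookups; the
# latent_dim formula follows from align_inputs concatenating three 2*latent blocks
# into config["in_channels"] channels.
def create_sample_inputs_from_config(config, batch_size: int = 1, seq_len: int = 25, device: str = "cpu"):
    inputs = create_sample_inputs(batch_size, seq_len, device)
    latent_dim = config.get("in_channels", 768) // 6
    text_dim = config.get("text_encoder", {}).get("dim", 1024)
    vision_dim = config.get("vision_encoder", {}).get("dim", 768)
    text_len = inputs["text_features"].shape[1]
    inputs["noisy_audio"] = torch.randn(batch_size, seq_len, 2 * latent_dim, device=device)
    inputs["audio_features"] = torch.randn(batch_size, seq_len, 2 * latent_dim, device=device)
    inputs["text_features"] = torch.randn(batch_size, text_len, text_dim, device=device)
    inputs["masked_video_features"] = torch.zeros(batch_size, vision_dim, seq_len, device=device)
    return inputs
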
def export_dit_single_step(
single_step: DiTSingleStepWrapper,
output_path: str,
opset_version: int = 21,
device: str = "cpu",
fp16: bool = False,
):
"""Export single-step DiT to ONNX (for runtime ODE solving)."""
import onnx
print(f"Exporting DiT single-step to {output_path}...")
# Convert to FP16 if requested
if fp16:
print(" Converting model to FP16...")
single_step = single_step.half()
sample_inputs = create_sample_inputs(device=device)
# Convert float inputs to FP16 if exporting in FP16
if fp16:
for key, value in sample_inputs.items():
if value.dtype == torch.float32:
sample_inputs[key] = value.half()
torch.onnx.export(
single_step,
tuple(sample_inputs.values()),
output_path,
input_names=list(sample_inputs.keys()),
output_names=["velocity"],
dynamic_axes={
"noisy_audio": {0: "batch_size", 1: "seq_len"},
"time": {0: "batch_size"},
"audio_features": {0: "batch_size", 1: "seq_len"},
"text_features": {0: "batch_size", 1: "text_len"},
"text_mask": {0: "batch_size", 1: "text_len"},
"masked_video_features": {0: "batch_size", 2: "seq_len"},
"anchor_ids": {0: "batch_size", 1: "seq_len"},
"anchor_alignment": {0: "batch_size", 1: "seq_len"},
"audio_pad_mask": {0: "batch_size", 1: "seq_len"},
"velocity": {0: "batch_size", 1: "seq_len"},
},
opset_version=opset_version,
do_constant_folding=True,
dynamo=True,
external_data=True,
)
print(" ✓ DiT single-step exported successfully")
# When using external_data=True, we can't run check_model on a model
# loaded without external data - the checker validates data references.
# Since torch.onnx.export with dynamo=True already validates the model,
# we just verify the files exist.
external_data_path = output_path + ".data"
if os.path.exists(external_data_path):
print(f" ✓ External data file exists ({os.path.getsize(external_data_path) / 1e9:.2f} GB)")
else:
raise RuntimeError(f"External data file missing: {external_data_path}")
# Verify the ONNX file structure is valid (without loading weights)
model = onnx.load(output_path, load_external_data=False)
print(f" ✓ ONNX model structure loaded ({len(model.graph.node)} nodes)")
return True
def verify_dit_single_step(
single_step: DiTSingleStepWrapper,
onnx_path: str,
device: str = "cpu",
tolerance: float = 1e-3,
) -> bool:
"""Verify single-step ONNX output matches PyTorch."""
import onnxruntime as ort
import numpy as np
print("Verifying DiT single-step output...")
sample_inputs = create_sample_inputs(device=device)
# PyTorch output
with torch.no_grad():
pytorch_output = single_step(**sample_inputs).cpu().numpy()
# ONNX Runtime output
sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
onnx_inputs = {}
for name, tensor in sample_inputs.items():
if tensor.dtype == torch.bool:
onnx_inputs[name] = tensor.cpu().numpy().astype(bool)
elif tensor.dtype == torch.long:
onnx_inputs[name] = tensor.cpu().numpy().astype(np.int64)
else:
onnx_inputs[name] = tensor.cpu().numpy().astype(np.float32)
onnx_output = sess.run(["velocity"], onnx_inputs)[0]
# Compare
max_diff = np.abs(pytorch_output - onnx_output).max()
mean_diff = np.abs(pytorch_output - onnx_output).mean()
print(f" Max difference: {max_diff:.2e}")
print(f" Mean difference: {mean_diff:.2e}")
if max_diff < tolerance:
print(f" ✓ Verification passed (tolerance: {tolerance})")
return True
else:
print(f" ✗ Verification failed (tolerance: {tolerance})")
return False
def main():
parser = argparse.ArgumentParser(description="Export DiT Transformer to ONNX")
parser.add_argument(
"--model-id",
type=str,
default="facebook/sam-audio-small",
help="SAM Audio model ID from HuggingFace",
)
parser.add_argument(
"--output-dir",
type=str,
default="onnx_models",
help="Output directory for ONNX models",
)
parser.add_argument(
"--num-steps",
type=int,
default=16,
help="Number of ODE solver steps (default: 16)",
)
parser.add_argument(
"--opset",
type=int,
default=21,
help="ONNX opset version (default: 21)",
)
parser.add_argument(
"--device",
type=str,
default="cpu",
help="Device to use for export (default: cpu)",
)
parser.add_argument(
"--verify",
action="store_true",
help="Verify ONNX output matches PyTorch",
)
parser.add_argument(
"--tolerance",
type=float,
default=1e-3,
help="Tolerance for verification (default: 1e-3)",
)
parser.add_argument(
"--fp16",
action="store_true",
help="Export model in FP16 precision (half the size)",
)
args = parser.parse_args()
# Create output directory
os.makedirs(args.output_dir, exist_ok=True)
# Load components
single_step, config = load_sam_audio_components(args.model_id, args.device)
print(f"\nDiT Configuration:")
print(f" Model: {args.model_id}")
print(f" ODE steps: {args.num_steps}")
print(f" Step size: {1.0/args.num_steps:.4f}")
# Export single-step model
single_step_path = os.path.join(args.output_dir, "dit_single_step.onnx")
export_dit_single_step(
single_step,
single_step_path,
opset_version=args.opset,
device=args.device,
fp16=args.fp16,
)
if args.fp16:
print(f" ✓ Model exported in FP16 precision")
# Verify single-step
if args.verify:
verify_dit_single_step(
single_step,
single_step_path,
device=args.device,
tolerance=args.tolerance,
)
print(f"\n✓ Export complete! Model saved to {args.output_dir}")
if __name__ == "__main__":
main()