""" |
|
|
Export DiT Transformer with unrolled ODE solver to ONNX format. |
|
|
|
|
|
The DiT transformer is the core denoising model in SAM Audio. It uses a flow-based |
|
|
generative model with an ODE solver. For ONNX export, we unroll the fixed-step |
|
|
midpoint ODE solver into a static computation graph. |
|
|
|
|
|
The default configuration uses: |
|
|
- method: "midpoint" |
|
|
- step_size: 2/32 (0.0625) |
|
|
- integration range: [0, 1] |
|
|
- total steps: 16 |
|
|
|
|
|
This creates a single ONNX model that performs the complete denoising process, |
|
|
taking noise and conditioning as input and producing denoised audio features. |
|
|
|
|
|
Usage: |
|
|
python -m onnx_export.export_dit --output-dir onnx_models --verify |
|
|
""" |
|
|
|
|
|
import argparse
import math
import os
from typing import Optional

import torch
import torch.nn as nn


class SinusoidalEmbedding(nn.Module):
    """Sinusoidal timestep embedding (identical to SAMAudio implementation)."""

    def __init__(self, dim, theta=10000):
        super().__init__()
        assert (dim % 2) == 0
        half_dim = dim // 2
        inv_freq = torch.exp(
            -math.log(theta) * torch.arange(half_dim).float() / half_dim
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, x, pos=None):
        if pos is None:
            seq_len, device = x.shape[1], x.device
            pos = torch.arange(seq_len, device=device)

        emb = torch.einsum("i, j -> i j", pos, self.inv_freq)
        emb = torch.cat((emb.cos(), emb.sin()), dim=-1)
        return emb


class EmbedAnchors(nn.Module):
    """Anchor embedding (identical to SAMAudio implementation)."""

    def __init__(self, num_embeddings: int, embedding_dim: int, out_dim: int):
        super().__init__()
        self.embed = nn.Embedding(
            num_embeddings + 1, embedding_dim, padding_idx=num_embeddings
        )
        # Zero-initialized gate: anchor information is blended in via tanh(gate).
        self.gate = nn.Parameter(torch.tensor([0.0]))
        self.proj = nn.Linear(embedding_dim, out_dim, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        anchor_ids: Optional[torch.Tensor] = None,
        anchor_alignment: Optional[torch.Tensor] = None,
    ):
        if anchor_ids is None:
            return x

        # Gather per-frame anchor ids via the alignment indices, embed, and project.
        embs = self.embed(anchor_ids.gather(1, anchor_alignment))
        proj = self.proj(embs)
        return x + self.gate.tanh() * proj


class DiTSingleStepWrapper(nn.Module):
    """
    Wrapper for DiT that performs a single forward pass (one ODE evaluation).

    This mirrors the SAMAudio.forward() method exactly.
    """

    def __init__(
        self,
        transformer: nn.Module,
        proj: nn.Module,
        align_masked_video: nn.Module,
        embed_anchors: nn.Module,
        timestep_emb: nn.Module,
        memory_proj: nn.Module,
    ):
        super().__init__()
        self.transformer = transformer
        self.proj = proj
        self.align_masked_video = align_masked_video
        self.embed_anchors = embed_anchors
        self.timestep_emb = timestep_emb
        self.memory_proj = memory_proj

    def forward(
        self,
        noisy_audio: torch.Tensor,
        time: torch.Tensor,
        audio_features: torch.Tensor,
        text_features: torch.Tensor,
        text_mask: torch.Tensor,
        masked_video_features: torch.Tensor,
        anchor_ids: torch.Tensor,
        anchor_alignment: torch.Tensor,
        audio_pad_mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Single forward pass of the DiT (one ODE function evaluation).

        This exactly mirrors the SAMAudio.forward() method.
        """
        # Concatenate the noisy latent with a zeroed placeholder and the
        # conditioning audio features along the channel dimension.
        x = torch.cat(
            [
                noisy_audio,
                torch.zeros_like(audio_features),
                audio_features,
            ],
            dim=2,
        )

        projected = self.proj(x)
        aligned = self.align_masked_video(projected, masked_video_features)
        aligned = self.embed_anchors(aligned, anchor_ids, anchor_alignment)

        # Text memory: project text features and add the sinusoidal timestep embedding.
        timestep_emb_val = self.timestep_emb(time, pos=time).unsqueeze(1)
        memory = self.memory_proj(text_features) + timestep_emb_val

        output = self.transformer(
            aligned,
            time,
            padding_mask=audio_pad_mask,
            memory=memory,
            memory_padding_mask=text_mask,
        )

        return output


class UnrolledDiTWrapper(nn.Module):
    """
    DiT wrapper with unrolled midpoint ODE solver.

    The midpoint method computes:
        k1 = f(t, y)
        k2 = f(t + h/2, y + h/2 * k1)
        y_new = y + h * k2

    With step_size=0.0625 and range [0, 1], we have 16 steps.
    """

    def __init__(
        self,
        single_step: DiTSingleStepWrapper,
        num_steps: int = 16,
    ):
        super().__init__()
        self.single_step = single_step
        self.num_steps = num_steps
        self.step_size = 1.0 / num_steps

    def forward(
        self,
        noise: torch.Tensor,
        audio_features: torch.Tensor,
        text_features: torch.Tensor,
        text_mask: torch.Tensor,
        masked_video_features: torch.Tensor,
        anchor_ids: torch.Tensor,
        anchor_alignment: torch.Tensor,
        audio_pad_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Complete denoising using the unrolled midpoint ODE solver."""
        B = noise.shape[0]
        h = self.step_size
        y = noise
        t = torch.zeros(B, device=noise.device, dtype=noise.dtype)

        for step in range(self.num_steps):
            # k1: velocity at the start of the step.
            k1 = self.single_step(
                y, t,
                audio_features, text_features, text_mask,
                masked_video_features, anchor_ids, anchor_alignment, audio_pad_mask,
            )

            # k2: velocity at the midpoint of the step.
            t_mid = t + h / 2
            y_mid = y + (h / 2) * k1
            k2 = self.single_step(
                y_mid, t_mid,
                audio_features, text_features, text_mask,
                masked_video_features, anchor_ids, anchor_alignment, audio_pad_mask,
            )

            # Advance the state using the midpoint slope.
            y = y + h * k2
            t = t + h

        return y
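

# The single-step model exported by this script can also be driven from ONNX Runtime
# with the same midpoint scheme that UnrolledDiTWrapper unrolls above. The sketch
# below is illustrative only and is not called anywhere in this module; it assumes
# onnxruntime is installed, that the ONNX file was produced by export_dit_single_step(),
# and that `inputs` holds numpy arrays keyed like create_sample_inputs() minus
# "noisy_audio" and "time".
def _example_onnxruntime_midpoint_loop(onnx_path, noise, inputs, num_steps=16):
    """Hedged sketch: run the exported single-step DiT in an external midpoint loop."""
    import numpy as np
    import onnxruntime as ort

    sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    h = 1.0 / num_steps
    y = noise.astype(np.float32)
    t = np.zeros(noise.shape[0], dtype=np.float32)

    for _ in range(num_steps):
        # k1 = f(t, y)
        k1 = sess.run(["velocity"], {"noisy_audio": y, "time": t, **inputs})[0]
        # k2 = f(t + h/2, y + h/2 * k1)
        k2 = sess.run(
            ["velocity"],
            {"noisy_audio": y + (h / 2) * k1, "time": t + h / 2, **inputs},
        )[0]
        y = y + h * k2
        t = t + h

    return y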


def load_sam_audio_components(model_id: str = "facebook/sam-audio-small", device: str = "cpu"):
    """
    Load the SAM Audio components needed for DiT export.

    Since we can't load the full SAMAudio model (missing perception_models),
    we construct the components directly and load weights from the checkpoint.
    """
    import importlib.util
    import json
    import sys
    import types

    from huggingface_hub import hf_hub_download

    print(f"Loading SAM Audio components from {model_id}...")

    # Download the config and checkpoint from the HuggingFace Hub.
    config_path = hf_hub_download(repo_id=model_id, filename="config.json")
    with open(config_path) as f:
        config = json.load(f)

    checkpoint_path = hf_hub_download(repo_id=model_id, filename="checkpoint.pt")

    from onnx_export.standalone_config import TransformerConfig

    sam_audio_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

    # Register stub packages so sam_audio submodules can be imported directly
    # without pulling in the full package and its heavy dependencies.
    if 'sam_audio' not in sys.modules:
        sam_audio_pkg = types.ModuleType('sam_audio')
        sam_audio_pkg.__path__ = [os.path.join(sam_audio_path, 'sam_audio')]
        sys.modules['sam_audio'] = sam_audio_pkg

    if 'sam_audio.model' not in sys.modules:
        model_pkg = types.ModuleType('sam_audio.model')
        model_pkg.__path__ = [os.path.join(sam_audio_path, 'sam_audio', 'model')]
        sys.modules['sam_audio.model'] = model_pkg

    # Alias the standalone config so sam_audio.model.config resolves to it.
    if 'sam_audio.model.config' not in sys.modules:
        import onnx_export.standalone_config as standalone_config
        sys.modules['sam_audio.model.config'] = standalone_config

    # Load the DiT transformer module directly from its source file.
    transformer_spec = importlib.util.spec_from_file_location(
        "sam_audio.model.transformer",
        os.path.join(sam_audio_path, "sam_audio", "model", "transformer.py"),
    )
    transformer_module = importlib.util.module_from_spec(transformer_spec)
    sys.modules['sam_audio.model.transformer'] = transformer_module
    transformer_spec.loader.exec_module(transformer_module)
    DiT = transformer_module.DiT

    # Load the modality alignment module the same way.
    align_spec = importlib.util.spec_from_file_location(
        "sam_audio.model.align",
        os.path.join(sam_audio_path, "sam_audio", "model", "align.py"),
    )
    align_module = importlib.util.module_from_spec(align_spec)
    sys.modules['sam_audio.model.align'] = align_module
    align_spec.loader.exec_module(align_module)
    AlignModalities = align_module.AlignModalities

    # Build the transformer and the surrounding projection/embedding components.
    transformer_config = TransformerConfig(**config.get("transformer", {}))
    transformer = DiT(transformer_config)

    in_channels = config.get("in_channels", 768)
    num_anchors = config.get("num_anchors", 3)
    anchor_embedding_dim = config.get("anchor_embedding_dim", 128)

    vision_config = config.get("vision_encoder", {})
    vision_dim = vision_config.get("dim", 768)

    proj = nn.Linear(in_channels, transformer_config.d_model)
    align_masked_video = AlignModalities(vision_dim, transformer_config.d_model)
    embed_anchors = EmbedAnchors(num_anchors, anchor_embedding_dim, transformer_config.d_model)
    timestep_emb = SinusoidalEmbedding(transformer_config.d_model)

    text_encoder_config = config.get("text_encoder", {})
    text_encoder_dim = text_encoder_config.get("dim", 1024)
    memory_proj = nn.Linear(text_encoder_dim, transformer_config.d_model)

    print("Loading weights from checkpoint...")
    state_dict = torch.load(checkpoint_path, map_location="cpu", mmap=True)

    # Split the flat checkpoint into per-component state dicts by prefix.
    transformer_state = {}
    proj_state = {}
    align_state = {}
    embed_anchors_state = {}
    memory_proj_state = {}

    for key, value in state_dict.items():
        if key.startswith("transformer."):
            transformer_state[key[len("transformer."):]] = value
        elif key.startswith("proj."):
            proj_state[key[len("proj."):]] = value
        elif key.startswith("align_masked_video."):
            align_state[key[len("align_masked_video."):]] = value
        elif key.startswith("embed_anchors."):
            embed_anchors_state[key[len("embed_anchors."):]] = value
        elif key.startswith("memory_proj."):
            memory_proj_state[key[len("memory_proj."):]] = value

    transformer.load_state_dict(transformer_state)
    proj.load_state_dict(proj_state)
    align_masked_video.load_state_dict(align_state)
    embed_anchors.load_state_dict(embed_anchors_state)
    memory_proj.load_state_dict(memory_proj_state)

    print(f" ✓ Loaded transformer weights ({len(transformer_state)} tensors)")
    print(" ✓ Loaded component weights")

    single_step = DiTSingleStepWrapper(
        transformer=transformer,
        proj=proj,
        align_masked_video=align_masked_video,
        embed_anchors=embed_anchors,
        timestep_emb=timestep_emb,
        memory_proj=memory_proj,
    ).eval().to(device)

    return single_step, config


def create_sample_inputs(batch_size: int = 1, seq_len: int = 25, device: str = "cpu"):
    """Create sample inputs for tracing."""
    # These dimensions must match the loaded model configuration.
    latent_dim = 128
    text_dim = 768
    vision_dim = 1024
    text_len = 77

    return {
        "noisy_audio": torch.randn(batch_size, seq_len, 2 * latent_dim, device=device),
        "time": torch.zeros(batch_size, device=device),
        "audio_features": torch.randn(batch_size, seq_len, 2 * latent_dim, device=device),
        "text_features": torch.randn(batch_size, text_len, text_dim, device=device),
        "text_mask": torch.ones(batch_size, text_len, dtype=torch.bool, device=device),
        "masked_video_features": torch.zeros(batch_size, vision_dim, seq_len, device=device),
        "anchor_ids": torch.zeros(batch_size, seq_len, dtype=torch.long, device=device),
        "anchor_alignment": torch.zeros(batch_size, seq_len, dtype=torch.long, device=device),
        "audio_pad_mask": torch.ones(batch_size, seq_len, dtype=torch.bool, device=device),
    }


def export_dit_single_step(
    single_step: DiTSingleStepWrapper,
    output_path: str,
    opset_version: int = 21,
    device: str = "cpu",
    fp16: bool = False,
):
    """Export the single-step DiT to ONNX (for runtime ODE solving)."""
    import onnx

    print(f"Exporting DiT single-step to {output_path}...")

    if fp16:
        print(" Converting model to FP16...")
        single_step = single_step.half()

    sample_inputs = create_sample_inputs(device=device)

    # Match the input dtypes to the model precision (bool/int inputs stay as-is).
    if fp16:
        for key, value in sample_inputs.items():
            if value.dtype == torch.float32:
                sample_inputs[key] = value.half()

    torch.onnx.export(
        single_step,
        tuple(sample_inputs.values()),
        output_path,
        input_names=list(sample_inputs.keys()),
        output_names=["velocity"],
        dynamic_axes={
            "noisy_audio": {0: "batch_size", 1: "seq_len"},
            "time": {0: "batch_size"},
            "audio_features": {0: "batch_size", 1: "seq_len"},
            "text_features": {0: "batch_size", 1: "text_len"},
            "text_mask": {0: "batch_size", 1: "text_len"},
            "masked_video_features": {0: "batch_size", 2: "seq_len"},
            "anchor_ids": {0: "batch_size", 1: "seq_len"},
            "anchor_alignment": {0: "batch_size", 1: "seq_len"},
            "audio_pad_mask": {0: "batch_size", 1: "seq_len"},
            "velocity": {0: "batch_size", 1: "seq_len"},
        },
        opset_version=opset_version,
        do_constant_folding=True,
        dynamo=True,
        external_data=True,
    )

    print(" ✓ DiT single-step exported successfully")

    # Check that the external weight file was written alongside the graph.
    external_data_path = output_path + ".data"
    if os.path.exists(external_data_path):
        print(f" ✓ External data file exists ({os.path.getsize(external_data_path) / 1e9:.2f} GB)")
    else:
        raise RuntimeError(f"External data file missing: {external_data_path}")

    # Sanity-check the graph structure without loading the external weights.
    model = onnx.load(output_path, load_external_data=False)
    print(f" ✓ ONNX model structure loaded ({len(model.graph.node)} nodes)")

    return True
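

# The module also defines UnrolledDiTWrapper, but only the single-step export is
# wired into main(). The sketch below shows how an unrolled export could look,
# reusing the same sample inputs minus "time" (the wrapper integrates t internally).
# This is a hedged sketch rather than a tested export path: whether the 16x-unrolled
# graph traces and fits in memory depends on the model size and torch/ONNX versions.
def _example_export_dit_unrolled(single_step, output_path, num_steps=16, opset_version=21):
    """Hedged sketch: export the fully unrolled denoiser as one static graph."""
    unrolled = UnrolledDiTWrapper(single_step, num_steps=num_steps).eval()

    inputs = create_sample_inputs()
    inputs.pop("time")  # the unrolled wrapper tracks t from 0 to 1 itself
    # "noise" must come first to match UnrolledDiTWrapper.forward's positional order.
    sample_inputs = {"noise": inputs.pop("noisy_audio"), **inputs}

    torch.onnx.export(
        unrolled,
        tuple(sample_inputs.values()),
        output_path,
        input_names=list(sample_inputs.keys()),
        output_names=["denoised"],
        opset_version=opset_version,
        do_constant_folding=True,
        dynamo=True,
        external_data=True,
    )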


def verify_dit_single_step(
    single_step: DiTSingleStepWrapper,
    onnx_path: str,
    device: str = "cpu",
    tolerance: float = 1e-3,
) -> bool:
    """Verify that the single-step ONNX output matches PyTorch."""
    import numpy as np
    import onnxruntime as ort

    print("Verifying DiT single-step output...")

    sample_inputs = create_sample_inputs(device=device)

    # Reference output from the PyTorch module.
    with torch.no_grad():
        pytorch_output = single_step(**sample_inputs).cpu().numpy()

    # Run the exported model with ONNX Runtime on the same inputs.
    sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])

    onnx_inputs = {}
    for name, tensor in sample_inputs.items():
        if tensor.dtype == torch.bool:
            onnx_inputs[name] = tensor.cpu().numpy().astype(bool)
        elif tensor.dtype == torch.long:
            onnx_inputs[name] = tensor.cpu().numpy().astype(np.int64)
        else:
            onnx_inputs[name] = tensor.cpu().numpy().astype(np.float32)

    onnx_output = sess.run(["velocity"], onnx_inputs)[0]

    max_diff = np.abs(pytorch_output - onnx_output).max()
    mean_diff = np.abs(pytorch_output - onnx_output).mean()

    print(f" Max difference: {max_diff:.2e}")
    print(f" Mean difference: {mean_diff:.2e}")

    if max_diff < tolerance:
        print(f" ✓ Verification passed (tolerance: {tolerance})")
        return True
    else:
        print(f" ✗ Verification failed (tolerance: {tolerance})")
        return False
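

# verify_dit_single_step() always feeds float32 inputs, which matches the default
# export. For a model exported with --fp16, the float inputs (and the PyTorch
# reference) would need to be float16 instead. A minimal sketch of that variant is
# shown below; the looser tolerance is an assumption (FP16 rounding is coarser),
# and the helper is not wired into main().
def _example_verify_fp16(single_step_fp16, onnx_path, tolerance=1e-2):
    """Hedged sketch: compare an FP16 export against the halved PyTorch module."""
    import numpy as np
    import onnxruntime as ort

    sample_inputs = create_sample_inputs()
    for key, value in sample_inputs.items():
        if value.dtype == torch.float32:
            sample_inputs[key] = value.half()

    with torch.no_grad():
        reference = single_step_fp16(**sample_inputs).float().cpu().numpy()

    sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    onnx_inputs = {k: v.cpu().numpy() for k, v in sample_inputs.items()}
    onnx_output = sess.run(["velocity"], onnx_inputs)[0].astype(np.float32)

    max_diff = np.abs(reference - onnx_output).max()
    print(f" FP16 max difference: {max_diff:.2e}")
    return max_diff < tolerance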


def main():
    parser = argparse.ArgumentParser(description="Export DiT Transformer to ONNX")
    parser.add_argument(
        "--model-id",
        type=str,
        default="facebook/sam-audio-small",
        help="SAM Audio model ID from HuggingFace",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="onnx_models",
        help="Output directory for ONNX models",
    )
    parser.add_argument(
        "--num-steps",
        type=int,
        default=16,
        help="Number of ODE solver steps (default: 16)",
    )
    parser.add_argument(
        "--opset",
        type=int,
        default=21,
        help="ONNX opset version (default: 21)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        help="Device to use for export (default: cpu)",
    )
    parser.add_argument(
        "--verify",
        action="store_true",
        help="Verify ONNX output matches PyTorch",
    )
    parser.add_argument(
        "--tolerance",
        type=float,
        default=1e-3,
        help="Tolerance for verification (default: 1e-3)",
    )
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Export model in FP16 precision (half the size)",
    )

    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    single_step, config = load_sam_audio_components(args.model_id, args.device)

    print("\nDiT Configuration:")
    print(f" Model: {args.model_id}")
    print(f" ODE steps: {args.num_steps}")
    print(f" Step size: {1.0 / args.num_steps:.4f}")

    single_step_path = os.path.join(args.output_dir, "dit_single_step.onnx")
    export_dit_single_step(
        single_step,
        single_step_path,
        opset_version=args.opset,
        device=args.device,
        fp16=args.fp16,
    )

    if args.fp16:
        print(" ✓ Model exported in FP16 precision")

    if args.verify:
        verify_dit_single_step(
            single_step,
            single_step_path,
            device=args.device,
            tolerance=args.tolerance,
        )

    print(f"\n✓ Export complete! Model saved to {args.output_dir}")


if __name__ == "__main__":
    main()