#!/usr/bin/env python3
"""
Export DiT Transformer with unrolled ODE solver to ONNX format.
The DiT transformer is the core denoising model in SAM Audio. It uses a flow-based
generative model with an ODE solver. For ONNX export, we unroll the fixed-step
midpoint ODE solver into a static computation graph.
The default configuration uses:
- method: "midpoint"
- step_size: 2/32 (0.0625)
- integration range: [0, 1]
- total steps: 16
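UnrolledDiTWrapper builds a single model that performs the complete denoising process,
taking noise and conditioning as input and producing denoised audio features.
Note: main() currently exports only the single-step model (dit_single_step.onnx),
i.e. one velocity evaluation; the midpoint ODE loop is then run by the caller at
inference time.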
Usage:
python -m onnx_export.export_dit --output-dir onnx_models --verify
"""
import os
import math
import argparse
import torch
import torch.nn as nn
from typing import Optional
class SinusoidalEmbedding(nn.Module):
    """Sinusoidal timestep embedding (numerically identical to the SAMAudio implementation).

    With ``half_dim = dim // 2`` the frequencies are
    ``inv_freq[k] = exp(-log(theta) * k / half_dim)`` for ``k in [0, half_dim)``,
    and a position ``p`` is embedded as ``cat(cos(p * inv_freq), sin(p * inv_freq))``.

    Args:
        dim: Output embedding size; must be even.
        theta: Base of the geometric frequency progression.

    Raises:
        ValueError: If ``dim`` is odd.
    """

    def __init__(self, dim, theta=10000):
        super().__init__()
        # Explicit check rather than ``assert`` so validation survives ``python -O``.
        if dim % 2 != 0:
            raise ValueError(f"SinusoidalEmbedding dim must be even, got {dim}")
        half_dim = dim // 2
        inv_freq = torch.exp(
            -math.log(theta) * torch.arange(half_dim).float() / half_dim
        )
        # Non-persistent buffer: follows .to()/.half() but is excluded from state_dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, x, pos=None):
        """Return embeddings of shape ``(len(pos), dim)``.

        Args:
            x: Only consulted when ``pos`` is None; positions ``0 .. x.shape[1] - 1``
                are then embedded on ``x.device``.
            pos: 1-D tensor of positions (e.g. diffusion timesteps).
        """
        if pos is None:
            seq_len, device = x.shape[1], x.device
            # Build positions in the buffer's dtype so einsum never sees mixed
            # integer/float operands.
            pos = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        emb = torch.einsum("i, j -> i j", pos, self.inv_freq)
        emb = torch.cat((emb.cos(), emb.sin()), dim=-1)
        return emb
class EmbedAnchors(nn.Module):
    """Gated additive anchor embedding (mirrors the SAMAudio implementation).

    The embedding table has ``num_embeddings + 1`` rows; the extra last row is the
    padding index. The projected embedding is scaled by ``tanh(gate)`` before being
    added to the input. ``gate`` is initialised to 0, so a freshly constructed
    module is an identity until trained/loaded weights change the gate.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, out_dim: int):
        super().__init__()
        pad_row = num_embeddings
        # Submodules/parameters are created in the same order as SAMAudio
        # (embed, gate, proj) so parameter ordering and RNG consumption match.
        self.embed = nn.Embedding(num_embeddings + 1, embedding_dim, padding_idx=pad_row)
        self.gate = nn.Parameter(torch.tensor([0.0]))
        self.proj = nn.Linear(embedding_dim, out_dim, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        anchor_ids: Optional[torch.Tensor] = None,
        anchor_alignment: Optional[torch.Tensor] = None,
    ):
        """Add gated anchor embeddings to ``x``.

        ``anchor_alignment`` is used as a ``gather`` index into ``anchor_ids``
        along dim 1; the gathered ids are embedded, projected and added to ``x``.
        When ``anchor_ids`` is None, ``x`` is returned unchanged (same object).
        """
        if anchor_ids is not None:
            per_position_ids = anchor_ids.gather(1, anchor_alignment)
            projected = self.proj(self.embed(per_position_ids))
            return x + self.gate.tanh() * projected
        return x
class DiTSingleStepWrapper(nn.Module):
    """One ODE function evaluation of the DiT, mirroring ``SAMAudio.forward()``.

    Bundles the transformer with the input projection, masked-video alignment,
    anchor embedding, timestep embedding and text-memory projection so a single
    velocity evaluation can be traced/exported as one module.
    """

    def __init__(
        self,
        transformer: nn.Module,
        proj: nn.Module,
        align_masked_video: nn.Module,
        embed_anchors: nn.Module,
        timestep_emb: nn.Module,
        memory_proj: nn.Module,
    ):
        super().__init__()
        # Registration order is significant for named_children()/state_dict ordering.
        self.transformer = transformer
        self.proj = proj
        self.align_masked_video = align_masked_video
        self.embed_anchors = embed_anchors
        self.timestep_emb = timestep_emb
        self.memory_proj = memory_proj

    def forward(
        self,
        noisy_audio: torch.Tensor,
        time: torch.Tensor,
        audio_features: torch.Tensor,
        text_features: torch.Tensor,
        text_mask: torch.Tensor,
        masked_video_features: torch.Tensor,
        anchor_ids: torch.Tensor,
        anchor_alignment: torch.Tensor,
        audio_pad_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Evaluate the velocity field once at ``(time, noisy_audio)``.

        The input sequence is the channel-wise (dim 2) concatenation
        ``[noisy_audio | zeros_like(audio_features) | audio_features]``, as in
        ``SAMAudio.align_inputs()``. The transformer's cross-attention memory is
        the projected text features plus the timestep embedding (unsqueezed at
        dim 1 so it broadcasts over text tokens).
        """
        placeholder = torch.zeros_like(audio_features)
        stacked = torch.cat((noisy_audio, placeholder, audio_features), dim=2)
        hidden = self.embed_anchors(
            self.align_masked_video(self.proj(stacked), masked_video_features),
            anchor_ids,
            anchor_alignment,
        )

        t_embed = self.timestep_emb(time, pos=time)[:, None]
        cond_memory = self.memory_proj(text_features) + t_embed

        return self.transformer(
            hidden,
            time,
            padding_mask=audio_pad_mask,
            memory=cond_memory,
            memory_padding_mask=text_mask,
        )
class UnrolledDiTWrapper(nn.Module):
    """DiT wrapper with the fixed-step midpoint ODE solver unrolled into the graph.

    Integrates ``dy/dt = f(t, y)`` from ``t = 0`` to ``t = 1`` using ``num_steps``
    equal steps of size ``h = 1 / num_steps``. Each step computes::

        k1 = f(t, y)
        k2 = f(t + h/2, y + h/2 * k1)
        y_new = y + h * k2

    so the unrolled graph contains ``2 * num_steps`` calls to ``single_step``.
    With the default ``num_steps=16`` the step size is 0.0625.

    Args:
        single_step: Module computing one velocity evaluation; it is called as
            ``single_step(y, t, audio_features, text_features, text_mask,
            masked_video_features, anchor_ids, anchor_alignment, audio_pad_mask)``.
        num_steps: Number of midpoint steps; must be >= 1.

    Raises:
        ValueError: If ``num_steps`` < 1.
    """

    def __init__(
        self,
        single_step: "DiTSingleStepWrapper",
        num_steps: int = 16,
    ):
        super().__init__()
        if num_steps < 1:
            raise ValueError(f"num_steps must be >= 1, got {num_steps}")
        self.single_step = single_step
        self.num_steps = num_steps
        self.step_size = 1.0 / num_steps

    def forward(
        self,
        noise: torch.Tensor,
        audio_features: torch.Tensor,
        text_features: torch.Tensor,
        text_mask: torch.Tensor,
        masked_video_features: torch.Tensor,
        anchor_ids: torch.Tensor,
        anchor_alignment: torch.Tensor,
        audio_pad_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Integrate from ``noise`` at t=0 to t=1 with the midpoint method; return y(1)."""
        batch = noise.shape[0]
        h = self.step_size
        half_h = h / 2
        # Conditioning is identical for every evaluation; pass it through positionally.
        cond = (
            audio_features, text_features, text_mask,
            masked_video_features, anchor_ids, anchor_alignment, audio_pad_mask,
        )
        y = noise
        for step in range(self.num_steps):
            # Derive t from the step index instead of accumulating ``t += h`` so
            # rounding error does not build up when 1/num_steps is not exactly
            # representable (fixed-grid solvers build their grid as index * h).
            t = torch.full((batch,), step * h, device=noise.device, dtype=noise.dtype)
            k1 = self.single_step(y, t, *cond)
            k2 = self.single_step(y + half_h * k1, t + half_h, *cond)
            y = y + h * k2
        return y
def _split_state_dict_by_prefix(state_dict, prefixes):
    """Split a flat checkpoint state dict into per-component state dicts.

    Args:
        state_dict: Mapping of ``"<component>.<param path>"`` keys to tensors.
        prefixes: Iterable of top-level component names (no trailing dot, and
            containing no dots themselves).

    Returns:
        Dict mapping each prefix to ``{param path: tensor}`` with ``"<prefix>."``
        stripped. Keys whose first dotted segment is not in ``prefixes`` are ignored.
    """
    grouped = {prefix: {} for prefix in prefixes}
    for key, value in state_dict.items():
        head, sep, rest = key.partition(".")
        if sep and head in grouped:
            grouped[head][rest] = value
    return grouped


def load_sam_audio_components(model_id: str = "facebook/sam-audio-small", device: str = "cpu"):
    """
    Load the SAM Audio components needed for DiT export.

    The full SAMAudio model cannot be loaded here (missing ``perception_models``),
    so the DiT and its small input/conditioning modules are constructed directly
    and their weights are loaded from the checkpoint.

    Args:
        model_id: Hugging Face Hub repo id providing ``config.json`` and
            ``checkpoint.pt``.
        device: Device the returned wrapper is moved to.

    Returns:
        ``(single_step, config)``: a ``DiTSingleStepWrapper`` in eval mode on
        ``device``, and the parsed ``config.json`` dict.

    Raises:
        RuntimeError: From ``load_state_dict`` (strict mode) if a component's
            checkpoint keys or shapes do not match the constructed module.

    Side effects:
        Downloads files via ``hf_hub_download`` and registers placeholder
        ``sam_audio`` / ``sam_audio.model`` packages plus
        ``sam_audio.model.{config,transformer,align}`` in ``sys.modules``.
    """
    import json
    import sys
    import types
    import importlib.util
    from huggingface_hub import hf_hub_download

    print(f"Loading SAM Audio components from {model_id}...")
    # Download config
    config_path = hf_hub_download(repo_id=model_id, filename="config.json")
    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)
    # Download checkpoint
    checkpoint_path = hf_hub_download(repo_id=model_id, filename="checkpoint.pt")

    # Use our standalone config that doesn't have 'core' dependencies
    from onnx_export.standalone_config import TransformerConfig

    # Repository root: two directory levels above this file
    # (this file is expected at <root>/onnx_export/).
    sam_audio_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

    # Create fake module hierarchy so transformer.py's relative imports work
    if 'sam_audio' not in sys.modules:
        sam_audio_pkg = types.ModuleType('sam_audio')
        sam_audio_pkg.__path__ = [os.path.join(sam_audio_path, 'sam_audio')]
        sys.modules['sam_audio'] = sam_audio_pkg
    if 'sam_audio.model' not in sys.modules:
        model_pkg = types.ModuleType('sam_audio.model')
        model_pkg.__path__ = [os.path.join(sam_audio_path, 'sam_audio', 'model')]
        sys.modules['sam_audio.model'] = model_pkg
    # Register our standalone config as sam_audio.model.config
    if 'sam_audio.model.config' not in sys.modules:
        import onnx_export.standalone_config as standalone_config
        sys.modules['sam_audio.model.config'] = standalone_config

    # Now import transformer module - it will use our standalone config.
    # The module must be in sys.modules before exec_module so its relative
    # imports resolve against the fake package hierarchy above.
    transformer_spec = importlib.util.spec_from_file_location(
        "sam_audio.model.transformer",
        os.path.join(sam_audio_path, "sam_audio", "model", "transformer.py")
    )
    transformer_module = importlib.util.module_from_spec(transformer_spec)
    sys.modules['sam_audio.model.transformer'] = transformer_module
    transformer_spec.loader.exec_module(transformer_module)
    DiT = transformer_module.DiT

    # Import align module
    align_spec = importlib.util.spec_from_file_location(
        "sam_audio.model.align",
        os.path.join(sam_audio_path, "sam_audio", "model", "align.py")
    )
    align_module = importlib.util.module_from_spec(align_spec)
    sys.modules['sam_audio.model.align'] = align_module
    align_spec.loader.exec_module(align_module)
    AlignModalities = align_module.AlignModalities

    # Create transformer
    transformer_config = TransformerConfig(**config.get("transformer", {}))
    transformer = DiT(transformer_config)

    # Calculate dimensions
    in_channels = config.get("in_channels", 768)
    num_anchors = config.get("num_anchors", 3)
    anchor_embedding_dim = config.get("anchor_embedding_dim", 128)
    # Get vision encoder dim for align_masked_video.
    # NOTE(review): this fallback (768) differs from the vision_dim used in
    # create_sample_inputs (1024); only matters if config.json omits the key.
    vision_config = config.get("vision_encoder", {})
    vision_dim = vision_config.get("dim", 768)

    # Create components exactly as SAMAudio does
    proj = nn.Linear(in_channels, transformer_config.d_model)
    align_masked_video = AlignModalities(vision_dim, transformer_config.d_model)
    embed_anchors = EmbedAnchors(num_anchors, anchor_embedding_dim, transformer_config.d_model)
    timestep_emb = SinusoidalEmbedding(transformer_config.d_model)
    # Memory projection for text features.
    # NOTE(review): this fallback (1024, flan-t5-large) differs from the text_dim
    # used in create_sample_inputs (768); only matters if config.json omits the key.
    text_encoder_config = config.get("text_encoder", {})
    text_encoder_dim = text_encoder_config.get("dim", 1024)  # google/flan-t5-large
    memory_proj = nn.Linear(text_encoder_dim, transformer_config.d_model)

    # Load weights from checkpoint
    print("Loading weights from checkpoint...")
    # NOTE(review): torch.load unpickles the file downloaded from the Hub repo
    # named by ``model_id``; only use trusted repos.
    state_dict = torch.load(checkpoint_path, map_location="cpu", mmap=True)

    # Checkpoint key prefix -> module; loaded in this order, strictly.
    components = {
        "transformer": transformer,
        "proj": proj,
        "align_masked_video": align_masked_video,
        "embed_anchors": embed_anchors,
        "memory_proj": memory_proj,
    }
    component_states = _split_state_dict_by_prefix(state_dict, components)
    for prefix, module in components.items():
        module.load_state_dict(component_states[prefix])

    print(f" ✓ Loaded transformer weights ({len(component_states['transformer'])} tensors)")
    other_tensor_count = sum(
        len(component_states[prefix]) for prefix in components if prefix != "transformer"
    )
    print(f" ✓ Loaded component weights ({other_tensor_count} tensors)")

    # Create single step wrapper
    single_step = DiTSingleStepWrapper(
        transformer=transformer,
        proj=proj,
        align_masked_video=align_masked_video,
        embed_anchors=embed_anchors,
        timestep_emb=timestep_emb,
        memory_proj=memory_proj,
    ).eval().to(device)
    return single_step, config
def create_sample_inputs(
    batch_size: int = 1,
    seq_len: int = 25,
    device: str = "cpu",
    *,
    latent_dim: int = 128,
    text_dim: int = 768,
    vision_dim: int = 1024,
    text_len: int = 77,
):
    """Create sample inputs for tracing/verifying ``DiTSingleStepWrapper``.

    The dict order matches the wrapper's ``forward`` parameter order; the export
    uses it for the ONNX input names, so it must not change.

    Args:
        batch_size: Leading dimension of every tensor.
        seq_len: Number of audio frames.
        device: Device the tensors are created on.
        latent_dim: Audio latent size; audio tensors have ``2 * latent_dim`` channels.
        text_dim: Text feature size (768: T5-base hidden size — SAM Audio was
            trained with 768-dim text).
        vision_dim: Vision encoder feature size.
        text_len: Number of text tokens.

    Returns:
        Dict of input name -> tensor. Float tensors are float32; masks are bool;
        anchor tensors are int64.
    """
    audio_channels = 2 * latent_dim
    return {
        "noisy_audio": torch.randn(batch_size, seq_len, audio_channels, device=device),
        "time": torch.zeros(batch_size, device=device),
        "audio_features": torch.randn(batch_size, seq_len, audio_channels, device=device),
        "text_features": torch.randn(batch_size, text_len, text_dim, device=device),
        "text_mask": torch.ones(batch_size, text_len, dtype=torch.bool, device=device),
        # Channels-first: (batch, vision_dim, seq_len).
        "masked_video_features": torch.zeros(batch_size, vision_dim, seq_len, device=device),
        "anchor_ids": torch.zeros(batch_size, seq_len, dtype=torch.long, device=device),
        "anchor_alignment": torch.zeros(batch_size, seq_len, dtype=torch.long, device=device),
        "audio_pad_mask": torch.ones(batch_size, seq_len, dtype=torch.bool, device=device),
    }
def export_dit_single_step(
    single_step: DiTSingleStepWrapper,
    output_path: str,
    opset_version: int = 21,
    device: str = "cpu",
    fp16: bool = False,
):
    """Export single-step DiT to ONNX (for runtime ODE solving).

    Uses the dynamo exporter with weights written to ``output_path + ".data"``.
    Batch size, sequence length and text length are exported as dynamic axes.

    Args:
        single_step: Module to export. It is NOT modified, even when ``fp16`` is set.
        output_path: Destination ``.onnx`` path.
        opset_version: ONNX opset to target.
        device: Device on which sample inputs are created.
        fp16: Export a half-precision copy of the model with FP16 float inputs.

    Returns:
        True on success.

    Raises:
        RuntimeError: If the expected external data file was not written.
    """
    import copy
    import onnx

    print(f"Exporting DiT single-step to {output_path}...")
    sample_inputs = create_sample_inputs(device=device)

    if fp16:
        print(" Converting model to FP16...")
        # nn.Module.half() converts parameters in place and returns self. Convert a
        # copy so the caller's FP32 module survives (main() reuses it as the
        # PyTorch reference for --verify). Costs one extra model copy in memory.
        single_step = copy.deepcopy(single_step).half()
        # Float inputs must match the FP16 weights; bool/int64 inputs stay as-is.
        sample_inputs = {
            name: value.half() if value.dtype == torch.float32 else value
            for name, value in sample_inputs.items()
        }

    torch.onnx.export(
        single_step,
        tuple(sample_inputs.values()),
        output_path,
        input_names=list(sample_inputs.keys()),
        output_names=["velocity"],
        dynamic_axes={
            "noisy_audio": {0: "batch_size", 1: "seq_len"},
            "time": {0: "batch_size"},
            "audio_features": {0: "batch_size", 1: "seq_len"},
            "text_features": {0: "batch_size", 1: "text_len"},
            "text_mask": {0: "batch_size", 1: "text_len"},
            "masked_video_features": {0: "batch_size", 2: "seq_len"},
            "anchor_ids": {0: "batch_size", 1: "seq_len"},
            "anchor_alignment": {0: "batch_size", 1: "seq_len"},
            "audio_pad_mask": {0: "batch_size", 1: "seq_len"},
            "velocity": {0: "batch_size", 1: "seq_len"},
        },
        opset_version=opset_version,
        do_constant_folding=True,
        dynamo=True,
        external_data=True,
    )
    print(" ✓ DiT single-step exported successfully")

    # When using external_data=True, we can't run check_model on a model
    # loaded without external data - the checker validates data references.
    # Since torch.onnx.export with dynamo=True already validates the model,
    # we just verify the files exist.
    external_data_path = output_path + ".data"
    if os.path.exists(external_data_path):
        print(f" ✓ External data file exists ({os.path.getsize(external_data_path) / 1e9:.2f} GB)")
    else:
        raise RuntimeError(f"External data file missing: {external_data_path}")

    # Verify the ONNX file structure is valid (without loading weights)
    model = onnx.load(output_path, load_external_data=False)
    print(f" ✓ ONNX model structure loaded ({len(model.graph.node)} nodes)")
    return True
def verify_dit_single_step(
    single_step: DiTSingleStepWrapper,
    onnx_path: str,
    device: str = "cpu",
    tolerance: float = 1e-3,
) -> bool:
    """Verify single-step ONNX output matches PyTorch.

    Runs the same sample inputs through ``single_step`` and through ONNX Runtime
    (CPU provider) and compares the ``velocity`` outputs in float32.

    Args:
        single_step: PyTorch reference module (FP32 or FP16).
        onnx_path: Path to the exported ``.onnx`` file.
        device: Device on which sample inputs are created.
        tolerance: Maximum allowed absolute difference.

    Returns:
        True if the max absolute difference is below ``tolerance`` (NaN fails).
    """
    import onnxruntime as ort
    import numpy as np

    print("Verifying DiT single-step output...")
    sample_inputs = create_sample_inputs(device=device)

    # PyTorch output. Feed floats in the module's own parameter dtype so an FP16
    # module does not hit a dtype mismatch against the FP32 sample inputs.
    first_param = next(single_step.parameters(), None)
    param_dtype = first_param.dtype if first_param is not None else torch.float32
    torch_inputs = {
        name: tensor.to(param_dtype) if tensor.is_floating_point() else tensor
        for name, tensor in sample_inputs.items()
    }
    with torch.no_grad():
        pytorch_output = single_step(**torch_inputs).float().cpu().numpy()

    # ONNX Runtime output. ORT rejects inputs whose dtype differs from the graph's
    # declared type, so float inputs follow the graph (float16 for --fp16 exports).
    sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    graph_float_dtypes = {
        node.name: (np.float16 if node.type == "tensor(float16)" else np.float32)
        for node in sess.get_inputs()
    }
    onnx_inputs = {}
    for name, tensor in sample_inputs.items():
        array = tensor.cpu().numpy()
        if tensor.dtype == torch.bool:
            onnx_inputs[name] = array.astype(bool)
        elif tensor.dtype == torch.long:
            onnx_inputs[name] = array.astype(np.int64)
        else:
            onnx_inputs[name] = array.astype(graph_float_dtypes.get(name, np.float32))
    onnx_output = sess.run(["velocity"], onnx_inputs)[0].astype(np.float32)

    # Compare
    abs_diff = np.abs(pytorch_output - onnx_output)
    max_diff = abs_diff.max()
    mean_diff = abs_diff.mean()
    print(f" Max difference: {max_diff:.2e}")
    print(f" Mean difference: {mean_diff:.2e}")
    if max_diff < tolerance:
        print(f" ✓ Verification passed (tolerance: {tolerance})")
        return True
    else:
        print(f" ✗ Verification failed (tolerance: {tolerance})")
        return False
def main():
    """Command-line entry point: load SAM Audio DiT components and export to ONNX.

    Writes ``dit_single_step.onnx`` (one velocity evaluation) into ``--output-dir``.
    With ``--verify``, compares ONNX Runtime against PyTorch and exits with a
    non-zero status if the outputs differ by more than ``--tolerance``.
    """
    parser = argparse.ArgumentParser(description="Export DiT Transformer to ONNX")
    parser.add_argument(
        "--model-id",
        type=str,
        default="facebook/sam-audio-small",
        help="SAM Audio model ID from HuggingFace",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="onnx_models",
        help="Output directory for ONNX models",
    )
    parser.add_argument(
        "--num-steps",
        type=int,
        default=16,
        help="Number of ODE solver steps (default: 16)",
    )
    parser.add_argument(
        "--opset",
        type=int,
        default=21,
        help="ONNX opset version (default: 21)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        help="Device to use for export (default: cpu)",
    )
    parser.add_argument(
        "--verify",
        action="store_true",
        help="Verify ONNX output matches PyTorch",
    )
    parser.add_argument(
        "--tolerance",
        type=float,
        default=1e-3,
        help="Tolerance for verification (default: 1e-3)",
    )
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Export model in FP16 precision (half the size)",
    )
    args = parser.parse_args()
    if args.num_steps < 1:
        # Would otherwise crash below with ZeroDivisionError when printing the step size.
        parser.error("--num-steps must be >= 1")

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Load components
    single_step, _config = load_sam_audio_components(args.model_id, args.device)

    print("\nDiT Configuration:")
    print(f" Model: {args.model_id}")
    # --num-steps is informational here: the exported graph is a single ODE
    # evaluation; the solver loop runs at inference time.
    print(f" ODE steps: {args.num_steps}")
    print(f" Step size: {1.0/args.num_steps:.4f}")

    # Export single-step model
    single_step_path = os.path.join(args.output_dir, "dit_single_step.onnx")
    export_dit_single_step(
        single_step,
        single_step_path,
        opset_version=args.opset,
        device=args.device,
        fp16=args.fp16,
    )
    if args.fp16:
        print(" ✓ Model exported in FP16 precision")

    # Verify single-step; a mismatch must surface as a failing exit status (CI/shell).
    if args.verify:
        passed = verify_dit_single_step(
            single_step,
            single_step_path,
            device=args.device,
            tolerance=args.tolerance,
        )
        if not passed:
            raise SystemExit(
                f"\n✗ Verification failed: {single_step_path} does not match PyTorch "
                f"within tolerance {args.tolerance}"
            )

    print(f"\n✓ Export complete! Model saved to {args.output_dir}")
if __name__ == "__main__":
    # Run the CLI only when executed as a script / via ``python -m``, not on import.
    main()