#!/usr/bin/env python3
"""
Export DiT Transformer with unrolled ODE solver to ONNX format.

The DiT transformer is the core denoising model in SAM Audio. It uses a flow-based
generative model with an ODE solver. For ONNX export, we unroll the fixed-step
midpoint ODE solver into a static computation graph.

The default configuration uses:
- method: "midpoint"
- step_size: 2/32 (0.0625)
- integration range: [0, 1]
- total steps: 16

Two wrappers are defined: DiTSingleStepWrapper (one ODE function evaluation) and
UnrolledDiTWrapper (all solver steps baked into one static graph that maps noise
and conditioning directly to denoised audio features). The export path in main()
currently emits the single-step model, which is driven by the same midpoint loop
at runtime.

Usage:
    python -m onnx_export.export_dit --output-dir onnx_models --verify
"""

import os
import math
import argparse
import torch
import torch.nn as nn
from typing import Optional


class SinusoidalEmbedding(nn.Module):
    """Sinusoidal timestep embedding (identical to SAMAudio implementation)."""
    
    def __init__(self, dim, theta=10000):
        super().__init__()
        assert (dim % 2) == 0
        half_dim = dim // 2
        inv_freq = torch.exp(
            -math.log(theta) * torch.arange(half_dim).float() / half_dim
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, x, pos=None):
        if pos is None:
            seq_len, device = x.shape[1], x.device
            pos = torch.arange(seq_len, device=device)

        emb = torch.einsum("i, j -> i j", pos, self.inv_freq)
        emb = torch.cat((emb.cos(), emb.sin()), dim=-1)
        return emb


class EmbedAnchors(nn.Module):
    """Anchor embedding (identical to SAMAudio implementation)."""
    
    def __init__(self, num_embeddings: int, embedding_dim: int, out_dim: int):
        super().__init__()
        self.embed = nn.Embedding(
            num_embeddings + 1, embedding_dim, padding_idx=num_embeddings
        )
        self.gate = nn.Parameter(torch.tensor([0.0]))
        self.proj = nn.Linear(embedding_dim, out_dim, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        anchor_ids: Optional[torch.Tensor] = None,
        anchor_alignment: Optional[torch.Tensor] = None,
    ):
        if anchor_ids is None:
            return x

        embs = self.embed(anchor_ids.gather(1, anchor_alignment))
        proj = self.proj(embs)
        return x + self.gate.tanh() * proj


class DiTSingleStepWrapper(nn.Module):
    """
    Wrapper for DiT that performs a single forward pass (one ODE evaluation).
    
    This mirrors the SAMAudio.forward() method exactly.
    """
    
    def __init__(
        self,
        transformer: nn.Module,
        proj: nn.Module,
        align_masked_video: nn.Module,
        embed_anchors: nn.Module,
        timestep_emb: nn.Module,
        memory_proj: nn.Module,
    ):
        super().__init__()
        self.transformer = transformer
        self.proj = proj
        self.align_masked_video = align_masked_video
        self.embed_anchors = embed_anchors
        self.timestep_emb = timestep_emb
        self.memory_proj = memory_proj
        
    def forward(
        self,
        noisy_audio: torch.Tensor,
        time: torch.Tensor,
        audio_features: torch.Tensor,
        text_features: torch.Tensor,
        text_mask: torch.Tensor,
        masked_video_features: torch.Tensor,
        anchor_ids: torch.Tensor,
        anchor_alignment: torch.Tensor,
        audio_pad_mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Single forward pass of the DiT (one ODE function evaluation).
        
        This exactly mirrors SAMAudio.forward() method.
        """
        # Align inputs (concatenate noisy_audio with audio_features)
        # Same as SAMAudio.align_inputs()
        x = torch.cat(
            [
                noisy_audio,
                torch.zeros_like(audio_features),
                audio_features,
            ],
            dim=2,
        )
        
        projected = self.proj(x)
        aligned = self.align_masked_video(projected, masked_video_features)
        aligned = self.embed_anchors(aligned, anchor_ids, anchor_alignment)
        
        # Timestep embedding and memory
        # Same as SAMAudio.forward()
        timestep_emb_val = self.timestep_emb(time, pos=time).unsqueeze(1)
        memory = self.memory_proj(text_features) + timestep_emb_val
        
        # Transformer forward
        output = self.transformer(
            aligned,
            time,
            padding_mask=audio_pad_mask,
            memory=memory,
            memory_padding_mask=text_mask,
        )
        
        return output


class UnrolledDiTWrapper(nn.Module):
    """
    DiT wrapper with unrolled midpoint ODE solver.
    
    The midpoint method computes:
        k1 = f(t, y)
        k2 = f(t + h/2, y + h/2 * k1)
        y_new = y + h * k2
    
    With step_size=0.0625 and range [0,1], we have 16 steps.
    """
    
    def __init__(
        self,
        single_step: DiTSingleStepWrapper,
        num_steps: int = 16,
    ):
        super().__init__()
        self.single_step = single_step
        self.num_steps = num_steps
        self.step_size = 1.0 / num_steps
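        # Note: exporting this wrapper unrolls 2 * num_steps transformer evaluations
        # into one graph. Parameters should be shared as initializers, but node count
        # grows linearly with num_steps.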
        
    def forward(
        self,
        noise: torch.Tensor,
        audio_features: torch.Tensor,
        text_features: torch.Tensor,
        text_mask: torch.Tensor,
        masked_video_features: torch.Tensor,
        anchor_ids: torch.Tensor,
        anchor_alignment: torch.Tensor,
        audio_pad_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Complete denoising using unrolled midpoint ODE solver."""
        B = noise.shape[0]
        h = self.step_size
        y = noise
        t = torch.zeros(B, device=noise.device, dtype=noise.dtype)
        
        for step in range(self.num_steps):
            # k1 = f(t, y)
            k1 = self.single_step(
                y, t,
                audio_features, text_features, text_mask,
                masked_video_features, anchor_ids, anchor_alignment, audio_pad_mask
            )
            
            # k2 = f(t + h/2, y + h/2 * k1)
            t_mid = t + h / 2
            y_mid = y + (h / 2) * k1
            k2 = self.single_step(
                y_mid, t_mid,
                audio_features, text_features, text_mask,
                masked_video_features, anchor_ids, anchor_alignment, audio_pad_mask
            )
            
            # y = y + h * k2
            y = y + h * k2
            t = t + h
        
        return y
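

# The export path in main() emits only the single-step model; the midpoint loop is
# then reproduced outside the graph at inference time. The function below is an
# illustrative sketch of one way to do that with an onnxruntime InferenceSession;
# it assumes `cond` holds the remaining graph inputs (audio_features, text_features,
# text_mask, masked_video_features, anchor_ids, anchor_alignment, audio_pad_mask)
# as numpy arrays keyed by the names used in export_dit_single_step().
def run_midpoint_with_onnx(sess, noise, cond, num_steps: int = 16):
    """Drive the exported single-step model with a fixed-step midpoint loop (sketch)."""
    import numpy as np

    h = 1.0 / num_steps
    y = noise
    t = np.zeros(noise.shape[0], dtype=noise.dtype)
    for _ in range(num_steps):
        # k1 = f(t, y)
        k1 = sess.run(["velocity"], {"noisy_audio": y, "time": t, **cond})[0]
        # k2 = f(t + h/2, y + h/2 * k1)
        k2 = sess.run(
            ["velocity"],
            {"noisy_audio": y + (h / 2) * k1, "time": t + h / 2, **cond},
        )[0]
        # y_{n+1} = y_n + h * k2
        y = y + h * k2
        t = t + h
    return y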


def load_sam_audio_components(model_id: str = "facebook/sam-audio-small", device: str = "cpu"):
    """
    Load SAM Audio components needed for DiT export.
    
    Since we can't load the full SAMAudio model (missing perception_models),
    we construct the components directly and load weights from checkpoint.
    """
    import json
    import sys
    import types
    import importlib.util
    from huggingface_hub import hf_hub_download
    
    print(f"Loading SAM Audio components from {model_id}...")
    
    # Download config
    config_path = hf_hub_download(repo_id=model_id, filename="config.json")
    with open(config_path) as f:
        config = json.load(f)
    
    # Download checkpoint
    checkpoint_path = hf_hub_download(repo_id=model_id, filename="checkpoint.pt")
    
    # Use our standalone config that doesn't have 'core' dependencies
    from onnx_export.standalone_config import TransformerConfig
    
    sam_audio_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    
    # Create fake module hierarchy so transformer.py's relative imports work
    if 'sam_audio' not in sys.modules:
        sam_audio_pkg = types.ModuleType('sam_audio')
        sam_audio_pkg.__path__ = [os.path.join(sam_audio_path, 'sam_audio')]
        sys.modules['sam_audio'] = sam_audio_pkg
    
    if 'sam_audio.model' not in sys.modules:
        model_pkg = types.ModuleType('sam_audio.model')
        model_pkg.__path__ = [os.path.join(sam_audio_path, 'sam_audio', 'model')]
        sys.modules['sam_audio.model'] = model_pkg
    
    # Register our standalone config as sam_audio.model.config
    if 'sam_audio.model.config' not in sys.modules:
        import onnx_export.standalone_config as standalone_config
        sys.modules['sam_audio.model.config'] = standalone_config
    
    # Now import transformer module - it will use our standalone config
    transformer_spec = importlib.util.spec_from_file_location(
        "sam_audio.model.transformer", 
        os.path.join(sam_audio_path, "sam_audio", "model", "transformer.py")
    )
    transformer_module = importlib.util.module_from_spec(transformer_spec)
    sys.modules['sam_audio.model.transformer'] = transformer_module
    transformer_spec.loader.exec_module(transformer_module)
    DiT = transformer_module.DiT
    
    # Import align module
    align_spec = importlib.util.spec_from_file_location(
        "sam_audio.model.align",
        os.path.join(sam_audio_path, "sam_audio", "model", "align.py")
    )
    align_module = importlib.util.module_from_spec(align_spec)
    sys.modules['sam_audio.model.align'] = align_module
    align_spec.loader.exec_module(align_module)
    AlignModalities = align_module.AlignModalities
    
    # Create transformer
    transformer_config = TransformerConfig(**config.get("transformer", {}))
    transformer = DiT(transformer_config)
    
    # Calculate dimensions
    in_channels = config.get("in_channels", 768)
    num_anchors = config.get("num_anchors", 3)
    anchor_embedding_dim = config.get("anchor_embedding_dim", 128)
    
    # Get vision encoder dim for align_masked_video
    vision_config = config.get("vision_encoder", {})
    vision_dim = vision_config.get("dim", 768)
    
    # Create components exactly as SAMAudio does
    proj = nn.Linear(in_channels, transformer_config.d_model)
    align_masked_video = AlignModalities(vision_dim, transformer_config.d_model)
    embed_anchors = EmbedAnchors(num_anchors, anchor_embedding_dim, transformer_config.d_model)
    timestep_emb = SinusoidalEmbedding(transformer_config.d_model)
    
    # Memory projection for text features
    text_encoder_config = config.get("text_encoder", {})
    text_encoder_dim = text_encoder_config.get("dim", 1024)  # google/flan-t5-large
    memory_proj = nn.Linear(text_encoder_dim, transformer_config.d_model)
    
    # Load weights from checkpoint
    print("Loading weights from checkpoint...")
    state_dict = torch.load(checkpoint_path, map_location="cpu", mmap=True)
    
    # Filter and load weights for each component
    transformer_state = {}
    proj_state = {}
    align_state = {}
    embed_anchors_state = {}
    memory_proj_state = {}
    
    for key, value in state_dict.items():
        if key.startswith("transformer."):
            new_key = key[len("transformer."):]
            transformer_state[new_key] = value
        elif key.startswith("proj."):
            new_key = key[len("proj."):]
            proj_state[new_key] = value
        elif key.startswith("align_masked_video."):
            new_key = key[len("align_masked_video."):]
            align_state[new_key] = value
        elif key.startswith("embed_anchors."):
            new_key = key[len("embed_anchors."):]
            embed_anchors_state[new_key] = value
        elif key.startswith("memory_proj."):
            new_key = key[len("memory_proj."):]
            memory_proj_state[new_key] = value
    
    transformer.load_state_dict(transformer_state)
    proj.load_state_dict(proj_state)
    align_masked_video.load_state_dict(align_state)
    embed_anchors.load_state_dict(embed_anchors_state)
    memory_proj.load_state_dict(memory_proj_state)
    
    print(f"  ✓ Loaded transformer weights ({len(transformer_state)} tensors)")
    print(f"  ✓ Loaded component weights")
    
    # Create single step wrapper
    single_step = DiTSingleStepWrapper(
        transformer=transformer,
        proj=proj,
        align_masked_video=align_masked_video,
        embed_anchors=embed_anchors,
        timestep_emb=timestep_emb,
        memory_proj=memory_proj,
    ).eval().to(device)
    
    return single_step, config


def create_sample_inputs(batch_size: int = 1, seq_len: int = 25, device: str = "cpu"):
    """Create sample inputs for tracing."""
    latent_dim = 128
    text_dim = 768  # text encoder hidden size; must match the dim used to build memory_proj
    vision_dim = 1024  # vision encoder dim; must match the dim used to build align_masked_video
    text_len = 77
    
    return {
        "noisy_audio": torch.randn(batch_size, seq_len, 2 * latent_dim, device=device),
        "time": torch.zeros(batch_size, device=device),
        "audio_features": torch.randn(batch_size, seq_len, 2 * latent_dim, device=device),
        "text_features": torch.randn(batch_size, text_len, text_dim, device=device),
        "text_mask": torch.ones(batch_size, text_len, dtype=torch.bool, device=device),
        "masked_video_features": torch.zeros(batch_size, vision_dim, seq_len, device=device),
        "anchor_ids": torch.zeros(batch_size, seq_len, dtype=torch.long, device=device),
        "anchor_alignment": torch.zeros(batch_size, seq_len, dtype=torch.long, device=device),
        "audio_pad_mask": torch.ones(batch_size, seq_len, dtype=torch.bool, device=device),
    }


def export_dit_single_step(
    single_step: DiTSingleStepWrapper,
    output_path: str,
    opset_version: int = 21,
    device: str = "cpu",
    fp16: bool = False,
):
    """Export single-step DiT to ONNX (for runtime ODE solving)."""
    import onnx
    
    print(f"Exporting DiT single-step to {output_path}...")
    
    # Convert to FP16 if requested
    if fp16:
        print("  Converting model to FP16...")
        single_step = single_step.half()
    
    sample_inputs = create_sample_inputs(device=device)
    
    # Convert float inputs to FP16 if exporting in FP16
    if fp16:
        for key, value in sample_inputs.items():
            if value.dtype == torch.float32:
                sample_inputs[key] = value.half()
    
    torch.onnx.export(
        single_step,
        tuple(sample_inputs.values()),
        output_path,
        input_names=list(sample_inputs.keys()),
        output_names=["velocity"],
        dynamic_axes={
            "noisy_audio": {0: "batch_size", 1: "seq_len"},
            "time": {0: "batch_size"},
            "audio_features": {0: "batch_size", 1: "seq_len"},
            "text_features": {0: "batch_size", 1: "text_len"},
            "text_mask": {0: "batch_size", 1: "text_len"},
            "masked_video_features": {0: "batch_size", 2: "seq_len"},
            "anchor_ids": {0: "batch_size", 1: "seq_len"},
            "anchor_alignment": {0: "batch_size", 1: "seq_len"},
            "audio_pad_mask": {0: "batch_size", 1: "seq_len"},
            "velocity": {0: "batch_size", 1: "seq_len"},
        },
        opset_version=opset_version,
        do_constant_folding=True,
        dynamo=True,
        external_data=True,
    )
    
    print("  ✓ DiT single-step exported successfully")
    
    # When using external_data=True, we can't run check_model on a model
    # loaded without external data - the checker validates data references.
    # Since torch.onnx.export with dynamo=True already validates the model,
    # we just verify the files exist.
    external_data_path = output_path + ".data"
    if os.path.exists(external_data_path):
        print(f"  ✓ External data file exists ({os.path.getsize(external_data_path) / 1e9:.2f} GB)")
    else:
        raise RuntimeError(f"External data file missing: {external_data_path}")
    
    # Verify the ONNX file structure is valid (without loading weights)
    model = onnx.load(output_path, load_external_data=False)
    print(f"  ✓ ONNX model structure loaded ({len(model.graph.node)} nodes)")
    
    return True


def verify_dit_single_step(
    single_step: DiTSingleStepWrapper,
    onnx_path: str,
    device: str = "cpu",
    tolerance: float = 1e-3,
) -> bool:
    """Verify single-step ONNX output matches PyTorch."""
    import onnxruntime as ort
    import numpy as np
    
    print("Verifying DiT single-step output...")
    
    sample_inputs = create_sample_inputs(device=device)

    # Match floating-point inputs to the model's dtype (e.g. after an FP16 export),
    # otherwise the PyTorch forward pass fails on a dtype mismatch.
    model_dtype = next(single_step.parameters()).dtype
    sample_inputs = {
        name: t.to(model_dtype) if t.is_floating_point() else t
        for name, t in sample_inputs.items()
    }
    
    # PyTorch output
    with torch.no_grad():
        pytorch_output = single_step(**sample_inputs).cpu().numpy()
    
    # ONNX Runtime output
    sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    
    onnx_inputs = {}
    for name, tensor in sample_inputs.items():
        if tensor.dtype == torch.bool:
            onnx_inputs[name] = tensor.cpu().numpy().astype(bool)
        elif tensor.dtype == torch.long:
            onnx_inputs[name] = tensor.cpu().numpy().astype(np.int64)
        else:
            # Preserve the tensor's own float dtype so FP16 exports also verify
            onnx_inputs[name] = tensor.cpu().numpy()
    
    onnx_output = sess.run(["velocity"], onnx_inputs)[0]
    
    # Compare
    max_diff = np.abs(pytorch_output - onnx_output).max()
    mean_diff = np.abs(pytorch_output - onnx_output).mean()
    
    print(f"  Max difference: {max_diff:.2e}")
    print(f"  Mean difference: {mean_diff:.2e}")
    
    if max_diff < tolerance:
        print(f"  ✓ Verification passed (tolerance: {tolerance})")
        return True
    else:
        print(f"  ✗ Verification failed (tolerance: {tolerance})")
        return False


def main():
    parser = argparse.ArgumentParser(description="Export DiT Transformer to ONNX")
    parser.add_argument(
        "--model-id",
        type=str,
        default="facebook/sam-audio-small",
        help="SAM Audio model ID from HuggingFace",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="onnx_models",
        help="Output directory for ONNX models",
    )
    parser.add_argument(
        "--num-steps",
        type=int,
        default=16,
        help="Number of ODE solver steps (default: 16)",
    )
    parser.add_argument(
        "--opset",
        type=int,
        default=21,
        help="ONNX opset version (default: 21)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        help="Device to use for export (default: cpu)",
    )
    parser.add_argument(
        "--verify",
        action="store_true",
        help="Verify ONNX output matches PyTorch",
    )
    parser.add_argument(
        "--tolerance",
        type=float,
        default=1e-3,
        help="Tolerance for verification (default: 1e-3)",
    )
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Export model in FP16 precision (half the size)",
    )
    
    args = parser.parse_args()
    
    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)
    
    # Load components
    single_step, config = load_sam_audio_components(args.model_id, args.device)
    
    print(f"\nDiT Configuration:")
    print(f"  Model: {args.model_id}")
    print(f"  ODE steps: {args.num_steps}")
    print(f"  Step size: {1.0/args.num_steps:.4f}")
    
    # Export single-step model
    single_step_path = os.path.join(args.output_dir, "dit_single_step.onnx")
    export_dit_single_step(
        single_step,
        single_step_path,
        opset_version=args.opset,
        device=args.device,
        fp16=args.fp16,
    )
    
    if args.fp16:
        print(f"  ✓ Model exported in FP16 precision")
    
    # Verify single-step
    if args.verify:
        verify_dit_single_step(
            single_step,
            single_step_path,
            device=args.device,
            tolerance=args.tolerance,
        )
    
    print(f"\n✓ Export complete! Model saved to {args.output_dir}")


if __name__ == "__main__":
    main()