Commit 1d971a3 (verified) · 1 parent: 301a400 · committed by matbee

Upload folder using huggingface_hub

onnx_export/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # ONNX Export utilities for SAM Audio
onnx_export/export_all.py ADDED
@@ -0,0 +1,115 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Export all SAM Audio components to ONNX format.
4
+
5
+ This script exports:
6
+ 1. DACVAE encoder and decoder (audio codec)
7
+ 2. T5 text encoder
8
+ 3. DiT transformer (single-step for ODE solving)
9
+
10
+ Usage:
11
+ python -m onnx_export.export_all --output-dir onnx_models --verify
12
+ """
13
+
14
+ import os
15
+ import argparse
16
+ import subprocess
17
+ import sys
18
+
19
+
20
+ def run_export(module: str, args: list[str]) -> bool:
21
+ """Run an export module with the given arguments."""
22
+ cmd = [sys.executable, "-m", module] + args
23
+ print(f"\n{'='*60}")
24
+ print(f"Running: {' '.join(cmd)}")
25
+ print(f"{'='*60}\n")
26
+
27
+ result = subprocess.run(cmd, cwd=os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
28
+ return result.returncode == 0
29
+
30
+
31
+ def main():
32
+ parser = argparse.ArgumentParser(description="Export all SAM Audio components to ONNX")
33
+ parser.add_argument(
34
+ "--output-dir",
35
+ type=str,
36
+ default="onnx_models",
37
+ help="Output directory for ONNX models",
38
+ )
39
+ parser.add_argument(
40
+ "--verify",
41
+ action="store_true",
42
+ help="Verify ONNX output matches PyTorch",
43
+ )
44
+ parser.add_argument(
45
+ "--skip-dacvae",
46
+ action="store_true",
47
+ help="Skip DACVAE export",
48
+ )
49
+ parser.add_argument(
50
+ "--skip-t5",
51
+ action="store_true",
52
+ help="Skip T5 export",
53
+ )
54
+ parser.add_argument(
55
+ "--skip-dit",
56
+ action="store_true",
57
+ help="Skip DiT export",
58
+ )
59
+
60
+ args = parser.parse_args()
61
+
62
+ os.makedirs(args.output_dir, exist_ok=True)
63
+
64
+ results = {}
65
+
66
+ # Export DACVAE
67
+ if not args.skip_dacvae:
68
+ export_args = ["--output-dir", args.output_dir]
69
+ if args.verify:
70
+ export_args.append("--verify")
71
+ results["DACVAE"] = run_export("onnx_export.export_dacvae", export_args)
72
+
73
+ # Export T5
74
+ if not args.skip_t5:
75
+ export_args = ["--output-dir", args.output_dir]
76
+ if args.verify:
77
+ export_args.append("--verify")
78
+ results["T5"] = run_export("onnx_export.export_t5", export_args)
79
+
80
+ # Export DiT
81
+ if not args.skip_dit:
82
+ export_args = ["--output-dir", args.output_dir]
83
+ if args.verify:
84
+ export_args.append("--verify")
85
+ results["DiT"] = run_export("onnx_export.export_dit", export_args)
86
+
87
+ # Print summary
88
+ print(f"\n{'='*60}")
89
+ print("Export Summary")
90
+ print(f"{'='*60}")
91
+
92
+ all_success = True
93
+ for name, success in results.items():
94
+ status = "✓" if success else "✗"
95
+ print(f" {status} {name}")
96
+ if not success:
97
+ all_success = False
98
+
99
+ # List exported files
100
+ print(f"\nExported files in {args.output_dir}:")
101
+ for f in sorted(os.listdir(args.output_dir)):
102
+ path = os.path.join(args.output_dir, f)
103
+ if os.path.isfile(path):
104
+ size_mb = os.path.getsize(path) / (1024 * 1024)
105
+ print(f" {f}: {size_mb:.1f} MB")
106
+
107
+ if all_success:
108
+ print("\n✓ All exports completed successfully!")
109
+ else:
110
+ print("\n✗ Some exports failed")
111
+ sys.exit(1)
112
+
113
+
114
+ if __name__ == "__main__":
115
+ main()
onnx_export/export_dacvae.py ADDED
@@ -0,0 +1,425 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Export DACVAE (audio codec) to ONNX format.
4
+
5
+ This exports the encoder and decoder separately:
6
+ - Encoder: audio waveform → latent features
7
+ - Decoder: latent features → audio waveform
8
+
9
+ Usage:
10
+ python -m onnx_export.export_dacvae --output-dir onnx_models --verify
11
+ """
12
+
13
+ import os
14
+ import argparse
15
+ import torch
16
+ import torch.nn as nn
17
+ import dacvae
18
+ from huggingface_hub import hf_hub_download
19
+
20
+
21
+ # Default DACVAE configuration (matches SAM Audio)
22
+ DEFAULT_CONFIG = {
23
+ "encoder_dim": 64,
24
+ "encoder_rates": [2, 8, 10, 12],
25
+ "latent_dim": 1024,
26
+ "decoder_dim": 1536,
27
+ "decoder_rates": [12, 10, 8, 2],
28
+ "n_codebooks": 16,
29
+ "codebook_size": 1024,
30
+ "codebook_dim": 128,
31
+ "quantizer_dropout": False,
32
+ "sample_rate": 48000,
33
+ }
34
+
35
+
36
+ class DACVAEEncoderWrapper(nn.Module):
37
+ """Wrapper for DACVAE encoder that outputs continuous latent features."""
38
+
39
+ def __init__(self, encoder, quantizer):
40
+ super().__init__()
41
+ self.encoder = encoder
42
+ self.in_proj = quantizer.in_proj
43
+
44
+ def forward(self, audio: torch.Tensor) -> torch.Tensor:
45
+ """
46
+ Encode audio to latent features.
47
+
48
+ Args:
49
+ audio: Input waveform, shape (batch, 1, samples)
50
+
51
+ Returns:
52
+ latent_features: Continuous latent mean, shape (batch, 128, time_steps)
53
+ """
54
+ x = self.encoder(audio)
55
+ # in_proj outputs 256 dim, chunk into mean and variance, use only mean
56
+ mean, _ = self.in_proj(x).chunk(2, dim=1)
57
+ return mean
58
+
59
+
60
+ class DACVAEDecoderWrapper(nn.Module):
61
+ """Wrapper for DACVAE decoder that takes continuous latent features."""
62
+
63
+ def __init__(self, decoder, quantizer):
64
+ super().__init__()
65
+ self.decoder = decoder
66
+ self.out_proj = quantizer.out_proj
67
+
68
+ def forward(self, latent_features: torch.Tensor) -> torch.Tensor:
69
+ """
70
+ Decode latent features to audio.
71
+
72
+ Args:
73
+ latent_features: Continuous latent, shape (batch, 128, time_steps)
74
+
75
+ Returns:
76
+ audio: Output waveform, shape (batch, 1, samples)
77
+ """
78
+ x = self.out_proj(latent_features)
79
+ return self.decoder(x)
80
+
81
+
82
+ def create_dacvae_model(model_id: str = "facebook/sam-audio-small") -> dacvae.DACVAE:
83
+ """
84
+ Create and load DACVAE model with weights from SAM Audio checkpoint.
85
+
86
+ This uses the standalone dacvae library, avoiding loading the full SAM Audio
87
+ model and its dependencies (vision encoder, imagebind, etc).
88
+ """
89
+ print(f"Creating DACVAE model...")
90
+
91
+ model = dacvae.DACVAE(
92
+ encoder_dim=DEFAULT_CONFIG["encoder_dim"],
93
+ encoder_rates=DEFAULT_CONFIG["encoder_rates"],
94
+ latent_dim=DEFAULT_CONFIG["latent_dim"],
95
+ decoder_dim=DEFAULT_CONFIG["decoder_dim"],
96
+ decoder_rates=DEFAULT_CONFIG["decoder_rates"],
97
+ n_codebooks=DEFAULT_CONFIG["n_codebooks"],
98
+ codebook_size=DEFAULT_CONFIG["codebook_size"],
99
+ codebook_dim=DEFAULT_CONFIG["codebook_dim"],
100
+ quantizer_dropout=DEFAULT_CONFIG["quantizer_dropout"],
101
+ sample_rate=DEFAULT_CONFIG["sample_rate"],
102
+ ).eval()
103
+
104
+ # Load weights from SAM Audio checkpoint
105
+ print(f"Downloading checkpoint from {model_id}...")
106
+ checkpoint_path = hf_hub_download(
107
+ repo_id=model_id,
108
+ filename="checkpoint.pt",
109
+ )
110
+
111
+ print("Loading DACVAE weights from checkpoint...")
112
+ state_dict = torch.load(
113
+ checkpoint_path,
114
+ map_location="cpu",
115
+ weights_only=True,
116
+ mmap=True, # Memory-efficient loading
117
+ )
118
+
119
+ # Extract only DACVAE weights (prefixed with "audio_codec.")
120
+ dacvae_state_dict = {}
121
+ for k, v in state_dict.items():
122
+ if k.startswith("audio_codec."):
123
+ new_key = k.replace("audio_codec.", "")
124
+ dacvae_state_dict[new_key] = v.clone()
125
+
126
+ # Load weights
127
+ model.load_state_dict(dacvae_state_dict, strict=False)
128
+
129
+ # Clear large checkpoint from memory
130
+ del state_dict
131
+
132
+ print(f" ✓ Loaded {len(dacvae_state_dict)} DACVAE weight tensors")
133
+
134
+ # Calculate hop_length for reference
135
+ import numpy as np
136
+ hop_length = int(np.prod(DEFAULT_CONFIG["encoder_rates"]))
137
+ model.hop_length = hop_length
138
+ model.sample_rate = DEFAULT_CONFIG["sample_rate"]
139
+
140
+ return model
141
+
142
+
143
+ def export_encoder(
144
+ dacvae_model: dacvae.DACVAE,
145
+ output_path: str,
146
+ opset_version: int = 18,
147
+ device: str = "cpu",
148
+ ) -> None:
149
+ """Export DACVAE encoder to ONNX."""
150
+ print(f"Exporting DACVAE encoder to {output_path}...")
151
+
152
+ wrapper = DACVAEEncoderWrapper(
153
+ dacvae_model.encoder,
154
+ dacvae_model.quantizer
155
+ ).eval().to(device)
156
+
157
+ # Sample input: 1 second of audio at 48kHz
158
+ sample_rate = DEFAULT_CONFIG["sample_rate"]
159
+ dummy_audio = torch.randn(1, 1, sample_rate, device=device)
160
+
161
+ torch.onnx.export(
162
+ wrapper,
163
+ (dummy_audio,),
164
+ output_path,
165
+ input_names=["audio"],
166
+ output_names=["latent_features"],
167
+ dynamic_axes={
168
+ "audio": {0: "batch", 2: "samples"},
169
+ "latent_features": {0: "batch", 2: "time_steps"},
170
+ },
171
+ opset_version=opset_version,
172
+ do_constant_folding=True,
173
+ dynamo=True,
174
+ external_data=True,
175
+ )
176
+
177
+ print(f" ✓ Encoder exported successfully")
178
+
179
+ # Validate
180
+ import onnx
181
+ model = onnx.load(output_path)
182
+ onnx.checker.check_model(model)
183
+ print(f" ✓ ONNX model validation passed")
184
+
185
+
186
+ def export_decoder(
187
+ dacvae_model: dacvae.DACVAE,
188
+ output_path: str,
189
+ opset_version: int = 18,
190
+ device: str = "cpu",
191
+ ) -> None:
192
+ """Export DACVAE decoder to ONNX."""
193
+ print(f"Exporting DACVAE decoder to {output_path}...")
194
+
195
+ wrapper = DACVAEDecoderWrapper(
196
+ dacvae_model.decoder,
197
+ dacvae_model.quantizer
198
+ ).eval().to(device)
199
+
200
+ # Sample input: 25 time steps (1 second at 48kHz with hop_length=1920)
201
+ hop_length = int(__import__("numpy").prod(DEFAULT_CONFIG["encoder_rates"]))
202
+ time_steps = DEFAULT_CONFIG["sample_rate"] // hop_length
203
+ dummy_latent = torch.randn(1, 128, time_steps, device=device)
204
+
205
+ torch.onnx.export(
206
+ wrapper,
207
+ (dummy_latent,),
208
+ output_path,
209
+ input_names=["latent_features"],
210
+ output_names=["waveform"],
211
+ dynamic_axes={
212
+ "latent_features": {0: "batch", 2: "time_steps"},
213
+ "waveform": {0: "batch", 2: "samples"},
214
+ },
215
+ opset_version=opset_version,
216
+ do_constant_folding=True,
217
+ dynamo=True,
218
+ external_data=True,
219
+ )
220
+
221
+ print(f" ✓ Decoder exported successfully")
222
+
223
+ # Validate
224
+ import onnx
225
+ model = onnx.load(output_path)
226
+ onnx.checker.check_model(model)
227
+ print(f" ✓ ONNX model validation passed")
228
+
229
+
230
+ def verify_encoder(
231
+ dacvae_model: dacvae.DACVAE,
232
+ onnx_path: str,
233
+ device: str = "cpu",
234
+ tolerance: float = 1e-4,
235
+ ) -> bool:
236
+ """Verify ONNX encoder output matches PyTorch."""
237
+ import onnxruntime as ort
238
+ import numpy as np
239
+
240
+ print("Verifying encoder output...")
241
+
242
+ wrapper = DACVAEEncoderWrapper(
243
+ dacvae_model.encoder,
244
+ dacvae_model.quantizer
245
+ ).eval().to(device)
246
+
247
+ # Test with random audio
248
+ sample_rate = DEFAULT_CONFIG["sample_rate"]
249
+ test_audio = torch.randn(1, 1, sample_rate * 2, device=device) # 2 seconds
250
+
251
+ # PyTorch output
252
+ with torch.no_grad():
253
+ pytorch_output = wrapper(test_audio).cpu().numpy()
254
+
255
+ # ONNX Runtime output
256
+ sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
257
+ onnx_output = sess.run(
258
+ ["latent_features"],
259
+ {"audio": test_audio.cpu().numpy()}
260
+ )[0]
261
+
262
+ # Compare
263
+ max_diff = np.abs(pytorch_output - onnx_output).max()
264
+ mean_diff = np.abs(pytorch_output - onnx_output).mean()
265
+
266
+ print(f" Max diff: {max_diff:.2e}, Mean diff: {mean_diff:.2e}")
267
+
268
+ if max_diff > tolerance:
269
+ print(f" ✗ Verification failed (tolerance: {tolerance})")
270
+ return False
271
+
272
+ print(f" ✓ Verification passed (tolerance: {tolerance})")
273
+ return True
274
+
275
+
276
+ def verify_decoder(
277
+ dacvae_model: dacvae.DACVAE,
278
+ onnx_path: str,
279
+ device: str = "cpu",
280
+ tolerance: float = 1e-3,
281
+ ) -> bool:
282
+ """Verify ONNX decoder output matches PyTorch."""
283
+ import onnxruntime as ort
284
+ import numpy as np
285
+
286
+ print("Verifying decoder output...")
287
+
288
+ wrapper = DACVAEDecoderWrapper(
289
+ dacvae_model.decoder,
290
+ dacvae_model.quantizer
291
+ ).eval().to(device)
292
+
293
+ # Test with random latent
294
+ hop_length = int(np.prod(DEFAULT_CONFIG["encoder_rates"]))
295
+ time_steps = DEFAULT_CONFIG["sample_rate"] // hop_length # 25 steps = 1 second
296
+ test_latent = torch.randn(1, 128, time_steps, device=device)
297
+
298
+ # PyTorch output
299
+ with torch.no_grad():
300
+ pytorch_output = wrapper(test_latent).cpu().numpy()
301
+
302
+ # ONNX Runtime output
303
+ sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
304
+ onnx_output = sess.run(
305
+ ["waveform"],
306
+ {"latent_features": test_latent.cpu().numpy()}
307
+ )[0]
308
+
309
+ # Compare
310
+ max_diff = np.abs(pytorch_output - onnx_output).max()
311
+ mean_diff = np.abs(pytorch_output - onnx_output).mean()
312
+
313
+ print(f" Max diff: {max_diff:.2e}, Mean diff: {mean_diff:.2e}")
314
+
315
+ if max_diff > tolerance:
316
+ print(f" ✗ Verification failed (tolerance: {tolerance})")
317
+ return False
318
+
319
+ print(f" ✓ Verification passed (tolerance: {tolerance})")
320
+ return True
321
+
322
+
323
+ def main():
324
+ parser = argparse.ArgumentParser(description="Export DACVAE to ONNX")
325
+ parser.add_argument(
326
+ "--model-id",
327
+ type=str,
328
+ default="facebook/sam-audio-small",
329
+ help="HuggingFace model ID (default: facebook/sam-audio-small)",
330
+ )
331
+ parser.add_argument(
332
+ "--output-dir",
333
+ type=str,
334
+ default="onnx_models",
335
+ help="Output directory for ONNX models",
336
+ )
337
+ parser.add_argument(
338
+ "--opset-version",
339
+ type=int,
340
+ default=18,
341
+ help="ONNX opset version (default: 18)",
342
+ )
343
+ parser.add_argument(
344
+ "--device",
345
+ type=str,
346
+ default="cpu",
347
+ help="Device to use for export (default: cpu)",
348
+ )
349
+ parser.add_argument(
350
+ "--verify",
351
+ action="store_true",
352
+ help="Verify ONNX output matches PyTorch",
353
+ )
354
+ parser.add_argument(
355
+ "--tolerance",
356
+ type=float,
357
+ default=1e-4,
358
+ help="Tolerance for verification (default: 1e-4)",
359
+ )
360
+ parser.add_argument(
361
+ "--encoder-only",
362
+ action="store_true",
363
+ help="Export only the encoder",
364
+ )
365
+ parser.add_argument(
366
+ "--decoder-only",
367
+ action="store_true",
368
+ help="Export only the decoder",
369
+ )
370
+
371
+ args = parser.parse_args()
372
+
373
+ # Create output directory
374
+ os.makedirs(args.output_dir, exist_ok=True)
375
+
376
+ # Load model
377
+ dacvae_model = create_dacvae_model(args.model_id)
378
+
379
+ print(f"\nDACVAE Configuration:")
380
+ print(f" Model: {args.model_id}")
381
+ print(f" Sample rate: {DEFAULT_CONFIG['sample_rate']} Hz")
382
+ print(f" Hop length: {int(__import__('numpy').prod(DEFAULT_CONFIG['encoder_rates']))}")
383
+ print(f" Latent dim: 128 (continuous)")
384
+
385
+ # Export encoder
386
+ if not args.decoder_only:
387
+ encoder_path = os.path.join(args.output_dir, "dacvae_encoder.onnx")
388
+ export_encoder(
389
+ dacvae_model,
390
+ encoder_path,
391
+ opset_version=args.opset_version,
392
+ device=args.device,
393
+ )
394
+
395
+ if args.verify:
396
+ verify_encoder(
397
+ dacvae_model,
398
+ encoder_path,
399
+ device=args.device,
400
+ tolerance=args.tolerance,
401
+ )
402
+
403
+ # Export decoder
404
+ if not args.encoder_only:
405
+ decoder_path = os.path.join(args.output_dir, "dacvae_decoder.onnx")
406
+ export_decoder(
407
+ dacvae_model,
408
+ decoder_path,
409
+ opset_version=args.opset_version,
410
+ device=args.device,
411
+ )
412
+
413
+ if args.verify:
414
+ verify_decoder(
415
+ dacvae_model,
416
+ decoder_path,
417
+ device=args.device,
418
+ tolerance=args.tolerance * 10, # Decoder has higher tolerance
419
+ )
420
+
421
+ print(f"\n✓ Export complete! Models saved to {args.output_dir}/")
422
+
423
+
424
+ if __name__ == "__main__":
425
+ main()
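
A minimal round-trip sketch for consuming the exported codec, assuming the default output directory, the I/O names used above ("audio"/"latent_features" for the encoder, "latent_features"/"waveform" for the decoder), and an onnxruntime install:

import numpy as np
import onnxruntime as ort

# Load the two exported sessions (external weight files must sit next to the .onnx files).
enc = ort.InferenceSession("onnx_models/dacvae_encoder.onnx", providers=["CPUExecutionProvider"])
dec = ort.InferenceSession("onnx_models/dacvae_decoder.onnx", providers=["CPUExecutionProvider"])

# One second of random audio at 48 kHz, shape (batch, channels, samples).
audio = np.random.randn(1, 1, 48000).astype(np.float32)

latent = enc.run(["latent_features"], {"audio": audio})[0]        # (1, 128, 25) with hop_length=1920
waveform = dec.run(["waveform"], {"latent_features": latent})[0]  # (1, 1, ~48000)
print(latent.shape, waveform.shape)
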
onnx_export/export_dit.py ADDED
@@ -0,0 +1,543 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Export the DiT transformer (single ODE step) to ONNX format.
4
+
5
+ The DiT transformer is the core denoising model in SAM Audio. It uses a flow-based
6
+ generative model with an ODE solver. For ONNX export, the network is captured as a
7
+ single ODE-function evaluation; the fixed-step midpoint solver is then run around it at inference time.
8
+
9
+ The default configuration uses:
10
+ - method: "midpoint"
11
+ - step_size: 2/32 (0.0625)
12
+ - integration range: [0, 1]
13
+ - total steps: 16
14
+
15
+ The exported model takes the current noisy latents, the timestep, and the conditioning as
16
+ input and returns the velocity used for one solver update.
17
+
18
+ Usage:
19
+ python -m onnx_export.export_dit --output-dir onnx_models --verify
20
+ """
21
+
22
+ import os
23
+ import math
24
+ import argparse
25
+ import torch
26
+ import torch.nn as nn
27
+ from typing import Optional
28
+
29
+
30
+ class SinusoidalEmbedding(nn.Module):
31
+ """Sinusoidal timestep embedding (identical to SAMAudio implementation)."""
32
+
33
+ def __init__(self, dim, theta=10000):
34
+ super().__init__()
35
+ assert (dim % 2) == 0
36
+ half_dim = dim // 2
37
+ inv_freq = torch.exp(
38
+ -math.log(theta) * torch.arange(half_dim).float() / half_dim
39
+ )
40
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
41
+
42
+ def forward(self, x, pos=None):
43
+ if pos is None:
44
+ seq_len, device = x.shape[1], x.device
45
+ pos = torch.arange(seq_len, device=device)
46
+
47
+ emb = torch.einsum("i, j -> i j", pos, self.inv_freq)
48
+ emb = torch.cat((emb.cos(), emb.sin()), dim=-1)
49
+ return emb
50
+
51
+
52
+ class EmbedAnchors(nn.Module):
53
+ """Anchor embedding (identical to SAMAudio implementation)."""
54
+
55
+ def __init__(self, num_embeddings: int, embedding_dim: int, out_dim: int):
56
+ super().__init__()
57
+ self.embed = nn.Embedding(
58
+ num_embeddings + 1, embedding_dim, padding_idx=num_embeddings
59
+ )
60
+ self.gate = nn.Parameter(torch.tensor([0.0]))
61
+ self.proj = nn.Linear(embedding_dim, out_dim, bias=False)
62
+
63
+ def forward(
64
+ self,
65
+ x: torch.Tensor,
66
+ anchor_ids: Optional[torch.Tensor] = None,
67
+ anchor_alignment: Optional[torch.Tensor] = None,
68
+ ):
69
+ if anchor_ids is None:
70
+ return x
71
+
72
+ embs = self.embed(anchor_ids.gather(1, anchor_alignment))
73
+ proj = self.proj(embs)
74
+ return x + self.gate.tanh() * proj
75
+
76
+
77
+ class DiTSingleStepWrapper(nn.Module):
78
+ """
79
+ Wrapper for DiT that performs a single forward pass (one ODE evaluation).
80
+
81
+ This mirrors the SAMAudio.forward() method exactly.
82
+ """
83
+
84
+ def __init__(
85
+ self,
86
+ transformer: nn.Module,
87
+ proj: nn.Module,
88
+ align_masked_video: nn.Module,
89
+ embed_anchors: nn.Module,
90
+ timestep_emb: nn.Module,
91
+ memory_proj: nn.Module,
92
+ ):
93
+ super().__init__()
94
+ self.transformer = transformer
95
+ self.proj = proj
96
+ self.align_masked_video = align_masked_video
97
+ self.embed_anchors = embed_anchors
98
+ self.timestep_emb = timestep_emb
99
+ self.memory_proj = memory_proj
100
+
101
+ def forward(
102
+ self,
103
+ noisy_audio: torch.Tensor,
104
+ time: torch.Tensor,
105
+ audio_features: torch.Tensor,
106
+ text_features: torch.Tensor,
107
+ text_mask: torch.Tensor,
108
+ masked_video_features: torch.Tensor,
109
+ anchor_ids: torch.Tensor,
110
+ anchor_alignment: torch.Tensor,
111
+ audio_pad_mask: torch.Tensor,
112
+ ) -> torch.Tensor:
113
+ """
114
+ Single forward pass of the DiT (one ODE function evaluation).
115
+
116
+ This exactly mirrors SAMAudio.forward() method.
117
+ """
118
+ # Align inputs (concatenate noisy_audio with audio_features)
119
+ # Same as SAMAudio.align_inputs()
120
+ x = torch.cat(
121
+ [
122
+ noisy_audio,
123
+ torch.zeros_like(audio_features),
124
+ audio_features,
125
+ ],
126
+ dim=2,
127
+ )
128
+
129
+ projected = self.proj(x)
130
+ aligned = self.align_masked_video(projected, masked_video_features)
131
+ aligned = self.embed_anchors(aligned, anchor_ids, anchor_alignment)
132
+
133
+ # Timestep embedding and memory
134
+ # Same as SAMAudio.forward()
135
+ timestep_emb_val = self.timestep_emb(time, pos=time).unsqueeze(1)
136
+ memory = self.memory_proj(text_features) + timestep_emb_val
137
+
138
+ # Transformer forward
139
+ output = self.transformer(
140
+ aligned,
141
+ time,
142
+ padding_mask=audio_pad_mask,
143
+ memory=memory,
144
+ memory_padding_mask=text_mask,
145
+ )
146
+
147
+ return output
148
+
149
+
150
+ class UnrolledDiTWrapper(nn.Module):
151
+ """
152
+ DiT wrapper with unrolled midpoint ODE solver.
153
+
154
+ The midpoint method computes:
155
+ k1 = f(t, y)
156
+ k2 = f(t + h/2, y + h/2 * k1)
157
+ y_new = y + h * k2
158
+
159
+ With step_size=0.0625 and range [0,1], we have 16 steps.
160
+ """
161
+
162
+ def __init__(
163
+ self,
164
+ single_step: DiTSingleStepWrapper,
165
+ num_steps: int = 16,
166
+ ):
167
+ super().__init__()
168
+ self.single_step = single_step
169
+ self.num_steps = num_steps
170
+ self.step_size = 1.0 / num_steps
171
+
172
+ def forward(
173
+ self,
174
+ noise: torch.Tensor,
175
+ audio_features: torch.Tensor,
176
+ text_features: torch.Tensor,
177
+ text_mask: torch.Tensor,
178
+ masked_video_features: torch.Tensor,
179
+ anchor_ids: torch.Tensor,
180
+ anchor_alignment: torch.Tensor,
181
+ audio_pad_mask: torch.Tensor,
182
+ ) -> torch.Tensor:
183
+ """Complete denoising using unrolled midpoint ODE solver."""
184
+ B = noise.shape[0]
185
+ h = self.step_size
186
+ y = noise
187
+ t = torch.zeros(B, device=noise.device, dtype=noise.dtype)
188
+
189
+ for step in range(self.num_steps):
190
+ # k1 = f(t, y)
191
+ k1 = self.single_step(
192
+ y, t,
193
+ audio_features, text_features, text_mask,
194
+ masked_video_features, anchor_ids, anchor_alignment, audio_pad_mask
195
+ )
196
+
197
+ # k2 = f(t + h/2, y + h/2 * k1)
198
+ t_mid = t + h / 2
199
+ y_mid = y + (h / 2) * k1
200
+ k2 = self.single_step(
201
+ y_mid, t_mid,
202
+ audio_features, text_features, text_mask,
203
+ masked_video_features, anchor_ids, anchor_alignment, audio_pad_mask
204
+ )
205
+
206
+ # y = y + h * k2
207
+ y = y + h * k2
208
+ t = t + h
209
+
210
+ return y
211
+
212
+
213
+ def load_sam_audio_components(model_id: str = "facebook/sam-audio-small", device: str = "cpu"):
214
+ """
215
+ Load SAM Audio components needed for DiT export.
216
+
217
+ Since we can't load the full SAMAudio model (missing perception_models),
218
+ we construct the components directly and load weights from checkpoint.
219
+ """
220
+ import json
221
+ import sys
222
+ import types
223
+ import importlib.util
224
+ from huggingface_hub import hf_hub_download
225
+
226
+ print(f"Loading SAM Audio components from {model_id}...")
227
+
228
+ # Download config
229
+ config_path = hf_hub_download(repo_id=model_id, filename="config.json")
230
+ with open(config_path) as f:
231
+ config = json.load(f)
232
+
233
+ # Download checkpoint
234
+ checkpoint_path = hf_hub_download(repo_id=model_id, filename="checkpoint.pt")
235
+
236
+ # Use our standalone config that doesn't have 'core' dependencies
237
+ from onnx_export.standalone_config import TransformerConfig
238
+
239
+ sam_audio_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
240
+
241
+ # Create fake module hierarchy so transformer.py's relative imports work
242
+ if 'sam_audio' not in sys.modules:
243
+ sam_audio_pkg = types.ModuleType('sam_audio')
244
+ sam_audio_pkg.__path__ = [os.path.join(sam_audio_path, 'sam_audio')]
245
+ sys.modules['sam_audio'] = sam_audio_pkg
246
+
247
+ if 'sam_audio.model' not in sys.modules:
248
+ model_pkg = types.ModuleType('sam_audio.model')
249
+ model_pkg.__path__ = [os.path.join(sam_audio_path, 'sam_audio', 'model')]
250
+ sys.modules['sam_audio.model'] = model_pkg
251
+
252
+ # Register our standalone config as sam_audio.model.config
253
+ if 'sam_audio.model.config' not in sys.modules:
254
+ import onnx_export.standalone_config as standalone_config
255
+ sys.modules['sam_audio.model.config'] = standalone_config
256
+
257
+ # Now import transformer module - it will use our standalone config
258
+ transformer_spec = importlib.util.spec_from_file_location(
259
+ "sam_audio.model.transformer",
260
+ os.path.join(sam_audio_path, "sam_audio", "model", "transformer.py")
261
+ )
262
+ transformer_module = importlib.util.module_from_spec(transformer_spec)
263
+ sys.modules['sam_audio.model.transformer'] = transformer_module
264
+ transformer_spec.loader.exec_module(transformer_module)
265
+ DiT = transformer_module.DiT
266
+
267
+ # Import align module
268
+ align_spec = importlib.util.spec_from_file_location(
269
+ "sam_audio.model.align",
270
+ os.path.join(sam_audio_path, "sam_audio", "model", "align.py")
271
+ )
272
+ align_module = importlib.util.module_from_spec(align_spec)
273
+ sys.modules['sam_audio.model.align'] = align_module
274
+ align_spec.loader.exec_module(align_module)
275
+ AlignModalities = align_module.AlignModalities
276
+
277
+ # Create transformer
278
+ transformer_config = TransformerConfig(**config.get("transformer", {}))
279
+ transformer = DiT(transformer_config)
280
+
281
+ # Calculate dimensions
282
+ in_channels = config.get("in_channels", 768)
283
+ num_anchors = config.get("num_anchors", 3)
284
+ anchor_embedding_dim = config.get("anchor_embedding_dim", 128)
285
+
286
+ # Get vision encoder dim for align_masked_video
287
+ vision_config = config.get("vision_encoder", {})
288
+ vision_dim = vision_config.get("dim", 768)
289
+
290
+ # Create components exactly as SAMAudio does
291
+ proj = nn.Linear(in_channels, transformer_config.d_model)
292
+ align_masked_video = AlignModalities(vision_dim, transformer_config.d_model)
293
+ embed_anchors = EmbedAnchors(num_anchors, anchor_embedding_dim, transformer_config.d_model)
294
+ timestep_emb = SinusoidalEmbedding(transformer_config.d_model)
295
+
296
+ # Memory projection for text features
297
+ text_encoder_config = config.get("text_encoder", {})
298
+ text_encoder_dim = text_encoder_config.get("dim", 1024) # google/flan-t5-large
299
+ memory_proj = nn.Linear(text_encoder_dim, transformer_config.d_model)
300
+
301
+ # Load weights from checkpoint
302
+ print("Loading weights from checkpoint...")
303
+ state_dict = torch.load(checkpoint_path, map_location="cpu", mmap=True)
304
+
305
+ # Filter and load weights for each component
306
+ transformer_state = {}
307
+ proj_state = {}
308
+ align_state = {}
309
+ embed_anchors_state = {}
310
+ memory_proj_state = {}
311
+
312
+ for key, value in state_dict.items():
313
+ if key.startswith("transformer."):
314
+ new_key = key[len("transformer."):]
315
+ transformer_state[new_key] = value
316
+ elif key.startswith("proj."):
317
+ new_key = key[len("proj."):]
318
+ proj_state[new_key] = value
319
+ elif key.startswith("align_masked_video."):
320
+ new_key = key[len("align_masked_video."):]
321
+ align_state[new_key] = value
322
+ elif key.startswith("embed_anchors."):
323
+ new_key = key[len("embed_anchors."):]
324
+ embed_anchors_state[new_key] = value
325
+ elif key.startswith("memory_proj."):
326
+ new_key = key[len("memory_proj."):]
327
+ memory_proj_state[new_key] = value
328
+
329
+ transformer.load_state_dict(transformer_state)
330
+ proj.load_state_dict(proj_state)
331
+ align_masked_video.load_state_dict(align_state)
332
+ embed_anchors.load_state_dict(embed_anchors_state)
333
+ memory_proj.load_state_dict(memory_proj_state)
334
+
335
+ print(f" ✓ Loaded transformer weights ({len(transformer_state)} tensors)")
336
+ print(f" ✓ Loaded component weights")
337
+
338
+ # Create single step wrapper
339
+ single_step = DiTSingleStepWrapper(
340
+ transformer=transformer,
341
+ proj=proj,
342
+ align_masked_video=align_masked_video,
343
+ embed_anchors=embed_anchors,
344
+ timestep_emb=timestep_emb,
345
+ memory_proj=memory_proj,
346
+ ).eval().to(device)
347
+
348
+ return single_step, config
349
+
350
+
351
+ def create_sample_inputs(batch_size: int = 1, seq_len: int = 25, device: str = "cpu"):
352
+ """Create sample inputs for tracing."""
353
+ latent_dim = 128
354
+ text_dim = 768 # T5-base hidden size (SAM Audio was trained with 768-dim text)
355
+ vision_dim = 1024 # Vision encoder dim from config
356
+ text_len = 77
357
+
358
+ return {
359
+ "noisy_audio": torch.randn(batch_size, seq_len, 2 * latent_dim, device=device),
360
+ "time": torch.zeros(batch_size, device=device),
361
+ "audio_features": torch.randn(batch_size, seq_len, 2 * latent_dim, device=device),
362
+ "text_features": torch.randn(batch_size, text_len, text_dim, device=device),
363
+ "text_mask": torch.ones(batch_size, text_len, dtype=torch.bool, device=device),
364
+ "masked_video_features": torch.zeros(batch_size, vision_dim, seq_len, device=device),
365
+ "anchor_ids": torch.zeros(batch_size, seq_len, dtype=torch.long, device=device),
366
+ "anchor_alignment": torch.zeros(batch_size, seq_len, dtype=torch.long, device=device),
367
+ "audio_pad_mask": torch.ones(batch_size, seq_len, dtype=torch.bool, device=device),
368
+ }
369
+
370
+
371
+ def export_dit_single_step(
372
+ single_step: DiTSingleStepWrapper,
373
+ output_path: str,
374
+ opset_version: int = 18,
375
+ device: str = "cpu",
376
+ ):
377
+ """Export single-step DiT to ONNX (for runtime ODE solving)."""
378
+ import onnx
379
+
380
+ print(f"Exporting DiT single-step to {output_path}...")
381
+
382
+ sample_inputs = create_sample_inputs(device=device)
383
+
384
+ torch.onnx.export(
385
+ single_step,
386
+ tuple(sample_inputs.values()),
387
+ output_path,
388
+ input_names=list(sample_inputs.keys()),
389
+ output_names=["velocity"],
390
+ dynamic_axes={
391
+ "noisy_audio": {0: "batch_size", 1: "seq_len"},
392
+ "time": {0: "batch_size"},
393
+ "audio_features": {0: "batch_size", 1: "seq_len"},
394
+ "text_features": {0: "batch_size", 1: "text_len"},
395
+ "text_mask": {0: "batch_size", 1: "text_len"},
396
+ "masked_video_features": {0: "batch_size", 2: "seq_len"},
397
+ "anchor_ids": {0: "batch_size", 1: "seq_len"},
398
+ "anchor_alignment": {0: "batch_size", 1: "seq_len"},
399
+ "audio_pad_mask": {0: "batch_size", 1: "seq_len"},
400
+ "velocity": {0: "batch_size", 1: "seq_len"},
401
+ },
402
+ opset_version=opset_version,
403
+ do_constant_folding=True,
404
+ dynamo=True,
405
+ external_data=True,
406
+ )
407
+
408
+ print(" ✓ DiT single-step exported successfully")
409
+
410
+ model = onnx.load(output_path)
411
+ onnx.checker.check_model(model)
412
+ print(" ✓ ONNX model validation passed")
413
+
414
+ return True
415
+
416
+
417
+ def verify_dit_single_step(
418
+ single_step: DiTSingleStepWrapper,
419
+ onnx_path: str,
420
+ device: str = "cpu",
421
+ tolerance: float = 1e-3,
422
+ ) -> bool:
423
+ """Verify single-step ONNX output matches PyTorch."""
424
+ import onnxruntime as ort
425
+ import numpy as np
426
+
427
+ print("Verifying DiT single-step output...")
428
+
429
+ sample_inputs = create_sample_inputs(device=device)
430
+
431
+ # PyTorch output
432
+ with torch.no_grad():
433
+ pytorch_output = single_step(**sample_inputs).cpu().numpy()
434
+
435
+ # ONNX Runtime output
436
+ sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
437
+
438
+ onnx_inputs = {}
439
+ for name, tensor in sample_inputs.items():
440
+ if tensor.dtype == torch.bool:
441
+ onnx_inputs[name] = tensor.cpu().numpy().astype(bool)
442
+ elif tensor.dtype == torch.long:
443
+ onnx_inputs[name] = tensor.cpu().numpy().astype(np.int64)
444
+ else:
445
+ onnx_inputs[name] = tensor.cpu().numpy().astype(np.float32)
446
+
447
+ onnx_output = sess.run(["velocity"], onnx_inputs)[0]
448
+
449
+ # Compare
450
+ max_diff = np.abs(pytorch_output - onnx_output).max()
451
+ mean_diff = np.abs(pytorch_output - onnx_output).mean()
452
+
453
+ print(f" Max difference: {max_diff:.2e}")
454
+ print(f" Mean difference: {mean_diff:.2e}")
455
+
456
+ if max_diff < tolerance:
457
+ print(f" ✓ Verification passed (tolerance: {tolerance})")
458
+ return True
459
+ else:
460
+ print(f" ✗ Verification failed (tolerance: {tolerance})")
461
+ return False
462
+
463
+
464
+ def main():
465
+ parser = argparse.ArgumentParser(description="Export DiT Transformer to ONNX")
466
+ parser.add_argument(
467
+ "--model-id",
468
+ type=str,
469
+ default="facebook/sam-audio-small",
470
+ help="SAM Audio model ID from HuggingFace",
471
+ )
472
+ parser.add_argument(
473
+ "--output-dir",
474
+ type=str,
475
+ default="onnx_models",
476
+ help="Output directory for ONNX models",
477
+ )
478
+ parser.add_argument(
479
+ "--num-steps",
480
+ type=int,
481
+ default=16,
482
+ help="Number of ODE solver steps (default: 16)",
483
+ )
484
+ parser.add_argument(
485
+ "--opset",
486
+ type=int,
487
+ default=18,
488
+ help="ONNX opset version (default: 18)",
489
+ )
490
+ parser.add_argument(
491
+ "--device",
492
+ type=str,
493
+ default="cpu",
494
+ help="Device to use for export (default: cpu)",
495
+ )
496
+ parser.add_argument(
497
+ "--verify",
498
+ action="store_true",
499
+ help="Verify ONNX output matches PyTorch",
500
+ )
501
+ parser.add_argument(
502
+ "--tolerance",
503
+ type=float,
504
+ default=1e-3,
505
+ help="Tolerance for verification (default: 1e-3)",
506
+ )
507
+
508
+ args = parser.parse_args()
509
+
510
+ # Create output directory
511
+ os.makedirs(args.output_dir, exist_ok=True)
512
+
513
+ # Load components
514
+ single_step, config = load_sam_audio_components(args.model_id, args.device)
515
+
516
+ print(f"\nDiT Configuration:")
517
+ print(f" Model: {args.model_id}")
518
+ print(f" ODE steps: {args.num_steps}")
519
+ print(f" Step size: {1.0/args.num_steps:.4f}")
520
+
521
+ # Export single-step model
522
+ single_step_path = os.path.join(args.output_dir, "dit_single_step.onnx")
523
+ export_dit_single_step(
524
+ single_step,
525
+ single_step_path,
526
+ opset_version=args.opset,
527
+ device=args.device,
528
+ )
529
+
530
+ # Verify single-step
531
+ if args.verify:
532
+ verify_dit_single_step(
533
+ single_step,
534
+ single_step_path,
535
+ device=args.device,
536
+ tolerance=args.tolerance,
537
+ )
538
+
539
+ print(f"\n✓ Export complete! Model saved to {args.output_dir}")
540
+
541
+
542
+ if __name__ == "__main__":
543
+ main()
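
A minimal sketch of driving the exported single-step model at inference time, mirroring the midpoint update in UnrolledDiTWrapper; the `cond` dict of conditioning arrays is assumed to be prepared by the caller, with names matching create_sample_inputs above:

import numpy as np
import onnxruntime as ort

def solve_midpoint(sess, noise, cond, num_steps=16):
    # Integrate dy/dt = f(t, y) from t=0 to t=1 with the midpoint rule.
    h = np.float32(1.0 / num_steps)
    y = noise
    t = np.zeros(noise.shape[0], dtype=np.float32)
    for _ in range(num_steps):
        k1 = sess.run(["velocity"], {"noisy_audio": y, "time": t, **cond})[0]
        k2 = sess.run(["velocity"], {"noisy_audio": y + (h / 2) * k1, "time": t + h / 2, **cond})[0]
        y = y + h * k2
        t = t + h
    return y

sess = ort.InferenceSession("onnx_models/dit_single_step.onnx", providers=["CPUExecutionProvider"])
# cond = {"audio_features": ..., "text_features": ..., "text_mask": ...,
#         "masked_video_features": ..., "anchor_ids": ..., "anchor_alignment": ...,
#         "audio_pad_mask": ...}
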
onnx_export/export_peaframe.py ADDED
@@ -0,0 +1,288 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Export PE-A-Frame (Perception Encoder Audio Frame) span predictor to ONNX.
4
+
5
+ The PE-A-Frame model is used for automatic anchor detection in SAM Audio.
6
+ It analyzes audio features and predicts which segments correspond to the
7
+ target audio source.
8
+
9
+ Usage:
10
+ python -m onnx_export.export_peaframe --output-dir onnx_models --verify
11
+ """
12
+
13
+ import os
14
+ import argparse
15
+ import torch
16
+ import torch.nn as nn
17
+ from typing import Optional
18
+
19
+
20
+ class PEAFrameWrapper(nn.Module):
21
+ """
22
+ Wrapper for PE-A-Frame model for ONNX export.
23
+
24
+ Exposes the forward pass that takes audio features and returns
25
+ frame-level predictions.
26
+ """
27
+
28
+ def __init__(self, model: nn.Module):
29
+ super().__init__()
30
+ self.model = model
31
+
32
+ def forward(
33
+ self,
34
+ audio_features: torch.Tensor,
35
+ audio_mask: Optional[torch.Tensor] = None,
36
+ ) -> torch.Tensor:
37
+ """
38
+ Forward pass for span prediction.
39
+
40
+ Args:
41
+ audio_features: Audio features [batch, seq_len, hidden_dim]
42
+ audio_mask: Optional attention mask [batch, seq_len]
43
+
44
+ Returns:
45
+ Frame-level predictions [batch, seq_len, num_classes]
46
+ """
47
+ return self.model(audio_features, attention_mask=audio_mask)
48
+
49
+
50
+ def load_peaframe_model(config_name: str = "pe-a-frame-large", device: str = "cpu"):
51
+ """Load the PE-A-Frame model from perception_models."""
52
+ from core.audio_visual_encoder.pe import PEAudioFrame
53
+
54
+ print(f"Loading PE-A-Frame model: {config_name}...")
55
+ model = PEAudioFrame.from_config(config_name, pretrained=True)
56
+ model = model.eval().to(device)
57
+
58
+ num_params = sum(p.numel() for p in model.parameters())
59
+ print(f" ✓ Model loaded: {num_params:,} parameters")
60
+
61
+ return model
62
+
63
+
64
+ def get_tokenizer(model):
65
+ """Get the text tokenizer from the model config."""
66
+ from transformers import AutoTokenizer
67
+
68
+ text_model_name = model.config.text_model._name_or_path
69
+ return AutoTokenizer.from_pretrained(text_model_name)
70
+
71
+
72
+ def create_sample_inputs(model, batch_size: int = 1, device: str = "cpu"):
73
+ """Create sample inputs for tracing."""
74
+ tokenizer = get_tokenizer(model)
75
+
76
+ # Sample text query
77
+ text = "a person speaking"
78
+ tokens = tokenizer(
79
+ [text] * batch_size,
80
+ return_tensors="pt",
81
+ padding=True,
82
+ truncation=True,
83
+ max_length=77,
84
+ )
85
+
86
+ # Sample audio (10 seconds at 16kHz)
87
+ # DAC encoder expects (batch, channels, samples) format
88
+ sample_rate = 16000
89
+ audio_len = sample_rate * 10
90
+ audio = torch.randn(batch_size, 1, audio_len, device=device) # Added channel dimension
91
+
92
+ return {
93
+ "input_ids": tokens["input_ids"].to(device),
94
+ "attention_mask": tokens["attention_mask"].to(device),
95
+ "input_values": audio,
96
+ }
97
+
98
+
99
+ def export_peaframe(
100
+ model: nn.Module,
101
+ output_path: str,
102
+ opset_version: int = 18,
103
+ device: str = "cpu",
104
+ ):
105
+ """Export PE-A-Frame to ONNX."""
106
+ import onnx
107
+
108
+ print(f"Exporting PE-A-Frame to {output_path}...")
109
+
110
+ sample_inputs = create_sample_inputs(model, device=device)
111
+
112
+ # Put model in eval mode
113
+ model = model.eval()
114
+
115
+ # Test forward pass first
116
+ with torch.no_grad():
117
+ try:
118
+ output = model(
119
+ input_ids=sample_inputs["input_ids"],
120
+ input_values=sample_inputs["input_values"],
121
+ attention_mask=sample_inputs["attention_mask"],
122
+ return_spans=False, # Disable span return for ONNX (list output)
123
+ )
124
+ print(f" Test forward pass: audio_embeds shape = {output.audio_embeds.shape}")
125
+ print(f" Test forward pass: text_embeds shape = {output.text_embeds.shape}")
126
+ except Exception as e:
127
+ print(f" Forward pass failed: {e}")
128
+ raise
129
+
130
+ # Create a wrapper that returns just the audio embeddings for simpler ONNX
131
+ class PEAFrameONNXWrapper(nn.Module):
132
+ def __init__(self, model):
133
+ super().__init__()
134
+ self.model = model
135
+
136
+ def forward(self, input_ids, input_values, attention_mask):
137
+ output = self.model(
138
+ input_ids=input_ids,
139
+ input_values=input_values,
140
+ attention_mask=attention_mask,
141
+ return_spans=False,
142
+ )
143
+ return output.audio_embeds, output.text_embeds
144
+
145
+ wrapper = PEAFrameONNXWrapper(model)
146
+ wrapper.eval()
147
+
148
+ torch.onnx.export(
149
+ wrapper,
150
+ (sample_inputs["input_ids"], sample_inputs["input_values"], sample_inputs["attention_mask"]),
151
+ output_path,
152
+ input_names=["input_ids", "input_values", "attention_mask"],
153
+ output_names=["audio_embeds", "text_embeds"],
154
+ dynamic_axes={
155
+ "input_ids": {0: "batch_size", 1: "seq_len"},
156
+ "input_values": {0: "batch_size", 1: "audio_len"},
157
+ "attention_mask": {0: "batch_size", 1: "seq_len"},
158
+ "audio_embeds": {0: "batch_size", 1: "num_frames"},
159
+ "text_embeds": {0: "batch_size"},
160
+ },
161
+ opset_version=opset_version,
162
+ do_constant_folding=True,
163
+ external_data=True,
164
+ )
165
+
166
+ print(" ✓ PE-A-Frame exported successfully")
167
+
168
+ # Validate
169
+ onnx_model = onnx.load(output_path)
170
+ onnx.checker.check_model(onnx_model)
171
+ print(" ✓ ONNX model validation passed")
172
+
173
+ return True
174
+
175
+
176
+ def verify_peaframe(
177
+ model: nn.Module,
178
+ onnx_path: str,
179
+ device: str = "cpu",
180
+ tolerance: float = 1e-3,
181
+ ) -> bool:
182
+ """Verify ONNX output matches PyTorch."""
183
+ import onnxruntime as ort
184
+ import numpy as np
185
+
186
+ print("Verifying PE-A-Frame output...")
187
+
188
+ sample_inputs = create_sample_inputs(model, device=device)
189
+
190
+ # PyTorch output
191
+ model = model.eval()
192
+ with torch.no_grad():
193
+ pytorch_output = model(
194
+ input_ids=sample_inputs["input_ids"],
195
+ input_values=sample_inputs["input_values"],
196
+ attention_mask=sample_inputs["attention_mask"],
197
+ return_spans=False,
198
+ )
199
+ pytorch_audio_embeds = pytorch_output.audio_embeds.cpu().numpy()
200
+ pytorch_text_embeds = pytorch_output.text_embeds.cpu().numpy()
201
+
202
+ # ONNX Runtime output
203
+ sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
204
+
205
+ onnx_inputs = {
206
+ "input_ids": sample_inputs["input_ids"].cpu().numpy().astype(np.int64),
207
+ "input_values": sample_inputs["input_values"].cpu().numpy().astype(np.float32),
208
+ "attention_mask": sample_inputs["attention_mask"].cpu().numpy().astype(np.int64),
209
+ }
210
+
211
+ onnx_outputs = sess.run(["audio_embeds", "text_embeds"], onnx_inputs)
212
+ onnx_audio_embeds = onnx_outputs[0]
213
+ onnx_text_embeds = onnx_outputs[1]
214
+
215
+ # Compare
216
+ audio_max_diff = np.abs(pytorch_audio_embeds - onnx_audio_embeds).max()
217
+ text_max_diff = np.abs(pytorch_text_embeds - onnx_text_embeds).max()
218
+
219
+ print(f" Audio embeds max diff: {audio_max_diff:.2e}")
220
+ print(f" Text embeds max diff: {text_max_diff:.2e}")
221
+
222
+ max_diff = max(audio_max_diff, text_max_diff)
223
+ if max_diff < tolerance:
224
+ print(f" ✓ Verification passed (tolerance: {tolerance})")
225
+ return True
226
+ else:
227
+ print(f" ✗ Verification failed (tolerance: {tolerance})")
228
+ return False
229
+
230
+
231
+ def main():
232
+ parser = argparse.ArgumentParser(description="Export PE-A-Frame to ONNX")
233
+ parser.add_argument(
234
+ "--config",
235
+ type=str,
236
+ default="pe-a-frame-large",
237
+ help="PE-A-Frame config name",
238
+ )
239
+ parser.add_argument(
240
+ "--output-dir",
241
+ type=str,
242
+ default="onnx_models",
243
+ help="Output directory for ONNX models",
244
+ )
245
+ parser.add_argument(
246
+ "--opset",
247
+ type=int,
248
+ default=18,
249
+ help="ONNX opset version",
250
+ )
251
+ parser.add_argument(
252
+ "--device",
253
+ type=str,
254
+ default="cpu",
255
+ help="Device to use",
256
+ )
257
+ parser.add_argument(
258
+ "--verify",
259
+ action="store_true",
260
+ help="Verify ONNX output",
261
+ )
262
+ parser.add_argument(
263
+ "--tolerance",
264
+ type=float,
265
+ default=1e-3,
266
+ help="Verification tolerance",
267
+ )
268
+
269
+ args = parser.parse_args()
270
+
271
+ os.makedirs(args.output_dir, exist_ok=True)
272
+
273
+ # Load model
274
+ model = load_peaframe_model(args.config, args.device)
275
+
276
+ # Export
277
+ output_path = os.path.join(args.output_dir, "peaframe.onnx")
278
+ export_peaframe(model, output_path, args.opset, args.device)
279
+
280
+ # Verify
281
+ if args.verify:
282
+ verify_peaframe(model, output_path, args.device, args.tolerance)
283
+
284
+ print(f"\n✓ Export complete! Model saved to {output_path}")
285
+
286
+
287
+ if __name__ == "__main__":
288
+ main()
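
A hypothetical smoke test for the exported peaframe.onnx using dummy token ids (a real run needs the model's own tokenizer); the I/O names match the export call above:

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("onnx_models/peaframe.onnx", providers=["CPUExecutionProvider"])
audio_embeds, text_embeds = sess.run(
    ["audio_embeds", "text_embeds"],
    {
        "input_ids": np.ones((1, 8), dtype=np.int64),       # placeholder tokens
        "attention_mask": np.ones((1, 8), dtype=np.int64),
        "input_values": np.random.randn(1, 1, 16000 * 10).astype(np.float32),  # 10 s at 16 kHz
    },
)
print(audio_embeds.shape, text_embeds.shape)
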
onnx_export/export_t5.py ADDED
@@ -0,0 +1,315 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Export T5 Text Encoder to ONNX format.
4
+
5
+ The T5 encoder takes tokenized input_ids and attention_mask, and produces
6
+ hidden states. For SAM Audio inference, the output hidden states and attention
7
+ mask are used as conditioning for the DiT transformer.
8
+
9
+ Usage:
10
+ python -m onnx_export.export_t5 --output-dir onnx_models --verify
11
+ """
12
+
13
+ import os
14
+ import argparse
15
+ import torch
16
+ import torch.nn as nn
17
+
18
+
19
+ class T5EncoderWrapper(nn.Module):
20
+ """
21
+ Wrapper for T5EncoderModel that provides a clean interface for ONNX export.
22
+
23
+ The wrapper takes tokenized inputs (input_ids, attention_mask) and returns
24
+ the last hidden state. This matches how SAMAudio's T5TextEncoder uses the model.
25
+ """
26
+
27
+ def __init__(self, t5_model, max_length: int = 77):
28
+ super().__init__()
29
+ self.model = t5_model
30
+ self.max_length = max_length
31
+
32
+ def forward(
33
+ self,
34
+ input_ids: torch.Tensor,
35
+ attention_mask: torch.Tensor,
36
+ ) -> torch.Tensor:
37
+ """
38
+ Args:
39
+ input_ids: Tokenized input IDs, shape (batch, seq_len)
40
+ attention_mask: Attention mask, shape (batch, seq_len)
41
+
42
+ Returns:
43
+ hidden_states: T5 encoder output, shape (batch, seq_len, hidden_dim)
44
+ """
45
+ outputs = self.model(
46
+ input_ids=input_ids,
47
+ attention_mask=attention_mask,
48
+ output_hidden_states=True,
49
+ )
50
+ return outputs.last_hidden_state
51
+
52
+
53
+ def load_t5_encoder(model_name: str = "google-t5/t5-base", device: str = "cpu"):
54
+ """
55
+ Load T5 encoder model and tokenizer.
56
+
57
+ SAM Audio's DiT was trained with T5-base (768-dim) text features.
58
+ """
59
+ from transformers import T5EncoderModel, AutoTokenizer
60
+
61
+ print(f"Loading T5 encoder: {model_name}...")
62
+
63
+ model = T5EncoderModel.from_pretrained(model_name)
64
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
65
+
66
+ model = model.eval().to(device)
67
+
68
+ return model, tokenizer
69
+
70
+
71
+ def export_t5_encoder(
72
+ t5_model,
73
+ tokenizer,
74
+ output_path: str,
75
+ opset_version: int = 18,
76
+ max_length: int = 77,
77
+ device: str = "cpu",
78
+ ):
79
+ """Export T5 encoder to ONNX format."""
80
+ import onnx
81
+
82
+ print(f"Exporting T5 encoder to {output_path}...")
83
+
84
+ wrapper = T5EncoderWrapper(t5_model, max_length=max_length).eval().to(device)
85
+
86
+ # Create sample input
87
+ sample_text = ["A dog barking loudly in the background"]
88
+ encoded = tokenizer(
89
+ sample_text,
90
+ truncation=True,
91
+ max_length=max_length,
92
+ padding="max_length", # Pad to max_length for consistent shape
93
+ return_tensors="pt",
94
+ )
95
+
96
+ sample_input_ids = encoded["input_ids"].to(device)
97
+ sample_attention_mask = encoded["attention_mask"].to(device)
98
+
99
+ # Export using torch.onnx.export
100
+ torch.onnx.export(
101
+ wrapper,
102
+ (sample_input_ids, sample_attention_mask),
103
+ output_path,
104
+ input_names=["input_ids", "attention_mask"],
105
+ output_names=["hidden_states"],
106
+ dynamic_axes={
107
+ "input_ids": {0: "batch_size", 1: "sequence_length"},
108
+ "attention_mask": {0: "batch_size", 1: "sequence_length"},
109
+ "hidden_states": {0: "batch_size", 1: "sequence_length"},
110
+ },
111
+ opset_version=opset_version,
112
+ do_constant_folding=True,
113
+ dynamo=True,
114
+ external_data=True, # T5-large is ~1GB
115
+ )
116
+
117
+ print(" ✓ T5 encoder exported successfully")
118
+
119
+ # Validate the model
120
+ model = onnx.load(output_path)
121
+ onnx.checker.check_model(model)
122
+ print(" ✓ ONNX model validation passed")
123
+
124
+ return True
125
+
126
+
127
+ def verify_t5_encoder(
128
+ t5_model,
129
+ tokenizer,
130
+ onnx_path: str,
131
+ max_length: int = 77,
132
+ device: str = "cpu",
133
+ tolerance: float = 1e-4,
134
+ ) -> bool:
135
+ """Verify ONNX T5 encoder output matches PyTorch."""
136
+ import onnxruntime as ort
137
+ import numpy as np
138
+
139
+ print("Verifying T5 encoder output...")
140
+
141
+ wrapper = T5EncoderWrapper(t5_model, max_length=max_length).eval().to(device)
142
+
143
+ # Test with multiple texts
144
+ test_texts = [
145
+ "A dog barking in the distance",
146
+ "Piano music playing softly",
147
+ "Rain falling on a rooftop",
148
+ ]
149
+
150
+ for text in test_texts:
151
+ # Tokenize
152
+ encoded = tokenizer(
153
+ [text],
154
+ truncation=True,
155
+ max_length=max_length,
156
+ padding="max_length",
157
+ return_tensors="pt",
158
+ )
159
+
160
+ input_ids = encoded["input_ids"].to(device)
161
+ attention_mask = encoded["attention_mask"].to(device)
162
+
163
+ # PyTorch output
164
+ with torch.no_grad():
165
+ pytorch_output = wrapper(input_ids, attention_mask).cpu().numpy()
166
+
167
+ # ONNX Runtime output
168
+ sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
169
+ onnx_output = sess.run(
170
+ ["hidden_states"],
171
+ {
172
+ "input_ids": input_ids.cpu().numpy().astype(np.int64),
173
+ "attention_mask": attention_mask.cpu().numpy().astype(np.int64),
174
+ }
175
+ )[0]
176
+
177
+ # Compare
178
+ max_diff = np.abs(pytorch_output - onnx_output).max()
179
+ mean_diff = np.abs(pytorch_output - onnx_output).mean()
180
+
181
+ print(f" Text: '{text[:30]}...'")
182
+ print(f" Max diff: {max_diff:.2e}, Mean diff: {mean_diff:.2e}")
183
+
184
+ if max_diff > tolerance:
185
+ print(f" ✗ Verification failed for text: {text}")
186
+ return False
187
+
188
+ print(f" ✓ Verification passed (tolerance: {tolerance})")
189
+ return True
190
+
191
+
192
+ def save_tokenizer_config(tokenizer, output_dir: str):
193
+ """
194
+ Save tokenizer vocabulary and configuration for runtime use.
195
+
196
+ This allows the ONNX runtime to perform tokenization without
197
+ needing the full transformers library.
198
+ """
199
+ import json
200
+
201
+ tokenizer_dir = os.path.join(output_dir, "tokenizer")
202
+ tokenizer.save_pretrained(tokenizer_dir)
203
+
204
+ # Also save a simple config for ONNX.js usage
205
+ config = {
206
+ "model_name": tokenizer.name_or_path,
207
+ "max_length": 77,
208
+ "vocab_size": tokenizer.vocab_size,
209
+ "pad_token_id": tokenizer.pad_token_id,
210
+ "eos_token_id": tokenizer.eos_token_id,
211
+ }
212
+
213
+ config_path = os.path.join(output_dir, "tokenizer_config.json")
214
+ with open(config_path, "w") as f:
215
+ json.dump(config, f, indent=2)
216
+
217
+ print(f" ✓ Tokenizer saved to {tokenizer_dir}")
218
+ return tokenizer_dir
219
+
220
+
221
+ def main():
222
+ parser = argparse.ArgumentParser(description="Export T5 Text Encoder to ONNX")
223
+ parser.add_argument(
224
+ "--model-name",
225
+ type=str,
226
+ default="google-t5/t5-base",
227
+ help="T5 model name from HuggingFace (default: google-t5/t5-base)",
228
+ )
229
+ parser.add_argument(
230
+ "--output-dir",
231
+ type=str,
232
+ default="onnx_models",
233
+ help="Output directory for ONNX models",
234
+ )
235
+ parser.add_argument(
236
+ "--max-length",
237
+ type=int,
238
+ default=77,
239
+ help="Maximum token sequence length (default: 77)",
240
+ )
241
+ parser.add_argument(
242
+ "--opset",
243
+ type=int,
244
+ default=18,
245
+ help="ONNX opset version (default: 18)",
246
+ )
247
+ parser.add_argument(
248
+ "--device",
249
+ type=str,
250
+ default="cpu",
251
+ help="Device to use for export (default: cpu)",
252
+ )
253
+ parser.add_argument(
254
+ "--verify",
255
+ action="store_true",
256
+ help="Verify ONNX output matches PyTorch",
257
+ )
258
+ parser.add_argument(
259
+ "--tolerance",
260
+ type=float,
261
+ default=1e-4,
262
+ help="Tolerance for verification (default: 1e-4)",
263
+ )
264
+ parser.add_argument(
265
+ "--save-tokenizer",
266
+ action="store_true",
267
+ default=True,
268
+ help="Save tokenizer for runtime use (default: True)",
269
+ )
270
+
271
+ args = parser.parse_args()
272
+
273
+ # Create output directory
274
+ os.makedirs(args.output_dir, exist_ok=True)
275
+
276
+ # Load T5
277
+ t5_model, tokenizer = load_t5_encoder(args.model_name, args.device)
278
+
279
+ print(f"\nT5 Configuration:")
280
+ print(f" Model: {args.model_name}")
281
+ print(f" Hidden size: {t5_model.config.d_model}")
282
+ print(f" Max length: {args.max_length}")
283
+ print(f" Vocab size: {tokenizer.vocab_size}")
284
+
285
+ # Export
286
+ encoder_path = os.path.join(args.output_dir, "t5_encoder.onnx")
287
+ export_t5_encoder(
288
+ t5_model,
289
+ tokenizer,
290
+ encoder_path,
291
+ opset_version=args.opset,
292
+ max_length=args.max_length,
293
+ device=args.device,
294
+ )
295
+
296
+ # Save tokenizer
297
+ if args.save_tokenizer:
298
+ save_tokenizer_config(tokenizer, args.output_dir)
299
+
300
+ # Verify
301
+ if args.verify:
302
+ verify_t5_encoder(
303
+ t5_model,
304
+ tokenizer,
305
+ encoder_path,
306
+ max_length=args.max_length,
307
+ device=args.device,
308
+ tolerance=args.tolerance,
309
+ )
310
+
311
+ print(f"\n✓ Export complete! Model saved to {encoder_path}")
312
+
313
+
314
+ if __name__ == "__main__":
315
+ main()
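
A minimal sketch of running the exported text encoder with the tokenizer saved by save_tokenizer_config(), assuming transformers and onnxruntime are installed:

import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("onnx_models/tokenizer")
enc = tokenizer(["A dog barking loudly"], max_length=77, padding="max_length",
                truncation=True, return_tensors="np")

sess = ort.InferenceSession("onnx_models/t5_encoder.onnx", providers=["CPUExecutionProvider"])
hidden_states = sess.run(["hidden_states"], {
    "input_ids": enc["input_ids"].astype(np.int64),
    "attention_mask": enc["attention_mask"].astype(np.int64),
})[0]
print(hidden_states.shape)  # (1, 77, d_model)
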
onnx_export/standalone_config.py ADDED
@@ -0,0 +1,116 @@
1
+ """
2
+ Standalone configuration classes for ONNX export.
3
+
4
+ These are copied from sam_audio/model/config.py but without the problematic
5
+ imports that require the 'perception_models' library.
6
+ """
7
+
8
+ from typing import Optional
9
+ import numpy as np
10
+
11
+
12
+ class DACVAEConfig:
13
+ def __init__(
14
+ self,
15
+ encoder_dim: int = 64,
16
+ encoder_rates: list[int] = [2, 8, 10, 12],
17
+ latent_dim: int = 1024,
18
+ decoder_dim: int = 1536,
19
+ decoder_rates: list[int] = [12, 10, 8, 2],
20
+ n_codebooks: int = 16,
21
+ codebook_size: int = 1024,
22
+ codebook_dim: int = 128,
23
+ quantizer_dropout: bool = False,
24
+ sample_rate: int = 48_000,
25
+ mean: float = 0.0,
26
+ std: float = 1.0,
27
+ ):
28
+ self.encoder_dim = encoder_dim
29
+ self.encoder_rates = encoder_rates
30
+ self.latent_dim = latent_dim
31
+ self.decoder_dim = decoder_dim
32
+ self.decoder_rates = decoder_rates
33
+ self.n_codebooks = n_codebooks
34
+ self.codebook_size = codebook_size
35
+ self.codebook_dim = codebook_dim
36
+ self.quantizer_dropout = quantizer_dropout
37
+ self.sample_rate = sample_rate
38
+ self.mean = mean
39
+ self.std = std
40
+
41
+ @property
42
+ def hop_length(self):
43
+ return int(np.prod(self.encoder_rates))
44
+
45
+
46
+ class T5EncoderConfig:
47
+ def __init__(
48
+ self,
49
+ name: str = "t5-base",
50
+ max_length: Optional[int] = 512,
51
+ pad_mode: str = "longest",
52
+ dim: int = 768,
53
+ ):
54
+ self.dim = dim
55
+ self.name = name
56
+ self.max_length = max_length
57
+ self.pad_mode = pad_mode
58
+
59
+
60
+ class TransformerConfig:
61
+ """Configuration for the DiT transformer."""
62
+
63
+ def __init__(
64
+ self,
65
+ dim: int = 2048,
66
+ n_heads: int = 16,
67
+ n_layers: int = 16,
68
+ dropout: float = 0.1,
69
+ norm_eps: float = 1.0e-05,
70
+ qk_norm: bool = True,
71
+ fc_bias: bool = False,
72
+ ffn_exp: int = 4,
73
+ ffn_dim_multiplier: int = 1,
74
+ multiple_of: int = 64,
75
+ non_linearity: str = "swiglu",
76
+ use_rope: bool = True,
77
+ max_positions: int = 10000,
78
+ frequency_embedding_dim: int = 256,
79
+ timestep_non_linearity: str = "swiglu",
80
+ t_block_non_linearity: str = "silu",
81
+ t_block_bias: bool = True,
82
+ context_dim: int = 2048,
83
+ context_non_linearity: str = "swiglu",
84
+ context_embedder_dropout: float = 0.0,
85
+ context_norm: bool = False,
86
+ out_channels: int = 256,
87
+ in_channels: Optional[int] = None,
88
+ ):
89
+ self.dim = dim
90
+ self.n_heads = n_heads
91
+ self.n_layers = n_layers
92
+ self.dropout = dropout
93
+ self.norm_eps = norm_eps
94
+ self.qk_norm = qk_norm
95
+ self.fc_bias = fc_bias
96
+ self.ffn_exp = ffn_exp
97
+ self.ffn_dim_multiplier = ffn_dim_multiplier
98
+ self.multiple_of = multiple_of
99
+ self.non_linearity = non_linearity
100
+ self.use_rope = use_rope
101
+ self.max_positions = max_positions
102
+ self.frequency_embedding_dim = frequency_embedding_dim
103
+ self.timestep_non_linearity = timestep_non_linearity
104
+ self.t_block_non_linearity = t_block_non_linearity
105
+ self.t_block_bias = t_block_bias
106
+ self.context_dim = context_dim
107
+ self.context_non_linearity = context_non_linearity
108
+ self.context_embedder_dropout = context_embedder_dropout
109
+ self.context_norm = context_norm
110
+ self.out_channels = out_channels
111
+ self.in_channels = in_channels
112
+
113
+ @property
114
+ def d_model(self):
115
+ """Alias for dim, used in transformer code."""
116
+ return self.dim