matbee committed · verified
Commit ba60410 · Parent: 7ac136a

Upload folder using huggingface_hub

.gitattributes CHANGED
@@ -33,8 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
-dacvae_decoder.onnx.data filter=lfs diff=lfs merge=lfs -text
-dacvae_encoder.onnx.data filter=lfs diff=lfs merge=lfs -text
-dit_single_step.onnx.data filter=lfs diff=lfs merge=lfs -text
-t5_encoder.onnx.data filter=lfs diff=lfs merge=lfs -text
+*.data filter=lfs diff=lfs merge=lfs -text
 test_audio.wav filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -61,6 +61,15 @@ python onnx_inference.py \
     --output-video visualization.mp4
 ```
 
+### Using a Custom Model Directory
+```bash
+python onnx_inference.py \
+    --video input.mp4 \
+    --text "woman speaking" \
+    --model-dir ./my_onnx_models \
+    --output separated.wav
+```
+
 ## Model Specifications
 
 - **Audio Sample Rate**: 48kHz
@@ -72,13 +81,55 @@ python onnx_inference.py \
 
 ## Exporting Models
 
-To re-export the models from PyTorch:
+Export scripts are in the `onnx_export/` directory.
 
+### Export All Models
 ```bash
-python onnx_export/export_all.py --output_dir ./onnx_models
-python onnx_export/export_vision.py --output ./onnx_models
+python -m onnx_export.export_all --output_dir ./onnx_models
 ```
 
+### Export Individual Components
+```bash
+# DiT Transformer (supports FP16 for 50% size reduction)
+python -m onnx_export.export_dit --output-dir ./onnx_models --model-id facebook/sam-audio-small
+python -m onnx_export.export_dit --output-dir ./onnx_models --model-id facebook/sam-audio-large --fp16 --device cuda
+
+# DACVAE (encoder + decoder)
+python -m onnx_export.export_dacvae --output-dir ./onnx_models --model-id facebook/sam-audio-small
+
+# T5 Text Encoder
+python -m onnx_export.export_t5 --output-dir ./onnx_models --model-id facebook/sam-audio-small
+
+# Vision Encoder
+python -m onnx_export.export_vision --model facebook/sam-audio-small --output ./onnx_models
+```
+
+### FP16 Quantization (for large models)
+
+For the large model (sam-audio-large), use `--fp16 --device cuda` during DiT export to reduce size by 50%:
+
+```bash
+# Export DiT in FP16 (11.7 GB → 5.9 GB)
+python -m onnx_export.export_dit \
+    --output-dir ./onnx_models_large_fp16 \
+    --model-id facebook/sam-audio-large \
+    --fp16 \
+    --device cuda
+```
+
+The inference script automatically detects FP16 models and handles input conversion.
+
+## Export Scripts Reference
+
+| Script | Description |
+|--------|-------------|
+| `export_all.py` | Export all components at once |
+| `export_dit.py` | DiT transformer with FP16 support |
+| `export_dacvae.py` | DACVAE encoder and decoder |
+| `export_t5.py` | T5 text encoder |
+| `export_vision.py` | Vision encoder (CLIP-based) |
+| `standalone_config.py` | Config classes for standalone export |
+
 ## License
 
 SAM-Audio is released under the [CC-BY-NC 4.0 license](https://creativecommons.org/licenses/by-nc/4.0/). See [original repository](https://huggingface.co/facebook/sam-audio-small) for full terms.
@@ -86,3 +137,4 @@ SAM-Audio is released under the [CC-BY-NC 4.0 license](https://creativecommons.o
 ## Acknowledgments
 
 Original model by [Meta AI Research](https://github.com/facebookresearch/sam-audio).
+
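The "automatically detects FP16" note above refers to logic added in `onnx_inference.py` (see the diff at the bottom of this commit). A minimal standalone sketch of that detection pattern, assuming an exported model at `onnx_models/dit_single_step.onnx`:

```python
import numpy as np
import onnxruntime as ort

# Inspect the first graph input's declared element type and pick the
# matching numpy dtype for all floating-point feed tensors.
sess = ort.InferenceSession(
    "onnx_models/dit_single_step.onnx",
    providers=["CPUExecutionProvider"],
)
use_fp16 = sess.get_inputs()[0].type == "tensor(float16)"
float_dtype = np.float16 if use_fp16 else np.float32
print(f"DiT expects {'FP16' if use_fp16 else 'FP32'} float inputs")
```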
onnx_export/export_all.py CHANGED
@@ -6,8 +6,7 @@ This script exports:
 1. DACVAE encoder and decoder (audio codec)
 2. T5 text encoder
 3. DiT transformer (single-step for ODE solving)
-
-Usage:
+4. Vision encoder (CLIP-based, for video-guided separation)
     python -m onnx_export.export_all --output-dir onnx_models --verify
 """
 
@@ -36,6 +35,12 @@ def main():
         default="onnx_models",
         help="Output directory for ONNX models",
     )
+    parser.add_argument(
+        "--model",
+        type=str,
+        default="facebook/sam-audio-small",
+        help="SAM-Audio model ID (e.g., facebook/sam-audio-small, facebook/sam-audio-large, facebook/sam-audio-base-tv)",
+    )
     parser.add_argument(
         "--verify",
         action="store_true",
@@ -56,6 +61,11 @@
         action="store_true",
         help="Skip DiT export",
     )
+    parser.add_argument(
+        "--skip-vision",
+        action="store_true",
+        help="Skip Vision encoder export",
+    )
 
     args = parser.parse_args()
 
@@ -65,12 +75,12 @@
 
     # Export DACVAE
     if not args.skip_dacvae:
-        export_args = ["--output-dir", args.output_dir]
+        export_args = ["--output-dir", args.output_dir, "--model-id", args.model]
         if args.verify:
            export_args.append("--verify")
        results["DACVAE"] = run_export("onnx_export.export_dacvae", export_args)
 
-    # Export T5
+    # Export T5 (always uses google-t5/t5-base, independent of SAM-Audio model)
     if not args.skip_t5:
         export_args = ["--output-dir", args.output_dir]
         if args.verify:
@@ -79,11 +89,16 @@
 
     # Export DiT
     if not args.skip_dit:
-        export_args = ["--output-dir", args.output_dir]
+        export_args = ["--output-dir", args.output_dir, "--model-id", args.model]
         if args.verify:
             export_args.append("--verify")
         results["DiT"] = run_export("onnx_export.export_dit", export_args)
 
+    # Export Vision Encoder
+    if not args.skip_vision:
+        export_args = ["--output", args.output_dir, "--model", args.model]
+        results["Vision"] = run_export("onnx_export.export_vision", export_args)
+
     # Print summary
     print(f"\n{'='*60}")
     print("Export Summary")
onnx_export/export_dacvae.py CHANGED
@@ -143,7 +143,7 @@ def create_dacvae_model(model_id: str = "facebook/sam-audio-small") -> dacvae.DA
 def export_encoder(
     dacvae_model: dacvae.DACVAE,
     output_path: str,
-    opset_version: int = 18,
+    opset_version: int = 21,
     device: str = "cpu",
 ) -> None:
     """Export DACVAE encoder to ONNX."""
@@ -178,15 +178,16 @@
 
     # Validate
     import onnx
-    model = onnx.load(output_path)
-    onnx.checker.check_model(model)
+    # Load without external data to avoid OOM - we just need to validate structure
+    model = onnx.load(output_path, load_external_data=False)
+    onnx.checker.check_model(model, full_check=False)
     print(f" ✓ ONNX model validation passed")
 
 
 def export_decoder(
     dacvae_model: dacvae.DACVAE,
     output_path: str,
-    opset_version: int = 18,
+    opset_version: int = 21,
     device: str = "cpu",
 ) -> None:
     """Export DACVAE decoder to ONNX."""
@@ -222,8 +223,9 @@
 
     # Validate
     import onnx
-    model = onnx.load(output_path)
-    onnx.checker.check_model(model)
+    # Load without external data to avoid OOM - we just need to validate structure
+    model = onnx.load(output_path, load_external_data=False)
+    onnx.checker.check_model(model, full_check=False)
     print(f" ✓ ONNX model validation passed")
 
 
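The validation change repeated across these exporters can be read as one reusable pattern. A small helper distilling it (a sketch; the repository does not define this function):

```python
import onnx

def validate_structure(onnx_path: str) -> None:
    # Load only the graph protobuf, leaving multi-gigabyte external weight
    # files (*.data) on disk; full_check=False skips the checks that would
    # need the actual tensor data.
    model = onnx.load(onnx_path, load_external_data=False)
    onnx.checker.check_model(model, full_check=False)
    print(f"✓ {onnx_path}: {len(model.graph.node)} nodes, structure valid")
```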
onnx_export/export_dit.py CHANGED
@@ -371,16 +371,28 @@ def create_sample_inputs(batch_size: int = 1, seq_len: int = 25, device: str = "
 def export_dit_single_step(
     single_step: DiTSingleStepWrapper,
     output_path: str,
-    opset_version: int = 18,
+    opset_version: int = 21,
     device: str = "cpu",
+    fp16: bool = False,
 ):
     """Export single-step DiT to ONNX (for runtime ODE solving)."""
     import onnx
 
     print(f"Exporting DiT single-step to {output_path}...")
 
+    # Convert to FP16 if requested
+    if fp16:
+        print(" Converting model to FP16...")
+        single_step = single_step.half()
+
     sample_inputs = create_sample_inputs(device=device)
 
+    # Convert float inputs to FP16 if exporting in FP16
+    if fp16:
+        for key, value in sample_inputs.items():
+            if value.dtype == torch.float32:
+                sample_inputs[key] = value.half()
+
     torch.onnx.export(
         single_step,
         tuple(sample_inputs.values()),
@@ -407,9 +419,19 @@
 
     print(" ✓ DiT single-step exported successfully")
 
-    model = onnx.load(output_path)
-    onnx.checker.check_model(model)
-    print(" ✓ ONNX model validation passed")
+    # When using external_data=True, we can't run check_model on a model
+    # loaded without external data - the checker validates data references.
+    # Since torch.onnx.export with dynamo=True already validates the model,
+    # we just verify the files exist.
+    external_data_path = output_path + ".data"
+    if os.path.exists(external_data_path):
+        print(f" ✓ External data file exists ({os.path.getsize(external_data_path) / 1e9:.2f} GB)")
+    else:
+        raise RuntimeError(f"External data file missing: {external_data_path}")
+
+    # Verify the ONNX file structure is valid (without loading weights)
+    model = onnx.load(output_path, load_external_data=False)
+    print(f" ✓ ONNX model structure loaded ({len(model.graph.node)} nodes)")
 
     return True
 
@@ -484,8 +506,8 @@ def main():
     parser.add_argument(
         "--opset",
         type=int,
-        default=18,
-        help="ONNX opset version (default: 18)",
+        default=21,
+        help="ONNX opset version (default: 21)",
     )
     parser.add_argument(
         "--device",
@@ -504,6 +526,11 @@
         default=1e-3,
         help="Tolerance for verification (default: 1e-3)",
     )
+    parser.add_argument(
+        "--fp16",
+        action="store_true",
+        help="Export model in FP16 precision (half the size)",
+    )
 
     args = parser.parse_args()
 
@@ -525,8 +552,12 @@
         single_step_path,
         opset_version=args.opset,
         device=args.device,
+        fp16=args.fp16,
     )
 
+    if args.fp16:
+        print(f" ✓ Model exported in FP16 precision")
+
     # Verify single-step
     if args.verify:
         verify_dit_single_step(
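The FP16 path above casts the module and only its floating-point sample inputs before tracing. A condensed sketch of that dtype rule (function name is illustrative; it assumes the same inputs dict shape as `create_sample_inputs` in this file):

```python
import torch

def cast_for_fp16_export(module, sample_inputs):
    # Only float32 tensors are cast; int64 ids and bool masks keep their
    # dtypes, since halving them would corrupt embedding lookups and masks.
    module = module.half()
    return module, {
        k: v.half() if v.dtype == torch.float32 else v
        for k, v in sample_inputs.items()
    }
```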
onnx_export/export_peaframe.py CHANGED
@@ -99,7 +99,7 @@ def create_sample_inputs(model, batch_size: int = 1, device: str = "cpu"):
 def export_peaframe(
     model: nn.Module,
     output_path: str,
-    opset_version: int = 18,
+    opset_version: int = 21,
     device: str = "cpu",
 ):
     """Export PE-A-Frame to ONNX."""
@@ -165,9 +165,9 @@
 
     print(" ✓ PE-A-Frame exported successfully")
 
-    # Validate
-    onnx_model = onnx.load(output_path)
-    onnx.checker.check_model(onnx_model)
+    # Load without external data to avoid OOM - we just need to validate structure
+    onnx_model = onnx.load(output_path, load_external_data=False)
+    onnx.checker.check_model(onnx_model, full_check=False)
     print(" ✓ ONNX model validation passed")
 
     return True
onnx_export/export_t5.py CHANGED
@@ -50,7 +50,7 @@ class T5EncoderWrapper(nn.Module):
         return outputs.last_hidden_state
 
 
-def load_t5_encoder(model_name: str = "google-t5/t5-base", device: str = "cpu"):
+def load_t5_encoder(model_name: str = "google-t5/t5-base", device: str = "cuda"):
     """
     Load T5 encoder model and tokenizer.
 
@@ -72,9 +72,9 @@ def export_t5_encoder(
     t5_model,
     tokenizer,
     output_path: str,
-    opset_version: int = 18,
+    opset_version: int = 21,
     max_length: int = 77,
-    device: str = "cpu",
+    device: str = "cuda",
 ):
     """Export T5 encoder to ONNX format."""
     import onnx
@@ -116,9 +116,9 @@
 
     print(" ✓ T5 encoder exported successfully")
 
-    # Validate the model
-    model = onnx.load(output_path)
-    onnx.checker.check_model(model)
+    # Load without external data to avoid OOM - we just need to validate structure
+    model = onnx.load(output_path, load_external_data=False)
+    onnx.checker.check_model(model, full_check=False)
     print(" ✓ ONNX model validation passed")
 
     return True
@@ -129,7 +129,7 @@ def verify_t5_encoder(
     tokenizer,
     onnx_path: str,
     max_length: int = 77,
-    device: str = "cpu",
+    device: str = "cuda",
     tolerance: float = 1e-4,
 ) -> bool:
     """Verify ONNX T5 encoder output matches PyTorch."""
@@ -165,7 +165,7 @@
     pytorch_output = wrapper(input_ids, attention_mask).cpu().numpy()
 
     # ONNX Runtime output
-    sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
+    sess = ort.InferenceSession(onnx_path, providers=["CUDAExecutionProvider"])
     onnx_output = sess.run(
         ["hidden_states"],
         {
@@ -247,8 +247,8 @@ def main():
     parser.add_argument(
         "--device",
         type=str,
-        default="cpu",
-        help="Device to use for export (default: cpu)",
+        default="cuda",
+        help="Device to use for export (default: cuda)",
     )
     parser.add_argument(
         "--verify",
onnx_export/export_vision.py ADDED
@@ -0,0 +1,113 @@
+#!/usr/bin/env python3
+import os
+import torch
+import torch.nn as nn
+import onnx
+from sam_audio.model.vision_encoder import PerceptionEncoder
+from onnx_export.standalone_config import PerceptionEncoderConfig
+
+class VisionEncoderWrapper(nn.Module):
+    """
+    Wrapper for the Vision Encoder (CLIP visual backbone).
+    """
+    def __init__(self, vision_encoder):
+        super().__init__()
+        self.model = vision_encoder.model
+        self.normalize = vision_encoder.normalize_feature
+
+    def forward(self, x):
+        # x: (N, 3, H, W) where N is number of frames
+        # returns: (N, 1024) features
+        return self.model.encode_image(x, normalize=self.normalize)
+
+def export_vision_encoder(model_id="facebook/sam-audio-small", output_dir="onnx_models", device="cpu"):
+    """Export the vision encoder to ONNX."""
+    print(f"Loading Vision Encoder from {model_id}...")
+
+    import torch
+    from transformers import AutoConfig
+    from sam_audio.model.vision_encoder import PerceptionEncoder
+    from onnx_export.standalone_config import PerceptionEncoderConfig
+
+    print("Fetching config...")
+    cfg_hf = AutoConfig.from_pretrained(model_id)
+    cfg_dict = cfg_hf.to_dict()
+
+    # Extract vision encoder config
+    v_cfg_dict = cfg_dict.get("vision_encoder", {})
+    v_cfg = PerceptionEncoderConfig(**v_cfg_dict)
+
+    print(f"Initializing PerceptionEncoder with name: {v_cfg.name}...")
+    vision_encoder = PerceptionEncoder(v_cfg)
+
+    # Load weights from checkpoint
+    print("Loading weights from SAM Audio checkpoint...")
+    from huggingface_hub import hf_hub_download
+    checkpoint_path = hf_hub_download(repo_id=model_id, filename="checkpoint.pt")
+    state_dict = torch.load(checkpoint_path, map_location="cpu", mmap=True)
+
+    # Filter vision encoder weights
+    vision_state = {}
+    prefix = "vision_encoder."
+    for key, value in state_dict.items():
+        if key.startswith(prefix):
+            new_key = key[len(prefix):]
+            vision_state[new_key] = value
+
+    if vision_state:
+        print(f" Loading {len(vision_state)} tensors into vision encoder...")
+        vision_encoder.load_state_dict(vision_state)
+        print(" ✓ Vision encoder weights loaded.")
+    else:
+        print(" WARNING: No 'vision_encoder' weights found in checkpoint. Using base weights.")
+
+    image_size = vision_encoder.image_size
+    print(f" Image size: {image_size}")
+
+    wrapper = VisionEncoderWrapper(vision_encoder).eval().to(device)
+
+    # Create dummy input on device
+    dummy_input = torch.randn(1, 3, image_size, image_size, device=device)
+
+    output_path = os.path.join(output_dir, "vision_encoder.onnx")
+    os.makedirs(output_dir, exist_ok=True)
+
+    print(f"Exporting to {output_path}...")
+    input_names = ["video_frames"]
+    output_names = ["vision_features"]
+    opset_version = 18  # Use opset 18 for better CUDA compatibility
+    torch.onnx.export(
+        wrapper,
+        dummy_input,
+        output_path,
+        input_names=input_names,
+        output_names=output_names,
+        dynamic_axes={
+            "video_frames": {0: "num_frames"},
+            "vision_features": {0: "num_frames"},
+        },
+        opset_version=opset_version,
+        do_constant_folding=True,
+        dynamo=True,
+        external_data=True,
+    )
+
+    # Check if data was saved separately
+    data_path = output_path + ".data"
+    if os.path.exists(data_path):
+        print(f" Large model detected, weights saved to {data_path}")
+
+    print("✓ Vision encoder export complete!")
+
+if __name__ == "__main__":
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model", type=str, default="facebook/sam-audio-small")
+    parser.add_argument("--output", type=str, default="onnx_models")
+    parser.add_argument("--device", type=str, default="cpu", help="Device for export (cpu or cuda)")
+    args = parser.parse_args()
+
+    export_vision_encoder(args.model, args.output, device=args.device)
+
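Programmatic use of the new exporter mirrors its `__main__` block:

```python
from onnx_export.export_vision import export_vision_encoder

# Same defaults as the CLI; pass device="cuda" to trace on GPU.
export_vision_encoder(
    model_id="facebook/sam-audio-small",
    output_dir="onnx_models",
    device="cpu",
)
```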
onnx_export/standalone_config.py CHANGED
@@ -57,6 +57,29 @@ class T5EncoderConfig:
         self.pad_mode = pad_mode
 
 
+class VisionEncoderConfig:
+    def __init__(self, dim: int = 1024, batch_size: int = 300):
+        self.dim = dim
+        self.batch_size = batch_size
+
+
+class PerceptionEncoderConfig(VisionEncoderConfig):
+    def __init__(
+        self,
+        dim: int = 1024,
+        batch_size: int = 300,
+        name: str = "PE-Core-L14-336",
+        normalize_feature: bool = True,
+        interpolation_mode: str = "BICUBIC",
+        image_size: int = 336,
+    ):
+        super().__init__(dim=dim, batch_size=batch_size)
+        self.name = name
+        self.normalize_feature = normalize_feature
+        self.interpolation_mode = interpolation_mode
+        self.image_size = image_size
+
+
 class TransformerConfig:
     """Configuration for the DiT transformer."""
 
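The new config class is constructed from the `"vision_encoder"` section of the Hugging Face config dict, as `export_vision.py` does. A quick usage example (keyword values here are just the defaults defined above):

```python
from onnx_export.standalone_config import PerceptionEncoderConfig

# Omitted fields fall back to the defaults in the class definition.
cfg = PerceptionEncoderConfig(name="PE-Core-L14-336", image_size=336)
print(cfg.dim, cfg.batch_size, cfg.normalize_feature)  # 1024 300 True
```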
onnx_inference.py CHANGED
@@ -377,6 +377,11 @@ class SAMAudioONNXPipeline:
         batch_size = noisy_audio.shape[0]
         seq_len = noisy_audio.shape[1]
 
+        # Detect if model expects FP16 inputs
+        first_input = self.dit.get_inputs()[0]
+        use_fp16 = first_input.type == 'tensor(float16)'
+        float_dtype = np.float16 if use_fp16 else np.float32
+
         # Prepare placeholders for anchors if not used
         # anchor_ids: <null>=0, <pad>=3. [B, 2]
         anchor_ids = np.zeros((batch_size, 2), dtype=np.int64)
@@ -392,15 +397,15 @@
         if masked_video_features is None:
             # Vision dimension is 1024 for small
             vision_dim = 1024
-            masked_video_features = np.zeros((batch_size, vision_dim, seq_len), dtype=np.float32)
+            masked_video_features = np.zeros((batch_size, vision_dim, seq_len), dtype=float_dtype)
 
         inputs = {
-            "noisy_audio": noisy_audio.astype(np.float32),
-            "time": np.array([time], dtype=np.float32),
-            "audio_features": audio_features.astype(np.float32),
-            "text_features": text_features.astype(np.float32),
+            "noisy_audio": noisy_audio.astype(float_dtype),
+            "time": np.array([time], dtype=float_dtype),
+            "audio_features": audio_features.astype(float_dtype),
+            "text_features": text_features.astype(float_dtype),
             "text_mask": text_mask.astype(np.bool_),
-            "masked_video_features": masked_video_features.astype(np.float32),
+            "masked_video_features": masked_video_features.astype(float_dtype),
             "anchor_ids": anchor_ids.astype(np.int64),
             "anchor_alignment": anchor_alignment.astype(np.int64),
             "audio_pad_mask": audio_pad_mask.astype(np.bool_),