Upload folder using huggingface_hub
- .gitattributes +1 -4
- README.md +55 -3
- onnx_export/export_all.py +20 -5
- onnx_export/export_dacvae.py +8 -6
- onnx_export/export_dit.py +37 -6
- onnx_export/export_peaframe.py +4 -4
- onnx_export/export_t5.py +10 -10
- onnx_export/export_vision.py +113 -0
- onnx_export/standalone_config.py +23 -0
- onnx_inference.py +11 -6
.gitattributes
CHANGED

@@ -33,8 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
-
-dacvae_encoder.onnx.data filter=lfs diff=lfs merge=lfs -text
-dit_single_step.onnx.data filter=lfs diff=lfs merge=lfs -text
-t5_encoder.onnx.data filter=lfs diff=lfs merge=lfs -text
+*.data filter=lfs diff=lfs merge=lfs -text
 test_audio.wav filter=lfs diff=lfs merge=lfs -text
README.md
CHANGED

@@ -61,6 +61,15 @@ python onnx_inference.py \
   --output-video visualization.mp4
 ```
 
+### Using a Custom Model Directory
+```bash
+python onnx_inference.py \
+  --video input.mp4 \
+  --text "woman speaking" \
+  --model-dir ./my_onnx_models \
+  --output separated.wav
+```
+
 ## Model Specifications
 
 - **Audio Sample Rate**: 48kHz
@@ -72,13 +81,55 @@ python onnx_inference.py \
 
 ## Exporting Models
 
-
+Export scripts are in the `onnx_export/` directory.
 
+### Export All Models
 ```bash
-python onnx_export
-python onnx_export/export_vision.py --output ./onnx_models
+python -m onnx_export.export_all --output-dir ./onnx_models
 ```
 
+### Export Individual Components
+```bash
+# DiT Transformer (supports FP16 for 50% size reduction)
+python -m onnx_export.export_dit --output-dir ./onnx_models --model-id facebook/sam-audio-small
+python -m onnx_export.export_dit --output-dir ./onnx_models --model-id facebook/sam-audio-large --fp16 --device cuda
+
+# DACVAE (encoder + decoder)
+python -m onnx_export.export_dacvae --output-dir ./onnx_models --model-id facebook/sam-audio-small
+
+# T5 Text Encoder
+python -m onnx_export.export_t5 --output-dir ./onnx_models --model-id facebook/sam-audio-small
+
+# Vision Encoder
+python -m onnx_export.export_vision --model facebook/sam-audio-small --output ./onnx_models
+```
+
+### FP16 Quantization (for large models)
+
+For the large model (sam-audio-large), use `--fp16 --device cuda` during DiT export to reduce size by 50%:
+
+```bash
+# Export DiT in FP16 (11.7GB → 5.9GB)
+python -m onnx_export.export_dit \
+  --output-dir ./onnx_models_large_fp16 \
+  --model-id facebook/sam-audio-large \
+  --fp16 \
+  --device cuda
+```
+
+The inference script automatically detects FP16 models and handles input conversion.
+
+## Export Scripts Reference
+
+| Script | Description |
+|--------|-------------|
+| `export_all.py` | Export all components at once |
+| `export_dit.py` | DiT transformer with FP16 support |
+| `export_dacvae.py` | DACVAE encoder and decoder |
+| `export_t5.py` | T5 text encoder |
+| `export_vision.py` | Vision encoder (CLIP-based) |
+| `standalone_config.py` | Config classes for standalone export |
+
 ## License
 
 SAM-Audio is released under the [CC-BY-NC 4.0 license](https://creativecommons.org/licenses/by-nc/4.0/). See [original repository](https://huggingface.co/facebook/sam-audio-small) for full terms.
@@ -86,3 +137,4 @@ SAM-Audio is released under the [CC-BY-NC 4.0 license](https://creativecommons.org/licenses/by-nc/4.0/). See [original repository](https://huggingface.co/facebook/sam-audio-small) for full terms.
 ## Acknowledgments
 
 Original model by [Meta AI Research](https://github.com/facebookresearch/sam-audio).
+
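Note on the FP16 auto-detection the README mentions: it amounts to inspecting the declared element type of the ONNX session's inputs. A minimal sketch, assuming the exported DiT lives at the path below (the actual logic lands in `onnx_inference.py` later in this commit):

```python
# Minimal sketch of FP16 detection; the model path and provider are assumptions.
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("onnx_models/dit_single_step.onnx",
                            providers=["CPUExecutionProvider"])
# ONNX Runtime reports each input's element type as a string like 'tensor(float16)'.
use_fp16 = sess.get_inputs()[0].type == "tensor(float16)"
float_dtype = np.float16 if use_fp16 else np.float32
print(f"Feeding float inputs as {float_dtype}")
```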
onnx_export/export_all.py
CHANGED

@@ -6,8 +6,7 @@ This script exports:
 1. DACVAE encoder and decoder (audio codec)
 2. T5 text encoder
 3. DiT transformer (single-step for ODE solving)
-
-Usage:
+4. Vision encoder (CLIP-based, for video-guided separation)
 python -m onnx_export.export_all --output-dir onnx_models --verify
 """
 
@@ -36,6 +35,12 @@ def main():
         default="onnx_models",
         help="Output directory for ONNX models",
     )
+    parser.add_argument(
+        "--model",
+        type=str,
+        default="facebook/sam-audio-small",
+        help="SAM-Audio model ID (e.g., facebook/sam-audio-small, facebook/sam-audio-large, facebook/sam-audio-base-tv)",
+    )
     parser.add_argument(
         "--verify",
         action="store_true",
@@ -56,6 +61,11 @@
         action="store_true",
         help="Skip DiT export",
     )
+    parser.add_argument(
+        "--skip-vision",
+        action="store_true",
+        help="Skip Vision encoder export",
+    )
 
     args = parser.parse_args()
 
@@ -65,12 +75,12 @@
 
     # Export DACVAE
     if not args.skip_dacvae:
-        export_args = ["--output-dir", args.output_dir]
+        export_args = ["--output-dir", args.output_dir, "--model-id", args.model]
         if args.verify:
             export_args.append("--verify")
         results["DACVAE"] = run_export("onnx_export.export_dacvae", export_args)
 
-    # Export T5
+    # Export T5 (always uses google-t5/t5-base, independent of SAM-Audio model)
     if not args.skip_t5:
         export_args = ["--output-dir", args.output_dir]
         if args.verify:
@@ -79,11 +89,16 @@
 
     # Export DiT
     if not args.skip_dit:
-        export_args = ["--output-dir", args.output_dir]
+        export_args = ["--output-dir", args.output_dir, "--model-id", args.model]
         if args.verify:
             export_args.append("--verify")
         results["DiT"] = run_export("onnx_export.export_dit", export_args)
 
+    # Export Vision Encoder
+    if not args.skip_vision:
+        export_args = ["--output", args.output_dir, "--model", args.model]
+        results["Vision"] = run_export("onnx_export.export_vision", export_args)
+
     # Print summary
     print(f"\n{'='*60}")
     print("Export Summary")
onnx_export/export_dacvae.py
CHANGED

@@ -143,7 +143,7 @@ def create_dacvae_model(model_id: str = "facebook/sam-audio-small") -> dacvae.DACVAE:
 def export_encoder(
     dacvae_model: dacvae.DACVAE,
     output_path: str,
-    opset_version: int =
+    opset_version: int = 21,
     device: str = "cpu",
 ) -> None:
     """Export DACVAE encoder to ONNX."""
@@ -178,15 +178,16 @@
 
     # Validate
     import onnx
-
-    onnx.
+    # Load without external data to avoid OOM - we just need to validate structure
+    model = onnx.load(output_path, load_external_data=False)
+    onnx.checker.check_model(model, full_check=False)
     print(f" ✓ ONNX model validation passed")
 
 
 def export_decoder(
     dacvae_model: dacvae.DACVAE,
     output_path: str,
-    opset_version: int =
+    opset_version: int = 21,
     device: str = "cpu",
 ) -> None:
     """Export DACVAE decoder to ONNX."""
@@ -222,8 +223,9 @@
 
     # Validate
     import onnx
-
-    onnx.
+    # Load without external data to avoid OOM - we just need to validate structure
+    model = onnx.load(output_path, load_external_data=False)
+    onnx.checker.check_model(model, full_check=False)
     print(f" ✓ ONNX model validation passed")
 
 
onnx_export/export_dit.py
CHANGED

@@ -371,16 +371,28 @@ def create_sample_inputs(batch_size: int = 1, seq_len: int = 25, device: str = "cpu"):
 def export_dit_single_step(
     single_step: DiTSingleStepWrapper,
     output_path: str,
-    opset_version: int =
+    opset_version: int = 21,
     device: str = "cpu",
+    fp16: bool = False,
 ):
     """Export single-step DiT to ONNX (for runtime ODE solving)."""
     import onnx
 
     print(f"Exporting DiT single-step to {output_path}...")
 
+    # Convert to FP16 if requested
+    if fp16:
+        print(" Converting model to FP16...")
+        single_step = single_step.half()
+
     sample_inputs = create_sample_inputs(device=device)
 
+    # Convert float inputs to FP16 if exporting in FP16
+    if fp16:
+        for key, value in sample_inputs.items():
+            if value.dtype == torch.float32:
+                sample_inputs[key] = value.half()
+
     torch.onnx.export(
         single_step,
         tuple(sample_inputs.values()),
@@ -407,9 +419,19 @@
 
     print(" ✓ DiT single-step exported successfully")
 
-
-
-
+    # When using external_data=True, we can't run check_model on a model
+    # loaded without external data - the checker validates data references.
+    # Since torch.onnx.export with dynamo=True already validates the model,
+    # we just verify the files exist.
+    external_data_path = output_path + ".data"
+    if os.path.exists(external_data_path):
+        print(f" ✓ External data file exists ({os.path.getsize(external_data_path) / 1e9:.2f} GB)")
+    else:
+        raise RuntimeError(f"External data file missing: {external_data_path}")
+
+    # Verify the ONNX file structure is valid (without loading weights)
+    model = onnx.load(output_path, load_external_data=False)
+    print(f" ✓ ONNX model structure loaded ({len(model.graph.node)} nodes)")
 
     return True
 
@@ -484,8 +506,8 @@ def main():
     parser.add_argument(
         "--opset",
         type=int,
-        default=
-        help="ONNX opset version (default:
+        default=21,
+        help="ONNX opset version (default: 21)",
     )
     parser.add_argument(
         "--device",
@@ -504,6 +526,11 @@
         default=1e-3,
         help="Tolerance for verification (default: 1e-3)",
    )
+    parser.add_argument(
+        "--fp16",
+        action="store_true",
+        help="Export model in FP16 precision (half the size)",
+    )
 
     args = parser.parse_args()
 
@@ -525,8 +552,12 @@
         single_step_path,
         opset_version=args.opset,
         device=args.device,
+        fp16=args.fp16,
     )
 
+    if args.fp16:
+        print(f" ✓ Model exported in FP16 precision")
+
     # Verify single-step
     if args.verify:
         verify_dit_single_step(
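For context on why the DiT is exported as a single step: the flow-matching ODE is integrated at runtime by calling the graph once per step. A hedged sketch of such a solver loop, assuming the input names from `onnx_inference.py` later in this commit; the single velocity output, the `cond` conditioning dict (audio/text/vision features, masks, anchors), and plain fixed-step Euler are assumptions, not this repo's exact solver:

```python
# Hypothetical Euler solver over the exported single-step DiT.
# Assumes the graph returns exactly one tensor: the predicted velocity field.
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("onnx_models/dit_single_step.onnx",
                            providers=["CPUExecutionProvider"])

def euler_solve(noisy_audio, cond, num_steps=32):
    """Integrate x' = v(x, t) from t=0 to t=1 with fixed-step Euler."""
    dt = 1.0 / num_steps
    x = noisy_audio
    for i in range(num_steps):
        t = np.array([i * dt], dtype=x.dtype)
        (velocity,) = sess.run(None, {"noisy_audio": x, "time": t, **cond})
        x = x + dt * velocity  # step along the predicted velocity field
    return x
```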
onnx_export/export_peaframe.py
CHANGED

@@ -99,7 +99,7 @@ def create_sample_inputs(model, batch_size: int = 1, device: str = "cpu"):
 def export_peaframe(
     model: nn.Module,
     output_path: str,
-    opset_version: int =
+    opset_version: int = 21,
     device: str = "cpu",
 ):
     """Export PE-A-Frame to ONNX."""
@@ -165,9 +165,9 @@
 
     print(" ✓ PE-A-Frame exported successfully")
 
-    #
-    onnx_model = onnx.load(output_path)
-    onnx.checker.check_model(onnx_model)
+    # Load without external data to avoid OOM - we just need to validate structure
+    onnx_model = onnx.load(output_path, load_external_data=False)
+    onnx.checker.check_model(onnx_model, full_check=False)
     print(" ✓ ONNX model validation passed")
 
     return True
onnx_export/export_t5.py
CHANGED

@@ -50,7 +50,7 @@ class T5EncoderWrapper(nn.Module):
         return outputs.last_hidden_state
 
 
-def load_t5_encoder(model_name: str = "google-t5/t5-base", device: str = "
+def load_t5_encoder(model_name: str = "google-t5/t5-base", device: str = "cuda"):
     """
     Load T5 encoder model and tokenizer.
 
@@ -72,9 +72,9 @@ def export_t5_encoder(
     t5_model,
     tokenizer,
     output_path: str,
-    opset_version: int =
+    opset_version: int = 21,
     max_length: int = 77,
-    device: str = "
+    device: str = "cuda",
 ):
     """Export T5 encoder to ONNX format."""
     import onnx
@@ -116,9 +116,9 @@
 
     print(" ✓ T5 encoder exported successfully")
 
-    #
-    model = onnx.load(output_path)
-    onnx.checker.check_model(model)
+    # Load without external data to avoid OOM - we just need to validate structure
+    model = onnx.load(output_path, load_external_data=False)
+    onnx.checker.check_model(model, full_check=False)
     print(" ✓ ONNX model validation passed")
 
     return True
@@ -129,7 +129,7 @@ def verify_t5_encoder(
     tokenizer,
     onnx_path: str,
     max_length: int = 77,
-    device: str = "
+    device: str = "cuda",
     tolerance: float = 1e-4,
 ) -> bool:
     """Verify ONNX T5 encoder output matches PyTorch."""
@@ -165,7 +165,7 @@
     pytorch_output = wrapper(input_ids, attention_mask).cpu().numpy()
 
     # ONNX Runtime output
-    sess = ort.InferenceSession(onnx_path, providers=["
+    sess = ort.InferenceSession(onnx_path, providers=["CUDAExecutionProvider"])
     onnx_output = sess.run(
         ["hidden_states"],
         {
@@ -247,8 +247,8 @@ def main():
     parser.add_argument(
         "--device",
         type=str,
-        default="
-        help="Device to use for export (default:
+        default="cuda",
+        help="Device to use for export (default: cuda)",
     )
     parser.add_argument(
         "--verify",
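A usage sketch for the exported encoder: the `hidden_states` output name and `max_length=77` come from the script above, while the `input_ids`/`attention_mask` input names are assumptions inferred from the wrapper's call signature:

```python
# Run the exported T5 encoder with ONNX Runtime; input names are assumed.
import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base")
batch = tokenizer("woman speaking", padding="max_length", max_length=77,
                  truncation=True, return_tensors="np")

sess = ort.InferenceSession("onnx_models/t5_encoder.onnx",
                            providers=["CPUExecutionProvider"])
(hidden_states,) = sess.run(
    ["hidden_states"],
    {"input_ids": batch["input_ids"].astype(np.int64),
     "attention_mask": batch["attention_mask"].astype(np.int64)},
)
print(hidden_states.shape)  # (1, 77, hidden_dim)
```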
onnx_export/export_vision.py
ADDED

@@ -0,0 +1,113 @@
+#!/usr/bin/env python3
+import os
+import torch
+import torch.nn as nn
+import onnx
+from sam_audio.model.vision_encoder import PerceptionEncoder
+from onnx_export.standalone_config import PerceptionEncoderConfig
+
+class VisionEncoderWrapper(nn.Module):
+    """
+    Wrapper for the Vision Encoder (CLIP visual backbone).
+    """
+    def __init__(self, vision_encoder):
+        super().__init__()
+        self.model = vision_encoder.model
+        self.normalize = vision_encoder.normalize_feature
+
+    def forward(self, x):
+        # x: (N, 3, H, W) where N is number of frames
+        # returns: (N, 1024) features
+        return self.model.encode_image(x, normalize=self.normalize)
+
+def export_vision_encoder(model_id="facebook/sam-audio-small", output_dir="onnx_models", device="cpu"):
+    """Export the vision encoder to ONNX."""
+    print(f"Loading Vision Encoder from {model_id}...")
+
+    import torch
+    from transformers import AutoConfig
+    from sam_audio.model.vision_encoder import PerceptionEncoder
+    from onnx_export.standalone_config import PerceptionEncoderConfig
+
+    print("Fetching config...")
+    cfg_hf = AutoConfig.from_pretrained(model_id)
+    cfg_dict = cfg_hf.to_dict()
+
+    # Extract vision encoder config
+    v_cfg_dict = cfg_dict.get("vision_encoder", {})
+    v_cfg = PerceptionEncoderConfig(**v_cfg_dict)
+
+    print(f"Initializing PerceptionEncoder with name: {v_cfg.name}...")
+    vision_encoder = PerceptionEncoder(v_cfg)
+
+    # Load weights from checkpoint
+    print("Loading weights from SAM Audio checkpoint...")
+    from huggingface_hub import hf_hub_download
+    checkpoint_path = hf_hub_download(repo_id=model_id, filename="checkpoint.pt")
+    state_dict = torch.load(checkpoint_path, map_location="cpu", mmap=True)
+
+    # Filter vision encoder weights
+    vision_state = {}
+    prefix = "vision_encoder."
+    for key, value in state_dict.items():
+        if key.startswith(prefix):
+            new_key = key[len(prefix):]
+            vision_state[new_key] = value
+
+    if vision_state:
+        print(f" Loading {len(vision_state)} tensors into vision encoder...")
+        vision_encoder.load_state_dict(vision_state)
+        print(" ✓ Vision encoder weights loaded.")
+    else:
+        print(" WARNING: No 'vision_encoder' weights found in checkpoint. Using base weights.")
+
+    image_size = vision_encoder.image_size
+    print(f" Image size: {image_size}")
+
+
+    wrapper = VisionEncoderWrapper(vision_encoder).eval().to(device)
+
+    # Create dummy input on device
+    image_size = vision_encoder.image_size
+    dummy_input = torch.randn(1, 3, image_size, image_size, device=device)
+
+    output_path = os.path.join(output_dir, "vision_encoder.onnx")
+    os.makedirs(output_dir, exist_ok=True)
+
+    print(f"Exporting to {output_path}...")
+    input_names = ["video_frames"]
+    output_names = ["vision_features"]
+    opset_version = 18  # Use opset 18 for better CUDA compatibility
+    torch.onnx.export(
+        wrapper,
+        dummy_input,
+        output_path,
+        input_names=input_names,
+        output_names=output_names,
+        dynamic_axes={
+            "video_frames": {0: "num_frames"},
+            "vision_features": {0: "num_frames"},
+        },
+        opset_version=opset_version,
+        do_constant_folding=True,
+        dynamo=True,
+        external_data=True,
+    )
+
+    # Check if data was saved separately
+    data_path = output_path + ".data"
+    if os.path.exists(data_path):
+        print(f" Large model detected, weights saved to {data_path}")
+
+    print("✓ Vision encoder export complete!")
+
+if __name__ == "__main__":
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model", type=str, default="facebook/sam-audio-small")
+    parser.add_argument("--output", type=str, default="onnx_models")
+    parser.add_argument("--device", type=str, default="cpu", help="Device for export (cpu or cuda)")
+    args = parser.parse_args()
+
+    export_vision_encoder(args.model, args.output, device=args.device)
+
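A usage sketch for the new export: the `video_frames`/`vision_features` names, the dynamic `num_frames` axis, the 336x336 input size, and the 1024-dim output all come from the script above; the random frames merely stand in for real preprocessed video:

```python
# Run the exported vision encoder on a stack of (preprocessed) frames.
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("onnx_models/vision_encoder.onnx",
                            providers=["CPUExecutionProvider"])
frames = np.random.rand(8, 3, 336, 336).astype(np.float32)  # (num_frames, 3, H, W)
(features,) = sess.run(["vision_features"], {"video_frames": frames})
print(features.shape)  # expected (8, 1024) per the wrapper's forward()
```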
onnx_export/standalone_config.py
CHANGED

@@ -57,6 +57,29 @@ class T5EncoderConfig:
         self.pad_mode = pad_mode
 
 
+class VisionEncoderConfig:
+    def __init__(self, dim: int = 1024, batch_size: int = 300):
+        self.dim = dim
+        self.batch_size = batch_size
+
+
+class PerceptionEncoderConfig(VisionEncoderConfig):
+    def __init__(
+        self,
+        dim: int = 1024,
+        batch_size: int = 300,
+        name: str = "PE-Core-L14-336",
+        normalize_feature: bool = True,
+        interpolation_mode: str = "BICUBIC",
+        image_size: int = 336,
+    ):
+        super().__init__(dim=dim, batch_size=batch_size)
+        self.name = name
+        self.normalize_feature = normalize_feature
+        self.interpolation_mode = interpolation_mode
+        self.image_size = image_size
+
+
 class TransformerConfig:
     """Configuration for the DiT transformer."""
 
onnx_inference.py
CHANGED

@@ -377,6 +377,11 @@ class SAMAudioONNXPipeline:
         batch_size = noisy_audio.shape[0]
         seq_len = noisy_audio.shape[1]
 
+        # Detect if model expects FP16 inputs
+        first_input = self.dit.get_inputs()[0]
+        use_fp16 = first_input.type == 'tensor(float16)'
+        float_dtype = np.float16 if use_fp16 else np.float32
+
         # Prepare placeholders for anchors if not used
         # anchor_ids: <null>=0, <pad>=3. [B, 2]
         anchor_ids = np.zeros((batch_size, 2), dtype=np.int64)
@@ -392,15 +397,15 @@
         if masked_video_features is None:
             # Vision dimension is 1024 for small
             vision_dim = 1024
-            masked_video_features = np.zeros((batch_size, vision_dim, seq_len), dtype=np.float32)
+            masked_video_features = np.zeros((batch_size, vision_dim, seq_len), dtype=float_dtype)
 
         inputs = {
-            "noisy_audio": noisy_audio.astype(np.float32),
-            "time": np.array([time], dtype=np.float32),
-            "audio_features": audio_features.astype(np.float32),
-            "text_features": text_features.astype(np.float32),
+            "noisy_audio": noisy_audio.astype(float_dtype),
+            "time": np.array([time], dtype=float_dtype),
+            "audio_features": audio_features.astype(float_dtype),
+            "text_features": text_features.astype(float_dtype),
             "text_mask": text_mask.astype(np.bool_),
-            "masked_video_features": masked_video_features.astype(np.float32),
+            "masked_video_features": masked_video_features.astype(float_dtype),
             "anchor_ids": anchor_ids.astype(np.int64),
             "anchor_alignment": anchor_alignment.astype(np.int64),
             "audio_pad_mask": audio_pad_mask.astype(np.bool_),