File size: 3,864 Bytes
9beb89b 7ac136a 9beb89b 7ac136a 9beb89b 7ac136a 9beb89b 7ac136a 9beb89b 7ac136a 9beb89b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
#!/usr/bin/env python3
"""
Export all SAM Audio components to ONNX format.

This script exports:
  1. DACVAE encoder and decoder (audio codec)
  2. T5 text encoder
  3. DiT transformer (single-step for ODE solving)
  4. Vision encoder (CLIP-based, for video-guided separation)

Usage:
    python -m onnx_export.export_all --output-dir onnx_models --verify
"""
import os
import argparse
import subprocess
import sys
def run_export(module: str, args: list[str]) -> bool:
    """Run an export module as a child Python process.

    The module is launched with the current interpreter as
    ``python -m <module> <args...>``; its output is not captured and goes
    straight to this process's stdout/stderr.

    Args:
        module: Dotted module name handed to ``python -m``.
        args: Command-line arguments forwarded to the module unchanged.

    Returns:
        True if the child exited with status 0, False otherwise.
    """
    cmd = [sys.executable, "-m", module, *args]
    banner = "=" * 60
    print(f"\n{banner}")
    print(f"Running: {' '.join(cmd)}")
    # Flush before spawning: the child writes directly to the inherited
    # file descriptor, so without a flush our buffered banner can show up
    # *after* the child's output when stdout is a pipe or a file.
    print(f"{banner}\n", flush=True)

    # Run from the parent of this file's directory (presumably the project
    # root, so that the ``onnx_export`` package is importable via ``-m``).
    parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    result = subprocess.run(cmd, cwd=parent_dir)
    return result.returncode == 0
def main():
    """Export the selected SAM Audio components to ONNX and summarize.

    Parses the command line, runs every exporter that was not skipped as a
    child process via ``run_export``, prints a per-component pass/fail
    summary followed by the files present in the output directory, and
    exits with status 1 if any exporter reported failure.
    """
    parser = argparse.ArgumentParser(description="Export all SAM Audio components to ONNX")
    parser.add_argument(
        "--output-dir",
        type=str,
        default="onnx_models",
        help="Output directory for ONNX models",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="facebook/sam-audio-small",
        help="SAM-Audio model ID (e.g., facebook/sam-audio-small, facebook/sam-audio-large, facebook/sam-audio-base-tv)",
    )
    parser.add_argument(
        "--verify",
        action="store_true",
        help="Verify ONNX output matches PyTorch",
    )
    for flag, label in (
        ("--skip-dacvae", "DACVAE"),
        ("--skip-t5", "T5"),
        ("--skip-dit", "DiT"),
        ("--skip-vision", "Vision encoder"),
    ):
        parser.add_argument(flag, action="store_true", help=f"Skip {label} export")
    args = parser.parse_args()

    # Resolve the output directory once, against the caller's working
    # directory. run_export launches the exporters with a different cwd
    # (the parent of this file's directory), so a relative path would
    # otherwise point the children at a different directory than the one
    # created and listed here.
    output_dir = os.path.abspath(args.output_dir)
    os.makedirs(output_dir, exist_ok=True)

    # DACVAE, T5 and DiT exporters share the same command-line options.
    shared_args = ["--output-dir", output_dir, "--model-id", args.model]
    if args.verify:
        shared_args.append("--verify")

    results = {}
    shared_exports = (
        ("DACVAE", "onnx_export.export_dacvae", args.skip_dacvae),
        ("T5", "onnx_export.export_t5", args.skip_t5),
        ("DiT", "onnx_export.export_dit", args.skip_dit),
    )
    for name, module, skipped in shared_exports:
        if not skipped:
            # Pass a copy so no exporter call can observe another's list.
            results[name] = run_export(module, list(shared_args))

    if not args.skip_vision:
        # The vision exporter is invoked with different option names.
        # NOTE(review): --verify is not forwarded here; confirm whether
        # export_vision supports it.
        vision_args = ["--output", output_dir, "--model", args.model]
        results["Vision"] = run_export("onnx_export.export_vision", vision_args)

    # Print summary
    print(f"\n{'='*60}")
    print("Export Summary")
    print(f"{'='*60}")
    for name, success in results.items():
        status = "✓" if success else "✗"
        print(f" {status} {name}")
    all_success = all(results.values())

    # List exported files (regular files only, sizes in MiB).
    print(f"\nExported files in {output_dir}:")
    for filename in sorted(os.listdir(output_dir)):
        path = os.path.join(output_dir, filename)
        if os.path.isfile(path):
            size_mb = os.path.getsize(path) / (1024 * 1024)
            print(f" {filename}: {size_mb:.1f} MB")

    if all_success:
        print("\n✓ All exports completed successfully!")
    else:
        print("\n✗ Some exports failed")
        sys.exit(1)
# Run the exporter only when executed as a script / via ``python -m``,
# not when this module is imported.
if __name__ == "__main__":
    main()
|