#!/usr/bin/env python3
"""
Export all SAM Audio components to ONNX format.

This script exports:
  1. DACVAE encoder and decoder (audio codec)
  2. T5 text encoder
  3. DiT transformer (single-step for ODE solving)
  4. Vision encoder (CLIP-based, for video-guided separation)

Usage:
    python -m onnx_export.export_all --output-dir onnx_models --verify
"""

import argparse
import os
import subprocess
import sys


def run_export(module: str, args: list[str]) -> bool:
    """Run ``python -m <module> <args...>`` as a child process.

    The child's working directory is the parent of this file's directory
    (the directory containing the ``onnx_export`` package when this file is
    run as ``onnx_export.export_all``), so ``-m onnx_export.*`` resolves.

    Args:
        module: Dotted module path executed with ``-m``.
        args: Command-line arguments forwarded to the module.

    Returns:
        True if the child exited with status 0, False otherwise.
    """
    cmd = [sys.executable, "-m", module, *args]
    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

    print(f"\n{'='*60}")
    print(f"Running: {' '.join(cmd)}")
    # Flush before spawning the child. When stdout is piped or redirected it is
    # block-buffered, and without the flush this banner would show up after
    # everything the child process prints.
    print(f"{'='*60}\n", flush=True)

    result = subprocess.run(cmd, cwd=project_root)
    return result.returncode == 0


def main() -> None:
    """Parse CLI options, run each selected exporter, and print a summary.

    Each exporter runs as a separate child process via ``run_export``.
    After all selected exporters finish, the per-component status and every
    regular file in the output directory (with size in MB) are printed.
    Exits with status 1 if any exporter that was run failed.
    """
    parser = argparse.ArgumentParser(description="Export all SAM Audio components to ONNX")
    parser.add_argument(
        "--output-dir",
        type=str,
        default="onnx_models",
        help="Output directory for ONNX models",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="facebook/sam-audio-small",
        help="SAM-Audio model ID (e.g., facebook/sam-audio-small, facebook/sam-audio-large, facebook/sam-audio-base-tv)",
    )
    parser.add_argument(
        "--verify",
        action="store_true",
        help="Verify ONNX output matches PyTorch",
    )
    parser.add_argument(
        "--skip-dacvae",
        action="store_true",
        help="Skip DACVAE export",
    )
    parser.add_argument(
        "--skip-t5",
        action="store_true",
        help="Skip T5 export",
    )
    parser.add_argument(
        "--skip-dit",
        action="store_true",
        help="Skip DiT export",
    )
    parser.add_argument(
        "--skip-vision",
        action="store_true",
        help="Skip Vision encoder export",
    )
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    # Summary label -> whether that exporter succeeded (insertion order = run order).
    results: dict[str, bool] = {}

    # Exporters sharing the --output-dir / --model-id / --verify CLI, listed in
    # the order they run: (summary label, module, skipped?).
    standard_exports = [
        ("DACVAE", "onnx_export.export_dacvae", args.skip_dacvae),
        ("T5", "onnx_export.export_t5", args.skip_t5),
        ("DiT", "onnx_export.export_dit", args.skip_dit),
    ]
    for name, module, skipped in standard_exports:
        if skipped:
            continue
        export_args = ["--output-dir", args.output_dir, "--model-id", args.model]
        if args.verify:
            export_args.append("--verify")
        results[name] = run_export(module, export_args)

    # Vision encoder: invoked with --output / --model rather than
    # --output-dir / --model-id, and --verify is not forwarded.
    # NOTE(review): confirm this matches onnx_export.export_vision's CLI.
    if not args.skip_vision:
        export_args = ["--output", args.output_dir, "--model", args.model]
        results["Vision"] = run_export("onnx_export.export_vision", export_args)

    # Print summary
    print(f"\n{'='*60}")
    print("Export Summary")
    print(f"{'='*60}")
    for name, success in results.items():
        status = "✓" if success else "✗"
        print(f" {status} {name}")
    all_success = all(results.values())

    # List exported files (regular files only; subdirectories are skipped)
    print(f"\nExported files in {args.output_dir}:")
    for f in sorted(os.listdir(args.output_dir)):
        path = os.path.join(args.output_dir, f)
        if os.path.isfile(path):
            size_mb = os.path.getsize(path) / (1024 * 1024)
            print(f" {f}: {size_mb:.1f} MB")

    if all_success:
        print("\n✓ All exports completed successfully!")
    else:
        print("\n✗ Some exports failed")
        sys.exit(1)


if __name__ == "__main__":
    main()