File size: 3,864 Bytes
9beb89b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7ac136a
 
 
 
 
 
9beb89b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7ac136a
9beb89b
 
 
 
 
 
7ac136a
9beb89b
 
 
 
 
 
7ac136a
9beb89b
 
 
 
 
 
7ac136a
9beb89b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
#!/usr/bin/env python3
"""
Export all SAM Audio components to ONNX format.

This script exports:
1. DACVAE encoder and decoder (audio codec)
2. T5 text encoder
3. DiT transformer (single-step for ODE solving)
4. Vision encoder (CLIP-based, for video-guided separation)

Usage:
    python -m onnx_export.export_all --output-dir onnx_models --verify
"""

import os
import argparse
import subprocess
import sys


def run_export(module: str, args: list[str]) -> bool:
    """Invoke ``python -m <module> <args...>`` in a child process.

    The child uses the current interpreter (``sys.executable``) and runs with
    its working directory set to the directory two levels above this file
    (presumably the project root containing the ``onnx_export`` package, so
    that ``-m`` can resolve ``module``).

    Args:
        module: Dotted module path to run with ``-m``.
        args: Extra command-line arguments forwarded to the module.

    Returns:
        True if the child exited with status 0, False otherwise.
    """
    base_command = [sys.executable, "-m", module]
    command = base_command + args
    banner = "=" * 60

    print(f"\n{banner}")
    print(f"Running: {' '.join(command)}")
    print(f"{banner}\n")

    this_file = os.path.abspath(__file__)
    project_root = os.path.dirname(os.path.dirname(this_file))
    completed = subprocess.run(command, cwd=project_root)
    return completed.returncode == 0


def _build_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for the export-all driver."""
    parser = argparse.ArgumentParser(description="Export all SAM Audio components to ONNX")
    parser.add_argument(
        "--output-dir",
        type=str,
        default="onnx_models",
        help="Output directory for ONNX models",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="facebook/sam-audio-small",
        help="SAM-Audio model ID (e.g., facebook/sam-audio-small, facebook/sam-audio-large, facebook/sam-audio-base-tv)",
    )
    parser.add_argument(
        "--verify",
        action="store_true",
        help="Verify ONNX output matches PyTorch",
    )
    parser.add_argument(
        "--skip-dacvae",
        action="store_true",
        help="Skip DACVAE export",
    )
    parser.add_argument(
        "--skip-t5",
        action="store_true",
        help="Skip T5 export",
    )
    parser.add_argument(
        "--skip-dit",
        action="store_true",
        help="Skip DiT export",
    )
    parser.add_argument(
        "--skip-vision",
        action="store_true",
        help="Skip Vision encoder export",
    )
    return parser


def _print_summary(results: dict[str, bool], output_dir: str) -> bool:
    """Print per-component status and the files in *output_dir*; return True if all succeeded."""
    print(f"\n{'='*60}")
    print("Export Summary")
    print(f"{'='*60}")

    for name, success in results.items():
        status = "✓" if success else "✗"
        print(f"  {status} {name}")

    # List exported files (top-level regular files only) with their sizes.
    print(f"\nExported files in {output_dir}:")
    for entry in sorted(os.listdir(output_dir)):
        path = os.path.join(output_dir, entry)
        if os.path.isfile(path):
            size_mb = os.path.getsize(path) / (1024 * 1024)
            print(f"  {entry}: {size_mb:.1f} MB")

    # An empty ``results`` (everything skipped) counts as success.
    return all(results.values())


def main():
    """Run every selected SAM Audio ONNX exporter and report the outcome.

    Each exporter runs as a separate ``python -m`` subprocess via
    ``run_export``. After all selected exporters finish, a status summary and
    a listing of the output directory are printed. The process exits with
    status 1 if any exporter failed.
    """
    args = _build_parser().parse_args()

    # run_export() starts each child with its cwd set to the directory above
    # this package. Resolve the output path once so the children and this
    # process (makedirs + listing) all refer to the same directory, even when
    # the script is launched from elsewhere with a relative --output-dir.
    output_dir = os.path.abspath(args.output_dir)
    os.makedirs(output_dir, exist_ok=True)

    shared_args = ["--output-dir", output_dir, "--model-id", args.model]
    if args.verify:
        shared_args.append("--verify")

    # (summary label, skip flag, exporter module, exporter arguments)
    steps = [
        ("DACVAE", args.skip_dacvae, "onnx_export.export_dacvae", shared_args),
        ("T5", args.skip_t5, "onnx_export.export_t5", shared_args),
        ("DiT", args.skip_dit, "onnx_export.export_dit", shared_args),
        # The vision exporter is given --output/--model instead of
        # --output-dir/--model-id, and is never passed --verify.
        # NOTE(review): presumably its CLI differs; confirm against
        # onnx_export/export_vision.py.
        (
            "Vision",
            args.skip_vision,
            "onnx_export.export_vision",
            ["--output", output_dir, "--model", args.model],
        ),
    ]

    results: dict[str, bool] = {}
    for name, skipped, module, export_args in steps:
        if skipped:
            continue
        # Pass a copy so no exporter invocation can affect another's arguments.
        results[name] = run_export(module, list(export_args))

    if _print_summary(results, output_dir):
        print("\n✓ All exports completed successfully!")
    else:
        print("\n✗ Some exports failed")
        sys.exit(1)


if __name__ == "__main__":
    main()