matbee commited on
Commit
9beb89b
·
verified ·
1 Parent(s): 136212b

Upload export_all.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. export_all.py +124 -0
export_all.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Export all SAM Audio components to ONNX format.
4
+
5
+ This script exports:
6
+ 1. DACVAE encoder and decoder (audio codec)
7
+ 2. T5 text encoder
8
+ 3. DiT transformer (single-step for ODE solving)
9
+ 4. Vision encoder (CLIP-based, for video-guided separation)
10
+ python -m onnx_export.export_all --output-dir onnx_models --verify
11
+ """
12
+
13
+ import os
14
+ import argparse
15
+ import subprocess
16
+ import sys
17
+
18
+
19
+ def run_export(module: str, args: list[str]) -> bool:
20
+ """Run an export module with the given arguments."""
21
+ cmd = [sys.executable, "-m", module] + args
22
+ print(f"\n{'='*60}")
23
+ print(f"Running: {' '.join(cmd)}")
24
+ print(f"{'='*60}\n")
25
+
26
+ result = subprocess.run(cmd, cwd=os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
27
+ return result.returncode == 0
28
+
29
+
30
+ def main():
31
+ parser = argparse.ArgumentParser(description="Export all SAM Audio components to ONNX")
32
+ parser.add_argument(
33
+ "--output-dir",
34
+ type=str,
35
+ default="onnx_models",
36
+ help="Output directory for ONNX models",
37
+ )
38
+ parser.add_argument(
39
+ "--verify",
40
+ action="store_true",
41
+ help="Verify ONNX output matches PyTorch",
42
+ )
43
+ parser.add_argument(
44
+ "--skip-dacvae",
45
+ action="store_true",
46
+ help="Skip DACVAE export",
47
+ )
48
+ parser.add_argument(
49
+ "--skip-t5",
50
+ action="store_true",
51
+ help="Skip T5 export",
52
+ )
53
+ parser.add_argument(
54
+ "--skip-dit",
55
+ action="store_true",
56
+ help="Skip DiT export",
57
+ )
58
+ parser.add_argument(
59
+ "--skip-vision",
60
+ action="store_true",
61
+ help="Skip Vision encoder export",
62
+ )
63
+
64
+ args = parser.parse_args()
65
+
66
+ os.makedirs(args.output_dir, exist_ok=True)
67
+
68
+ results = {}
69
+
70
+ # Export DACVAE
71
+ if not args.skip_dacvae:
72
+ export_args = ["--output-dir", args.output_dir]
73
+ if args.verify:
74
+ export_args.append("--verify")
75
+ results["DACVAE"] = run_export("onnx_export.export_dacvae", export_args)
76
+
77
+ # Export T5
78
+ if not args.skip_t5:
79
+ export_args = ["--output-dir", args.output_dir]
80
+ if args.verify:
81
+ export_args.append("--verify")
82
+ results["T5"] = run_export("onnx_export.export_t5", export_args)
83
+
84
+ # Export DiT
85
+ if not args.skip_dit:
86
+ export_args = ["--output-dir", args.output_dir]
87
+ if args.verify:
88
+ export_args.append("--verify")
89
+ results["DiT"] = run_export("onnx_export.export_dit", export_args)
90
+
91
+ # Export Vision Encoder
92
+ if not args.skip_vision:
93
+ export_args = ["--output", args.output_dir]
94
+ results["Vision"] = run_export("onnx_export.export_vision", export_args)
95
+
96
+ # Print summary
97
+ print(f"\n{'='*60}")
98
+ print("Export Summary")
99
+ print(f"{'='*60}")
100
+
101
+ all_success = True
102
+ for name, success in results.items():
103
+ status = "✓" if success else "✗"
104
+ print(f" {status} {name}")
105
+ if not success:
106
+ all_success = False
107
+
108
+ # List exported files
109
+ print(f"\nExported files in {args.output_dir}:")
110
+ for f in sorted(os.listdir(args.output_dir)):
111
+ path = os.path.join(args.output_dir, f)
112
+ if os.path.isfile(path):
113
+ size_mb = os.path.getsize(path) / (1024 * 1024)
114
+ print(f" {f}: {size_mb:.1f} MB")
115
+
116
+ if all_success:
117
+ print("\n✓ All exports completed successfully!")
118
+ else:
119
+ print("\n✗ Some exports failed")
120
+ sys.exit(1)
121
+
122
+
123
+ if __name__ == "__main__":
124
+ main()