"""Export phoneme ASR models to optimized ONNX format and upload to HF Hub.

Usage:
    pip install optimum[onnxruntime] onnxruntime
    HF_TOKEN=... python scripts/export_onnx.py
    HF_TOKEN=... python scripts/export_onnx.py --quantize      # + dynamic INT8
    HF_TOKEN=... python scripts/export_onnx.py --models Base   # single model

Exports fp32 ONNX with ORT graph optimizations baked in.
Optionally applies dynamic INT8 quantization for CPU inference.

Note: ORTOptimizer's transformer-specific fusions (attention, LayerNorm, GELU)
do NOT support wav2vec2. We use ORT's general graph optimizations instead
(constant folding, redundant node elimination, common subexpression elimination).
Runtime ORT_ENABLE_ALL adds further optimizations at session load time.
"""
import argparse
import os
import shutil
import sys
from pathlib import Path

# Make the project root importable so `config` resolves when this script is
# run directly from the scripts/ directory (one level below the root).
sys.path.insert(0, str(Path(__file__).parent.parent))

from config import PHONEME_ASR_MODELS
def get_hf_token():
    """Return a Hugging Face auth token, or None if none is available.

    Resolution order:
      1. the ``HF_TOKEN`` environment variable,
      2. the token cached by ``huggingface-cli login`` (via huggingface_hub).

    Prints a warning instead of failing when no token is found, since public
    models may still be downloadable anonymously.
    """
    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        try:
            # get_token() is the supported API; HfFolder.get_token() is
            # deprecated in recent huggingface_hub releases.
            from huggingface_hub import get_token

            hf_token = get_token()
        except Exception:
            # huggingface_hub missing or token cache unreadable -- best effort.
            hf_token = None
    if not hf_token:
        print("WARNING: No HF token found. Set HF_TOKEN env var or run `huggingface-cli login`.")
    return hf_token
def _optimize_graph(model_path: Path):
    """Bake ORT graph optimizations into ``<model_path>/model.onnx`` in place.

    Creating an InferenceSession with ``optimized_model_filepath`` set makes
    ONNX Runtime serialize the optimized graph to disk; the optimized file
    then replaces the original.

    Uses ORT_ENABLE_EXTENDED rather than ORT_ENABLE_ALL for the saved file:
    per the ONNX Runtime docs, the extra "layout" optimizations enabled by
    ENABLE_ALL are hardware-specific and should not be baked into a model
    that will be loaded on other machines -- these exports are uploaded to
    the Hub and may run anywhere. Those optimizations still run at session
    load time on the target host (the module docstring notes ORT_ENABLE_ALL
    is applied again at runtime).

    Args:
        model_path: Directory containing ``model.onnx``; modified in place.
    """
    import onnxruntime as ort

    model_file = str(model_path / "model.onnx")
    optimized_file = str(model_path / "model_optimized.onnx")

    sess_options = ort.SessionOptions()
    # EXTENDED = basic optimizations + complex node fusions, without the
    # hardware-dependent layout transforms (safe to serialize for offline use).
    sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
    sess_options.optimized_model_filepath = optimized_file
    # Session construction triggers optimization and writes optimized_file.
    ort.InferenceSession(model_file, sess_options, providers=["CPUExecutionProvider"])
    # Swap the optimized graph over the original.
    os.replace(optimized_file, model_file)
def export_model(model_name: str, model_path: str, output_dir: Path, hf_token: str,
                 quantize: bool = False):
    """Export a single model to optimized fp32 ONNX and push to HF Hub.

    Pipeline: fp32 ONNX export -> ORT graph optimization -> optional dynamic
    INT8 quantization -> smoke-test inference -> upload to a private Hub repo.

    Args:
        model_name: Display name used only in log output.
        model_path: Hub repo id of the source model (e.g. "org/model").
        output_dir: Local directory for the export; recreated from scratch.
        hf_token: Hugging Face token used for both download and upload.
        quantize: When True, replace the fp32 graph with a dynamic-INT8 one.
    """
    from optimum.onnxruntime import ORTModelForCTC

    rule = "=" * 60
    print(f"\n{rule}")
    print(f"Exporting '{model_name}' ({model_path})")
    print(rule)

    # Start from an empty directory so stale artifacts are never uploaded.
    if output_dir.exists():
        shutil.rmtree(output_dir)
    output_dir.mkdir(parents=True)

    # --- Step 1: fp32 ONNX export ---------------------------------------
    print(" [1/5] Exporting fp32 ONNX...")
    ORTModelForCTC.from_pretrained(
        model_path, export=True, token=hf_token
    ).save_pretrained(output_dir)
    print(f" Saved to {output_dir}")

    # --- Step 2: general ORT graph optimizations ------------------------
    # (wav2vec2 is unsupported by ORTOptimizer's transformer-specific
    # fusions, so ORT's built-in graph optimizations are used directly.)
    print(" [2/5] Applying ORT graph optimizations...")
    _optimize_graph(output_dir)
    print(" Graph optimization complete")

    # --- Step 3: optional dynamic INT8 quantization ---------------------
    model_file = output_dir / "model.onnx"
    if not quantize:
        print(" [3/5] Skipping quantization (use --quantize to enable)")
    else:
        print(" [3/5] Applying dynamic INT8 quantization (avx2)...")
        from onnxruntime.quantization import QuantType, quantize_dynamic

        quantized_file = output_dir / "model_quantized.onnx"
        quantize_dynamic(
            model_input=str(model_file),
            model_output=str(quantized_file),
            weight_type=QuantType.QInt8,
        )
        # The quantized graph takes the original's place.
        os.replace(str(quantized_file), str(model_file))
        print(" INT8 quantization complete")

    # --- Step 4: smoke test with a dummy forward pass -------------------
    print(" [4/5] Verifying model...")
    import numpy as np
    import onnxruntime as ort

    session = ort.InferenceSession(str(model_file), providers=["CPUExecutionProvider"])
    first_input = session.get_inputs()[0]
    print(f" Input: name={first_input.name}, type={first_input.type}, shape={first_input.shape}")
    # One second of random audio at 16 kHz -- assumes a 16 kHz wav2vec2-style
    # model; the shape only needs to be accepted, not meaningful.
    sample = np.random.randn(1, 16000).astype(np.float32)
    outputs = session.run(None, {"input_values": sample})
    print(f" Output shape: {outputs[0].shape} (dtype={outputs[0].dtype})")
    del session, sample, outputs

    # --- Step 5: upload to the Hub --------------------------------------
    print(" [5/5] Uploading to HF Hub...")
    from huggingface_hub import HfApi

    hub_repo = f"hetchyy/{model_path.split('/')[-1]}-onnx"
    api = HfApi(token=hf_token)
    api.create_repo(repo_id=hub_repo, repo_type="model", private=True, exist_ok=True)
    api.upload_folder(folder_path=str(output_dir), repo_id=hub_repo, repo_type="model")
    print(f" Pushed to {hub_repo}")
def main():
    """CLI entry point: export each requested model and upload it."""
    parser = argparse.ArgumentParser(description="Export phoneme ASR models to optimized ONNX")
    parser.add_argument("--quantize", action="store_true",
                        help="Apply dynamic INT8 quantization after graph optimization")
    all_model_names = list(PHONEME_ASR_MODELS.keys())
    parser.add_argument("--models", nargs="+", choices=all_model_names,
                        default=all_model_names,
                        help="Which models to export (default: all)")
    args = parser.parse_args()

    hf_token = get_hf_token()

    # All exports land under <project root>/models/onnx_<name>.
    models_dir = Path(__file__).parent.parent / "models"
    models_dir.mkdir(exist_ok=True)

    for name in args.models:
        export_model(
            name,
            PHONEME_ASR_MODELS[name],
            models_dir / f"onnx_{name}",
            hf_token,
            quantize=args.quantize,
        )

    suffix = " + INT8 quantized" if args.quantize else ""
    print(f"\nDone. ONNX fp32 optimized{suffix} models exported and uploaded.")


if __name__ == "__main__":
    main()