"""Export phoneme ASR models to optimized ONNX format and upload to HF Hub. Usage: pip install optimum[onnxruntime] onnxruntime HF_TOKEN=... python scripts/export_onnx.py HF_TOKEN=... python scripts/export_onnx.py --quantize # + dynamic INT8 HF_TOKEN=... python scripts/export_onnx.py --models Base # single model Exports fp32 ONNX with ORT graph optimizations baked in. Optionally applies dynamic INT8 quantization for CPU inference. Note: ORTOptimizer's transformer-specific fusions (attention, LayerNorm, GELU) do NOT support wav2vec2. We use ORT's general graph optimizations instead (constant folding, redundant node elimination, common subexpression elimination). Runtime ORT_ENABLE_ALL adds further optimizations at session load time. """ import argparse import os import shutil import sys from pathlib import Path # Add project root to path sys.path.insert(0, str(Path(__file__).parent.parent)) from config import PHONEME_ASR_MODELS def get_hf_token(): """Get HF token from env or cached login.""" hf_token = os.environ.get("HF_TOKEN") if not hf_token: try: from huggingface_hub import HfFolder hf_token = HfFolder.get_token() except Exception: pass if not hf_token: print("WARNING: No HF token found. Set HF_TOKEN env var or run `huggingface-cli login`.") return hf_token def _optimize_graph(model_path: Path): """Apply general ORT graph optimizations (no transformer-specific fusions). Bakes constant folding, redundant node elimination, and common subexpression elimination into the model file so they don't need to run at session load time. """ import onnxruntime as ort model_file = str(model_path / "model.onnx") optimized_file = str(model_path / "model_optimized.onnx") sess_options = ort.SessionOptions() sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL sess_options.optimized_model_filepath = optimized_file # Create session just to trigger optimization and save ort.InferenceSession(model_file, sess_options, providers=["CPUExecutionProvider"]) # Replace original with optimized os.replace(optimized_file, model_file) def export_model(model_name: str, model_path: str, output_dir: Path, hf_token: str, quantize: bool = False): """Export a single model to optimized fp32 ONNX and push to HF Hub.""" from optimum.onnxruntime import ORTModelForCTC print(f"\n{'='*60}") print(f"Exporting '{model_name}' ({model_path})") print(f"{'='*60}") # Clean output dir for fresh export if output_dir.exists(): shutil.rmtree(output_dir) output_dir.mkdir(parents=True) # Step 1: Export to ONNX (fp32) print(f" [1/5] Exporting fp32 ONNX...") model = ORTModelForCTC.from_pretrained(model_path, export=True, token=hf_token) model.save_pretrained(output_dir) print(f" Saved to {output_dir}") # Step 2: Apply general ORT graph optimizations # (wav2vec2 is not supported by ORTOptimizer's transformer-specific fusions, # so we use ORT's built-in graph optimizations directly) print(f" [2/5] Applying ORT graph optimizations...") _optimize_graph(output_dir) print(f" Graph optimization complete") # Step 3: Optional dynamic INT8 quantization model_file = output_dir / "model.onnx" if quantize: print(f" [3/5] Applying dynamic INT8 quantization (avx2)...") from onnxruntime.quantization import QuantType, quantize_dynamic quantized_file = output_dir / "model_quantized.onnx" quantize_dynamic( model_input=str(model_file), model_output=str(quantized_file), weight_type=QuantType.QInt8, ) # Replace original with quantized os.replace(str(quantized_file), str(model_file)) print(f" INT8 quantization complete") else: print(f" [3/5] Skipping quantization (use --quantize to enable)") # Step 4: Verify with dummy forward pass print(f" [4/5] Verifying model...") import numpy as np import onnxruntime as ort sess = ort.InferenceSession(str(model_file), providers=["CPUExecutionProvider"]) input_info = sess.get_inputs()[0] print(f" Input: name={input_info.name}, type={input_info.type}, shape={input_info.shape}") dummy = np.random.randn(1, 16000).astype(np.float32) out = sess.run(None, {"input_values": dummy}) print(f" Output shape: {out[0].shape} (dtype={out[0].dtype})") del sess, dummy, out # Step 5: Push to HF Hub print(f" [5/5] Uploading to HF Hub...") from huggingface_hub import HfApi repo_name = model_path.split("/")[-1] hub_repo = f"hetchyy/{repo_name}-onnx" api = HfApi(token=hf_token) api.create_repo(repo_id=hub_repo, repo_type="model", private=True, exist_ok=True) api.upload_folder(folder_path=str(output_dir), repo_id=hub_repo, repo_type="model") print(f" Pushed to {hub_repo}") def main(): parser = argparse.ArgumentParser(description="Export phoneme ASR models to optimized ONNX") parser.add_argument("--quantize", action="store_true", help="Apply dynamic INT8 quantization after graph optimization") parser.add_argument("--models", nargs="+", choices=list(PHONEME_ASR_MODELS.keys()), default=list(PHONEME_ASR_MODELS.keys()), help="Which models to export (default: all)") args = parser.parse_args() hf_token = get_hf_token() models_dir = Path(__file__).parent.parent / "models" models_dir.mkdir(exist_ok=True) for name in args.models: path = PHONEME_ASR_MODELS[name] output_dir = models_dir / f"onnx_{name}" export_model(name, path, output_dir, hf_token, quantize=args.quantize) suffix = " + INT8 quantized" if args.quantize else "" print(f"\nDone. ONNX fp32 optimized{suffix} models exported and uploaded.") if __name__ == "__main__": main()