# scripts/export_onnx.py
"""Export phoneme ASR models to optimized ONNX format and upload to HF Hub.
Usage:
pip install optimum[onnxruntime] onnxruntime
HF_TOKEN=... python scripts/export_onnx.py
HF_TOKEN=... python scripts/export_onnx.py --quantize # + dynamic INT8
HF_TOKEN=... python scripts/export_onnx.py --models Base # single model
Exports fp32 ONNX with ORT graph optimizations baked in.
Optionally applies dynamic INT8 quantization for CPU inference.
Note: ORTOptimizer's transformer-specific fusions (attention, LayerNorm, GELU)
do NOT support wav2vec2. We use ORT's general graph optimizations instead
(constant folding, redundant node elimination, common subexpression elimination).
Runtime ORT_ENABLE_ALL adds further optimizations at session load time.
"""
import argparse
import os
import shutil
import sys
from pathlib import Path
# Add project root to path
sys.path.insert(0, str(Path(__file__).parent.parent))
from config import PHONEME_ASR_MODELS


def get_hf_token():
    """Get HF token from env or cached login."""
    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        try:
            from huggingface_hub import HfFolder
            hf_token = HfFolder.get_token()
        except Exception:
            pass
    if not hf_token:
        print("WARNING: No HF token found. Set HF_TOKEN env var or run `huggingface-cli login`.")
    return hf_token


def _optimize_graph(model_path: Path):
    """Apply general ORT graph optimizations (no transformer-specific fusions).

    Bakes constant folding, redundant node elimination, and common subexpression
    elimination into the model file so they don't need to run at session load time.
    """
    import onnxruntime as ort

    model_file = str(model_path / "model.onnx")
    optimized_file = str(model_path / "model_optimized.onnx")
    sess_options = ort.SessionOptions()
    sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    sess_options.optimized_model_filepath = optimized_file
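    # ORT_ENABLE_ALL covers basic, extended, and layout optimizations; setting
    # optimized_model_filepath makes ORT serialize the optimized graph to disk
    # when the session is created.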
    # Create session just to trigger optimization and save
    ort.InferenceSession(model_file, sess_options, providers=["CPUExecutionProvider"])
    # Replace original with optimized
    os.replace(optimized_file, model_file)


def export_model(model_name: str, model_path: str, output_dir: Path, hf_token: str,
                 quantize: bool = False):
    """Export a single model to optimized fp32 ONNX and push to HF Hub."""
    from optimum.onnxruntime import ORTModelForCTC

    print(f"\n{'='*60}")
    print(f"Exporting '{model_name}' ({model_path})")
    print(f"{'='*60}")

    # Clean output dir for fresh export
    if output_dir.exists():
        shutil.rmtree(output_dir)
    output_dir.mkdir(parents=True)

    # Step 1: Export to ONNX (fp32)
    print(" [1/5] Exporting fp32 ONNX...")
    model = ORTModelForCTC.from_pretrained(model_path, export=True, token=hf_token)
    model.save_pretrained(output_dir)
    print(f" Saved to {output_dir}")

    # Step 2: Apply general ORT graph optimizations
    # (wav2vec2 is not supported by ORTOptimizer's transformer-specific fusions,
    # so we use ORT's built-in graph optimizations directly)
    print(" [2/5] Applying ORT graph optimizations...")
    _optimize_graph(output_dir)
    print(" Graph optimization complete")

    # Step 3: Optional dynamic INT8 quantization
    model_file = output_dir / "model.onnx"
    if quantize:
        print(" [3/5] Applying dynamic INT8 quantization (avx2)...")
        from onnxruntime.quantization import QuantType, quantize_dynamic
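        # Dynamic quantization stores the weights as INT8 and quantizes
        # activations on the fly at runtime, so no calibration data is needed.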
        quantized_file = output_dir / "model_quantized.onnx"
        quantize_dynamic(
            model_input=str(model_file),
            model_output=str(quantized_file),
            weight_type=QuantType.QInt8,
        )
        # Replace original with quantized
        os.replace(str(quantized_file), str(model_file))
        print(" INT8 quantization complete")
    else:
        print(" [3/5] Skipping quantization (use --quantize to enable)")

    # Step 4: Verify with dummy forward pass
    print(" [4/5] Verifying model...")
    import numpy as np
    import onnxruntime as ort

    sess = ort.InferenceSession(str(model_file), providers=["CPUExecutionProvider"])
    input_info = sess.get_inputs()[0]
    print(f" Input: name={input_info.name}, type={input_info.type}, shape={input_info.shape}")
    dummy = np.random.randn(1, 16000).astype(np.float32)
    out = sess.run(None, {"input_values": dummy})
    print(f" Output shape: {out[0].shape} (dtype={out[0].dtype})")
    del sess, dummy, out

    # Step 5: Push to HF Hub
    print(" [5/5] Uploading to HF Hub...")
    from huggingface_hub import HfApi

    repo_name = model_path.split("/")[-1]
    hub_repo = f"hetchyy/{repo_name}-onnx"
    api = HfApi(token=hf_token)
    api.create_repo(repo_id=hub_repo, repo_type="model", private=True, exist_ok=True)
    api.upload_folder(folder_path=str(output_dir), repo_id=hub_repo, repo_type="model")
    print(f" Pushed to {hub_repo}")


def main():
    parser = argparse.ArgumentParser(description="Export phoneme ASR models to optimized ONNX")
    parser.add_argument("--quantize", action="store_true",
                        help="Apply dynamic INT8 quantization after graph optimization")
    parser.add_argument("--models", nargs="+", choices=list(PHONEME_ASR_MODELS.keys()),
                        default=list(PHONEME_ASR_MODELS.keys()),
                        help="Which models to export (default: all)")
    args = parser.parse_args()

    hf_token = get_hf_token()
    models_dir = Path(__file__).parent.parent / "models"
    models_dir.mkdir(exist_ok=True)

    for name in args.models:
        path = PHONEME_ASR_MODELS[name]
        output_dir = models_dir / f"onnx_{name}"
        export_model(name, path, output_dir, hf_token, quantize=args.quantize)

    suffix = " + INT8 quantized" if args.quantize else ""
    print(f"\nDone. ONNX fp32 optimized{suffix} models exported and uploaded.")


if __name__ == "__main__":
    main()