ground-zero / src /optimization /onnx_exporter.py
jefffffff9
Initial commit: Sahel-Agri Voice AI
76db545
Raw
History Blame Contribute Delete
3.35 kB
"""
Merges LoRA adapter weights into the backbone and exports the result to ONNX.
Produces one ONNX export directory per language, because ONNX cannot
hot-swap adapters at runtime.
Requires: optimum[onnxruntime]; jiwer is needed only for WER scoring in validate().
"""
from __future__ import annotations

import copy
import logging
import tempfile
from pathlib import Path
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from peft import PeftModel
    from transformers import WhisperProcessor
logger = logging.getLogger(__name__)
class ONNXExporter:
"""Merges a LoRA PeftModel into its base model and exports to ONNX."""
def merge_and_export(
self,
peft_model: "PeftModel",
processor: "WhisperProcessor",
output_dir: str,
language: str,
) -> Path:
"""
1. Merge LoRA weights into base model (merge_and_unload)
2. Export merged model to ONNX via optimum
Returns the output directory path.
"""
output_path = Path(output_dir) / language
output_path.mkdir(parents=True, exist_ok=True)
logger.info("Merging LoRA adapter '%s' into base model...", language)
merged_model = peft_model.merge_and_unload()
merged_model.eval()
logger.info("Exporting to ONNX: %s", output_path)
self._export_with_optimum(merged_model, processor, str(output_path))
return output_path
def _export_with_optimum(
self,
merged_model,
processor: "WhisperProcessor",
output_dir: str,
) -> None:
"""Use optimum's ONNX export pipeline."""
from optimum.exporters.onnx import main_export
# Save merged model to a temp directory first
import tempfile
with tempfile.TemporaryDirectory() as tmp_dir:
logger.info("Saving merged model to temp dir for export...")
merged_model.save_pretrained(tmp_dir)
processor.save_pretrained(tmp_dir)
logger.info("Running optimum ONNX export...")
main_export(
model_name_or_path=tmp_dir,
output=output_dir,
task="automatic-speech-recognition",
opset=17,
optimize="O2",
)
logger.info("ONNX export complete: %s", output_dir)
def validate(
self,
onnx_dir: str,
processor: "WhisperProcessor",
test_audio_arrays: list,
sample_rate: int = 16_000,
reference_texts: list[str] | None = None,
) -> dict:
"""
Run inference with the exported ONNX model and compute WER vs. references.
"""
import numpy as np
from optimum.onnxruntime import ORTModelForSpeechSeq2Seq
logger.info("Validating ONNX model at %s...", onnx_dir)
ort_model = ORTModelForSpeechSeq2Seq.from_pretrained(onnx_dir)
transcriptions = []
for audio in test_audio_arrays:
inputs = processor(audio, sampling_rate=sample_rate, return_tensors="pt")
outputs = ort_model.generate(inputs.input_features)
text = processor.batch_decode(outputs, skip_special_tokens=True)[0]
transcriptions.append(text)
result = {"transcriptions": transcriptions}
if reference_texts:
import jiwer
wer = jiwer.wer(reference_texts, transcriptions)
result["wer"] = round(wer, 4)
return result