Spaces:
Sleeping
Sleeping
File size: 3,354 Bytes
"""
Merges LoRA adapter weights into the backbone and exports to ONNX.
Produces one ONNX export per language (ONNX cannot hot-swap adapters at runtime).
Requires: optimum[onnxruntime]
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from peft import PeftModel
from transformers import WhisperProcessor
logger = logging.getLogger(__name__)
class ONNXExporter:
    """Merges a LoRA PeftModel into its base model and exports to ONNX.

    Each export is written to its own ``<output_dir>/<language>`` directory,
    so one merged model is produced per language adapter.
    """

    def merge_and_export(
        self,
        peft_model: "PeftModel",
        processor: "WhisperProcessor",
        output_dir: str,
        language: str,
    ) -> Path:
        """Merge the LoRA adapter into the base model and export it to ONNX.

        Steps:
            1. Merge LoRA weights into the base model (``merge_and_unload``).
            2. Export the merged model to ONNX via optimum.

        Args:
            peft_model: Adapter-wrapped model; ``merge_and_unload()`` is called
                on it to obtain the merged model.
            processor: Processor saved next to the merged model before export.
            output_dir: Root directory for exports.
            language: Name of the sub-directory (under ``output_dir``) that
                receives this export; also used in log messages.

        Returns:
            The ``<output_dir>/<language>`` directory as a ``Path``.
        """
        output_path = Path(output_dir) / language
        output_path.mkdir(parents=True, exist_ok=True)

        logger.info("Merging LoRA adapter '%s' into base model...", language)
        merged_model = peft_model.merge_and_unload()
        merged_model.eval()

        logger.info("Exporting to ONNX: %s", output_path)
        self._export_with_optimum(merged_model, processor, str(output_path))
        return output_path

    def _export_with_optimum(
        self,
        merged_model,
        processor: "WhisperProcessor",
        output_dir: str,
    ) -> None:
        """Use optimum's ONNX export pipeline.

        The merged model and processor are first written to a temporary
        directory, because ``main_export`` is given a path
        (``model_name_or_path``) rather than an in-memory model.

        Args:
            merged_model: Model object exposing ``save_pretrained``.
            processor: Processor object exposing ``save_pretrained``.
            output_dir: Directory passed to ``main_export`` as ``output``.
        """
        # Imported lazily so the module can be imported without optimum.
        import tempfile

        from optimum.exporters.onnx import main_export

        # The temporary checkpoint is removed when the ``with`` block exits,
        # i.e. only after the export has finished reading from it.
        with tempfile.TemporaryDirectory() as tmp_dir:
            logger.info("Saving merged model to temp dir for export...")
            merged_model.save_pretrained(tmp_dir)
            processor.save_pretrained(tmp_dir)

            logger.info("Running optimum ONNX export...")
            main_export(
                model_name_or_path=tmp_dir,
                output=output_dir,
                task="automatic-speech-recognition",
                opset=17,
                optimize="O2",
            )
        logger.info("ONNX export complete: %s", output_dir)

    def validate(
        self,
        onnx_dir: str,
        processor: "WhisperProcessor",
        test_audio_arrays: list,
        sample_rate: int = 16_000,
        reference_texts: list[str] | None = None,
    ) -> dict:
        """Run inference with the exported ONNX model and compute WER vs. references.

        Args:
            onnx_dir: Directory containing the exported ONNX model.
            processor: Processor used to build input features and to decode
                the generated token ids.
            test_audio_arrays: Audio samples, transcribed one at a time.
            sample_rate: Sampling rate passed to the processor.
            reference_texts: Optional ground-truth transcripts, one per audio
                sample. When omitted or empty, no WER is computed.

        Returns:
            A dict with key ``"transcriptions"`` (list of decoded strings) and,
            when references were given, ``"wer"`` rounded to 4 decimals.

        Raises:
            ValueError: If ``reference_texts`` is provided but its length
                differs from ``test_audio_arrays``.
        """
        # A bare string is treated as a single reference, so callers that
        # pass one transcript for one audio sample keep working.
        if isinstance(reference_texts, str):
            reference_texts = [reference_texts]

        # Fail fast: a length mismatch would otherwise only surface inside
        # jiwer, after the model was loaded and every sample was transcribed.
        if reference_texts and len(reference_texts) != len(test_audio_arrays):
            raise ValueError(
                "reference_texts and test_audio_arrays must have the same "
                f"length (got {len(reference_texts)} and "
                f"{len(test_audio_arrays)})"
            )

        from optimum.onnxruntime import ORTModelForSpeechSeq2Seq

        logger.info("Validating ONNX model at %s...", onnx_dir)
        ort_model = ORTModelForSpeechSeq2Seq.from_pretrained(onnx_dir)

        transcriptions = []
        for audio in test_audio_arrays:
            inputs = processor(
                audio, sampling_rate=sample_rate, return_tensors="pt"
            )
            outputs = ort_model.generate(inputs.input_features)
            text = processor.batch_decode(outputs, skip_special_tokens=True)[0]
            transcriptions.append(text)

        result = {"transcriptions": transcriptions}
        if reference_texts:
            import jiwer

            wer = jiwer.wer(reference_texts, transcriptions)
            result["wer"] = round(wer, 4)
        return result
|