File size: 3,354 Bytes
76db545
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""
Merges LoRA adapter weights into the backbone and exports to ONNX.
Produces one ONNX export directory per language (the adapter is baked into the exported graph, so ONNX cannot hot-swap adapters at runtime).

Requires: optimum[onnxruntime] (plus jiwer when computing WER in ONNXExporter.validate)
"""
from __future__ import annotations

import logging
from pathlib import Path
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from peft import PeftModel
    from transformers import WhisperProcessor

logger = logging.getLogger(__name__)


class ONNXExporter:
    """Merges a LoRA PeftModel into its base model and exports to ONNX."""

    def merge_and_export(
        self,
        peft_model: "PeftModel",
        processor: "WhisperProcessor",
        output_dir: str,
        language: str,
    ) -> Path:
        """
        1. Merge LoRA weights into base model (merge_and_unload)
        2. Export merged model to ONNX via optimum
        Returns the output directory path.
        """
        output_path = Path(output_dir) / language
        output_path.mkdir(parents=True, exist_ok=True)

        logger.info("Merging LoRA adapter '%s' into base model...", language)
        merged_model = peft_model.merge_and_unload()
        merged_model.eval()

        logger.info("Exporting to ONNX: %s", output_path)
        self._export_with_optimum(merged_model, processor, str(output_path))

        return output_path

    def _export_with_optimum(
        self,
        merged_model,
        processor: "WhisperProcessor",
        output_dir: str,
    ) -> None:
        """Use optimum's ONNX export pipeline."""
        from optimum.exporters.onnx import main_export

        # Save merged model to a temp directory first
        import tempfile

        with tempfile.TemporaryDirectory() as tmp_dir:
            logger.info("Saving merged model to temp dir for export...")
            merged_model.save_pretrained(tmp_dir)
            processor.save_pretrained(tmp_dir)

            logger.info("Running optimum ONNX export...")
            main_export(
                model_name_or_path=tmp_dir,
                output=output_dir,
                task="automatic-speech-recognition",
                opset=17,
                optimize="O2",
            )

        logger.info("ONNX export complete: %s", output_dir)

    def validate(
        self,
        onnx_dir: str,
        processor: "WhisperProcessor",
        test_audio_arrays: list,
        sample_rate: int = 16_000,
        reference_texts: list[str] | None = None,
    ) -> dict:
        """
        Run inference with the exported ONNX model and compute WER vs. references.
        """
        import numpy as np
        from optimum.onnxruntime import ORTModelForSpeechSeq2Seq

        logger.info("Validating ONNX model at %s...", onnx_dir)
        ort_model = ORTModelForSpeechSeq2Seq.from_pretrained(onnx_dir)

        transcriptions = []
        for audio in test_audio_arrays:
            inputs = processor(audio, sampling_rate=sample_rate, return_tensors="pt")
            outputs = ort_model.generate(inputs.input_features)
            text = processor.batch_decode(outputs, skip_special_tokens=True)[0]
            transcriptions.append(text)

        result = {"transcriptions": transcriptions}

        if reference_texts:
            import jiwer
            wer = jiwer.wer(reference_texts, transcriptions)
            result["wer"] = round(wer, 4)

        return result