| import logging |
| from time import perf_counter |
| from baseHandler import BaseHandler |
| from lightning_whisper_mlx import LightningWhisperMLX |
| import numpy as np |
| from rich.console import Console |
| from copy import copy |
| import torch |
|
|
# Module-scoped logger, named after this module.
logger = logging.getLogger(__name__)

# Rich console used to echo the recognised user text.
console = Console()

# Language codes a detected language is checked against before it is accepted.
SUPPORTED_LANGUAGES = ["en", "fr", "es", "zh", "ja", "ko"]
|
|
|
|
class LightningWhisperSTTHandler(BaseHandler):
    """
    Handles the Speech To Text generation using a Whisper model.

    The model is a ``LightningWhisperMLX`` instance created in ``setup``.
    ``process`` transcribes one prompt and yields ``(text, language_code)``.
    """

    def setup(
        self,
        model_name="distil-large-v3",
        device="mps",
        torch_dtype="float16",
        compile_mode=None,
        language=None,
        gen_kwargs=None,
    ):
        """Load the model, remember the language settings and warm up.

        Args:
            model_name: Model identifier. Only the part after the last ``/``
                is passed to ``LightningWhisperMLX``.
            device: Stored on ``self.device``; not otherwise used here.
            torch_dtype: Accepted for interface compatibility; unused here.
            compile_mode: Accepted for interface compatibility; unused here.
            language: ``"auto"`` makes ``process`` detect the language and
                fall back when it is unsupported; any other value is passed
                straight to ``transcribe`` as ``language=``.
            gen_kwargs: Accepted for interface compatibility; unused here.
                Defaults to ``None`` instead of a shared mutable ``{}``.
        """
        # "org/name" -> "name"; a name without "/" is returned unchanged.
        model_name = model_name.split("/")[-1]

        self.device = device
        self.model = LightningWhisperMLX(model=model_name, batch_size=6, quant=None)
        self.start_language = language
        # Last accepted language; starts as the configured value and is
        # updated by process() whenever a supported language is detected.
        self.last_language = language

        self.warmup()

    def warmup(self):
        """Run one transcription on a dummy all-zero input."""
        logger.info("Warming up %s", self.__class__.__name__)

        n_steps = 1
        dummy_input = np.array([0] * 512)

        for _ in range(n_steps):
            # Result is discarded; the call itself is the warm-up.
            self.model.transcribe(dummy_input)["text"].strip()

    def _transcribe_with_language_fallback(self, spoken_prompt):
        """Transcribe with language detection, falling back if unsupported.

        If the detected language is in ``SUPPORTED_LANGUAGES`` it becomes
        ``self.last_language``. Otherwise the prompt is transcribed again
        with ``self.last_language`` when that one is supported, or an empty
        English result is returned.
        """
        transcription_dict = self.model.transcribe(spoken_prompt)
        detected = transcription_dict["language"]

        if detected in SUPPORTED_LANGUAGES:
            self.last_language = detected
            return transcription_dict

        logger.warning("Whisper detected unsupported language: %s", detected)
        if self.last_language in SUPPORTED_LANGUAGES:
            return self.model.transcribe(spoken_prompt, language=self.last_language)
        return {"text": "", "language": "en"}

    def process(self, spoken_prompt):
        """Transcribe ``spoken_prompt`` and yield ``(text, language_code)``.

        Args:
            spoken_prompt: Input forwarded unchanged to ``model.transcribe``
                (presumably audio samples; confirm against the caller).

        Yields:
            A single ``(pred_text, language_code)`` tuple, where ``pred_text``
            is the stripped ``"text"`` entry of the transcription result and
            ``language_code`` is its ``"language"`` entry.
        """
        logger.debug("infering whisper...")

        # Module-level timestamp marking the start of this processing step.
        global pipeline_start
        pipeline_start = perf_counter()

        if self.start_language != "auto":
            transcription_dict = self.model.transcribe(
                spoken_prompt, language=self.start_language
            )
        else:
            transcription_dict = self._transcribe_with_language_fallback(spoken_prompt)

        pred_text = transcription_dict["text"].strip()
        language_code = transcription_dict["language"]
        torch.mps.empty_cache()

        logger.debug("finished whisper inference")
        console.print(f"[yellow]USER: {pred_text}")
        logger.debug("Language Code Whisper: %s", language_code)

        yield (pred_text, language_code)
|
|