import gc
import time

import gradio as gr
import torch

from src.constants import MODEL_IDS, PARAKEET_V3
from src.utils import get_audio_duration_seconds, serialize

# Process-wide cache for the Parakeet ASR model (loading it is expensive).
_PARAKEET_MODEL = None


def _get_parakeet_model():
    """Return the cached NVIDIA Parakeet ASR model, loading it on first use.

    Raises:
        gr.Error: if the NeMo ASR package cannot be imported.
    """
    global _PARAKEET_MODEL
    if _PARAKEET_MODEL is not None:
        return _PARAKEET_MODEL
    # NeMo is a heavy optional dependency; import lazily so the rest of the
    # app can start without it.  The broad `except Exception` is deliberate:
    # importing NeMo can fail with errors other than ImportError (e.g. a
    # missing native library), and all of them mean "backend unavailable".
    try:
        import nemo.collections.asr as nemo_asr
    except Exception as exc:
        raise gr.Error(
            "NVIDIA Parakeet backend requested but NeMo ASR package is missing. "
            "Add nemo_toolkit[asr] to requirements.txt"
        ) from exc
    model = nemo_asr.models.ASRModel.from_pretrained(model_name=MODEL_IDS[PARAKEET_V3])
    model.eval()
    _PARAKEET_MODEL = model
    return _PARAKEET_MODEL


def preload_parakeet_model() -> None:
    """Eagerly load the Parakeet model so the first request does not pay the cost."""
    _get_parakeet_model()


def run_parakeet(
    audio_file: str,
    language: str,
    model_options: dict,
    duration_seconds: float | None = None,
) -> dict:
    """Transcribe ``audio_file`` with the Parakeet model and return a result dict.

    Args:
        audio_file: Path to the audio file to transcribe.
        language: Language hint; only echoed back in the output ("auto" if falsy) —
            Parakeet itself is not given this value.
        model_options: Tuning knobs read via ``.get`` with defaults:
            ``long_audio_threshold_seconds`` (480), ``local_attention_left`` (256),
            ``local_attention_right`` (256), ``subsampling_conv_chunking_factor`` (1),
            ``enable_long_audio_optimizations`` (True).
        duration_seconds: Known audio duration; probed via
            ``get_audio_duration_seconds`` when None.

    Returns:
        Dict with ``raw_output`` (serialized NeMo output plus hints and the
        long-audio settings actually applied) and ``timing``.

    Raises:
        gr.Error: on CUDA out-of-memory during transcription.
    """
    model = _get_parakeet_model()
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Long-audio knobs, all overridable via model_options.
    long_audio_threshold_seconds = float(model_options.get("long_audio_threshold_seconds", 480))
    local_attention_left = int(model_options.get("local_attention_left", 256))
    local_attention_right = int(model_options.get("local_attention_right", 256))
    subsampling_conv_chunking_factor = int(model_options.get("subsampling_conv_chunking_factor", 1))
    enable_long_audio_optimizations = bool(model_options.get("enable_long_audio_optimizations", True))

    if duration_seconds is None:
        duration_seconds = get_audio_duration_seconds(audio_file)
    # Duration may be unknown (None); only then do we skip the long-audio path.
    is_long_audio = duration_seconds is not None and duration_seconds > long_audio_threshold_seconds

    applied_long_audio_settings = False
    # Non-fatal problems (optimization setup/revert, cleanup) are collected
    # here and reported in the output instead of aborting the transcription.
    optimization_errors: list[str] = []

    infer_start = time.perf_counter()
    try:
        model.to(device)
        # Cast to float32 first; NOTE(review): the model is re-cast to bfloat16
        # below, so this looks like a "reset to a known dtype" step before
        # changing the attention model — confirm against NeMo requirements.
        model.to(torch.float32)
        if enable_long_audio_optimizations and is_long_audio:
            # Local attention + conv chunking reduce memory for long inputs.
            # Failure here is non-fatal: fall back to default full attention.
            try:
                model.change_attention_model("rel_pos_local_attn", [local_attention_left, local_attention_right])
                model.change_subsampling_conv_chunking_factor(subsampling_conv_chunking_factor)
                applied_long_audio_settings = True
            except Exception as exc:
                optimization_errors.append(f"long_audio_settings failed: {exc}")
        # NOTE(review): bfloat16 is applied unconditionally, including on CPU —
        # verify CPU bfloat16 inference is intended/supported for this model.
        model.to(torch.bfloat16)
        outputs = model.transcribe([audio_file], timestamps=True)
    except torch.cuda.OutOfMemoryError as exc:
        raise gr.Error("CUDA out of memory while running Parakeet transcription.") from exc
    finally:
        # Revert long-audio settings so the cached model is back in its
        # default configuration for the next call.
        if applied_long_audio_settings:
            try:
                model.change_attention_model("rel_pos")
                model.change_subsampling_conv_chunking_factor(-1)
            except Exception as exc:
                optimization_errors.append(f"revert_long_audio_settings failed: {exc}")
        # Best-effort GPU memory release; never mask the in-flight exception.
        try:
            if device == "cuda":
                model.cpu()
            gc.collect()
            if device == "cuda":
                torch.cuda.empty_cache()
        except Exception as exc:
            optimization_errors.append(f"cleanup failed: {exc}")
    infer_end = time.perf_counter()

    # transcribe() returns a list (one entry per input file); we sent one file.
    item = outputs[0] if outputs else None
    raw_output = {
        "output": serialize(item),
        "timestamp_hint": "word timestamps available in output.timestamp['word'] when provided by NeMo",
        "language_hint": language or "auto",
        "long_audio_settings": {
            "duration_seconds": duration_seconds,
            "is_long_audio": is_long_audio,
            "threshold_seconds": long_audio_threshold_seconds,
            "enable_long_audio_optimizations": enable_long_audio_optimizations,
            "applied_long_audio_settings": applied_long_audio_settings,
            "local_attention_left": local_attention_left,
            "local_attention_right": local_attention_right,
            "subsampling_conv_chunking_factor": subsampling_conv_chunking_factor,
            "optimization_errors": optimization_errors,
        },
    }
    return {
        "raw_output": raw_output,
        "timing": {
            "inference_seconds": round(infer_end - infer_start, 4),
        },
    }