Spaces:
Running on Zero
Running on Zero
| import gc | |
| import time | |
| import gradio as gr | |
| import torch | |
| from src.constants import MODEL_IDS, PARAKEET_V3 | |
| from src.utils import get_audio_duration_seconds, serialize | |
# Lazily-loaded NeMo ASR model shared across requests (process-wide cache).
_PARAKEET_MODEL = None


def _get_parakeet_model():
    """Return the cached Parakeet ASR model, loading it on first call.

    Raises:
        gr.Error: if the NeMo ASR package is not installed.
    """
    global _PARAKEET_MODEL
    if _PARAKEET_MODEL is None:
        try:
            import nemo.collections.asr as nemo_asr
        except Exception as exc:
            raise gr.Error(
                "NVIDIA Parakeet backend requested but NeMo ASR package is missing. "
                "Add nemo_toolkit[asr] to requirements.txt"
            ) from exc
        loaded = nemo_asr.models.ASRModel.from_pretrained(model_name=MODEL_IDS[PARAKEET_V3])
        loaded.eval()
        _PARAKEET_MODEL = loaded
    return _PARAKEET_MODEL
def preload_parakeet_model() -> None:
    """Eagerly load the Parakeet model so the first user request is not delayed."""
    _get_parakeet_model()
def run_parakeet(
    audio_file: str,
    language: str,
    model_options: dict,
    duration_seconds: float | None = None,
) -> dict:
    """Transcribe ``audio_file`` with the cached NVIDIA Parakeet model.

    Args:
        audio_file: Path to the audio file to transcribe.
        language: Language hint; only echoed back in the output ("auto" if falsy).
        model_options: Tuning knobs — ``long_audio_threshold_seconds`` (default 480),
            ``local_attention_left``/``local_attention_right`` (default 256),
            ``subsampling_conv_chunking_factor`` (default 1), and
            ``enable_long_audio_optimizations`` (default True).
        duration_seconds: Known audio duration; probed from the file when None.

    Returns:
        dict with ``raw_output`` (serialized NeMo result plus long-audio metadata)
        and ``timing`` (``inference_seconds``).

    Raises:
        gr.Error: on CUDA out-of-memory during transcription.
    """
    model = _get_parakeet_model()
    device = "cuda" if torch.cuda.is_available() else "cpu"

    long_audio_threshold_seconds = float(model_options.get("long_audio_threshold_seconds", 480))
    local_attention_left = int(model_options.get("local_attention_left", 256))
    local_attention_right = int(model_options.get("local_attention_right", 256))
    subsampling_conv_chunking_factor = int(model_options.get("subsampling_conv_chunking_factor", 1))
    enable_long_audio_optimizations = bool(model_options.get("enable_long_audio_optimizations", True))

    if duration_seconds is None:
        duration_seconds = get_audio_duration_seconds(audio_file)
    is_long_audio = duration_seconds is not None and duration_seconds > long_audio_threshold_seconds

    # `attempted` is set BEFORE mutating the (globally cached) model so that a
    # partial application — attention switched but chunking-factor call raised —
    # is still reverted in the finally block. Previously the revert only ran
    # when BOTH calls succeeded, leaving the shared model in local-attention
    # mode for every subsequent request after a partial failure.
    attempted_long_audio_settings = False
    applied_long_audio_settings = False
    optimization_errors: list[str] = []
    infer_start = time.perf_counter()
    try:
        model.to(device)
        model.to(torch.float32)
        if enable_long_audio_optimizations and is_long_audio:
            try:
                attempted_long_audio_settings = True
                model.change_attention_model(
                    "rel_pos_local_attn", [local_attention_left, local_attention_right]
                )
                model.change_subsampling_conv_chunking_factor(subsampling_conv_chunking_factor)
                applied_long_audio_settings = True
            except Exception as exc:
                # Best-effort optimization: record the failure and fall back to
                # the default attention; do not abort the transcription.
                optimization_errors.append(f"long_audio_settings failed: {exc}")
        model.to(torch.bfloat16)
        outputs = model.transcribe([audio_file], timestamps=True)
    except torch.cuda.OutOfMemoryError as exc:
        raise gr.Error("CUDA out of memory while running Parakeet transcription.") from exc
    finally:
        if attempted_long_audio_settings:
            try:
                # Restore the model's default attention/chunking so the cached
                # instance is clean for the next (possibly short-audio) request.
                model.change_attention_model("rel_pos")
                model.change_subsampling_conv_chunking_factor(-1)
            except Exception as exc:
                optimization_errors.append(f"revert_long_audio_settings failed: {exc}")
        try:
            # Release GPU memory between requests; cleanup must never mask the
            # primary exception, so failures are only recorded.
            if device == "cuda":
                model.cpu()
            gc.collect()
            if device == "cuda":
                torch.cuda.empty_cache()
        except Exception as exc:
            optimization_errors.append(f"cleanup failed: {exc}")
    infer_end = time.perf_counter()

    item = outputs[0] if outputs else None
    raw_output = {
        "output": serialize(item),
        "timestamp_hint": "word timestamps available in output.timestamp['word'] when provided by NeMo",
        "language_hint": language or "auto",
        "long_audio_settings": {
            "duration_seconds": duration_seconds,
            "is_long_audio": is_long_audio,
            "threshold_seconds": long_audio_threshold_seconds,
            "enable_long_audio_optimizations": enable_long_audio_optimizations,
            "applied_long_audio_settings": applied_long_audio_settings,
            "local_attention_left": local_attention_left,
            "local_attention_right": local_attention_right,
            "subsampling_conv_chunking_factor": subsampling_conv_chunking_factor,
            "optimization_errors": optimization_errors,
        },
    }
    return {
        "raw_output": raw_output,
        "timing": {
            "inference_seconds": round(infer_end - infer_start, 4),
        },
    }