# transcribe-diarize/src/models/parakeet_model.py
# Provenance (HF Spaces page header): Ratnesh-dev — "Revert repository state to c7d2aa0" (commit ca6855a)
import gc
import time
import gradio as gr
import torch
from src.constants import MODEL_IDS, PARAKEET_V3
from src.utils import get_audio_duration_seconds, serialize
_PARAKEET_MODEL = None
def _get_parakeet_model():
global _PARAKEET_MODEL
if _PARAKEET_MODEL is not None:
return _PARAKEET_MODEL
try:
import nemo.collections.asr as nemo_asr
except Exception as exc:
raise gr.Error(
"NVIDIA Parakeet backend requested but NeMo ASR package is missing. "
"Add nemo_toolkit[asr] to requirements.txt"
) from exc
model = nemo_asr.models.ASRModel.from_pretrained(model_name=MODEL_IDS[PARAKEET_V3])
model.eval()
_PARAKEET_MODEL = model
return _PARAKEET_MODEL
def preload_parakeet_model() -> None:
_get_parakeet_model()
def run_parakeet(
audio_file: str,
language: str,
model_options: dict,
duration_seconds: float | None = None,
) -> dict:
model = _get_parakeet_model()
device = "cuda" if torch.cuda.is_available() else "cpu"
long_audio_threshold_seconds = float(model_options.get("long_audio_threshold_seconds", 480))
local_attention_left = int(model_options.get("local_attention_left", 256))
local_attention_right = int(model_options.get("local_attention_right", 256))
subsampling_conv_chunking_factor = int(model_options.get("subsampling_conv_chunking_factor", 1))
enable_long_audio_optimizations = bool(model_options.get("enable_long_audio_optimizations", True))
if duration_seconds is None:
duration_seconds = get_audio_duration_seconds(audio_file)
is_long_audio = duration_seconds is not None and duration_seconds > long_audio_threshold_seconds
applied_long_audio_settings = False
optimization_errors: list[str] = []
infer_start = time.perf_counter()
try:
model.to(device)
model.to(torch.float32)
if enable_long_audio_optimizations and is_long_audio:
try:
model.change_attention_model("rel_pos_local_attn", [local_attention_left, local_attention_right])
model.change_subsampling_conv_chunking_factor(subsampling_conv_chunking_factor)
applied_long_audio_settings = True
except Exception as exc:
optimization_errors.append(f"long_audio_settings failed: {exc}")
model.to(torch.bfloat16)
outputs = model.transcribe([audio_file], timestamps=True)
except torch.cuda.OutOfMemoryError as exc:
raise gr.Error("CUDA out of memory while running Parakeet transcription.") from exc
finally:
if applied_long_audio_settings:
try:
model.change_attention_model("rel_pos")
model.change_subsampling_conv_chunking_factor(-1)
except Exception as exc:
optimization_errors.append(f"revert_long_audio_settings failed: {exc}")
try:
if device == "cuda":
model.cpu()
gc.collect()
if device == "cuda":
torch.cuda.empty_cache()
except Exception as exc:
optimization_errors.append(f"cleanup failed: {exc}")
infer_end = time.perf_counter()
item = outputs[0] if outputs else None
raw_output = {
"output": serialize(item),
"timestamp_hint": "word timestamps available in output.timestamp['word'] when provided by NeMo",
"language_hint": language or "auto",
"long_audio_settings": {
"duration_seconds": duration_seconds,
"is_long_audio": is_long_audio,
"threshold_seconds": long_audio_threshold_seconds,
"enable_long_audio_optimizations": enable_long_audio_optimizations,
"applied_long_audio_settings": applied_long_audio_settings,
"local_attention_left": local_attention_left,
"local_attention_right": local_attention_right,
"subsampling_conv_chunking_factor": subsampling_conv_chunking_factor,
"optimization_errors": optimization_errors,
},
}
return {
"raw_output": raw_output,
"timing": {
"inference_seconds": round(infer_end - infer_start, 4),
},
}