Spaces:
Running on Zero
Running on Zero
File size: 5,289 Bytes
7d1e08d 13fe947 7d1e08d 13fe947 6d9770a 7d1e08d 13fe947 6d9770a 7d1e08d 6d9770a 7d1e08d 13fe947 7d1e08d 13fe947 7d1e08d 6d9770a 7d1e08d 6d9770a 7d1e08d 6d9770a 13fe947 6d9770a 7d1e08d 6d9770a 13fe947 6d9770a 7d1e08d 13fe947 7d1e08d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 | from __future__ import annotations
from dataclasses import dataclass
import os
from pathlib import Path
import shutil
import subprocess
import tempfile
from typing import Any
from hackathon_advisor.config import int_env
# Model identifier used when no model id is supplied, or when the supplied
# value is blank after stripping whitespace.
DEFAULT_ASR_MODEL_ID = "nvidia/nemotron-speech-streaming-en-0.6b"
# Backend label reported in status snapshots and transcript results.
DEFAULT_ASR_BACKEND = "nemo-asr"
# Default value passed to ffmpeg's `-ar` option when audio is normalized
# before transcription (16 kHz).
DEFAULT_ASR_SAMPLE_RATE = 16_000
@dataclass(frozen=True)
class AsrTranscript:
    """Immutable record of one transcription and the settings that produced it."""

    transcript: str
    model_id: str
    backend: str
    sample_rate: int

    def to_dict(self) -> dict[str, Any]:
        """Return the record as a plain dictionary keyed by field name."""
        field_names = ("transcript", "model_id", "backend", "sample_rate")
        return {name: getattr(self, name) for name in field_names}
@dataclass(frozen=True)
class AsrStatus:
    """Immutable snapshot describing a transcriber's backend, model and load state."""

    backend: str
    model_id: str
    loaded: bool
    sample_rate: int

    def to_dict(self) -> dict[str, Any]:
        """Return the snapshot as a plain dictionary keyed by field name."""
        field_names = ("backend", "model_id", "loaded", "sample_rate")
        return {name: getattr(self, name) for name in field_names}
class NemotronAsrTranscriber:
    """Nemotron voice input through NVIDIA NeMo ASR.

    The NeMo model is not loaded at construction time; it is loaded on the
    first call to ``transcribe`` and kept on the instance afterwards.
    """

    backend = DEFAULT_ASR_BACKEND

    def __init__(
        self,
        model_id: str = DEFAULT_ASR_MODEL_ID,
        sample_rate: int = DEFAULT_ASR_SAMPLE_RATE,
    ) -> None:
        """Store the configuration without loading any model.

        Args:
            model_id: Model name to load; a blank value falls back to
                ``DEFAULT_ASR_MODEL_ID``.
            sample_rate: Sample rate handed to the audio normalization step.
        """
        cleaned_model_id = model_id.strip()
        self.model_id = cleaned_model_id if cleaned_model_id else DEFAULT_ASR_MODEL_ID
        self.sample_rate = sample_rate
        # The three attributes below are filled in by _load_nemo().
        self._engine: Any | None = None
        self._active_backend = ""
        self._active_model_id = ""

    def status(self) -> AsrStatus:
        """Return a snapshot of the backend, model id, load state and sample rate."""
        is_loaded = self._engine is not None
        # Until a model is loaded, report the configured values.
        reported_backend = self._active_backend or self.backend
        reported_model = self._active_model_id or self.model_id
        return AsrStatus(
            backend=reported_backend,
            model_id=reported_model,
            loaded=is_loaded,
            sample_rate=self.sample_rate,
        )

    def transcribe(self, audio_path: Path) -> AsrTranscript:
        """Transcribe the audio file at ``audio_path``.

        The file is first normalized into a temporary ``voice.wav`` and that
        normalized copy is what the model receives.

        Raises:
            RuntimeError: If ``audio_path`` is not an existing file, if the
                model or ffmpeg is unavailable, or if the resulting
                transcript is empty after stripping whitespace.
        """
        recording = Path(audio_path)
        if not recording.is_file():
            raise RuntimeError("Voice note was not saved before transcription.")

        self._ensure_loaded()
        asr_model = self._engine
        with tempfile.TemporaryDirectory(prefix="advisor-asr-") as scratch_dir:
            normalized_path = Path(scratch_dir) / "voice.wav"
            normalize_audio_for_asr(recording, normalized_path, self.sample_rate)
            raw_outputs = asr_model.transcribe([str(normalized_path)], batch_size=1)

        text = extract_transcript(raw_outputs).strip()
        if text:
            return AsrTranscript(
                transcript=text,
                model_id=self._active_model_id or self.model_id,
                backend=self._active_backend or self.backend,
                sample_rate=self.sample_rate,
            )
        raise RuntimeError(f"{self._active_backend or self.backend} returned an empty transcript.")

    def _ensure_loaded(self) -> None:
        """Load the model on first use; do nothing once an engine is present."""
        if self._engine is None:
            self._load_nemo()

    def _load_nemo(self) -> None:
        """Import NeMo, load the configured model and place it on a device.

        The device comes from ``ADVISOR_ASR_DEVICE`` when that variable is set
        to a non-blank value; otherwise ``"cuda"`` is used if torch reports it
        available, else ``"cpu"``.

        Raises:
            RuntimeError: If ``torch`` or ``nemo.collections.asr`` cannot be
                imported.
        """
        try:
            import torch
            import nemo.collections.asr as nemo_asr
        except ImportError as import_failure:
            raise RuntimeError(
                "Nemotron voice input requires NVIDIA NeMo ASR. Install `nemo_toolkit[asr]` "
                "before enabling voice transcription."
            ) from import_failure

        asr_model = nemo_asr.models.ASRModel.from_pretrained(model_name=self.model_id)

        target_device = os.environ.get("ADVISOR_ASR_DEVICE", "").strip()
        if not target_device:
            target_device = "cuda" if torch.cuda.is_available() else "cpu"

        asr_model.to(target_device)
        asr_model.eval()

        self._engine = asr_model
        self._active_backend = self.backend
        self._active_model_id = self.model_id
def create_asr_transcriber() -> NemotronAsrTranscriber:
    """Build a transcriber configured from environment variables.

    ``ADVISOR_ASR_SAMPLE_RATE`` is read through ``int_env`` with
    ``DEFAULT_ASR_SAMPLE_RATE`` as the default and ``minimum=1``;
    ``ADVISOR_ASR_MODEL_ID`` falls back to ``DEFAULT_ASR_MODEL_ID`` when unset.
    """
    configured_rate = int_env("ADVISOR_ASR_SAMPLE_RATE", DEFAULT_ASR_SAMPLE_RATE, minimum=1)
    configured_model = os.environ.get("ADVISOR_ASR_MODEL_ID", DEFAULT_ASR_MODEL_ID)
    return NemotronAsrTranscriber(model_id=configured_model, sample_rate=configured_rate)
def normalize_audio_for_asr(source: Path, target: Path, sample_rate: int = DEFAULT_ASR_SAMPLE_RATE) -> None:
    """Re-encode ``source`` into ``target`` with ffmpeg for ASR input.

    The output is requested as a single channel (``-ac 1``), resampled to
    ``sample_rate`` (``-ar``) with the ``s16`` sample format. An existing
    ``target`` is overwritten (``-y``).

    Raises:
        RuntimeError: If no ``ffmpeg`` executable is found on the search
            path, or if ffmpeg exits with a non-zero status. In the latter
            case the message is ffmpeg's stderr, or a generic message when
            stderr is empty.
    """
    ffmpeg = shutil.which("ffmpeg")
    if not ffmpeg:
        raise RuntimeError("Voice transcription requires ffmpeg to normalize audio.")

    # Keep ffmpeg quiet so stderr only carries real errors.
    quiet_flags = ["-hide_banner", "-loglevel", "error", "-y"]
    input_args = ["-i", str(source)]
    output_format_args = ["-ac", "1", "-ar", str(sample_rate), "-sample_fmt", "s16"]
    command = [ffmpeg, *quiet_flags, *input_args, *output_format_args, str(target)]

    result = subprocess.run(command, check=False, capture_output=True, text=True)
    if result.returncode == 0:
        return
    raise RuntimeError(result.stderr.strip() or "ffmpeg could not read this audio file.")
def extract_transcript(outputs: Any) -> str:
    """Pull the transcript text out of whatever shape the ASR engine returned.

    Supported shapes, checked in this order:

    * ``None`` -> ``""``.
    * ``str`` -> returned unchanged.
    * ``dict`` -> the ``"text"`` value, else the ``"transcript"`` value, else
      ``""`` (falsy values are skipped).
    * ``list`` / ``tuple`` -> the first element, extracted recursively; an
      empty sequence yields ``""``.
    * Any other object -> its ``text`` attribute, else its ``transcript``
      attribute, whichever is first found to be not ``None``.

    An object that exposes ``text`` and/or ``transcript`` but holds ``None``
    in all of them yields ``""``. Objects with neither attribute fall back to
    ``str(outputs)``.

    Args:
        outputs: Raw return value of the engine's transcribe call.

    Returns:
        The extracted transcript; possibly empty, never the literal ``"None"``
        produced by stringifying a missing value.
    """
    # A missing result is an empty transcript, not the text "None"; this
    # matches how the dict branch below already treats None values.
    if outputs is None:
        return ""
    if isinstance(outputs, str):
        return outputs
    if isinstance(outputs, dict):
        return str(outputs.get("text") or outputs.get("transcript") or "")
    if isinstance(outputs, (list, tuple)):
        return extract_transcript(outputs[0]) if outputs else ""

    missing = object()  # distinguishes "attribute absent" from "attribute is None"
    has_transcript_field = False
    for field_name in ("text", "transcript"):
        value = getattr(outputs, field_name, missing)
        if value is missing:
            continue
        has_transcript_field = True
        if value is not None:
            return str(value)

    # A result object whose transcript fields are all None carries no text;
    # returning its repr would be mistaken for a real transcript.
    if has_transcript_field:
        return ""
    return str(outputs)
|