File size: 5,289 Bytes
7d1e08d
 
 
 
 
 
 
 
 
 
13fe947
7d1e08d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13fe947
6d9770a
7d1e08d
 
 
 
 
 
 
 
 
13fe947
6d9770a
 
7d1e08d
 
 
6d9770a
 
 
7d1e08d
 
 
 
 
 
 
 
13fe947
7d1e08d
 
 
13fe947
 
7d1e08d
6d9770a
7d1e08d
 
6d9770a
 
7d1e08d
 
 
 
6d9770a
 
13fe947
6d9770a
 
7d1e08d
 
 
 
 
 
 
 
 
6d9770a
 
 
13fe947
6d9770a
 
 
7d1e08d
 
13fe947
7d1e08d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
from __future__ import annotations

from dataclasses import dataclass
import os
from pathlib import Path
import shutil
import subprocess
import tempfile
from typing import Any

from hackathon_advisor.config import int_env

# Model name handed to NeMo when no ADVISOR_ASR_MODEL_ID override is set
# (also the fallback for a blank model id).
DEFAULT_ASR_MODEL_ID = "nvidia/nemotron-speech-streaming-en-0.6b"
# Backend label reported in transcription results and status payloads.
DEFAULT_ASR_BACKEND = "nemo-asr"
# Rate (Hz) that audio is resampled to before transcription, unless overridden
# through ADVISOR_ASR_SAMPLE_RATE.
DEFAULT_ASR_SAMPLE_RATE = 16000


@dataclass(frozen=True)
class AsrTranscript:
    """Immutable result of one transcription run.

    Holds the recognized text together with the model id, backend name and
    sample rate that produced it.
    """

    transcript: str
    model_id: str
    backend: str
    sample_rate: int

    def to_dict(self) -> dict[str, Any]:
        """Return the fields as a plain dict.

        Key order: transcript, model_id, backend, sample_rate.
        """
        return dict(
            transcript=self.transcript,
            model_id=self.model_id,
            backend=self.backend,
            sample_rate=self.sample_rate,
        )


@dataclass(frozen=True)
class AsrStatus:
    """Immutable snapshot of a transcriber's configuration and load state."""

    backend: str
    model_id: str
    loaded: bool
    sample_rate: int

    def to_dict(self) -> dict[str, Any]:
        """Return the fields as a plain dict.

        Key order: backend, model_id, loaded, sample_rate.
        """
        ordered_keys = ("backend", "model_id", "loaded", "sample_rate")
        return {key: getattr(self, key) for key in ordered_keys}


class NemotronAsrTranscriber:
    """Transcribe saved voice notes with an NVIDIA NeMo ASR model.

    The NeMo model is not loaded at construction time; it is loaded on the
    first ``transcribe`` call and kept on the instance for later calls.
    """

    backend = DEFAULT_ASR_BACKEND

    def __init__(
        self,
        model_id: str = DEFAULT_ASR_MODEL_ID,
        sample_rate: int = DEFAULT_ASR_SAMPLE_RATE,
    ) -> None:
        """Store configuration without loading the model.

        Args:
            model_id: Model name passed to NeMo's ``ASRModel.from_pretrained``.
                A blank or whitespace-only value falls back to
                ``DEFAULT_ASR_MODEL_ID``.
            sample_rate: Rate (Hz) that audio is resampled to before
                transcription.
        """
        cleaned = model_id.strip()
        self.model_id = cleaned if cleaned else DEFAULT_ASR_MODEL_ID
        self.sample_rate = sample_rate
        # Filled in by _load_nemo(); None / "" mean "not loaded yet".
        self._engine: Any | None = None
        self._active_backend = ""
        self._active_model_id = ""

    def _reported_backend(self) -> str:
        """Return the backend of the loaded engine, or the configured one before loading."""
        return self._active_backend or self.backend

    def _reported_model_id(self) -> str:
        """Return the model id of the loaded engine, or the configured one before loading."""
        return self._active_model_id or self.model_id

    def status(self) -> AsrStatus:
        """Describe backend, model, sample rate and whether the model is loaded.

        Does not trigger a model load.
        """
        return AsrStatus(
            backend=self._reported_backend(),
            model_id=self._reported_model_id(),
            loaded=self._engine is not None,
            sample_rate=self.sample_rate,
        )

    def transcribe(self, audio_path: Path) -> AsrTranscript:
        """Transcribe the audio file at ``audio_path``.

        The file is first converted with ffmpeg into a temporary WAV at
        ``self.sample_rate``. That WAV is fed to the NeMo model, and the first
        result is returned with surrounding whitespace stripped.

        Raises:
            RuntimeError: if the file does not exist, the model or ffmpeg step
                fails, or the model returns an empty transcript.
        """
        voice_note = Path(audio_path)
        if not voice_note.is_file():
            raise RuntimeError("Voice note was not saved before transcription.")
        self._ensure_loaded()
        model = self._engine
        # The normalized WAV only needs to live for the duration of the call.
        with tempfile.TemporaryDirectory(prefix="advisor-asr-") as scratch:
            normalized = Path(scratch) / "voice.wav"
            normalize_audio_for_asr(voice_note, normalized, self.sample_rate)
            raw_result = model.transcribe([str(normalized)], batch_size=1)
            text = extract_transcript(raw_result).strip()
        if not text:
            raise RuntimeError(f"{self._reported_backend()} returned an empty transcript.")
        return AsrTranscript(
            transcript=text,
            model_id=self._reported_model_id(),
            backend=self._reported_backend(),
            sample_rate=self.sample_rate,
        )

    def _ensure_loaded(self) -> None:
        """Load the NeMo model if it has not been loaded yet; otherwise do nothing."""
        if self._engine is None:
            self._load_nemo()

    def _load_nemo(self) -> None:
        """Load ``self.model_id`` through NeMo, move it to a device and set eval mode.

        Device selection:
        - ``ADVISOR_ASR_DEVICE``, when set to a non-blank value;
        - otherwise ``"cuda"``, when torch reports CUDA is available;
        - otherwise ``"cpu"``.

        Raises:
            RuntimeError: if torch or NeMo ASR cannot be imported.
        """
        try:
            import torch
            import nemo.collections.asr as nemo_asr
        except ImportError as error:
            raise RuntimeError(
                "Nemotron voice input requires NVIDIA NeMo ASR. Install `nemo_toolkit[asr]` "
                "before enabling voice transcription."
            ) from error
        asr_model = nemo_asr.models.ASRModel.from_pretrained(model_name=self.model_id)
        override = os.environ.get("ADVISOR_ASR_DEVICE", "").strip()
        if override:
            target_device = override
        elif torch.cuda.is_available():
            target_device = "cuda"
        else:
            target_device = "cpu"
        asr_model.to(target_device)
        asr_model.eval()
        # Publish the engine only after it is fully prepared, so any failure
        # above leaves the transcriber in the "not loaded" state.
        self._engine = asr_model
        self._active_backend = self.backend
        self._active_model_id = self.model_id


def create_asr_transcriber() -> NemotronAsrTranscriber:
    """Build a transcriber configured from the environment.

    Environment variables read:
    - ``ADVISOR_ASR_SAMPLE_RATE``: read through ``int_env`` with ``minimum=1``;
      defaults to ``DEFAULT_ASR_SAMPLE_RATE``.
    - ``ADVISOR_ASR_MODEL_ID``: defaults to ``DEFAULT_ASR_MODEL_ID``.
    """
    rate = int_env("ADVISOR_ASR_SAMPLE_RATE", DEFAULT_ASR_SAMPLE_RATE, minimum=1)
    configured_model = os.environ.get("ADVISOR_ASR_MODEL_ID", DEFAULT_ASR_MODEL_ID)
    return NemotronAsrTranscriber(model_id=configured_model, sample_rate=rate)


def normalize_audio_for_asr(source: Path, target: Path, sample_rate: int = DEFAULT_ASR_SAMPLE_RATE) -> None:
    ffmpeg = shutil.which("ffmpeg")
    if not ffmpeg:
        raise RuntimeError("Voice transcription requires ffmpeg to normalize audio.")
    command = [
        ffmpeg,
        "-hide_banner",
        "-loglevel",
        "error",
        "-y",
        "-i",
        str(source),
        "-ac",
        "1",
        "-ar",
        str(sample_rate),
        "-sample_fmt",
        "s16",
        str(target),
    ]
    completed = subprocess.run(command, check=False, capture_output=True, text=True)
    if completed.returncode != 0:
        message = completed.stderr.strip() or "ffmpeg could not read this audio file."
        raise RuntimeError(message)


def extract_transcript(outputs: Any) -> str:
    """Pull the transcript text out of whatever an ASR ``transcribe`` call returned.

    Cases are checked in this order:
    - ``None`` returns ``""``.
    - A ``str`` is returned as-is.
    - A ``dict`` yields its ``"text"`` value, else its ``"transcript"`` value,
      else ``""``.
    - A ``list`` or ``tuple`` recurses into its first element only; an empty
      sequence returns ``""``.
    - An object with a non-``None`` ``text`` attribute returns that attribute,
      else one with a non-``None`` ``transcript`` attribute returns that.
    - Anything else is converted with ``str()``.
    """
    if outputs is None:
        # Without this guard the final str() fallback would turn a missing
        # result into the literal transcript "None".
        return ""
    if isinstance(outputs, str):
        return outputs
    if isinstance(outputs, dict):
        return str(outputs.get("text") or outputs.get("transcript") or "")
    if isinstance(outputs, (list, tuple)):
        if not outputs:
            return ""
        return extract_transcript(outputs[0])
    text = getattr(outputs, "text", None)
    if text is not None:
        return str(text)
    transcript = getattr(outputs, "transcript", None)
    if transcript is not None:
        return str(transcript)
    return str(outputs)