unknownfriend00007 commited on
Commit
2ffbd85
·
verified ·
1 Parent(s): 48c3b28

Upload 8 files

Browse files
Files changed (3) hide show
  1. config.py +3 -0
  2. inference.py +230 -58
  3. requirements.txt +1 -0
config.py CHANGED
@@ -62,6 +62,9 @@ class VoiceRuntimeConfig:
62
  diarization_min_speakers: int = int(os.environ.get("VOICE_DIARIZATION_MIN_SPEAKERS", "0"))
63
  diarization_max_speakers: int = int(os.environ.get("VOICE_DIARIZATION_MAX_SPEAKERS", "0"))
64
 
 
 
 
65
  @classmethod
66
  def from_env(cls) -> "VoiceRuntimeConfig":
67
  return cls()
 
62
  diarization_min_speakers: int = int(os.environ.get("VOICE_DIARIZATION_MIN_SPEAKERS", "0"))
63
  diarization_max_speakers: int = int(os.environ.get("VOICE_DIARIZATION_MAX_SPEAKERS", "0"))
64
 
65
+ groq_api_key: str = os.environ.get("GROQ_API_KEY", "")
66
+ groq_model_id: str = os.environ.get("GROQ_MODEL_ID", "whisper-large-v3-turbo")
67
+
68
  @classmethod
69
  def from_env(cls) -> "VoiceRuntimeConfig":
70
  return cls()
inference.py CHANGED
@@ -1,58 +1,230 @@
1
- from __future__ import annotations
2
-
3
- import threading
4
- from typing import Any
5
-
6
- from faster_whisper import WhisperModel
7
-
8
- try:
9
- from .config import VoiceRuntimeConfig
10
- except ImportError: # HF flat-root execution fallback
11
- from config import VoiceRuntimeConfig
12
-
13
-
14
- class WhisperRuntime:
15
- _lock = threading.Lock()
16
- _model: WhisperModel | None = None
17
- _loaded_id: str | None = None
18
-
19
- @classmethod
20
- def get_model(cls, config: VoiceRuntimeConfig) -> WhisperModel:
21
- with cls._lock:
22
- if cls._model is not None and cls._loaded_id == config.runtime_model_id:
23
- return cls._model
24
-
25
- cls._model = WhisperModel(
26
- config.runtime_model_id,
27
- device="cpu",
28
- compute_type=config.compute_type,
29
- cpu_threads=config.cpu_threads,
30
- num_workers=1,
31
- )
32
- cls._loaded_id = config.runtime_model_id
33
- return cls._model
34
-
35
-
36
- def transcribe(
37
- wav_path: str,
38
- config: VoiceRuntimeConfig,
39
- language_hint: str | None,
40
- ) -> tuple[list[Any], str, str]:
41
- model = WhisperRuntime.get_model(config)
42
- requested_language = None if not language_hint or language_hint == "auto" else language_hint
43
-
44
- segments_iter, info = model.transcribe(
45
- wav_path,
46
- task="transcribe",
47
- language=requested_language,
48
- beam_size=1,
49
- best_of=1,
50
- temperature=0.0,
51
- condition_on_previous_text=False,
52
- word_timestamps=True,
53
- vad_filter=False,
54
- )
55
- segments = list(segments_iter)
56
- detected_language = (info.language or requested_language or "unknown").lower()
57
- language_source = "request" if requested_language else "auto_detect"
58
- return segments, detected_language, language_source
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ import os
5
+ import tempfile
6
+ import threading
7
+ from typing import Any
8
+
9
+ import soundfile as sf
10
+ from faster_whisper import WhisperModel
11
+
12
+ try:
13
+ from .config import VoiceRuntimeConfig
14
+ except ImportError: # HF flat-root execution fallback
15
+ from config import VoiceRuntimeConfig
16
+
17
+
18
+ # ---------------------------------------------------------------------------
19
+ # Local Whisper (CPU fallback)
20
+ # ---------------------------------------------------------------------------
21
+
22
+ class WhisperRuntime:
23
+ _lock = threading.Lock()
24
+ _model: WhisperModel | None = None
25
+ _loaded_id: str | None = None
26
+
27
+ @classmethod
28
+ def get_model(cls, config: VoiceRuntimeConfig) -> WhisperModel:
29
+ with cls._lock:
30
+ if cls._model is not None and cls._loaded_id == config.runtime_model_id:
31
+ return cls._model
32
+ cls._model = WhisperModel(
33
+ config.runtime_model_id,
34
+ device="cpu",
35
+ compute_type=config.compute_type,
36
+ cpu_threads=config.cpu_threads,
37
+ num_workers=1,
38
+ )
39
+ cls._loaded_id = config.runtime_model_id
40
+ return cls._model
41
+
42
+
43
+ def _transcribe_local(
44
+ wav_path: str,
45
+ config: VoiceRuntimeConfig,
46
+ language_hint: str | None,
47
+ ) -> tuple[list[Any], str, str]:
48
+ model = WhisperRuntime.get_model(config)
49
+ requested_language = None if not language_hint or language_hint == "auto" else language_hint
50
+
51
+ segments_iter, info = model.transcribe(
52
+ wav_path,
53
+ task="transcribe",
54
+ language=requested_language,
55
+ beam_size=1,
56
+ best_of=1,
57
+ temperature=0.0,
58
+ condition_on_previous_text=False,
59
+ word_timestamps=True,
60
+ vad_filter=False,
61
+ )
62
+ segments = list(segments_iter)
63
+ detected_language = (info.language or requested_language or "unknown").lower()
64
+ language_source = "request" if requested_language else "auto_detect"
65
+ return segments, detected_language, language_source
66
+
67
+
68
+ # ---------------------------------------------------------------------------
69
+ # Groq API path
70
+ # ---------------------------------------------------------------------------
71
+
72
+ # Stay safely under Groq's 25 MB per-file limit
73
+ _GROQ_MAX_BYTES = 23 * 1024 * 1024
74
+
75
+
76
+ class _GroqWord:
77
+ """Mimics faster_whisper Word namedtuple so service.py needs zero changes."""
78
+ __slots__ = ("word", "start", "end", "probability")
79
+
80
+ def __init__(self, word: str, start: float, end: float) -> None:
81
+ self.word = word
82
+ self.start = start
83
+ self.end = end
84
+ self.probability = None # Groq doesn't provide per-word confidence
85
+
86
+
87
+ class _GroqSegment:
88
+ """Mimics faster_whisper Segment namedtuple so service.py needs zero changes."""
89
+ __slots__ = ("start", "end", "text", "words")
90
+
91
+ def __init__(self, start: float, end: float, text: str, words: list[_GroqWord]) -> None:
92
+ self.start = start
93
+ self.end = end
94
+ self.text = text
95
+ self.words = words
96
+
97
+
98
+ def _chunk_wav(wav_path: str, sample_rate: int) -> list[tuple[str, float]]:
99
+ """
100
+ Split WAV into chunks that fit within _GROQ_MAX_BYTES.
101
+ Returns list of (chunk_wav_path, start_time_offset_sec).
102
+ Chunks are written to a temp dir and must be cleaned up by the caller.
103
+ """
104
+ audio, _ = sf.read(wav_path, dtype="float32")
105
+
106
+ bytes_per_sec = sample_rate * 2 # mono PCM_16 = 2 bytes/sample
107
+ max_samples = int(math.floor(_GROQ_MAX_BYTES / bytes_per_sec) * sample_rate)
108
+
109
+ tmp_dir = tempfile.mkdtemp(prefix="groq-chunks-")
110
+ chunks: list[tuple[str, float]] = []
111
+ cursor = 0
112
+ idx = 0
113
+
114
+ while cursor < len(audio):
115
+ end = min(cursor + max_samples, len(audio))
116
+ chunk_path = os.path.join(tmp_dir, f"chunk_{idx:04d}.wav")
117
+ sf.write(chunk_path, audio[cursor:end], sample_rate, subtype="PCM_16")
118
+ chunks.append((chunk_path, cursor / sample_rate))
119
+ cursor = end
120
+ idx += 1
121
+
122
+ return chunks
123
+
124
+
125
+ def _call_groq(
126
+ wav_path: str,
127
+ api_key: str,
128
+ groq_model: str,
129
+ language_hint: str | None,
130
+ ) -> dict:
131
+ """Call Groq transcriptions endpoint for a single chunk file."""
132
+ from groq import Groq # imported lazily so local-only installs don't break
133
+
134
+ client = Groq(api_key=api_key)
135
+ kwargs: dict[str, Any] = {
136
+ "model": groq_model,
137
+ "response_format": "verbose_json",
138
+ "timestamp_granularities": ["word", "segment"],
139
+ }
140
+ if language_hint and language_hint != "auto":
141
+ kwargs["language"] = language_hint
142
+
143
+ with open(wav_path, "rb") as f:
144
+ result = client.audio.transcriptions.create(file=f, **kwargs)
145
+
146
+ return result.model_dump() if hasattr(result, "model_dump") else dict(result)
147
+
148
+
149
+ def _transcribe_groq(
150
+ wav_path: str,
151
+ config: VoiceRuntimeConfig,
152
+ language_hint: str | None,
153
+ ) -> tuple[list[Any], str, str]:
154
+ api_key = config.groq_api_key
155
+ groq_model = config.groq_model_id
156
+ requested_language = None if not language_hint or language_hint == "auto" else language_hint
157
+
158
+ chunks = _chunk_wav(wav_path, config.sample_rate)
159
+ all_segments: list[_GroqSegment] = []
160
+ detected_language: str = requested_language or "unknown"
161
+
162
+ for chunk_path, time_offset in chunks:
163
+ try:
164
+ result = _call_groq(chunk_path, api_key, groq_model, language_hint)
165
+
166
+ # Capture language from the first chunk that reports it
167
+ if detected_language in ("unknown", None):
168
+ detected_language = (result.get("language") or "unknown").lower()
169
+
170
+ raw_segments: list[dict] = result.get("segments") or []
171
+ raw_words: list[dict] = result.get("words") or []
172
+
173
+ # Build segment-id → words mapping by time overlap
174
+ seg_words: dict[int, list[_GroqWord]] = {}
175
+ for w in raw_words:
176
+ w_start = float(w.get("start", 0.0)) + time_offset
177
+ w_end = float(w.get("end", w_start)) + time_offset
178
+ w_text = str(w.get("word", "")).strip()
179
+ if not w_text:
180
+ continue
181
+
182
+ best_sid: int = 0
183
+ best_overlap: float = -1.0
184
+ for seg in raw_segments:
185
+ s_start = float(seg.get("start", 0.0)) + time_offset
186
+ s_end = float(seg.get("end", s_start)) + time_offset
187
+ overlap = min(w_end, s_end) - max(w_start, s_start)
188
+ if overlap > best_overlap:
189
+ best_overlap = overlap
190
+ best_sid = int(seg.get("id", 0))
191
+
192
+ seg_words.setdefault(best_sid, []).append(
193
+ _GroqWord(word=w_text, start=w_start, end=w_end)
194
+ )
195
+
196
+ for seg in raw_segments:
197
+ sid = int(seg.get("id", 0))
198
+ all_segments.append(_GroqSegment(
199
+ start=float(seg.get("start", 0.0)) + time_offset,
200
+ end=float(seg.get("end", 0.0)) + time_offset,
201
+ text=str(seg.get("text", "")).strip(),
202
+ words=seg_words.get(sid, []),
203
+ ))
204
+ finally:
205
+ try:
206
+ os.remove(chunk_path)
207
+ except OSError:
208
+ pass
209
+
210
+ language_source = "request" if requested_language else "auto_detect"
211
+ return all_segments, detected_language, language_source
212
+
213
+
214
+ # ---------------------------------------------------------------------------
215
+ # Public entry point — called by service.py
216
+ # ---------------------------------------------------------------------------
217
+
218
+ def transcribe(
219
+ wav_path: str,
220
+ config: VoiceRuntimeConfig,
221
+ language_hint: str | None,
222
+ ) -> tuple[list[Any], str, str]:
223
+ """
224
+ Routes to Groq API when GROQ_API_KEY is configured, otherwise falls back
225
+ to local faster-whisper. Both paths return objects compatible with
226
+ _build_alignment_payload in service.py.
227
+ """
228
+ if config.groq_api_key:
229
+ return _transcribe_groq(wav_path, config, language_hint)
230
+ return _transcribe_local(wav_path, config, language_hint)
requirements.txt CHANGED
@@ -5,3 +5,4 @@ faster-whisper>=1.1.1
5
  numpy>=1.26.0
6
  soundfile>=0.12.1
7
  pyannote.audio>=3.3.2
 
 
5
  numpy>=1.26.0
6
  soundfile>=0.12.1
7
  pyannote.audio>=3.3.2
8
+ groq>=0.9.0