Spaces:
Sleeping
Sleeping
Upload 8 files
Browse files- config.py +3 -0
- inference.py +230 -58
- requirements.txt +1 -0
config.py
CHANGED
|
@@ -62,6 +62,9 @@ class VoiceRuntimeConfig:
|
|
| 62 |
diarization_min_speakers: int = int(os.environ.get("VOICE_DIARIZATION_MIN_SPEAKERS", "0"))
|
| 63 |
diarization_max_speakers: int = int(os.environ.get("VOICE_DIARIZATION_MAX_SPEAKERS", "0"))
|
| 64 |
|
|
|
|
|
|
|
|
|
|
| 65 |
@classmethod
|
| 66 |
def from_env(cls) -> "VoiceRuntimeConfig":
|
| 67 |
return cls()
|
|
|
|
| 62 |
diarization_min_speakers: int = int(os.environ.get("VOICE_DIARIZATION_MIN_SPEAKERS", "0"))
|
| 63 |
diarization_max_speakers: int = int(os.environ.get("VOICE_DIARIZATION_MAX_SPEAKERS", "0"))
|
| 64 |
|
| 65 |
+
groq_api_key: str = os.environ.get("GROQ_API_KEY", "")
|
| 66 |
+
groq_model_id: str = os.environ.get("GROQ_MODEL_ID", "whisper-large-v3-turbo")
|
| 67 |
+
|
| 68 |
@classmethod
|
| 69 |
def from_env(cls) -> "VoiceRuntimeConfig":
|
| 70 |
return cls()
|
inference.py
CHANGED
|
@@ -1,58 +1,230 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
import
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
cls.
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
import os
|
| 5 |
+
import tempfile
|
| 6 |
+
import threading
|
| 7 |
+
from typing import Any
|
| 8 |
+
|
| 9 |
+
import soundfile as sf
|
| 10 |
+
from faster_whisper import WhisperModel
|
| 11 |
+
|
| 12 |
+
try:
|
| 13 |
+
from .config import VoiceRuntimeConfig
|
| 14 |
+
except ImportError: # HF flat-root execution fallback
|
| 15 |
+
from config import VoiceRuntimeConfig
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# ---------------------------------------------------------------------------
|
| 19 |
+
# Local Whisper (CPU fallback)
|
| 20 |
+
# ---------------------------------------------------------------------------
|
| 21 |
+
|
| 22 |
+
class WhisperRuntime:
    """Class-level holder for one lazily constructed faster-whisper model.

    The model and the id it was loaded for are stored on the class itself, and
    every access goes through a single lock, so all callers share one model
    and construction happens one caller at a time.
    """

    _lock = threading.Lock()
    _model: WhisperModel | None = None
    _loaded_id: str | None = None

    @classmethod
    def get_model(cls, config: VoiceRuntimeConfig) -> WhisperModel:
        """Return the cached model, building a new one when the id differs.

        NOTE(review): the cache is validated against
        ``config.runtime_model_id`` only. A config with the same id but a
        different ``compute_type`` or ``cpu_threads`` gets the model that is
        already loaded -- confirm that is intended.
        """
        with cls._lock:
            needs_load = cls._model is None or cls._loaded_id != config.runtime_model_id
            if needs_load:
                cls._model = WhisperModel(
                    config.runtime_model_id,
                    device="cpu",
                    compute_type=config.compute_type,
                    cpu_threads=config.cpu_threads,
                    num_workers=1,
                )
                cls._loaded_id = config.runtime_model_id
            return cls._model
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _transcribe_local(
    wav_path: str,
    config: VoiceRuntimeConfig,
    language_hint: str | None,
) -> tuple[list[Any], str, str]:
    """Transcribe ``wav_path`` with the local faster-whisper model.

    Returns:
        A ``(segments, detected_language, language_source)`` tuple.
        ``language_source`` is ``"request"`` when an explicit language was
        supplied and ``"auto_detect"`` otherwise.
    """
    # An empty hint and the literal "auto" both mean "let the model decide".
    if language_hint and language_hint != "auto":
        requested_language = language_hint
    else:
        requested_language = None

    whisper = WhisperRuntime.get_model(config)
    segment_stream, info = whisper.transcribe(
        wav_path,
        task="transcribe",
        language=requested_language,
        beam_size=1,
        best_of=1,
        temperature=0.0,
        condition_on_previous_text=False,
        word_timestamps=True,
        vad_filter=False,
    )
    # Materialize the returned iterable so callers receive a plain list.
    segments = list(segment_stream)

    language = info.language or requested_language or "unknown"
    if requested_language:
        language_source = "request"
    else:
        language_source = "auto_detect"
    return segments, language.lower(), language_source
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
# ---------------------------------------------------------------------------
|
| 69 |
+
# Groq API path
|
| 70 |
+
# ---------------------------------------------------------------------------
|
| 71 |
+
|
| 72 |
+
# Per-chunk upload ceiling: 23 MiB, which leaves headroom below Groq's
# 25 MB per-file limit.
_GROQ_MAX_BYTES = 23 * (1 << 20)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class _GroqWord:
    """Word-level timing record shaped like a faster_whisper ``Word``.

    Exposes the ``word``, ``start``, ``end`` and ``probability`` attributes so
    that code written against faster_whisper words (service.py, per the
    original note) can consume Groq results without changes.

    Args:
        word: The transcribed token text.
        start: Start time in seconds.
        end: End time in seconds.
        probability: Optional confidence for the word. Defaults to ``None``
            because Groq doesn't provide per-word confidence.
    """

    __slots__ = ("word", "start", "end", "probability")

    def __init__(
        self,
        word: str,
        start: float,
        end: float,
        probability: float | None = None,
    ) -> None:
        self.word = word
        self.start = start
        self.end = end
        self.probability = probability

    def __repr__(self) -> str:
        # __slots__ classes get no useful default repr; spell the fields out.
        return (
            f"{type(self).__name__}(word={self.word!r}, start={self.start!r}, "
            f"end={self.end!r}, probability={self.probability!r})"
        )
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class _GroqSegment:
    """Segment record shaped like a faster_whisper ``Segment``.

    Exposes the ``start``, ``end``, ``text`` and ``words`` attributes so that
    code written against faster_whisper segments (service.py, per the
    original note) can consume Groq results without changes.

    Args:
        start: Segment start time in seconds.
        end: Segment end time in seconds.
        text: Transcribed text of the segment.
        words: Word records belonging to this segment.
    """

    __slots__ = ("start", "end", "text", "words")

    def __init__(self, start: float, end: float, text: str, words: list[_GroqWord]) -> None:
        self.start = start
        self.end = end
        self.text = text
        self.words = words

    def __repr__(self) -> str:
        # __slots__ classes get no useful default repr; spell the fields out.
        return (
            f"{type(self).__name__}(start={self.start!r}, end={self.end!r}, "
            f"text={self.text!r}, words={self.words!r})"
        )
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def _chunk_wav(wav_path: str, sample_rate: int) -> list[tuple[str, float]]:
    """Split an audio file into PCM_16 WAV chunks that fit within _GROQ_MAX_BYTES.

    Args:
        wav_path: Path of the source audio file, read with soundfile.
        sample_rate: Expected sample rate in Hz. Kept for interface
            compatibility and used only as a fallback: the rate reported by
            the file itself is authoritative, so chunk offsets and the chunk
            headers stay correct even if the two disagree.

    Returns:
        A list of ``(chunk_wav_path, start_time_offset_sec)`` tuples in
        playback order; empty when the file contains no audio. All chunks
        live in one freshly created temp directory, which the caller must
        remove together with the chunk files.

    Raises:
        ValueError: If the sample rate is not positive, or is so high that
            one second of audio no longer fits in a chunk.
    """
    import shutil  # local import: only needed to clean up after a failure

    audio, file_rate = sf.read(wav_path, dtype="float32")
    rate = int(file_rate or sample_rate)
    if rate <= 0:
        raise ValueError(f"invalid sample rate: {rate}")

    total_frames = len(audio)
    if total_frames == 0:
        # Nothing to upload; avoid creating a temp dir nobody would clean up.
        return []

    # soundfile yields a 1-D array for mono and (frames, channels) otherwise;
    # every frame costs 2 bytes per channel once written as PCM_16.
    channels = audio.shape[1] if audio.ndim > 1 else 1
    bytes_per_sec = rate * channels * 2
    # Chunk on whole-second boundaries, as many seconds as the budget allows.
    max_frames = (_GROQ_MAX_BYTES // bytes_per_sec) * rate
    if max_frames <= 0:
        raise ValueError(
            f"one second of audio ({bytes_per_sec} bytes) exceeds the chunk limit"
        )

    tmp_dir = tempfile.mkdtemp(prefix="groq-chunks-")
    chunks: list[tuple[str, float]] = []
    try:
        for idx, cursor in enumerate(range(0, total_frames, max_frames)):
            end = min(cursor + max_frames, total_frames)
            chunk_path = os.path.join(tmp_dir, f"chunk_{idx:04d}.wav")
            sf.write(chunk_path, audio[cursor:end], rate, subtype="PCM_16")
            chunks.append((chunk_path, cursor / rate))
    except Exception:
        # The caller never learns about partially written chunks, so remove
        # them here before propagating the error.
        shutil.rmtree(tmp_dir, ignore_errors=True)
        raise

    return chunks
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def _call_groq(
    wav_path: str,
    api_key: str,
    groq_model: str,
    language_hint: str | None,
) -> dict:
    """Send one audio file to Groq's transcription endpoint.

    Requests ``verbose_json`` with word and segment timestamps and returns
    the response converted to a plain dict. The ``language`` field is only
    sent when ``language_hint`` is set and is not ``"auto"``.
    """
    from groq import Groq  # lazy import keeps local-only installs working

    request_args: dict[str, Any] = dict(
        model=groq_model,
        response_format="verbose_json",
        timestamp_granularities=["word", "segment"],
    )
    has_explicit_language = bool(language_hint) and language_hint != "auto"
    if has_explicit_language:
        request_args["language"] = language_hint

    groq_client = Groq(api_key=api_key)
    with open(wav_path, "rb") as audio_file:
        response = groq_client.audio.transcriptions.create(file=audio_file, **request_args)

    # Prefer model_dump() when the response object offers it.
    if hasattr(response, "model_dump"):
        return response.model_dump()
    return dict(response)
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def _groq_span(item: dict, time_offset: float) -> tuple[float, float]:
    """Return the absolute ``(start, end)`` seconds of a Groq word/segment dict.

    Defaults are resolved in chunk-relative time (missing ``start`` -> 0.0,
    missing ``end`` -> the start) and the chunk offset is added exactly once.
    """
    start = float(item.get("start", 0.0))
    end = float(item.get("end", start))
    return start + time_offset, end + time_offset


def _groq_result_to_segments(result: dict, time_offset: float) -> list[_GroqSegment]:
    """Convert one Groq response dict into ``_GroqSegment`` objects.

    Each non-empty word is attached to the segment it overlaps most (or, when
    it overlaps none, the segment with the smallest gap). Segments are matched
    by position in the response, so missing or repeated ``id`` values cannot
    merge the words of different segments. Words are dropped when the
    response has no segments, since they have nothing to attach to.
    """
    raw_segments: list[dict] = result.get("segments") or []
    if not raw_segments:
        return []

    spans = [_groq_span(seg, time_offset) for seg in raw_segments]
    words_per_segment: list[list[_GroqWord]] = [[] for _ in spans]

    for raw_word in result.get("words") or []:
        text = str(raw_word.get("word", "")).strip()
        if not text:
            continue
        w_start, w_end = _groq_span(raw_word, time_offset)

        best_index = 0
        best_overlap = float("-inf")
        for index, (s_start, s_end) in enumerate(spans):
            # Negative values mean a gap; the least negative is the nearest.
            overlap = min(w_end, s_end) - max(w_start, s_start)
            if overlap > best_overlap:
                best_index, best_overlap = index, overlap

        words_per_segment[best_index].append(
            _GroqWord(word=text, start=w_start, end=w_end)
        )

    return [
        _GroqSegment(
            start=s_start,
            end=s_end,
            text=str(seg.get("text", "")).strip(),
            words=words,
        )
        for seg, (s_start, s_end), words in zip(raw_segments, spans, words_per_segment)
    ]


def _remove_groq_chunks(chunks: list[tuple[str, float]]) -> None:
    """Best-effort removal of chunk files and the directories that held them."""
    parent_dirs: set[str] = set()
    for chunk_path, _ in chunks:
        parent_dirs.add(os.path.dirname(chunk_path))
        try:
            os.remove(chunk_path)
        except OSError:
            pass  # already gone or not removable; cleanup must not mask errors
    for directory in parent_dirs:
        try:
            os.rmdir(directory)  # only succeeds once the directory is empty
        except OSError:
            pass


def _transcribe_groq(
    wav_path: str,
    config: VoiceRuntimeConfig,
    language_hint: str | None,
) -> tuple[list[Any], str, str]:
    """Transcribe ``wav_path`` through the Groq API, chunk by chunk.

    The audio is split with ``_chunk_wav``, each chunk is sent via
    ``_call_groq``, and the per-chunk timestamps are shifted by the chunk's
    start offset so the returned segments share one timeline.

    Returns:
        A ``(segments, detected_language, language_source)`` tuple.
        ``detected_language`` is the requested language when one was given,
        otherwise the lower-cased ``language`` of the first chunk that
        reports one, otherwise ``"unknown"``. ``language_source`` is
        ``"request"`` or ``"auto_detect"``.

    Raises:
        Whatever ``_chunk_wav`` or ``_call_groq`` raise. All chunk files and
        their temp directory are removed before the error propagates.

    NOTE(review): confirm that the ``language`` value Groq returns has the
    format callers expect (language name vs. ISO code).
    """
    requested_language = None if not language_hint or language_hint == "auto" else language_hint
    detected_language: str = requested_language or "unknown"
    all_segments: list[_GroqSegment] = []

    chunks = _chunk_wav(wav_path, config.sample_rate)
    try:
        for chunk_path, time_offset in chunks:
            result = _call_groq(
                chunk_path, config.groq_api_key, config.groq_model_id, language_hint
            )

            # Capture language from the first chunk that reports it.
            if detected_language == "unknown":
                detected_language = (result.get("language") or "unknown").lower()

            all_segments.extend(_groq_result_to_segments(result, time_offset))
    finally:
        # Runs on success and failure alike, so chunks after a failed one are
        # cleaned up too.
        _remove_groq_chunks(chunks)

    language_source = "request" if requested_language else "auto_detect"
    return all_segments, detected_language, language_source
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
# ---------------------------------------------------------------------------
|
| 215 |
+
# Public entry point — called by service.py
|
| 216 |
+
# ---------------------------------------------------------------------------
|
| 217 |
+
|
| 218 |
+
def transcribe(
    wav_path: str,
    config: VoiceRuntimeConfig,
    language_hint: str | None,
) -> tuple[list[Any], str, str]:
    """Public entry point: pick a transcription backend and run it.

    The Groq API path is used when ``config.groq_api_key`` is non-empty;
    otherwise the local faster-whisper path runs. Both return the same
    ``(segments, detected_language, language_source)`` tuple, with objects
    meant to be compatible with _build_alignment_payload in service.py.
    """
    backend = _transcribe_groq if config.groq_api_key else _transcribe_local
    return backend(wav_path, config, language_hint)
|
requirements.txt
CHANGED
|
@@ -5,3 +5,4 @@ faster-whisper>=1.1.1
|
|
| 5 |
numpy>=1.26.0
|
| 6 |
soundfile>=0.12.1
|
| 7 |
pyannote.audio>=3.3.2
|
|
|
|
|
|
| 5 |
numpy>=1.26.0
|
| 6 |
soundfile>=0.12.1
|
| 7 |
pyannote.audio>=3.3.2
|
| 8 |
+
groq>=0.9.0
|