Upload app.py with huggingface_hub
app.py CHANGED
@@ -7,7 +7,6 @@ import io
 import json
 import logging
 import os
-import re
 import threading
 from contextlib import asynccontextmanager
 
@@ -20,6 +19,8 @@ from fastapi.responses import JSONResponse
 from pydub import AudioSegment
 from sse_starlette.sse import EventSourceResponse
 
+from voxtral_inference import VoxtralModel
+
 logger = logging.getLogger("gpu_service")
 
 # ---------------------------------------------------------------------------
@@ -32,18 +33,14 @@ PYANNOTE_MIN_SPEAKERS = int(os.environ.get("PYANNOTE_MIN_SPEAKERS", "1"))
 PYANNOTE_MAX_SPEAKERS = int(os.environ.get("PYANNOTE_MAX_SPEAKERS", "10"))
 TARGET_SR = 16000
 
+MODEL_DIR = os.environ.get("VOXTRAL_MODEL_DIR", "/repository/voxtral-model")
+
 # ---------------------------------------------------------------------------
 # Singletons
 # ---------------------------------------------------------------------------
 _diarize_pipeline = None
 _embed_model = None
-
-_voxtral_processor = None
-
-VOXTRAL_MODEL_ID = "mistralai/Voxtral-Mini-4B-Realtime-2602"
-
-# Markers to strip from Voxtral output
-_MARKER_RE = re.compile(r"\[STREAMING_PAD\]|\[STREAMING_WORD\]")
+_voxtral: VoxtralModel | None = None
 
 
 def _load_diarize_pipeline():
@@ -68,26 +65,12 @@ def _load_embed_model():
 
 
 def _load_voxtral():
-
-
-
-
-
-    logger.info("Loading Voxtral model %s ...", VOXTRAL_MODEL_ID)
-    _voxtral_processor = AutoProcessor.from_pretrained(
-        VOXTRAL_MODEL_ID, trust_remote_code=True
-    )
-    _voxtral_model = AutoModelForSpeechSeq2Seq.from_pretrained(
-        VOXTRAL_MODEL_ID, torch_dtype=torch.float16, trust_remote_code=True
-    ).to("cuda")
+    global _voxtral
+    if _voxtral is None:
+        logger.info("Loading Voxtral from %s ...", MODEL_DIR)
+        _voxtral = VoxtralModel(MODEL_DIR)
     logger.info("Voxtral model loaded.")
-    return
-
-
-def _clean_voxtral_text(text: str) -> str:
-    """Strip Voxtral streaming markers and collapse whitespace."""
-    text = _MARKER_RE.sub("", text)
-    return " ".join(text.split()).strip()
+    return _voxtral
 
 
 # ---------------------------------------------------------------------------
@@ -198,25 +181,13 @@ async def embed(
 
 
 @app.post("/transcribe")
-async def transcribe(
-    audio: UploadFile = File(...),
-    prompt: str = Form("Transcribe this audio."),
-):
+async def transcribe(audio: UploadFile = File(...)):
     try:
         raw = await audio.read()
         audio_16k = prepare_audio(raw)
 
-        model, processor = _load_voxtral()
-        inputs = processor(
-            audios=audio_16k,
-            sampling_rate=TARGET_SR,
-            text=prompt,
-            return_tensors="pt",
-        ).to("cuda")
-
-        output_ids = model.generate(**inputs, max_new_tokens=1024)
-        text = processor.batch_decode(output_ids, skip_special_tokens=True)[0]
-        text = _clean_voxtral_text(text)
+        model = _load_voxtral()
+        text = model.transcribe(audio_16k)
 
         return {"text": text}
     except Exception as e:
@@ -225,10 +196,7 @@ async def transcribe(
 
 
 @app.post("/transcribe/stream")
-async def transcribe_stream(
-    audio: UploadFile = File(...),
-    prompt: str = Form("Transcribe this audio."),
-):
+async def transcribe_stream(audio: UploadFile = File(...)):
    try:
         raw = await audio.read()
         audio_16k = prepare_audio(raw)
@@ -238,33 +206,25 @@ async def transcribe_stream(
 
     async def event_generator():
         try:
-
-
-            model, processor = _load_voxtral()
-            inputs = processor(
-                audios=audio_16k,
-                sampling_rate=TARGET_SR,
-                text=prompt,
-                return_tensors="pt",
-            ).to("cuda")
-
-            streamer = TextIteratorStreamer(
-                processor.tokenizer, skip_prompt=True, skip_special_tokens=True
-            )
-            gen_kwargs = {**inputs, "max_new_tokens": 1024, "streamer": streamer}
+            model = _load_voxtral()
+            full_text = ""
 
-
-
+            # Run blocking generator in a thread
+            tokens = []
 
-
-
-
-                if chunk:
-                    full_text += chunk
-                    yield {"event": "token", "data": json.dumps({"token": chunk})}
+            def _run():
+                for tok in model.transcribe_stream(audio_16k):
+                    tokens.append(tok)
 
+            thread = threading.Thread(target=_run)
+            thread.start()
             thread.join()
-
+
+            for tok in tokens:
+                full_text += tok
+                yield {"event": "token", "data": json.dumps({"token": tok})}
+
+            full_text = full_text.strip()
             yield {"event": "done", "data": json.dumps({"text": full_text})}
         except Exception as e:
             logger.exception("Streaming transcription failed")
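Note: the rewritten app.py delegates all model handling to a VoxtralModel class imported from voxtral_inference, a local module that is not part of this commit. Inferred from the call sites (VoxtralModel(MODEL_DIR), model.transcribe(audio_16k), model.transcribe_stream(audio_16k)), the module is assumed to expose roughly the interface sketched below; the names and signatures are hypothetical.

# Hypothetical sketch of the voxtral_inference interface assumed by app.py.
# The real module is not included in this commit; everything here is
# inferred from the call sites in the diff above.
from typing import Iterator


class VoxtralModel:
    def __init__(self, model_dir: str) -> None:
        # Assumed: loads weights and tokenizer from a local directory
        # (MODEL_DIR defaults to /repository/voxtral-model).
        self.model_dir = model_dir

    def transcribe(self, audio_16k) -> str:
        # Assumed: returns the full transcript for 16 kHz mono audio,
        # in whatever array format prepare_audio() produces.
        raise NotImplementedError

    def transcribe_stream(self, audio_16k) -> Iterator[str]:
        # Assumed: a blocking generator yielding transcript tokens
        # incrementally as they are decoded.
        raise NotImplementedError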
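One behavioral consequence of the new /transcribe/stream handler is worth noting: event_generator() starts the worker thread and immediately joins it before yielding anything, so tokens are buffered in a list and replayed only after generation has finished, and clients receive the SSE events in a burst rather than live. If live delivery is wanted, one option is to hand tokens from the worker thread to the async generator through a queue. A minimal sketch under the same assumed VoxtralModel interface (an alternative, not the committed code):

# Incremental alternative (sketch, not the committed code): the worker
# thread pushes each token onto an asyncio.Queue so SSE events go out
# while generation is still running.
import asyncio
import json
import threading

_DONE = object()  # sentinel marking the end of the token stream


async def event_generator(model, audio_16k):
    loop = asyncio.get_running_loop()
    queue: asyncio.Queue = asyncio.Queue()

    def _run():
        try:
            for tok in model.transcribe_stream(audio_16k):
                loop.call_soon_threadsafe(queue.put_nowait, tok)
        finally:
            loop.call_soon_threadsafe(queue.put_nowait, _DONE)

    threading.Thread(target=_run, daemon=True).start()

    full_text = ""
    while True:
        tok = await queue.get()
        if tok is _DONE:
            break
        full_text += tok
        yield {"event": "token", "data": json.dumps({"token": tok})}

    yield {"event": "done", "data": json.dumps({"text": full_text.strip()})}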
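For quick manual testing, the two endpoints can be exercised as below. The base URL is an assumption for local testing; host and port are not shown in this excerpt. Note that the prompt form field from the previous version is gone, so only the audio file is sent.

# Example client calls; the base URL is an assumption, not part of the commit.
import httpx

BASE = "http://localhost:8000"

# Plain transcription: one POST, JSON response with the full text.
with open("sample.wav", "rb") as f:
    resp = httpx.post(f"{BASE}/transcribe", files={"audio": f}, timeout=300)
print(resp.json()["text"])

# Streaming transcription: server-sent events, "token" events followed by
# a final "done" event carrying the full text.
with open("sample.wav", "rb") as f:
    with httpx.stream(
        "POST", f"{BASE}/transcribe/stream", files={"audio": f}, timeout=300
    ) as r:
        for line in r.iter_lines():
            if line:
                print(line)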