tantk committed on
Commit
a0d2600
·
verified ·
1 Parent(s): a28871c

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +29 -69
app.py CHANGED
@@ -7,7 +7,6 @@ import io
7
  import json
8
  import logging
9
  import os
10
- import re
11
  import threading
12
  from contextlib import asynccontextmanager
13
 
@@ -20,6 +19,8 @@ from fastapi.responses import JSONResponse
20
  from pydub import AudioSegment
21
  from sse_starlette.sse import EventSourceResponse
22
 
 
 
23
  logger = logging.getLogger("gpu_service")
24
 
25
  # ---------------------------------------------------------------------------
@@ -32,18 +33,14 @@ PYANNOTE_MIN_SPEAKERS = int(os.environ.get("PYANNOTE_MIN_SPEAKERS", "1"))
32
  PYANNOTE_MAX_SPEAKERS = int(os.environ.get("PYANNOTE_MAX_SPEAKERS", "10"))
33
  TARGET_SR = 16000
34
 
 
 
35
  # ---------------------------------------------------------------------------
36
  # Singletons
37
  # ---------------------------------------------------------------------------
38
  _diarize_pipeline = None
39
  _embed_model = None
40
- _voxtral_model = None
41
- _voxtral_processor = None
42
-
43
- VOXTRAL_MODEL_ID = "mistralai/Voxtral-Mini-4B-Realtime-2602"
44
-
45
- # Markers to strip from Voxtral output
46
- _MARKER_RE = re.compile(r"\[STREAMING_PAD\]|\[STREAMING_WORD\]")
47
 
48
 
49
  def _load_diarize_pipeline():
@@ -68,26 +65,12 @@ def _load_embed_model():
68
 
69
 
70
  def _load_voxtral():
71
- """Lazy-load Voxtral model and processor (first call only)."""
72
- global _voxtral_model, _voxtral_processor
73
- if _voxtral_model is None:
74
- from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
75
-
76
- logger.info("Loading Voxtral model %s ...", VOXTRAL_MODEL_ID)
77
- _voxtral_processor = AutoProcessor.from_pretrained(
78
- VOXTRAL_MODEL_ID, trust_remote_code=True
79
- )
80
- _voxtral_model = AutoModelForSpeechSeq2Seq.from_pretrained(
81
- VOXTRAL_MODEL_ID, torch_dtype=torch.float16, trust_remote_code=True
82
- ).to("cuda")
83
  logger.info("Voxtral model loaded.")
84
- return _voxtral_model, _voxtral_processor
85
-
86
-
87
- def _clean_voxtral_text(text: str) -> str:
88
- """Strip Voxtral streaming markers and collapse whitespace."""
89
- text = _MARKER_RE.sub("", text)
90
- return " ".join(text.split()).strip()
91
 
92
 
93
  # ---------------------------------------------------------------------------
@@ -198,25 +181,13 @@ async def embed(
198
 
199
 
200
  @app.post("/transcribe")
201
- async def transcribe(
202
- audio: UploadFile = File(...),
203
- prompt: str = Form("Transcribe this audio."),
204
- ):
205
  try:
206
  raw = await audio.read()
207
  audio_16k = prepare_audio(raw)
208
 
209
- model, processor = _load_voxtral()
210
- inputs = processor(
211
- audios=audio_16k,
212
- sampling_rate=TARGET_SR,
213
- text=prompt,
214
- return_tensors="pt",
215
- ).to("cuda")
216
-
217
- output_ids = model.generate(**inputs, max_new_tokens=1024)
218
- text = processor.batch_decode(output_ids, skip_special_tokens=True)[0]
219
- text = _clean_voxtral_text(text)
220
 
221
  return {"text": text}
222
  except Exception as e:
@@ -225,10 +196,7 @@ async def transcribe(
225
 
226
 
227
  @app.post("/transcribe/stream")
228
- async def transcribe_stream(
229
- audio: UploadFile = File(...),
230
- prompt: str = Form("Transcribe this audio."),
231
- ):
232
  try:
233
  raw = await audio.read()
234
  audio_16k = prepare_audio(raw)
@@ -238,33 +206,25 @@ async def transcribe_stream(
238
 
239
  async def event_generator():
240
  try:
241
- from transformers import TextIteratorStreamer
242
-
243
- model, processor = _load_voxtral()
244
- inputs = processor(
245
- audios=audio_16k,
246
- sampling_rate=TARGET_SR,
247
- text=prompt,
248
- return_tensors="pt",
249
- ).to("cuda")
250
-
251
- streamer = TextIteratorStreamer(
252
- processor.tokenizer, skip_prompt=True, skip_special_tokens=True
253
- )
254
- gen_kwargs = {**inputs, "max_new_tokens": 1024, "streamer": streamer}
255
 
256
- thread = threading.Thread(target=model.generate, kwargs=gen_kwargs)
257
- thread.start()
258
 
259
- full_text = ""
260
- for chunk in streamer:
261
- chunk = _MARKER_RE.sub("", chunk)
262
- if chunk:
263
- full_text += chunk
264
- yield {"event": "token", "data": json.dumps({"token": chunk})}
265
 
 
 
266
  thread.join()
267
- full_text = " ".join(full_text.split()).strip()
 
 
 
 
 
268
  yield {"event": "done", "data": json.dumps({"text": full_text})}
269
  except Exception as e:
270
  logger.exception("Streaming transcription failed")
 
7
  import json
8
  import logging
9
  import os
 
10
  import threading
11
  from contextlib import asynccontextmanager
12
 
 
19
  from pydub import AudioSegment
20
  from sse_starlette.sse import EventSourceResponse
21
 
22
+ from voxtral_inference import VoxtralModel
23
+
24
  logger = logging.getLogger("gpu_service")
25
 
26
  # ---------------------------------------------------------------------------
 
33
  PYANNOTE_MAX_SPEAKERS = int(os.environ.get("PYANNOTE_MAX_SPEAKERS", "10"))
34
  TARGET_SR = 16000
35
 
36
+ MODEL_DIR = os.environ.get("VOXTRAL_MODEL_DIR", "/repository/voxtral-model")
37
+
38
  # ---------------------------------------------------------------------------
39
  # Singletons
40
  # ---------------------------------------------------------------------------
41
  _diarize_pipeline = None
42
  _embed_model = None
43
+ _voxtral: VoxtralModel | None = None
 
 
 
 
 
 
44
 
45
 
46
  def _load_diarize_pipeline():
 
65
 
66
 
67
  def _load_voxtral():
68
+ global _voxtral
69
+ if _voxtral is None:
70
+ logger.info("Loading Voxtral from %s ...", MODEL_DIR)
71
+ _voxtral = VoxtralModel(MODEL_DIR)
 
 
 
 
 
 
 
 
72
  logger.info("Voxtral model loaded.")
73
+ return _voxtral
 
 
 
 
 
 
74
 
75
 
76
  # ---------------------------------------------------------------------------
 
181
 
182
 
183
  @app.post("/transcribe")
184
+ async def transcribe(audio: UploadFile = File(...)):
 
 
 
185
  try:
186
  raw = await audio.read()
187
  audio_16k = prepare_audio(raw)
188
 
189
+ model = _load_voxtral()
190
+ text = model.transcribe(audio_16k)
 
 
 
 
 
 
 
 
 
191
 
192
  return {"text": text}
193
  except Exception as e:
 
196
 
197
 
198
  @app.post("/transcribe/stream")
199
+ async def transcribe_stream(audio: UploadFile = File(...)):
 
 
 
200
  try:
201
  raw = await audio.read()
202
  audio_16k = prepare_audio(raw)
 
206
 
207
  async def event_generator():
208
  try:
209
+ model = _load_voxtral()
210
+ full_text = ""
 
 
 
 
 
 
 
 
 
 
 
 
211
 
212
+ # Run blocking generator in a thread
213
+ tokens = []
214
 
215
+ def _run():
216
+ for tok in model.transcribe_stream(audio_16k):
217
+ tokens.append(tok)
 
 
 
218
 
219
+ thread = threading.Thread(target=_run)
220
+ thread.start()
221
  thread.join()
222
+
223
+ for tok in tokens:
224
+ full_text += tok
225
+ yield {"event": "token", "data": json.dumps({"token": tok})}
226
+
227
+ full_text = full_text.strip()
228
  yield {"event": "done", "data": json.dumps({"text": full_text})}
229
  except Exception as e:
230
  logger.exception("Streaming transcription failed")