Lior-0618 Claude Sonnet 4.6 committed on
Commit
04ce75c
·
1 Parent(s): d0f985d

feat: switch to Voxtral-Mini-3B-2507 + evoxtral-lora adapter

Browse files

- Load base model (VoxtralForConditionalGeneration) then apply
PeftModel LoRA from YongkangZOU/evoxtral-lora
- Inference uses apply_chat_template conversation format
- Add peft>=0.18.0 to requirements; bump transformers to >=4.54.0
- Extract _transcribe() helper; update warm-up accordingly

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

model/voxtral-server/main.py CHANGED
@@ -17,7 +17,8 @@ import soundfile as sf
17
  from fastapi import FastAPI, File, UploadFile, HTTPException, Query
18
  from fastapi.middleware.cors import CORSMiddleware
19
 
20
- REPO_ID = os.environ.get("VOXTRAL_MODEL_ID", "mistralai/Voxtral-Mini-4B-Realtime-2602")
 
21
  MAX_UPLOAD_BYTES = int(os.environ.get("MAX_UPLOAD_MB", "100")) * 1024 * 1024
22
  HF_TOKEN = os.environ.get("HF_TOKEN") # optional: enables pyannote speaker diarization
23
 
@@ -92,18 +93,19 @@ async def lifespan(app: FastAPI):
92
  _dtype = torch.bfloat16 # halves memory vs float32 (8 GB vs 16 GB); supported on modern x86
93
  print(f"[voxtral] Device: {_device} dtype: {_dtype}")
94
 
95
- print(f"[voxtral] Loading model: {REPO_ID} (first run may download ~8–16GB)...")
 
96
  try:
97
- from transformers import (
98
- VoxtralRealtimeForConditionalGeneration,
99
- AutoProcessor,
100
- )
101
- processor = AutoProcessor.from_pretrained(REPO_ID)
102
- model = VoxtralRealtimeForConditionalGeneration.from_pretrained(
103
- REPO_ID, torch_dtype=_dtype
104
  ).to(_device)
 
105
  model.eval()
106
- print(f"[voxtral] Model loaded: {REPO_ID} on {_device}")
107
  except Exception as e:
108
  raise RuntimeError(
109
  f"Model load failed: {e}\n"
@@ -112,14 +114,15 @@ async def lifespan(app: FastAPI):
112
  ) from e
113
 
114
  # Warm-up: run one silent dummy inference to pre-compile MPS Metal shaders.
115
- # Without this the first real request pays a ~15s compilation penalty.
116
- print("[voxtral] Warming up MPS shaders (dummy inference)...")
117
  try:
118
- sr = processor.feature_extractor.sampling_rate
119
  dummy = np.zeros(sr, dtype=np.float32) # 1 second of silence
 
120
  with torch.inference_mode():
121
- dummy_inputs = processor(dummy, return_tensors="pt")
122
- dummy_inputs = dummy_inputs.to(_device, dtype=_dtype)
 
123
  model.generate(**dummy_inputs, max_new_tokens=1)
124
  print("[voxtral] Warm-up complete — first request will be fast")
125
  except Exception as e:
@@ -381,6 +384,23 @@ def _analyze_emotion(chunk: np.ndarray, sr: int) -> dict:
381
  return {"emotion": "Neutral", "valence": 0.0, "arousal": 0.0}
382
 
383
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
384
  # ─── Endpoints ─────────────────────────────────────────────────────────────────
385
 
386
  @app.post("/transcribe")
@@ -405,7 +425,7 @@ async def transcribe(audio: UploadFile = File(...)):
405
  if suffix not in (".wav", ".mp3", ".flac", ".ogg", ".m4a", ".webm"):
406
  suffix = ".wav"
407
 
408
- target_sr = processor.feature_extractor.sampling_rate
409
  with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
410
  tmp.write(contents)
411
  tmp_path = tmp.name
@@ -421,13 +441,7 @@ async def transcribe(audio: UploadFile = File(...)):
421
  except OSError:
422
  pass
423
 
424
- with torch.inference_mode():
425
- inputs = processor(audio_array, return_tensors="pt")
426
- inputs = inputs.to(model.device, dtype=model.dtype)
427
- outputs = model.generate(**{k: v for k, v in inputs.items()}, max_new_tokens=1024)
428
- decoded = processor.batch_decode(outputs, skip_special_tokens=True)
429
-
430
- text = (decoded[0] or "").strip()
431
  total_ms = (time.perf_counter() - req_start) * 1000
432
  print(f"[voxtral] {req_id} done total={total_ms:.0f}ms text_len={len(text)}")
433
  return {"text": text, "words": [], "languageCode": None}
@@ -458,7 +472,7 @@ async def transcribe_diarize(
458
  if suffix not in (".wav", ".mp3", ".flac", ".ogg", ".m4a", ".webm"):
459
  suffix = ".wav"
460
 
461
- target_sr = processor.feature_extractor.sampling_rate
462
 
463
  with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
464
  tmp.write(contents)
@@ -481,12 +495,7 @@ async def transcribe_diarize(
481
 
482
  # ── Step 1: full transcription via Voxtral ──────────────────────────────
483
  t0 = time.perf_counter()
484
- with torch.inference_mode():
485
- inputs = processor(audio_array, return_tensors="pt")
486
- inputs = inputs.to(model.device, dtype=model.dtype)
487
- outputs = model.generate(**{k: v for k, v in inputs.items()}, max_new_tokens=1024)
488
- decoded = processor.batch_decode(outputs, skip_special_tokens=True)
489
- full_text = (decoded[0] or "").strip()
490
  print(f"[voxtral] {req_id} transcription done in {(time.perf_counter()-t0)*1000:.0f}ms text_len={len(full_text)}")
491
 
492
  # ── Step 2: VAD sentence segmentation ───────────────────────────────────
 
17
  from fastapi import FastAPI, File, UploadFile, HTTPException, Query
18
  from fastapi.middleware.cors import CORSMiddleware
19
 
20
+ REPO_ID = os.environ.get("VOXTRAL_MODEL_ID", "YongkangZOU/evoxtral-lora")
21
+ BASE_MODEL_ID = "mistralai/Voxtral-Mini-3B-2507"
22
  MAX_UPLOAD_BYTES = int(os.environ.get("MAX_UPLOAD_MB", "100")) * 1024 * 1024
23
  HF_TOKEN = os.environ.get("HF_TOKEN") # optional: enables pyannote speaker diarization
24
 
 
93
  _dtype = torch.bfloat16 # halves memory vs float32 (8 GB vs 16 GB); supported on modern x86
94
  print(f"[voxtral] Device: {_device} dtype: {_dtype}")
95
 
96
+ print(f"[voxtral] Loading base model: {BASE_MODEL_ID} ...")
97
+ print(f"[voxtral] Applying LoRA adapter: {REPO_ID} ...")
98
  try:
99
+ from transformers import VoxtralForConditionalGeneration, AutoProcessor
100
+ from peft import PeftModel
101
+
102
+ processor = AutoProcessor.from_pretrained(BASE_MODEL_ID)
103
+ base = VoxtralForConditionalGeneration.from_pretrained(
104
+ BASE_MODEL_ID, torch_dtype=_dtype
 
105
  ).to(_device)
106
+ model = PeftModel.from_pretrained(base, REPO_ID)
107
  model.eval()
108
+ print(f"[voxtral] Model ready: {BASE_MODEL_ID} + LoRA {REPO_ID} on {_device}")
109
  except Exception as e:
110
  raise RuntimeError(
111
  f"Model load failed: {e}\n"
 
114
  ) from e
115
 
116
  # Warm-up: run one silent dummy inference to pre-compile MPS Metal shaders.
117
+ print("[voxtral] Warming up (dummy inference)...")
 
118
  try:
119
+ sr = getattr(getattr(processor, "feature_extractor", None), "sampling_rate", 16000)
120
  dummy = np.zeros(sr, dtype=np.float32) # 1 second of silence
121
+ conversation = [{"role": "user", "content": [{"type": "audio", "audio": dummy}]}]
122
  with torch.inference_mode():
123
+ dummy_inputs = processor.apply_chat_template(
124
+ conversation, return_tensors="pt", tokenize=True
125
+ ).to(_device)
126
  model.generate(**dummy_inputs, max_new_tokens=1)
127
  print("[voxtral] Warm-up complete β€” first request will be fast")
128
  except Exception as e:
 
384
  return {"emotion": "Neutral", "valence": 0.0, "arousal": 0.0}
385
 
386
 
387
+ # ─── Inference helper ──────────────────────────────────────────────────────────
388
+
389
def _transcribe(audio_array: np.ndarray) -> str:
    """Run Voxtral-3B + LoRA inference via chat template; return transcribed text.

    Args:
        audio_array: mono audio samples as float32 (resampled upstream by the
            endpoint handlers to the processor's sampling rate — TODO confirm
            callers always pass mono).

    Returns:
        The decoded transcription with surrounding whitespace stripped.
    """
    conversation = [{"role": "user", "content": [{"type": "audio", "audio": audio_array}]}]
    with torch.inference_mode():
        inputs = processor.apply_chat_template(
            conversation, return_tensors="pt", tokenize=True
        )
        # Match the model's dtype as well as its device: the base model is
        # loaded in bfloat16 while the processor emits float32 audio features,
        # which would otherwise raise a dtype mismatch inside generate().
        # (BatchFeature.to only casts floating-point tensors, so the integer
        # input_ids are moved but not cast.)
        inputs = inputs.to(model.device, dtype=model.dtype)
        outputs = model.generate(**inputs, max_new_tokens=1024)
        # Decode only the newly generated tokens — slice off the prompt prefix
        # so the returned text contains just the transcription.
        text = processor.decode(
            outputs[0][inputs["input_ids"].shape[1]:],
            skip_special_tokens=True,
        )
    return text.strip()
402
+
403
+
404
  # ─── Endpoints ─────────────────────────────────────────────────────────────────
405
 
406
  @app.post("/transcribe")
 
425
  if suffix not in (".wav", ".mp3", ".flac", ".ogg", ".m4a", ".webm"):
426
  suffix = ".wav"
427
 
428
+ target_sr = getattr(getattr(processor, "feature_extractor", None), "sampling_rate", 16000)
429
  with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
430
  tmp.write(contents)
431
  tmp_path = tmp.name
 
441
  except OSError:
442
  pass
443
 
444
+ text = _transcribe(audio_array)
 
 
 
 
 
 
445
  total_ms = (time.perf_counter() - req_start) * 1000
446
  print(f"[voxtral] {req_id} done total={total_ms:.0f}ms text_len={len(text)}")
447
  return {"text": text, "words": [], "languageCode": None}
 
472
  if suffix not in (".wav", ".mp3", ".flac", ".ogg", ".m4a", ".webm"):
473
  suffix = ".wav"
474
 
475
+ target_sr = getattr(getattr(processor, "feature_extractor", None), "sampling_rate", 16000)
476
 
477
  with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
478
  tmp.write(contents)
 
495
 
496
  # ── Step 1: full transcription via Voxtral ──────────────────────────────
497
  t0 = time.perf_counter()
498
+ full_text = _transcribe(audio_array)
 
 
 
 
 
499
  print(f"[voxtral] {req_id} transcription done in {(time.perf_counter()-t0)*1000:.0f}ms text_len={len(full_text)}")
500
 
501
  # ── Step 2: VAD sentence segmentation ───────────────────────────────────
model/voxtral-server/requirements.txt CHANGED
@@ -1,12 +1,12 @@
1
- # Voxtral Mini 4B Realtime - speech-to-text inference
2
- # https://huggingface.co/mistralai/Voxtral-Mini-4B-Realtime-2602
3
  fastapi>=0.115.0
4
  uvicorn[standard]>=0.32.0
5
  python-multipart>=0.0.9
6
- transformers>=5.2.0
 
7
  torch>=2.0.0
8
  accelerate>=0.33.0
9
- mistral-common[audio]>=1.9.0
10
  librosa>=0.10.0
11
  soundfile>=0.12.0
12
  numpy>=1.24.0
 
1
+ # Voxtral-Mini-3B-2507 + LoRA adapter (YongkangZOU/evoxtral-lora)
 
2
  fastapi>=0.115.0
3
  uvicorn[standard]>=0.32.0
4
  python-multipart>=0.0.9
5
+ transformers>=4.54.0
6
+ peft>=0.18.0
7
  torch>=2.0.0
8
  accelerate>=0.33.0
9
+ mistral-common[audio]>=1.5.0
10
  librosa>=0.10.0
11
  soundfile>=0.12.0
12
  numpy>=1.24.0