Spaces:
Sleeping
Sleeping
feat: switch to Voxtral-Mini-3B-2507 + evoxtral-lora adapter
Browse files
- Load base model (VoxtralForConditionalGeneration) then apply
PeftModel LoRA from YongkangZOU/evoxtral-lora
- Inference uses apply_chat_template conversation format
- Add peft>=0.18.0 to requirements; bump transformers to >=4.54.0
- Extract _transcribe() helper; update warm-up accordingly
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
model/voxtral-server/main.py
CHANGED
|
@@ -17,7 +17,8 @@ import soundfile as sf
|
|
| 17 |
from fastapi import FastAPI, File, UploadFile, HTTPException, Query
|
| 18 |
from fastapi.middleware.cors import CORSMiddleware
|
| 19 |
|
| 20 |
-
REPO_ID = os.environ.get("VOXTRAL_MODEL_ID", "
|
|
|
|
| 21 |
MAX_UPLOAD_BYTES = int(os.environ.get("MAX_UPLOAD_MB", "100")) * 1024 * 1024
|
| 22 |
HF_TOKEN = os.environ.get("HF_TOKEN") # optional: enables pyannote speaker diarization
|
| 23 |
|
|
@@ -92,18 +93,19 @@ async def lifespan(app: FastAPI):
|
|
| 92 |
_dtype = torch.bfloat16 # halves memory vs float32 (8 GB vs 16 GB); supported on modern x86
|
| 93 |
print(f"[voxtral] Device: {_device} dtype: {_dtype}")
|
| 94 |
|
| 95 |
-
print(f"[voxtral] Loading model: {
|
|
|
|
| 96 |
try:
|
| 97 |
-
from transformers import
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
)
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
REPO_ID, torch_dtype=_dtype
|
| 104 |
).to(_device)
|
|
|
|
| 105 |
model.eval()
|
| 106 |
-
print(f"[voxtral] Model
|
| 107 |
except Exception as e:
|
| 108 |
raise RuntimeError(
|
| 109 |
f"Model load failed: {e}\n"
|
|
@@ -112,14 +114,15 @@ async def lifespan(app: FastAPI):
|
|
| 112 |
) from e
|
| 113 |
|
| 114 |
# Warm-up: run one silent dummy inference to pre-compile MPS Metal shaders.
|
| 115 |
-
|
| 116 |
-
print("[voxtral] Warming up MPS shaders (dummy inference)...")
|
| 117 |
try:
|
| 118 |
-
sr = processor
|
| 119 |
dummy = np.zeros(sr, dtype=np.float32) # 1 second of silence
|
|
|
|
| 120 |
with torch.inference_mode():
|
| 121 |
-
dummy_inputs = processor(
|
| 122 |
-
|
|
|
|
| 123 |
model.generate(**dummy_inputs, max_new_tokens=1)
|
| 124 |
print("[voxtral] Warm-up complete — first request will be fast")
|
| 125 |
except Exception as e:
|
|
@@ -381,6 +384,23 @@ def _analyze_emotion(chunk: np.ndarray, sr: int) -> dict:
|
|
| 381 |
return {"emotion": "Neutral", "valence": 0.0, "arousal": 0.0}
|
| 382 |
|
| 383 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 384 |
# ─── Endpoints ─────────────────────────────────────────────────────────────────
|
| 385 |
|
| 386 |
@app.post("/transcribe")
|
|
@@ -405,7 +425,7 @@ async def transcribe(audio: UploadFile = File(...)):
|
|
| 405 |
if suffix not in (".wav", ".mp3", ".flac", ".ogg", ".m4a", ".webm"):
|
| 406 |
suffix = ".wav"
|
| 407 |
|
| 408 |
-
target_sr = processor
|
| 409 |
with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
|
| 410 |
tmp.write(contents)
|
| 411 |
tmp_path = tmp.name
|
|
@@ -421,13 +441,7 @@ async def transcribe(audio: UploadFile = File(...)):
|
|
| 421 |
except OSError:
|
| 422 |
pass
|
| 423 |
|
| 424 |
-
|
| 425 |
-
inputs = processor(audio_array, return_tensors="pt")
|
| 426 |
-
inputs = inputs.to(model.device, dtype=model.dtype)
|
| 427 |
-
outputs = model.generate(**{k: v for k, v in inputs.items()}, max_new_tokens=1024)
|
| 428 |
-
decoded = processor.batch_decode(outputs, skip_special_tokens=True)
|
| 429 |
-
|
| 430 |
-
text = (decoded[0] or "").strip()
|
| 431 |
total_ms = (time.perf_counter() - req_start) * 1000
|
| 432 |
print(f"[voxtral] {req_id} done total={total_ms:.0f}ms text_len={len(text)}")
|
| 433 |
return {"text": text, "words": [], "languageCode": None}
|
|
@@ -458,7 +472,7 @@ async def transcribe_diarize(
|
|
| 458 |
if suffix not in (".wav", ".mp3", ".flac", ".ogg", ".m4a", ".webm"):
|
| 459 |
suffix = ".wav"
|
| 460 |
|
| 461 |
-
target_sr = processor
|
| 462 |
|
| 463 |
with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
|
| 464 |
tmp.write(contents)
|
|
@@ -481,12 +495,7 @@ async def transcribe_diarize(
|
|
| 481 |
|
| 482 |
# ── Step 1: full transcription via Voxtral ──────────────────────────────
|
| 483 |
t0 = time.perf_counter()
|
| 484 |
-
|
| 485 |
-
inputs = processor(audio_array, return_tensors="pt")
|
| 486 |
-
inputs = inputs.to(model.device, dtype=model.dtype)
|
| 487 |
-
outputs = model.generate(**{k: v for k, v in inputs.items()}, max_new_tokens=1024)
|
| 488 |
-
decoded = processor.batch_decode(outputs, skip_special_tokens=True)
|
| 489 |
-
full_text = (decoded[0] or "").strip()
|
| 490 |
print(f"[voxtral] {req_id} transcription done in {(time.perf_counter()-t0)*1000:.0f}ms text_len={len(full_text)}")
|
| 491 |
|
| 492 |
# ── Step 2: VAD sentence segmentation ───────────────────────────────────
|
|
|
|
| 17 |
from fastapi import FastAPI, File, UploadFile, HTTPException, Query
|
| 18 |
from fastapi.middleware.cors import CORSMiddleware
|
| 19 |
|
| 20 |
+
REPO_ID = os.environ.get("VOXTRAL_MODEL_ID", "YongkangZOU/evoxtral-lora")
|
| 21 |
+
BASE_MODEL_ID = "mistralai/Voxtral-Mini-3B-2507"
|
| 22 |
MAX_UPLOAD_BYTES = int(os.environ.get("MAX_UPLOAD_MB", "100")) * 1024 * 1024
|
| 23 |
HF_TOKEN = os.environ.get("HF_TOKEN") # optional: enables pyannote speaker diarization
|
| 24 |
|
|
|
|
| 93 |
_dtype = torch.bfloat16 # halves memory vs float32 (8 GB vs 16 GB); supported on modern x86
|
| 94 |
print(f"[voxtral] Device: {_device} dtype: {_dtype}")
|
| 95 |
|
| 96 |
+
print(f"[voxtral] Loading base model: {BASE_MODEL_ID} ...")
|
| 97 |
+
print(f"[voxtral] Applying LoRA adapter: {REPO_ID} ...")
|
| 98 |
try:
|
| 99 |
+
from transformers import VoxtralForConditionalGeneration, AutoProcessor
|
| 100 |
+
from peft import PeftModel
|
| 101 |
+
|
| 102 |
+
processor = AutoProcessor.from_pretrained(BASE_MODEL_ID)
|
| 103 |
+
base = VoxtralForConditionalGeneration.from_pretrained(
|
| 104 |
+
BASE_MODEL_ID, torch_dtype=_dtype
|
|
|
|
| 105 |
).to(_device)
|
| 106 |
+
model = PeftModel.from_pretrained(base, REPO_ID)
|
| 107 |
model.eval()
|
| 108 |
+
print(f"[voxtral] Model ready: {BASE_MODEL_ID} + LoRA {REPO_ID} on {_device}")
|
| 109 |
except Exception as e:
|
| 110 |
raise RuntimeError(
|
| 111 |
f"Model load failed: {e}\n"
|
|
|
|
| 114 |
) from e
|
| 115 |
|
| 116 |
# Warm-up: run one silent dummy inference to pre-compile MPS Metal shaders.
|
| 117 |
+
print("[voxtral] Warming up (dummy inference)...")
|
|
|
|
| 118 |
try:
|
| 119 |
+
sr = getattr(getattr(processor, "feature_extractor", None), "sampling_rate", 16000)
|
| 120 |
dummy = np.zeros(sr, dtype=np.float32) # 1 second of silence
|
| 121 |
+
conversation = [{"role": "user", "content": [{"type": "audio", "audio": dummy}]}]
|
| 122 |
with torch.inference_mode():
|
| 123 |
+
dummy_inputs = processor.apply_chat_template(
|
| 124 |
+
conversation, return_tensors="pt", tokenize=True
|
| 125 |
+
).to(_device)
|
| 126 |
model.generate(**dummy_inputs, max_new_tokens=1)
|
| 127 |
print("[voxtral] Warm-up complete — first request will be fast")
|
| 128 |
except Exception as e:
|
|
|
|
| 384 |
return {"emotion": "Neutral", "valence": 0.0, "arousal": 0.0}
|
| 385 |
|
| 386 |
|
| 387 |
+
# ─── Inference helper ──────────────────────────────────────────────────────────
|
| 388 |
+
|
| 389 |
+
def _transcribe(audio_array: np.ndarray) -> str:
|
| 390 |
+
"""Run Voxtral-3B + LoRA inference via chat template; return transcribed text."""
|
| 391 |
+
conversation = [{"role": "user", "content": [{"type": "audio", "audio": audio_array}]}]
|
| 392 |
+
with torch.inference_mode():
|
| 393 |
+
inputs = processor.apply_chat_template(
|
| 394 |
+
conversation, return_tensors="pt", tokenize=True
|
| 395 |
+
).to(model.device)
|
| 396 |
+
outputs = model.generate(**inputs, max_new_tokens=1024)
|
| 397 |
+
text = processor.decode(
|
| 398 |
+
outputs[0][inputs["input_ids"].shape[1]:],
|
| 399 |
+
skip_special_tokens=True,
|
| 400 |
+
)
|
| 401 |
+
return text.strip()
|
| 402 |
+
|
| 403 |
+
|
| 404 |
# ─── Endpoints ─────────────────────────────────────────────────────────────────
|
| 405 |
|
| 406 |
@app.post("/transcribe")
|
|
|
|
| 425 |
if suffix not in (".wav", ".mp3", ".flac", ".ogg", ".m4a", ".webm"):
|
| 426 |
suffix = ".wav"
|
| 427 |
|
| 428 |
+
target_sr = getattr(getattr(processor, "feature_extractor", None), "sampling_rate", 16000)
|
| 429 |
with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
|
| 430 |
tmp.write(contents)
|
| 431 |
tmp_path = tmp.name
|
|
|
|
| 441 |
except OSError:
|
| 442 |
pass
|
| 443 |
|
| 444 |
+
text = _transcribe(audio_array)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 445 |
total_ms = (time.perf_counter() - req_start) * 1000
|
| 446 |
print(f"[voxtral] {req_id} done total={total_ms:.0f}ms text_len={len(text)}")
|
| 447 |
return {"text": text, "words": [], "languageCode": None}
|
|
|
|
| 472 |
if suffix not in (".wav", ".mp3", ".flac", ".ogg", ".m4a", ".webm"):
|
| 473 |
suffix = ".wav"
|
| 474 |
|
| 475 |
+
target_sr = getattr(getattr(processor, "feature_extractor", None), "sampling_rate", 16000)
|
| 476 |
|
| 477 |
with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
|
| 478 |
tmp.write(contents)
|
|
|
|
| 495 |
|
| 496 |
# ── Step 1: full transcription via Voxtral ──────────────────────────────
|
| 497 |
t0 = time.perf_counter()
|
| 498 |
+
full_text = _transcribe(audio_array)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 499 |
print(f"[voxtral] {req_id} transcription done in {(time.perf_counter()-t0)*1000:.0f}ms text_len={len(full_text)}")
|
| 500 |
|
| 501 |
# ── Step 2: VAD sentence segmentation ───────────────────────────────────
|
model/voxtral-server/requirements.txt
CHANGED
|
@@ -1,12 +1,12 @@
|
|
| 1 |
-
# Voxtral
|
| 2 |
-
# https://huggingface.co/mistralai/Voxtral-Mini-4B-Realtime-2602
|
| 3 |
fastapi>=0.115.0
|
| 4 |
uvicorn[standard]>=0.32.0
|
| 5 |
python-multipart>=0.0.9
|
| 6 |
-
transformers>=
|
|
|
|
| 7 |
torch>=2.0.0
|
| 8 |
accelerate>=0.33.0
|
| 9 |
-
mistral-common[audio]>=1.
|
| 10 |
librosa>=0.10.0
|
| 11 |
soundfile>=0.12.0
|
| 12 |
numpy>=1.24.0
|
|
|
|
| 1 |
+
# Voxtral-Mini-3B-2507 + LoRA adapter (YongkangZOU/evoxtral-lora)
|
|
|
|
| 2 |
fastapi>=0.115.0
|
| 3 |
uvicorn[standard]>=0.32.0
|
| 4 |
python-multipart>=0.0.9
|
| 5 |
+
transformers>=4.54.0
|
| 6 |
+
peft>=0.18.0
|
| 7 |
torch>=2.0.0
|
| 8 |
accelerate>=0.33.0
|
| 9 |
+
mistral-common[audio]>=1.5.0
|
| 10 |
librosa>=0.10.0
|
| 11 |
soundfile>=0.12.0
|
| 12 |
numpy>=1.24.0
|