# ethos/api/main.py
# (merge: master -> dev/video-fer — SSE transcribe-stream, commit aa15e90)
"""
Evoxtral speech-to-text server (Model layer).
Runs Voxtral-Mini-3B + evoxtral-rl locally for transcription with expressive
tags. For video files, also runs FER (MobileViT-XXS ONNX) per segment.
"""
import asyncio
import os
import re
import shutil
import subprocess
import tempfile
import time
from contextlib import asynccontextmanager
from typing import Optional
import librosa
import numpy as np
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
# Base STT model and LoRA adapter, overridable via environment for deployment.
MODEL_ID = os.environ.get("MODEL_ID", "mistralai/Voxtral-Mini-3B-2507")
ADAPTER_ID = os.environ.get("ADAPTER_ID", "YongkangZOU/evoxtral-rl")
# Upload size cap in bytes (env var is expressed in MB; default 100 MB).
MAX_UPLOAD_BYTES = int(os.environ.get("MAX_UPLOAD_MB", "100")) * 1024 * 1024
# All audio is resampled to 16 kHz mono before STT and VAD.
TARGET_SR = 16000
# ─── STT model ────────────────────────────────────────────────────────────────
# Lazily populated by _init_model() during app startup (see lifespan).
_model = None         # merged Voxtral + LoRA model (eval mode)
_processor = None     # AutoProcessor for transcription requests / decoding
_model_dtype = None   # torch dtype applied to floating-point inputs
_model_device = None  # device of the loaded model's parameters
def _init_model() -> None:
    """Load the Voxtral base model, apply the evoxtral LoRA adapter, merge the
    weights, and publish the result through the module-level globals."""
    global _model, _processor, _model_dtype, _model_device
    import torch
    from transformers import VoxtralForConditionalGeneration, AutoProcessor
    from peft import PeftModel

    # Saturate the host: one compute thread per core, half as many interop threads.
    thread_count = os.cpu_count() or 4
    interop = max(1, thread_count // 2)
    torch.set_num_threads(thread_count)
    torch.set_num_interop_threads(interop)
    print(f"[voxtral] torch threads={thread_count}, interop={interop}")

    # bfloat16 on both GPU and CPU — halves memory vs float32 (~6 GB vs ~12 GB).
    # PyTorch CPU supports bfloat16 natively since 1.12.
    _model_dtype = torch.bfloat16
    device_map = "auto" if torch.cuda.is_available() else "cpu"

    print(f"[voxtral] Loading processor {MODEL_ID} ...")
    _processor = AutoProcessor.from_pretrained(MODEL_ID)

    print(f"[voxtral] Loading base model {MODEL_ID} (dtype={_model_dtype}) ...")
    base = VoxtralForConditionalGeneration.from_pretrained(
        MODEL_ID,
        dtype=_model_dtype,
        device_map=device_map,
    )

    print(f"[voxtral] Applying LoRA adapter {ADAPTER_ID} ...")
    adapted = PeftModel.from_pretrained(base, ADAPTER_ID)

    # Folding the LoRA weights into the base removes per-forward adapter overhead.
    print("[voxtral] Merging LoRA weights into base model ...")
    _model = adapted.merge_and_unload()
    _model.eval()
    _model_device = next(_model.parameters()).device
    print(f"[voxtral] Model ready on {_model_device} (dtype={_model_dtype})")
def _transcribe_sync(wav_path: str) -> str:
    """Run local Voxtral inference (blocking — call via run_in_executor)."""
    import torch

    waveform, _sr = librosa.load(wav_path, sr=TARGET_SR, mono=True)
    batch = _processor.apply_transcription_request(
        audio=[waveform],
        format=["WAV"],
        language="en",
        model_id=MODEL_ID,
        return_tensors="pt",
    )

    # Move every tensor to the model's device. Only floating tensors are cast
    # to the model dtype — integer token ids must stay integral.
    float_dtypes = (torch.float32, torch.float16, torch.bfloat16)
    prepared = {}
    for key, value in batch.items():
        if not hasattr(value, "to"):
            prepared[key] = value
        elif value.dtype in float_dtypes:
            prepared[key] = value.to(_model_device, dtype=_model_dtype)
        else:
            prepared[key] = value.to(_model_device)

    with torch.inference_mode():
        generated = _model.generate(**prepared, max_new_tokens=448, do_sample=False)

    # Decode only the newly generated tokens, skipping the prompt portion.
    prompt_len = prepared["input_ids"].shape[1]
    return _processor.tokenizer.decode(
        generated[0][prompt_len:], skip_special_tokens=True
    ).strip()
# ─── FER setup ────────────────────────────────────────────────────────────────
# Lazily populated by _init_fer(); FER stays disabled when assets are missing.
_fer_session = None        # onnxruntime InferenceSession, or None when disabled
_fer_input_name = "input"  # replaced with the model's actual input name on load
_face_cascade = None       # OpenCV Haar cascade; None → _fer_frame center-crops
# Class order matching the ONNX classifier's output indices (argmax → label).
_FER_CLASSES = ["Anger", "Contempt", "Disgust", "Fear", "Happy", "Neutral", "Sad", "Surprise"]
# Upload extensions treated as video (enables the FER pass).
_VIDEO_EXTS = {".mp4", ".mkv", ".avi", ".mov", ".m4v"}
def _is_lfs_pointer(path: str) -> bool:
"""Return True if the file looks like a Git LFS pointer (small text file)."""
try:
size = os.path.getsize(path)
if size > 10_000:
return False
with open(path, "rb") as f:
header = f.read(64)
return header.startswith(b"version https://git-lfs")
except Exception:
return False
def _resolve_lfs_model(fer_path: str) -> str:
    """
    If fer_path is a Git LFS pointer, download the real binary.
    Returns the path to the actual model file (or the original path on failure).
    """
    import urllib.request

    resolved = fer_path + ".resolved"
    # Use HF Space's own file resolution URL to download the actual binary
    url = "https://huggingface.co/spaces/mistral-hackaton-2026/ethos/resolve/main/models/emotion_model_web.onnx"
    print(f"[voxtral] FER: file is LFS pointer — downloading from {url}")
    try:
        urllib.request.urlretrieve(url, resolved)
        size = os.path.getsize(resolved)
        print(f"[voxtral] FER: downloaded {size} bytes to {resolved}")
        return resolved
    except Exception as e:
        # Best effort: fall back to the pointer path so the caller's load
        # attempt produces a clear error instead of crashing here.
        print(f"[voxtral] FER: download failed: {e}")
        return fer_path
def _init_fer() -> None:
    """Locate and load the FER ONNX model and the OpenCV face cascade.

    Any failure is logged and leaves FER disabled; transcription keeps working.
    """
    global _fer_session, _fer_input_name, _face_cascade

    search_paths = [
        os.environ.get("FER_MODEL_PATH", ""),
        "/app/models/emotion_model_web.onnx",  # Docker
        os.path.join(os.path.dirname(__file__), "../models/emotion_model_web.onnx"),  # local: api/../models/
        os.path.join(os.path.dirname(__file__), "../../models/emotion_model_web.onnx"),  # fallback
    ]
    fer_path = next((p for p in search_paths if p and os.path.exists(p)), None)
    if not fer_path:
        print("[voxtral] FER model not found — facial emotion disabled")
        return

    # Debug: log file size and first bytes to diagnose LFS pointer vs real binary
    try:
        file_size = os.path.getsize(fer_path)
        with open(fer_path, "rb") as fh:
            head = fh.read(32).hex()
        print(f"[voxtral] FER file: {fer_path} size={file_size} first_bytes={head}")
    except Exception as e:
        print(f"[voxtral] FER file stat error: {e}")

    # If it's a Git LFS pointer, download the actual binary
    if _is_lfs_pointer(fer_path):
        print("[voxtral] FER: detected Git LFS pointer — resolving...")
        fer_path = _resolve_lfs_model(fer_path)

    try:
        import onnxruntime as rt
        print(f"[voxtral] FER: onnxruntime version = {rt.__version__}")
        _fer_session = rt.InferenceSession(fer_path, providers=["CPUExecutionProvider"])
        _fer_input_name = _fer_session.get_inputs()[0].name
        print(f"[voxtral] FER model loaded: {fer_path} (input={_fer_input_name}, shape={_fer_session.get_inputs()[0].shape})")
    except Exception as e:
        import traceback
        print(f"[voxtral] FER model load failed: {e}")
        print(f"[voxtral] FER traceback: {traceback.format_exc()}")
        return

    try:
        import cv2
        _face_cascade = cv2.CascadeClassifier(
            cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
        )
        print("[voxtral] Face cascade loaded")
    except Exception as e:
        print(f"[voxtral] Face cascade load failed (FER will use center crop): {e}")
def _is_video(filename: str) -> bool:
    """True when *filename* carries one of the recognized video extensions."""
    _root, ext = os.path.splitext(filename)
    return ext.lower() in _VIDEO_EXTS
def _fer_frame(img_bgr: np.ndarray) -> Optional[str]:
    """Detect face (or center-crop), run FER ONNX; return emotion label or None.

    img_bgr: one frame as an OpenCV-style BGR array (as produced by cv2.imread
    in _fer_for_segments). Returns one of _FER_CLASSES, or None when FER is
    disabled, no prediction is made, or any step raises.
    """
    if _fer_session is None:
        return None
    try:
        import cv2
        face_crop = None
        if _face_cascade is not None:
            # Haar cascade wants grayscale; keep the largest detected face.
            gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
            faces = _face_cascade.detectMultiScale(gray, 1.1, 5, minSize=(40, 40))
            if len(faces) > 0:
                x, y, w, h = max(faces, key=lambda f: f[2] * f[3])
                # Pad the box by 15% of its short side, clamped to the frame.
                pad = int(min(w, h) * 0.15)
                x1, y1 = max(0, x - pad), max(0, y - pad)
                x2, y2 = min(img_bgr.shape[1], x + w + pad), min(img_bgr.shape[0], y + h + pad)
                face_crop = img_bgr[y1:y2, x1:x2]
        if face_crop is None:
            # No cascade or no detection: fall back to a centered crop of the
            # top 60% of the frame, where a speaker's face usually sits.
            h, w = img_bgr.shape[:2]
            crop_h = int(h * 0.6)
            cx = w // 2
            half = min(crop_h, w) // 2
            face_crop = img_bgr[:crop_h, max(0, cx - half):cx + half]
        # Preprocess to the classifier's expected input: 224x224 RGB in [0,1].
        resized = cv2.resize(face_crop, (224, 224))
        rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
        # ImageNet normalization (matches original emotion-recognition.ts)
        mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
        rgb = (rgb - mean) / std
        tensor = np.transpose(rgb, (2, 0, 1))[np.newaxis]  # [1, 3, 224, 224]
        out = _fer_session.run(None, {_fer_input_name: tensor})[0]
        # Argmax over the first (only) batch item indexes into _FER_CLASSES.
        return _FER_CLASSES[int(np.argmax(out[0]))]
    except Exception as e:
        # Per-frame best effort: a bad frame must not kill the whole pass.
        print(f"[voxtral] FER frame error: {e}")
        return None
def _fer_for_segments(
    video_path: str, segments: list[dict]
) -> tuple[dict[int, str], dict[int, str]]:
    """
    Sample the video at ~1 fps and classify the face emotion of each frame.

    Returns (segment_emotions, timeline):
      segment_emotions maps segment id -> majority emotion over the segment;
      timeline maps whole second -> emotion, for the live playback panel.
    Both are empty when FER is disabled or no frames could be extracted.
    """
    if _fer_session is None:
        return {}, {}

    frames_dir = tempfile.mkdtemp()
    try:
        import cv2
        from collections import Counter

        # Dump up to 600 numbered JPEGs (i.e. at most 10 minutes at 1 fps).
        subprocess.run(
            ["ffmpeg", "-y", "-i", video_path,
             "-vf", "fps=1", "-vframes", "600",
             "-q:v", "5", os.path.join(frames_dir, "%06d.jpg")],
            capture_output=True, timeout=120,
        )
        jpegs = sorted(f for f in os.listdir(frames_dir) if f.endswith(".jpg"))
        if not jpegs:
            print("[voxtral] FER: no video frames extracted (audio-only?)")
            return {}, {}

        # ffmpeg numbers frames from 1; shift to 0-based seconds.
        timeline: dict[int, str] = {}
        for name in jpegs:
            frame = cv2.imread(os.path.join(frames_dir, name))
            if frame is None:
                continue
            label = _fer_frame(frame)
            if label:
                timeline[int(os.path.splitext(name)[0]) - 1] = label

        # Majority vote across the seconds each segment spans.
        segment_emotions: dict[int, str] = {}
        for seg in segments:
            lo = int(seg["start"])
            hi = max(int(seg["end"]), lo + 1)
            votes = [timeline[s] for s in range(lo, hi) if s in timeline]
            if votes:
                segment_emotions[seg["id"]] = Counter(votes).most_common(1)[0][0]

        print(f"[voxtral] FER: {len(jpegs)} frames → {len(segment_emotions)} segs, {len(timeline)} timeline pts")
        return segment_emotions, timeline
    except Exception as e:
        print(f"[voxtral] FER extraction error: {e}")
        return {}, {}
    finally:
        shutil.rmtree(frames_dir, ignore_errors=True)
# ─── Startup ──────────────────────────────────────────────────────────────────
def _check_ffmpeg():
    """Fail fast at startup when the ffmpeg binary is not on PATH."""
    if shutil.which("ffmpeg") is not None:
        return
    raise RuntimeError(
        "ffmpeg not found.\n"
        " macOS: brew install ffmpeg\n"
        " Ubuntu: sudo apt install ffmpeg"
    )
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan: verify ffmpeg, then load FER and STT before serving."""
    _check_ffmpeg()
    print(f"[voxtral] ffmpeg: {shutil.which('ffmpeg')}")
    _init_fer()    # best-effort: server still runs if FER assets are missing
    _init_model()  # blocks startup until Voxtral + LoRA are merged and ready
    yield
app = FastAPI(title="Evoxtral Speech-to-Text (local)", lifespan=lifespan)
# CORS: only the local dev frontend on port 3000 may call this API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000", "http://127.0.0.1:3000"],
    allow_methods=["GET", "POST", "OPTIONS"],
    allow_headers=["*"],
)
@app.get("/debug-inference")
async def debug_inference():
    """Quick smoke-test: synthesize 0.5s of silence and run a minimal generate() call.

    Returns {"ok": True, "text": ...} on success, or {"ok": False, "error": ...,
    "traceback": ...} so failures are inspectable without reading server logs.
    """
    import traceback
    import soundfile as sf  # local import: only needed for this debug route

    if _model is None:
        return {"ok": False, "error": "model not loaded"}
    try:
        silence = np.zeros(8000, dtype=np.float32)  # 0.5 s @ 16 kHz
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
            wav_path = f.name
        sf.write(wav_path, silence, 16000)
        try:
            loop = asyncio.get_running_loop()
            text = await loop.run_in_executor(None, _transcribe_sync, wav_path)
        finally:
            # Remove the temp wav even when inference raises (it previously leaked).
            os.unlink(wav_path)
        return {"ok": True, "text": text, "dtype": str(_model_dtype), "device": str(_model_device)}
    except Exception as e:
        return {"ok": False, "error": str(e), "traceback": traceback.format_exc()}
@app.get("/health")
async def health():
    """Liveness probe: report model/FER/ffmpeg readiness and the upload limit."""
    report = {
        "status": "ok",
        "model": f"{MODEL_ID} + {ADAPTER_ID} (local)",
    }
    report["model_loaded"] = _model is not None
    report["ffmpeg"] = shutil.which("ffmpeg") is not None
    report["fer_enabled"] = _fer_session is not None
    report["device"] = str(_model_device) if _model_device else None
    report["max_upload_mb"] = MAX_UPLOAD_BYTES // 1024 // 1024
    return report
# ─── Audio helpers ─────────────────────────────────────────────────────────────
def _convert_to_wav_ffmpeg(path: str, target_sr: int) -> str:
    """Transcode any audio/video file to mono 16-bit PCM WAV at target_sr.

    Returns the path of a fresh temp .wav; the caller must delete it.
    Raises RuntimeError carrying the tail of ffmpeg's stderr on failure.
    """
    tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    tmp.close()  # ffmpeg writes the file itself; we only need its name
    proc = subprocess.run(
        ["ffmpeg", "-y", "-i", path,
         "-vn", "-acodec", "pcm_s16le", "-ar", str(target_sr), "-ac", "1",
         "-f", "wav", tmp.name],
        capture_output=True, timeout=120,
    )
    if proc.returncode != 0:
        os.unlink(tmp.name)  # don't leak the (empty) temp file on failure
        raise RuntimeError(f"ffmpeg failed: {proc.stderr.decode(errors='replace')[:400]}")
    return tmp.name
def _load_audio(file_path: str, target_sr: int) -> np.ndarray:
    """Decode file_path to a mono float32 waveform resampled to target_sr."""
    waveform, _sr = librosa.load(file_path, sr=target_sr, mono=True)
    return waveform.astype(np.float32)
def _validate_upload(contents: bytes) -> None:
    """Reject empty uploads (HTTP 400) and oversized uploads (HTTP 413)."""
    if not contents:
        raise HTTPException(status_code=400, detail="Audio file is empty")
    size = len(contents)
    if size > MAX_UPLOAD_BYTES:
        raise HTTPException(
            status_code=413,
            detail=f"File too large ({size/1024/1024:.1f} MB); max {MAX_UPLOAD_BYTES//1024//1024} MB",
        )
# ─── Segmentation ──────────────────────────────────────────────────────────────
def _vad_segment(audio: np.ndarray, sr: int) -> list[tuple[int, int]]:
    """Energy-based VAD: split on silence, merge gaps < 0.3 s, drop chunks < 0.3 s.

    Returns (start, end) sample-index pairs; falls back to the whole signal
    when nothing is detected or nothing survives filtering.
    """
    intervals = librosa.effects.split(audio, top_db=28, frame_length=2048, hop_length=512)
    if len(intervals) == 0:
        return [(0, len(audio))]

    merged: list[list[int]] = []
    for start, end in intervals:
        start, end = int(start), int(end)
        if merged and (start - merged[-1][1]) / sr < 0.3:
            merged[-1][1] = end  # bridge the short silence gap
        else:
            merged.append([start, end])

    kept = [(s, e) for s, e in merged if (e - s) / sr >= 0.3]
    return kept if kept else [(0, len(audio))]
def _segments_from_vad(audio: np.ndarray, sr: int) -> tuple[list[dict], str]:
    """Turn VAD sample intervals into segment dicts; returns (segments, "vad")."""
    segments = []
    for seg_id, (start, end) in enumerate(_vad_segment(audio, sr), start=1):
        segments.append({
            "id": seg_id,
            "speaker": "SPEAKER_00",
            "start": round(start / sr, 3),
            "end": round(end / sr, 3),
        })
    print(f"[voxtral] VAD: {len(segments)} segment(s)")
    return segments, "vad"
def _split_sentences(text: str) -> list[str]:
parts = re.split(r'(?<=[οΌŸοΌγ€‚?!])\s*', text)
return [p for p in parts if p.strip()]
def _distribute_text(full_text: str, segs: list[dict]) -> list[dict]:
    """Spread the single full transcription across VAD segments by duration.

    Splits full_text into sentences (falling back to words, or to individual
    characters for space-free text treated as CJK) and assigns the units to
    segments proportionally to each segment's share of the total duration.
    Returns copies of segs with a "text" key added; segs is not mutated.
    """
    if not full_text or not segs:
        return [{**s, "text": ""} for s in segs]
    if len(segs) == 1:
        return [{**segs[0], "text": full_text}]
    sentences = _split_sentences(full_text)
    if len(sentences) <= 1:
        # No sentence punctuation: split on words, or per character when the
        # text has no spaces at all (assumed CJK — TODO confirm heuristic).
        is_cjk = len(full_text.split()) <= 1
        sentences = list(full_text) if is_cjk else full_text.split()
    total_dur = sum(s["end"] - s["start"] for s in segs)
    if total_dur <= 0:
        # Degenerate timing: put everything in the first segment.
        return [{**segs[0], "text": full_text}] + [{**s, "text": ""} for s in segs[1:]]
    is_cjk = len(full_text.split()) <= 1 and len(full_text) > 1
    sep = "" if is_cjk else " "
    n = len(sentences)
    result_texts: list[list[str]] = [[] for _ in segs]
    cumulative = 0.0
    for i, seg in enumerate(segs):
        # threshold = how many units should be assigned in total once this
        # segment's share of the duration has been consumed.
        cumulative += (seg["end"] - seg["start"]) / total_dur
        threshold = cumulative * n
        while len(result_texts[i]) + sum(len(t) for t in result_texts[:i]) < round(threshold):
            idx = sum(len(t) for t in result_texts)  # units assigned so far
            if idx >= n:
                break
            result_texts[i].append(sentences[idx])
    # Rounding can leave trailing units unassigned; give them to the last segment.
    assigned = sum(len(t) for t in result_texts)
    result_texts[-1].extend(sentences[assigned:])
    return [{**seg, "text": sep.join(texts)} for seg, texts in zip(segs, result_texts)]
# ─── Emotion parsing from evoxtral expression tags ─────────────────────────────
_TAG_EMOTIONS: dict[str, tuple[str, float, float]] = {
"laughs": ("Happy", 0.70, 0.60),
"laughing": ("Happy", 0.70, 0.60),
"chuckles": ("Happy", 0.50, 0.30),
"giggles": ("Happy", 0.60, 0.40),
"sighs": ("Sad", -0.30, -0.30),
"sighing": ("Sad", -0.30, -0.30),
"cries": ("Sad", -0.70, 0.40),
"crying": ("Sad", -0.70, 0.40),
"whispers": ("Calm", 0.10, -0.50),
"whispering": ("Calm", 0.10, -0.50),
"shouts": ("Angry", -0.50, 0.80),
"shouting": ("Angry", -0.50, 0.80),
"exclaims": ("Excited", 0.50, 0.70),
"gasps": ("Surprised", 0.20, 0.70),
"hesitates": ("Anxious", -0.20, 0.30),
"stutters": ("Anxious", -0.20, 0.40),
"stammers": ("Anxious", -0.25, 0.35),
"mumbles": ("Sad", -0.20, -0.30),
"nervous": ("Anxious", -0.30, 0.40),
"frustrated": ("Frustrated", -0.50, 0.50),
"excited": ("Excited", 0.50, 0.70),
"sad": ("Sad", -0.60, -0.20),
"angry": ("Angry", -0.60, 0.70),
"claps": ("Happy", 0.60, 0.50),
"applause": ("Happy", 0.60, 0.50),
"clears throat": ("Neutral", 0.00, 0.10),
"pause": ("Neutral", 0.00, -0.10),
"laughs nervously": ("Anxious", -0.10, 0.40),
}
def _parse_emotion(text: str) -> dict:
tags = re.findall(r'\[([^\]]+)\]', text.lower())
for tag in tags:
tag = tag.strip()
if tag in _TAG_EMOTIONS:
label, valence, arousal = _TAG_EMOTIONS[tag]
return {"emotion": label, "valence": valence, "arousal": arousal}
for key, (label, valence, arousal) in _TAG_EMOTIONS.items():
if key in tag:
return {"emotion": label, "valence": valence, "arousal": arousal}
return {"emotion": "Neutral", "valence": 0.0, "arousal": 0.0}
# ─── Endpoints ─────────────────────────────────────────────────────────────────
@app.post("/transcribe")
async def transcribe(audio: UploadFile = File(...)):
    """
    Upload audio → plain transcription.

    Converts the upload to 16 kHz mono WAV via ffmpeg, then runs Voxtral in a
    worker thread. Returns {"text": ..., "words": []}.
    Raises 503 before the model is loaded, 400 on undecodable audio.
    """
    req_start = time.perf_counter()
    filename = audio.filename or "audio.wav"
    # Log the actual filename (previously this printed the literal "(unknown)").
    print(f"[voxtral] POST /transcribe filename={filename}")
    if _model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    contents = await audio.read()
    _validate_upload(contents)
    suffix = os.path.splitext(filename)[1].lower() or ".wav"
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        tmp.write(contents)
        tmp_path = tmp.name
    wav_path = None
    try:
        wav_path = _convert_to_wav_ffmpeg(tmp_path, TARGET_SR)
        loop = asyncio.get_running_loop()
        text = await loop.run_in_executor(None, _transcribe_sync, wav_path)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Cannot process audio: {e}")
    finally:
        # Always remove both temp files (original upload + converted wav).
        for p in (tmp_path, wav_path):
            if p and os.path.exists(p):
                try:
                    os.unlink(p)
                except OSError:
                    pass
    print(f"[voxtral] /transcribe done {(time.perf_counter()-req_start)*1000:.0f}ms")
    return {"text": text, "words": []}
@app.post("/transcribe-diarize")
async def transcribe_diarize(audio: UploadFile = File(...)):
    """
    Upload audio/video → transcription + VAD segmentation + emotion.
    For video files (.mp4, .mkv, .avi, .mov, .m4v), also runs FER.

    Pipeline: ffmpeg → 16 kHz WAV → Voxtral STT (thread pool) → librosa VAD →
    proportional text distribution → tag-based emotion parsing → optional
    per-second facial-emotion timeline.
    Raises 503 if the model is not loaded, 400 on undecodable input,
    500 on transcription failure.
    """
    req_start = time.perf_counter()
    req_id = f"diarize-{int(req_start * 1000)}"
    filename = audio.filename or "audio.wav"
    # Log the actual filename (previously this printed the literal "(unknown)").
    print(f"[voxtral] {req_id} POST /transcribe-diarize filename={filename}")
    if _model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    contents = await audio.read()
    _validate_upload(contents)
    # Keep the real suffix only for formats ffmpeg is known to handle here.
    suffix = os.path.splitext(filename)[1].lower() or ".wav"
    if suffix not in (".wav", ".mp3", ".flac", ".ogg", ".m4a", ".webm",
                      ".mp4", ".mkv", ".avi", ".mov", ".m4v"):
        suffix = ".wav"
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        tmp.write(contents)
        tmp_path = tmp.name
    wav_path = None
    try:
        t0 = time.perf_counter()
        wav_path = _convert_to_wav_ffmpeg(tmp_path, TARGET_SR)
        audio_array = _load_audio(wav_path, TARGET_SR)
        print(f"[voxtral] {req_id} audio loaded shape={audio_array.shape} in {(time.perf_counter()-t0)*1000:.0f}ms")
    except Exception as e:
        for p in (tmp_path, wav_path):
            if p and os.path.exists(p):
                try: os.unlink(p)
                except OSError: pass
        raise HTTPException(status_code=400, detail=f"Cannot decode audio: {e}")
    duration = round(len(audio_array) / TARGET_SR, 3)
    try:
        # ── STT (local model, run in thread pool) ────────────────────────────
        try:
            t0 = time.perf_counter()
            loop = asyncio.get_running_loop()
            full_text = await loop.run_in_executor(None, _transcribe_sync, wav_path)
            print(f"[voxtral] {req_id} STT done {(time.perf_counter()-t0)*1000:.0f}ms text_len={len(full_text)}")
        except Exception as e:
            import traceback as _tb
            print(f"[voxtral] {req_id} STT error: {e}\n{_tb.format_exc()}")
            raise HTTPException(status_code=500, detail=f"Transcription failed: {e}")
        finally:
            if wav_path and os.path.exists(wav_path):
                try: os.unlink(wav_path)
                except OSError: pass
        # ── VAD segmentation + text distribution ─────────────────────────────
        raw_segs, seg_method = _segments_from_vad(audio_array, TARGET_SR)
        segs_with_text = _distribute_text(full_text, raw_segs)
        # ── FER (video only) ─────────────────────────────────────────────────
        has_fer = False
        face_emotions: dict[int, str] = {}
        fer_timeline: dict[int, str] = {}
        if _is_video(filename) and _fer_session is not None:
            t0 = time.perf_counter()
            face_emotions, fer_timeline = await loop.run_in_executor(
                None, _fer_for_segments, tmp_path, raw_segs
            )
            has_fer = bool(face_emotions)
            print(f"[voxtral] {req_id} FER done {(time.perf_counter()-t0)*1000:.0f}ms faces={len(face_emotions)} timeline={len(fer_timeline)}")
    finally:
        # Always remove the uploaded temp file — before this try/finally it
        # leaked whenever STT or FER raised.
        if tmp_path and os.path.exists(tmp_path):
            try: os.unlink(tmp_path)
            except OSError: pass
    # ── Build segments ────────────────────────────────────────────────────────
    segments = []
    for s in segs_with_text:
        emo = _parse_emotion(s["text"])
        seg_data = {
            "id": s["id"],
            "speaker": s["speaker"],
            "start": s["start"],
            "end": s["end"],
            "text": s["text"],
            "emotion": emo["emotion"],
            "valence": emo["valence"],
            "arousal": emo["arousal"],
        }
        if s["id"] in face_emotions:
            seg_data["face_emotion"] = face_emotions[s["id"]]
        segments.append(seg_data)
    total_ms = (time.perf_counter() - req_start) * 1000
    print(f"[voxtral] {req_id} complete total={total_ms:.0f}ms segments={len(segments)} has_fer={has_fer}")
    return {
        "segments": segments,
        "duration": duration,
        "text": full_text,
        "filename": filename,
        "diarization_method": seg_method,
        "has_video": has_fer,
        # Per-second face emotion timeline for live playback panel
        # Keys are strings (JSON), values are emotion labels e.g. "Happy"
        "face_emotion_timeline": {str(k): v for k, v in fer_timeline.items()},
    }