"""
Evoxtral speech-to-text server (Model layer).
Runs Voxtral-Mini-3B + evoxtral-rl locally for transcription with expressive
tags. For video files, also runs FER (MobileViT-XXS ONNX) per segment.
"""
import asyncio
import os
import re
import shutil
import subprocess
import tempfile
import time
from contextlib import asynccontextmanager
from typing import Optional

import librosa
import numpy as np
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware

MODEL_ID   = os.environ.get("MODEL_ID",   "mistralai/Voxtral-Mini-3B-2507")
ADAPTER_ID = os.environ.get("ADAPTER_ID", "YongkangZOU/evoxtral-rl")
MAX_UPLOAD_BYTES = int(os.environ.get("MAX_UPLOAD_MB", "100")) * 1024 * 1024
TARGET_SR = 16000
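
# Example environment overrides (illustrative values; the defaults above apply
# when these are unset — FER_MODEL_PATH is read in _init_fer below):
#   export MAX_UPLOAD_MB=200
#   export FER_MODEL_PATH=/app/models/emotion_model_web.onnx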

# ─── STT model ────────────────────────────────────────────────────────────────

_model     = None
_processor = None
_model_dtype  = None
_model_device = None


def _init_model() -> None:
    global _model, _processor, _model_dtype, _model_device
    import torch
    from transformers import VoxtralForConditionalGeneration, AutoProcessor
    from peft import PeftModel

    # Use all available CPU cores for parallel compute
    n_threads = os.cpu_count() or 4
    torch.set_num_threads(n_threads)
    torch.set_num_interop_threads(max(1, n_threads // 2))
    print(f"[voxtral] torch threads={n_threads}, interop={max(1, n_threads // 2)}")

    # bfloat16 on both GPU and CPU — halves memory vs float32 (~6 GB vs ~12 GB)
    # PyTorch CPU supports bfloat16 natively since 1.12
    _model_dtype = torch.bfloat16
    if torch.cuda.is_available():
        device_map = "auto"
    else:
        device_map = "cpu"

    print(f"[voxtral] Loading processor {MODEL_ID} ...")
    _processor = AutoProcessor.from_pretrained(MODEL_ID)

    print(f"[voxtral] Loading base model {MODEL_ID} (dtype={_model_dtype}) ...")
    base_model = VoxtralForConditionalGeneration.from_pretrained(
        MODEL_ID,
        dtype=_model_dtype,
        device_map=device_map,
    )

    print(f"[voxtral] Applying LoRA adapter {ADAPTER_ID} ...")
    peft_model = PeftModel.from_pretrained(base_model, ADAPTER_ID)

    # Merge LoRA weights into base model and unload adapter — removes per-forward overhead
    print("[voxtral] Merging LoRA weights into base model ...")
    _model = peft_model.merge_and_unload()
    _model.eval()

    _model_device = next(_model.parameters()).device
    print(f"[voxtral] Model ready on {_model_device} (dtype={_model_dtype})")


def _transcribe_sync(wav_path: str) -> str:
    """Run local Voxtral inference (blocking β€” call via run_in_executor)."""
    import torch

    audio_array, _ = librosa.load(wav_path, sr=TARGET_SR, mono=True)

    inputs = _processor.apply_transcription_request(
        audio=[audio_array],
        format=["WAV"],
        language="en",
        model_id=MODEL_ID,
        return_tensors="pt",
    )

    # Move inputs to model device / dtype
    inputs = {
        k: (v.to(_model_device, dtype=_model_dtype)
            if v.dtype in (torch.float32, torch.float16, torch.bfloat16)
            else v.to(_model_device))
        if hasattr(v, "to") else v
        for k, v in inputs.items()
    }

    with torch.inference_mode():
        output_ids = _model.generate(**inputs, max_new_tokens=448, do_sample=False)

    input_len = inputs["input_ids"].shape[1]
    text = _processor.tokenizer.decode(
        output_ids[0][input_len:], skip_special_tokens=True
    ).strip()
    return text


# ─── FER setup ────────────────────────────────────────────────────────────────

_fer_session   = None
_fer_input_name = "input"
_face_cascade  = None
_FER_CLASSES   = ["Anger", "Contempt", "Disgust", "Fear", "Happy", "Neutral", "Sad", "Surprise"]
_VIDEO_EXTS    = {".mp4", ".mkv", ".avi", ".mov", ".m4v"}


def _is_lfs_pointer(path: str) -> bool:
    """Return True if the file looks like a Git LFS pointer (small text file)."""
    try:
        size = os.path.getsize(path)
        if size > 10_000:
            return False
        with open(path, "rb") as f:
            header = f.read(64)
        return header.startswith(b"version https://git-lfs")
    except Exception:
        return False


def _resolve_lfs_model(fer_path: str) -> str:
    """
    If fer_path is a Git LFS pointer, download the real binary.
    Returns the path to the actual model file.
    """
    import urllib.request
    real_path = fer_path + ".resolved"
    # Use HF Space's own file resolution URL to download the actual binary
    url = "https://huggingface.co/spaces/mistral-hackaton-2026/ethos/resolve/main/models/emotion_model_web.onnx"
    print(f"[voxtral] FER: file is LFS pointer β€” downloading from {url}")
    try:
        urllib.request.urlretrieve(url, real_path)
        size = os.path.getsize(real_path)
        print(f"[voxtral] FER: downloaded {size} bytes to {real_path}")
        return real_path
    except Exception as e:
        print(f"[voxtral] FER: download failed: {e}")
        return fer_path


def _init_fer() -> None:
    global _fer_session, _fer_input_name, _face_cascade

    candidates = [
        os.environ.get("FER_MODEL_PATH", ""),
        "/app/models/emotion_model_web.onnx",                                          # Docker
        os.path.join(os.path.dirname(__file__), "../models/emotion_model_web.onnx"),   # local: api/../models/
        os.path.join(os.path.dirname(__file__), "../../models/emotion_model_web.onnx"), # fallback
    ]
    fer_path = next((p for p in candidates if p and os.path.exists(p)), None)
    if not fer_path:
        print("[voxtral] FER model not found β€” facial emotion disabled")
        return

    # Debug: log file size and first bytes to diagnose LFS pointer vs real binary
    try:
        file_size = os.path.getsize(fer_path)
        with open(fer_path, "rb") as f:
            first_bytes = f.read(32).hex()
        print(f"[voxtral] FER file: {fer_path} size={file_size} first_bytes={first_bytes}")
    except Exception as e:
        print(f"[voxtral] FER file stat error: {e}")

    # If it's a Git LFS pointer, download the actual binary
    if _is_lfs_pointer(fer_path):
        print("[voxtral] FER: detected Git LFS pointer β€” resolving...")
        fer_path = _resolve_lfs_model(fer_path)

    try:
        import onnxruntime as rt
        print(f"[voxtral] FER: onnxruntime version = {rt.__version__}")
        _fer_session    = rt.InferenceSession(fer_path, providers=["CPUExecutionProvider"])
        _fer_input_name = _fer_session.get_inputs()[0].name
        print(f"[voxtral] FER model loaded: {fer_path} (input={_fer_input_name}, shape={_fer_session.get_inputs()[0].shape})")
    except Exception as e:
        import traceback
        print(f"[voxtral] FER model load failed: {e}")
        print(f"[voxtral] FER traceback: {traceback.format_exc()}")
        return

    try:
        import cv2
        cascade_path  = cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
        _face_cascade = cv2.CascadeClassifier(cascade_path)
        print("[voxtral] Face cascade loaded")
    except Exception as e:
        print(f"[voxtral] Face cascade load failed (FER will use center crop): {e}")


def _is_video(filename: str) -> bool:
    return os.path.splitext(filename)[1].lower() in _VIDEO_EXTS


def _fer_frame(img_bgr: np.ndarray) -> Optional[str]:
    """Detect face (or center-crop), run FER ONNX; return emotion label or None."""
    if _fer_session is None:
        return None
    try:
        import cv2
        face_crop = None

        if _face_cascade is not None:
            gray  = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
            faces = _face_cascade.detectMultiScale(gray, 1.1, 5, minSize=(40, 40))
            if len(faces) > 0:
                x, y, w, h = max(faces, key=lambda f: f[2] * f[3])
                pad = int(min(w, h) * 0.15)
                x1, y1 = max(0, x - pad), max(0, y - pad)
                x2, y2 = min(img_bgr.shape[1], x + w + pad), min(img_bgr.shape[0], y + h + pad)
                face_crop = img_bgr[y1:y2, x1:x2]

        if face_crop is None:
            h, w    = img_bgr.shape[:2]
            crop_h  = int(h * 0.6)
            cx      = w // 2
            half    = min(crop_h, w) // 2
            face_crop = img_bgr[:crop_h, max(0, cx - half):cx + half]

        resized = cv2.resize(face_crop, (224, 224))
        rgb     = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
        # ImageNet normalization (matches original emotion-recognition.ts)
        mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        std  = np.array([0.229, 0.224, 0.225], dtype=np.float32)
        rgb  = (rgb - mean) / std
        tensor = np.transpose(rgb, (2, 0, 1))[np.newaxis]  # [1, 3, 224, 224]

        out = _fer_session.run(None, {_fer_input_name: tensor})[0]
        return _FER_CLASSES[int(np.argmax(out[0]))]
    except Exception as e:
        print(f"[voxtral] FER frame error: {e}")
        return None


def _fer_for_segments(
    video_path: str, segments: list[dict]
) -> tuple[dict[int, str], dict[int, str]]:
    """
    Extract ~1fps frames from video, run FER.
    Returns:
      segment_emotions : {segment_id: majority_emotion}
      timeline         : {second: emotion}  — per-second, for live panel
    """
    if _fer_session is None:
        return {}, {}

    frames_dir = tempfile.mkdtemp()
    try:
        import cv2
        from collections import Counter

        subprocess.run(
            ["ffmpeg", "-y", "-i", video_path,
             "-vf", "fps=1", "-vframes", "600",
             "-q:v", "5", os.path.join(frames_dir, "%06d.jpg")],
            capture_output=True, timeout=120,
        )
        frame_files = sorted(f for f in os.listdir(frames_dir) if f.endswith(".jpg"))
        if not frame_files:
            print("[voxtral] FER: no video frames extracted (audio-only?)")
            return {}, {}

        # Build per-second emotion map
        timeline: dict[int, str] = {}
        for fname in frame_files:
            second = int(os.path.splitext(fname)[0]) - 1
            img = cv2.imread(os.path.join(frames_dir, fname))
            if img is None:
                continue
            emo = _fer_frame(img)
            if emo:
                timeline[second] = emo

        # Majority-vote per segment
        segment_emotions: dict[int, str] = {}
        for seg in segments:
            start_s = int(seg["start"])
            end_s   = max(int(seg["end"]), start_s + 1)
            emos    = [timeline[s] for s in range(start_s, end_s) if s in timeline]
            if emos:
                segment_emotions[seg["id"]] = Counter(emos).most_common(1)[0][0]

        print(f"[voxtral] FER: {len(frame_files)} frames β†’ {len(segment_emotions)} segs, {len(timeline)} timeline pts")
        return segment_emotions, timeline
    except Exception as e:
        print(f"[voxtral] FER extraction error: {e}")
        return {}, {}
    finally:
        shutil.rmtree(frames_dir, ignore_errors=True)


# ─── Startup ──────────────────────────────────────────────────────────────────

def _check_ffmpeg():
    if shutil.which("ffmpeg") is None:
        raise RuntimeError(
            "ffmpeg not found.\n"
            "  macOS:   brew install ffmpeg\n"
            "  Ubuntu:  sudo apt install ffmpeg"
        )


@asynccontextmanager
async def lifespan(app: FastAPI):
    _check_ffmpeg()
    print(f"[voxtral] ffmpeg: {shutil.which('ffmpeg')}")
    _init_fer()
    _init_model()
    yield


app = FastAPI(title="Evoxtral Speech-to-Text (local)", lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000", "http://127.0.0.1:3000"],
    allow_methods=["GET", "POST", "OPTIONS"],
    allow_headers=["*"],
)


@app.get("/debug-inference")
async def debug_inference():
    """Quick smoke-test: synthesize 0.5s of silence and run a minimal generate() call."""
    import traceback, torch
    if _model is None:
        return {"ok": False, "error": "model not loaded"}
    try:
        import numpy as np
        silence = np.zeros(8000, dtype=np.float32)  # 0.5 s @ 16 kHz
        import tempfile, soundfile as sf, asyncio
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
            wav_path = f.name
        sf.write(wav_path, silence, 16000)
        loop = asyncio.get_running_loop()
        text = await loop.run_in_executor(None, _transcribe_sync, wav_path)
        import os; os.unlink(wav_path)
        return {"ok": True, "text": text, "dtype": str(_model_dtype), "device": str(_model_device)}
    except Exception as e:
        return {"ok": False, "error": str(e), "traceback": traceback.format_exc()}


@app.get("/health")
async def health():
    return {
        "status": "ok",
        "model": f"{MODEL_ID} + {ADAPTER_ID} (local)",
        "model_loaded": _model is not None,
        "ffmpeg": shutil.which("ffmpeg") is not None,
        "fer_enabled": _fer_session is not None,
        "device": str(_model_device) if _model_device else None,
        "max_upload_mb": MAX_UPLOAD_BYTES // 1024 // 1024,
    }
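
# Example check (hypothetical host/port):
#   curl http://localhost:8000/health
#   -> {"status": "ok", "model_loaded": true, "fer_enabled": true, ...}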


# ─── Audio helpers ─────────────────────────────────────────────────────────────

def _convert_to_wav_ffmpeg(path: str, target_sr: int) -> str:
    """Convert any audio/video input to mono 16-bit PCM WAV at target_sr via ffmpeg."""
    out = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    out.close()
    proc = subprocess.run(
        ["ffmpeg", "-y", "-i", path,
         "-vn", "-acodec", "pcm_s16le", "-ar", str(target_sr), "-ac", "1",
         "-f", "wav", out.name],
        capture_output=True, timeout=120,
    )
    if proc.returncode != 0:
        os.unlink(out.name)
        raise RuntimeError(f"ffmpeg failed: {proc.stderr.decode(errors='replace')[:400]}")
    return out.name


def _load_audio(file_path: str, target_sr: int) -> np.ndarray:
    y, _ = librosa.load(file_path, sr=target_sr, mono=True)
    return y.astype(np.float32)


def _validate_upload(contents: bytes) -> None:
    if len(contents) == 0:
        raise HTTPException(status_code=400, detail="Audio file is empty")
    if len(contents) > MAX_UPLOAD_BYTES:
        raise HTTPException(
            status_code=413,
            detail=f"File too large ({len(contents)/1024/1024:.1f} MB); max {MAX_UPLOAD_BYTES//1024//1024} MB",
        )


# ─── Segmentation ──────────────────────────────────────────────────────────────

def _vad_segment(audio: np.ndarray, sr: int) -> list[tuple[int, int]]:
    """Energy-based VAD: split on silence (top_db=28), merge gaps shorter than
    0.3 s, drop segments shorter than 0.3 s. Returns sample-index intervals."""
    intervals = librosa.effects.split(audio, top_db=28, frame_length=2048, hop_length=512)
    if len(intervals) == 0:
        return [(0, len(audio))]
    merged: list[list[int]] = [[int(intervals[0][0]), int(intervals[0][1])]]
    for s, e in intervals[1:]:
        if (int(s) - merged[-1][1]) / sr < 0.3:
            merged[-1][1] = int(e)
        else:
            merged.append([int(s), int(e)])
    result = [(s, e) for s, e in merged if (e - s) / sr >= 0.3]
    return result if result else [(0, len(audio))]
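
# Worked example for _vad_segment (illustrative): at sr=16000, raw intervals
# [(0, 8000), (8800, 24000)] merge into [(0, 24000)] because the gap is
# 800 / 16000 = 0.05 s, under the 0.3 s threshold; the merged 1.5 s span then
# passes the minimum-length filter.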


def _segments_from_vad(audio: np.ndarray, sr: int) -> tuple[list[dict], str]:
    intervals = _vad_segment(audio, sr)
    segs = [
        {"id": i + 1, "speaker": "SPEAKER_00", "start": round(s / sr, 3), "end": round(e / sr, 3)}
        for i, (s, e) in enumerate(intervals)
    ]
    print(f"[voxtral] VAD: {len(segs)} segment(s)")
    return segs, "vad"


def _split_sentences(text: str) -> list[str]:
    # Split after sentence-final punctuation (fullwidth ？！。 and ASCII ?!).
    # Plain periods are not split on, so multi-sentence English text may come
    # back as a single chunk and hit the word-split fallback in _distribute_text.
    parts = re.split(r'(?<=[？！。?!])\s*', text)
    return [p for p in parts if p.strip()]


def _distribute_text(full_text: str, segs: list[dict]) -> list[dict]:
    """Distribute the full transcript across segments in proportion to each
    segment's share of the total speech duration, sentence by sentence (falling
    back to words, or to characters for CJK text without whitespace)."""
    if not full_text or not segs:
        return [{**s, "text": ""} for s in segs]
    if len(segs) == 1:
        return [{**segs[0], "text": full_text}]
    sentences = _split_sentences(full_text)
    if len(sentences) <= 1:
        is_cjk = len(full_text.split()) <= 1
        sentences = list(full_text) if is_cjk else full_text.split()
    total_dur = sum(s["end"] - s["start"] for s in segs)
    if total_dur <= 0:
        return [{**segs[0], "text": full_text}] + [{**s, "text": ""} for s in segs[1:]]
    is_cjk = len(full_text.split()) <= 1 and len(full_text) > 1
    sep = "" if is_cjk else " "
    n = len(sentences)
    result_texts: list[list[str]] = [[] for _ in segs]
    cumulative = 0.0
    for i, seg in enumerate(segs):
        cumulative += (seg["end"] - seg["start"]) / total_dur
        threshold = cumulative * n
        while len(result_texts[i]) + sum(len(t) for t in result_texts[:i]) < round(threshold):
            idx = sum(len(t) for t in result_texts)
            if idx >= n:
                break
            result_texts[i].append(sentences[idx])
    assigned = sum(len(t) for t in result_texts)
    result_texts[-1].extend(sentences[assigned:])
    return [{**seg, "text": sep.join(texts)} for seg, texts in zip(segs, result_texts)]
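
# Worked example for _distribute_text (illustrative, assuming the sentences
# are already split): segments of 2 s and 4 s with sentences ["A!", "B!", "C!"]
# give thresholds round(3 * 2/6) = 1 and round(3 * 6/6) = 3, so the first
# segment gets "A!" and the second gets "B! C!".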


# ─── Emotion parsing from evoxtral expression tags ─────────────────────────────

_TAG_EMOTIONS: dict[str, tuple[str, float, float]] = {
    "laughs":           ("Happy",       0.70,  0.60),
    "laughing":         ("Happy",       0.70,  0.60),
    "chuckles":         ("Happy",       0.50,  0.30),
    "giggles":          ("Happy",       0.60,  0.40),
    "sighs":            ("Sad",        -0.30, -0.30),
    "sighing":          ("Sad",        -0.30, -0.30),
    "cries":            ("Sad",        -0.70,  0.40),
    "crying":           ("Sad",        -0.70,  0.40),
    "whispers":         ("Calm",        0.10, -0.50),
    "whispering":       ("Calm",        0.10, -0.50),
    "shouts":           ("Angry",      -0.50,  0.80),
    "shouting":         ("Angry",      -0.50,  0.80),
    "exclaims":         ("Excited",     0.50,  0.70),
    "gasps":            ("Surprised",   0.20,  0.70),
    "hesitates":        ("Anxious",    -0.20,  0.30),
    "stutters":         ("Anxious",    -0.20,  0.40),
    "stammers":         ("Anxious",    -0.25,  0.35),
    "mumbles":          ("Sad",        -0.20, -0.30),
    "nervous":          ("Anxious",    -0.30,  0.40),
    "frustrated":       ("Frustrated", -0.50,  0.50),
    "excited":          ("Excited",     0.50,  0.70),
    "sad":              ("Sad",        -0.60, -0.20),
    "angry":            ("Angry",      -0.60,  0.70),
    "claps":            ("Happy",       0.60,  0.50),
    "applause":         ("Happy",       0.60,  0.50),
    "clears throat":    ("Neutral",     0.00,  0.10),
    "pause":            ("Neutral",     0.00, -0.10),
    "laughs nervously": ("Anxious",    -0.10,  0.40),
}


def _parse_emotion(text: str) -> dict:
    tags = re.findall(r'\[([^\]]+)\]', text.lower())
    for tag in tags:
        tag = tag.strip()
        if tag in _TAG_EMOTIONS:
            label, valence, arousal = _TAG_EMOTIONS[tag]
            return {"emotion": label, "valence": valence, "arousal": arousal}
        for key, (label, valence, arousal) in _TAG_EMOTIONS.items():
            if key in tag:
                return {"emotion": label, "valence": valence, "arousal": arousal}
    return {"emotion": "Neutral", "valence": 0.0, "arousal": 0.0}


# ─── Endpoints ─────────────────────────────────────────────────────────────────

@app.post("/transcribe")
async def transcribe(audio: UploadFile = File(...)):
    req_start = time.perf_counter()
    filename  = audio.filename or "audio.wav"
    print(f"[voxtral] POST /transcribe filename={filename}")

    if _model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    contents = await audio.read()
    _validate_upload(contents)

    suffix = os.path.splitext(filename)[1].lower() or ".wav"
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        tmp.write(contents)
        tmp_path = tmp.name

    wav_path = None
    try:
        wav_path = _convert_to_wav_ffmpeg(tmp_path, TARGET_SR)
        loop = asyncio.get_running_loop()
        text = await loop.run_in_executor(None, _transcribe_sync, wav_path)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Cannot process audio: {e}")
    finally:
        for p in (tmp_path, wav_path):
            if p and os.path.exists(p):
                try: os.unlink(p)
                except OSError: pass

    print(f"[voxtral] /transcribe done {(time.perf_counter()-req_start)*1000:.0f}ms")
    return {"text": text, "words": []}


@app.post("/transcribe-diarize")
async def transcribe_diarize(audio: UploadFile = File(...)):
    """
    Upload audio/video → transcription + VAD segmentation + emotion.
    For video files (.mp4, .mkv, .avi, .mov, .m4v), also runs FER.
    """
    req_start = time.perf_counter()
    req_id    = f"diarize-{int(req_start * 1000)}"
    filename  = audio.filename or "audio.wav"
    print(f"[voxtral] {req_id} POST /transcribe-diarize filename={filename}")

    if _model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    contents = await audio.read()
    _validate_upload(contents)

    suffix = os.path.splitext(filename)[1].lower() or ".wav"
    if suffix not in (".wav", ".mp3", ".flac", ".ogg", ".m4a", ".webm",
                      ".mp4", ".mkv", ".avi", ".mov", ".m4v"):
        suffix = ".wav"

    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        tmp.write(contents)
        tmp_path = tmp.name

    wav_path = None
    try:
        t0 = time.perf_counter()
        wav_path     = _convert_to_wav_ffmpeg(tmp_path, TARGET_SR)
        audio_array  = _load_audio(wav_path, TARGET_SR)
        print(f"[voxtral] {req_id} audio loaded shape={audio_array.shape} in {(time.perf_counter()-t0)*1000:.0f}ms")
    except Exception as e:
        for p in (tmp_path, wav_path):
            if p and os.path.exists(p):
                try: os.unlink(p)
                except OSError: pass
        raise HTTPException(status_code=400, detail=f"Cannot decode audio: {e}")

    duration = round(len(audio_array) / TARGET_SR, 3)

    # ── STT (local model, run in thread pool) ────────────────────────────────
    loop = asyncio.get_running_loop()  # reused below for the FER thread-pool call
    try:
        t0 = time.perf_counter()
        full_text = await loop.run_in_executor(None, _transcribe_sync, wav_path)
        print(f"[voxtral] {req_id} STT done {(time.perf_counter()-t0)*1000:.0f}ms text_len={len(full_text)}")
    except Exception as e:
        import traceback as _tb
        print(f"[voxtral] {req_id} STT error: {e}\n{_tb.format_exc()}")
        raise HTTPException(status_code=500, detail=f"Transcription failed: {e}")
    finally:
        if wav_path and os.path.exists(wav_path):
            try: os.unlink(wav_path)
            except OSError: pass

    # ── VAD segmentation + text distribution ─────────────────────────────────
    raw_segs, seg_method = _segments_from_vad(audio_array, TARGET_SR)
    segs_with_text       = _distribute_text(full_text, raw_segs)

    # ── FER (video only) ─────────────────────────────────────────────────────
    has_fer              = False
    face_emotions:  dict[int, str] = {}
    fer_timeline:   dict[int, str] = {}
    if _is_video(filename) and _fer_session is not None:
        t0 = time.perf_counter()
        face_emotions, fer_timeline = await loop.run_in_executor(
            None, _fer_for_segments, tmp_path, raw_segs
        )
        has_fer = bool(face_emotions)
        print(f"[voxtral] {req_id} FER done {(time.perf_counter()-t0)*1000:.0f}ms faces={len(face_emotions)} timeline={len(fer_timeline)}")

    if tmp_path and os.path.exists(tmp_path):
        try: os.unlink(tmp_path)
        except OSError: pass

    # ── Build segments ────────────────────────────────────────────────────────
    segments = []
    for s in segs_with_text:
        emo      = _parse_emotion(s["text"])
        seg_data = {
            "id":      s["id"],
            "speaker": s["speaker"],
            "start":   s["start"],
            "end":     s["end"],
            "text":    s["text"],
            "emotion": emo["emotion"],
            "valence": emo["valence"],
            "arousal": emo["arousal"],
        }
        if s["id"] in face_emotions:
            seg_data["face_emotion"] = face_emotions[s["id"]]
        segments.append(seg_data)

    total_ms = (time.perf_counter() - req_start) * 1000
    print(f"[voxtral] {req_id} complete total={total_ms:.0f}ms segments={len(segments)} has_fer={has_fer}")

    return {
        "segments":               segments,
        "duration":               duration,
        "text":                   full_text,
        "filename":               filename,
        "diarization_method":     seg_method,
        "has_video":              has_fer,
        # Per-second face emotion timeline for live playback panel
        # Keys are strings (JSON), values are emotion labels e.g. "Happy"
        "face_emotion_timeline":  {str(k): v for k, v in fer_timeline.items()},
    }
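
# Example request (hypothetical host/port):
#   curl -F "audio=@interview.mp4" http://localhost:8000/transcribe-diarize

# Minimal local runner, a sketch under the assumption that port 8000 is the
# intended default (this file does not pin one; containerized deployments
# typically launch uvicorn directly):
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)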