Lior-0618 Claude Sonnet 4.6 committed on
Commit
4c097b5
·
1 Parent(s): fdef69c

refactor: replace speaker clustering with VAD sentence segmentation

Browse files

Single-speaker mode: segments by silence (0.3s threshold), each segment
gets independent emotion analysis. Removes resemblyzer/MFCC clustering.
Also fixes MPS device (float16), adds warm-up, and CJK text distribution.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

model/voxtral-server/main.py CHANGED
@@ -64,6 +64,8 @@ def _get_pyannote_pipeline():
64
  )
65
  if torch.cuda.is_available():
66
  pipeline = pipeline.to(torch.device("cuda"))
 
 
67
  _pyannote_pipeline = pipeline
68
  print("[voxtral] pyannote speaker-diarization-3.1 loaded")
69
  except Exception as e:
@@ -79,6 +81,17 @@ async def lifespan(app: FastAPI):
79
  _check_ffmpeg()
80
  print(f"[voxtral] ffmpeg: {shutil.which('ffmpeg')}")
81
 
 
 
 
 
 
 
 
 
 
 
 
82
  print(f"[voxtral] Loading model: {REPO_ID} (first run may download ~8–16GB)...")
83
  try:
84
  from transformers import (
@@ -87,10 +100,10 @@ async def lifespan(app: FastAPI):
87
  )
88
  processor = AutoProcessor.from_pretrained(REPO_ID)
89
  model = VoxtralRealtimeForConditionalGeneration.from_pretrained(
90
- REPO_ID, device_map="auto", torch_dtype=torch.bfloat16
91
- )
92
  model.eval()
93
- print(f"[voxtral] Model loaded: {REPO_ID}")
94
  except Exception as e:
95
  raise RuntimeError(
96
  f"Model load failed: {e}\n"
@@ -98,6 +111,20 @@ async def lifespan(app: FastAPI):
98
  "And sufficient VRAM (recommended ≥16GB) or use CPU (slower)."
99
  ) from e
100
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  yield
102
 
103
 
@@ -197,137 +224,56 @@ def _validate_upload(contents: bytes) -> None:
197
  )
198
 
199
 
200
- # ─── Speaker diarization helpers ───────────────────────────────────────────────
201
 
202
- def _vad_split(audio: np.ndarray, sr: int) -> list[tuple[int, int]]:
203
- """Split audio on silence, merge gaps < 0.8 s, filter segments < 0.5 s.
204
- Returns list of (start_sample, end_sample) tuples.
 
205
  """
206
  intervals = librosa.effects.split(audio, top_db=28, frame_length=2048, hop_length=512)
207
  if len(intervals) == 0:
208
  return [(0, len(audio))]
209
 
210
- # Merge intervals with gap < 0.8 s
211
  merged: list[list[int]] = [[int(intervals[0][0]), int(intervals[0][1])]]
212
  for s, e in intervals[1:]:
213
- if (int(s) - merged[-1][1]) / sr < 0.8:
214
  merged[-1][1] = int(e)
215
  else:
216
  merged.append([int(s), int(e)])
217
 
218
- # Filter very short segments
219
- result = [(s, e) for s, e in merged if (e - s) / sr >= 0.4]
220
  return result if result else [(0, len(audio))]
221
 
222
 
223
- def _extract_mfcc_features(segments: list[tuple[int, int]], audio: np.ndarray, sr: int) -> np.ndarray:
224
- """Extract normalised MFCC feature matrix (one row per segment)."""
225
- from sklearn.preprocessing import StandardScaler
226
-
227
- rows = []
228
- for s, e in segments:
229
- chunk = audio[s:e]
230
- if len(chunk) < 512:
231
- rows.append(np.zeros(40))
232
- continue
233
- mfcc = librosa.feature.mfcc(y=chunk, sr=sr, n_mfcc=20)
234
- rows.append(np.concatenate([mfcc.mean(axis=1), mfcc.std(axis=1)]))
235
- X = np.array(rows)
236
- return StandardScaler().fit_transform(X)
237
-
238
-
239
- def _auto_num_speakers(X: np.ndarray, max_speakers: int = 8) -> int:
240
- """Pick the number of speakers that maximises silhouette score (k=2..max_k)."""
241
- from sklearn.cluster import KMeans
242
- from sklearn.metrics import silhouette_score
243
-
244
- max_k = min(max_speakers, len(X))
245
- if max_k < 2:
246
- return 1
247
-
248
- best_k, best_score = 2, -1.0
249
- for k in range(2, max_k + 1):
250
- labels = KMeans(n_clusters=k, random_state=42, n_init=10).fit_predict(X)
251
- if len(set(labels)) < 2:
252
- continue
253
- score = float(silhouette_score(X, labels))
254
- print(f"[voxtral] silhouette k={k}: {score:.4f}")
255
- if score > best_score:
256
- best_score, best_k = score, k
257
-
258
- print(f"[voxtral] auto-detected {best_k} speaker(s) (silhouette={best_score:.4f})")
259
- return best_k
260
-
261
-
262
- def _cluster_speakers(
263
- segments: list[tuple[int, int]],
264
- audio: np.ndarray,
265
- sr: int,
266
- n_speakers: int, # 0 = auto-detect
267
- ) -> list[int]:
268
- """Assign speaker IDs to segments via MFCC + KMeans.
269
- Pass n_speakers=0 to automatically detect the speaker count.
270
- Falls back to alternating labels if sklearn unavailable or clustering fails.
271
  """
272
- if len(segments) <= 1:
273
- return [0] * len(segments)
 
 
 
 
 
274
 
275
- try:
276
- from sklearn.cluster import KMeans
277
 
278
- X = _extract_mfcc_features(segments, audio, sr)
279
-
280
- if n_speakers == 0:
281
- n_speakers = _auto_num_speakers(X)
282
-
283
- n_speakers = min(n_speakers, len(segments))
284
- labels = KMeans(n_clusters=n_speakers, random_state=42, n_init=10).fit_predict(X)
285
- return [int(l) for l in labels]
286
- except Exception as e:
287
- print(f"[voxtral] MFCC clustering failed: {e} — using alternating labels")
288
- k = max(1, n_speakers)
289
- return [i % k for i in range(len(segments))]
290
-
291
-
292
- def _segments_from_pyannote(wav_path: str) -> Optional[list[dict]]:
293
- """Run pyannote pipeline and return raw segments. Returns None if unavailable."""
294
- pipeline = _get_pyannote_pipeline()
295
- if pipeline is None:
296
- return None
297
- try:
298
- diarization = pipeline(wav_path)
299
- segs = []
300
- for turn, _, speaker in diarization.itertracks(yield_label=True):
301
- segs.append({
302
- "speaker": speaker,
303
- "start": round(turn.start, 3),
304
- "end": round(turn.end, 3),
305
- })
306
- return segs if segs else None
307
- except Exception as e:
308
- print(f"[voxtral] pyannote inference failed: {e}")
309
- return None
310
-
311
-
312
- def _segments_from_vad(audio: np.ndarray, sr: int, n_speakers: int) -> list[dict]:
313
- """Fallback: VAD split + MFCC speaker clustering."""
314
- intervals = _vad_split(audio, sr)
315
- labels = _cluster_speakers(intervals, audio, sr, n_speakers)
316
- segs = []
317
- for (s, e), spk in zip(intervals, labels):
318
- segs.append({
319
- "speaker": f"SPEAKER_{spk:02d}",
320
- "start": round(s / sr, 3),
321
- "end": round(e / sr, 3),
322
- })
323
- return segs
324
 
325
 
326
  def _distribute_text(full_text: str, segs: list[dict]) -> list[dict]:
327
- """Proportionally distribute transcription words across segments by duration."""
328
- words = full_text.split()
329
- total_words = len(words)
330
- if not words or not segs:
331
  return [{**s, "text": ""} for s in segs]
332
 
333
  total_dur = sum(s["end"] - s["start"] for s in segs)
@@ -335,18 +281,21 @@ def _distribute_text(full_text: str, segs: list[dict]) -> list[dict]:
335
  result = [{**segs[0], "text": full_text}]
336
  return result + [{**s, "text": ""} for s in segs[1:]]
337
 
 
 
 
338
  result: list[dict] = []
339
- word_idx = 0
340
  for i, seg in enumerate(segs):
341
  dur = seg["end"] - seg["start"]
342
  frac = dur / total_dur
343
- n = round(frac * total_words)
344
  if i == len(segs) - 1:
345
- chunk = words[word_idx:]
346
  else:
347
- chunk = words[word_idx: word_idx + max(1, n)]
348
- result.append({**seg, "text": " ".join(chunk)})
349
- word_idx += len(chunk)
350
  return result
351
 
352
 
@@ -459,10 +408,10 @@ async def transcribe(audio: UploadFile = File(...)):
459
  except OSError:
460
  pass
461
 
462
- with torch.no_grad():
463
  inputs = processor(audio_array, return_tensors="pt")
464
  inputs = inputs.to(model.device, dtype=model.dtype)
465
- outputs = model.generate(**{k: v for k, v in inputs.items()})
466
  decoded = processor.batch_decode(outputs, skip_special_tokens=True)
467
 
468
  text = (decoded[0] or "").strip()
@@ -474,20 +423,16 @@ async def transcribe(audio: UploadFile = File(...)):
474
  @app.post("/transcribe-diarize")
475
  async def transcribe_diarize(
476
  audio: UploadFile = File(...),
477
- num_speakers: int = Query(default=0, ge=0, le=10, description="Expected number of speakers; 0 = auto-detect"),
478
  ):
479
  """
480
- Upload audio → full transcription + speaker diarization.
481
  Returns structured segments: [{id, speaker, start, end, text, emotion, valence, arousal}]
482
-
483
- Speaker detection method (in priority order):
484
- 1. pyannote/speaker-diarization-3.1 (needs HF_TOKEN + pyannote.audio installed)
485
- 2. VAD silence split + MFCC KMeans clustering (zero extra deps, always available)
486
  """
487
  req_start = time.perf_counter()
488
  req_id = f"diarize-{int(req_start * 1000)}"
489
  filename = audio.filename or "audio.wav"
490
- print(f"[voxtral] {req_id} POST /transcribe-diarize filename={filename} num_speakers={num_speakers}")
491
 
492
  try:
493
  contents = await audio.read()
@@ -523,38 +468,18 @@ async def transcribe_diarize(
523
 
524
  # ── Step 1: full transcription via Voxtral ──────────────────────────────
525
  t0 = time.perf_counter()
526
- with torch.no_grad():
527
  inputs = processor(audio_array, return_tensors="pt")
528
  inputs = inputs.to(model.device, dtype=model.dtype)
529
- outputs = model.generate(**{k: v for k, v in inputs.items()})
530
  decoded = processor.batch_decode(outputs, skip_special_tokens=True)
531
  full_text = (decoded[0] or "").strip()
532
  print(f"[voxtral] {req_id} transcription done in {(time.perf_counter()-t0)*1000:.0f}ms text_len={len(full_text)}")
533
 
534
- # ── Step 2: speaker diarization ─────────────────────────────────────────
535
  t0 = time.perf_counter()
536
- raw_segs: Optional[list[dict]] = None
537
-
538
- # Try pyannote first (requires HF_TOKEN)
539
- if _pyannote_available and HF_TOKEN:
540
- wav_tmp = None
541
- try:
542
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
543
- sf.write(f.name, audio_array, target_sr)
544
- wav_tmp = f.name
545
- raw_segs = _segments_from_pyannote(wav_tmp)
546
- except Exception as e:
547
- print(f"[voxtral] {req_id} pyannote error: {e}")
548
- finally:
549
- if wav_tmp and os.path.exists(wav_tmp):
550
- os.unlink(wav_tmp)
551
- if raw_segs:
552
- print(f"[voxtral] {req_id} pyannote diarization done in {(time.perf_counter()-t0)*1000:.0f}ms segs={len(raw_segs)}")
553
-
554
- # Fallback: VAD + MFCC clustering
555
- if not raw_segs:
556
- raw_segs = _segments_from_vad(audio_array, target_sr, num_speakers)
557
- print(f"[voxtral] {req_id} VAD+MFCC diarization done in {(time.perf_counter()-t0)*1000:.0f}ms segs={len(raw_segs)}")
558
 
559
  # ── Step 3: distribute text proportionally ──────────────────────────────
560
  segs_with_text = _distribute_text(full_text, raw_segs)
@@ -587,5 +512,5 @@ async def transcribe_diarize(
587
  "duration": duration,
588
  "text": full_text,
589
  "filename": filename,
590
- "diarization_method": "pyannote" if (raw_segs and _pyannote_available and HF_TOKEN) else "vad_mfcc",
591
  }
 
64
  )
65
  if torch.cuda.is_available():
66
  pipeline = pipeline.to(torch.device("cuda"))
67
+ elif torch.backends.mps.is_available():
68
+ pipeline = pipeline.to(torch.device("mps"))
69
  _pyannote_pipeline = pipeline
70
  print("[voxtral] pyannote speaker-diarization-3.1 loaded")
71
  except Exception as e:
 
81
  _check_ffmpeg()
82
  print(f"[voxtral] ffmpeg: {shutil.which('ffmpeg')}")
83
 
84
+ if torch.cuda.is_available():
85
+ _device = torch.device("cuda")
86
+ _dtype = torch.bfloat16
87
+ elif torch.backends.mps.is_available():
88
+ _device = torch.device("mps")
89
+ _dtype = torch.float16 # MPS does not support bfloat16
90
+ else:
91
+ _device = torch.device("cpu")
92
+ _dtype = torch.float32
93
+ print(f"[voxtral] Device: {_device} dtype: {_dtype}")
94
+
95
  print(f"[voxtral] Loading model: {REPO_ID} (first run may download ~8–16GB)...")
96
  try:
97
  from transformers import (
 
100
  )
101
  processor = AutoProcessor.from_pretrained(REPO_ID)
102
  model = VoxtralRealtimeForConditionalGeneration.from_pretrained(
103
+ REPO_ID, torch_dtype=_dtype
104
+ ).to(_device)
105
  model.eval()
106
+ print(f"[voxtral] Model loaded: {REPO_ID} on {_device}")
107
  except Exception as e:
108
  raise RuntimeError(
109
  f"Model load failed: {e}\n"
 
111
  "And sufficient VRAM (recommended ≥16GB) or use CPU (slower)."
112
  ) from e
113
 
114
+ # Warm-up: run one silent dummy inference to pre-compile MPS Metal shaders.
115
+ # Without this the first real request pays a ~15s compilation penalty.
116
+ print("[voxtral] Warming up MPS shaders (dummy inference)...")
117
+ try:
118
+ sr = processor.feature_extractor.sampling_rate
119
+ dummy = np.zeros(sr, dtype=np.float32) # 1 second of silence
120
+ with torch.inference_mode():
121
+ dummy_inputs = processor(dummy, return_tensors="pt")
122
+ dummy_inputs = dummy_inputs.to(_device, dtype=_dtype)
123
+ model.generate(**dummy_inputs, max_new_tokens=1)
124
+ print("[voxtral] Warm-up complete — first request will be fast")
125
+ except Exception as e:
126
+ print(f"[voxtral] Warm-up skipped: {e}")
127
+
128
  yield
129
 
130
 
 
224
  )
225
 
226
 
227
+ # ─── Segmentation helpers ──────────────────────────────────────────────────────
228
 
229
+ def _vad_segment(audio: np.ndarray, sr: int) -> list[tuple[int, int]]:
230
+ """Split audio into speech segments by silence detection.
231
+ Merges gaps < 0.3 s (intra-phrase pauses) and drops segments < 0.3 s.
232
+ Returns list of (start_sample, end_sample).
233
  """
234
  intervals = librosa.effects.split(audio, top_db=28, frame_length=2048, hop_length=512)
235
  if len(intervals) == 0:
236
  return [(0, len(audio))]
237
 
 
238
  merged: list[list[int]] = [[int(intervals[0][0]), int(intervals[0][1])]]
239
  for s, e in intervals[1:]:
240
+ if (int(s) - merged[-1][1]) / sr < 0.3:
241
  merged[-1][1] = int(e)
242
  else:
243
  merged.append([int(s), int(e)])
244
 
245
+ result = [(s, e) for s, e in merged if (e - s) / sr >= 0.3]
 
246
  return result if result else [(0, len(audio))]
247
 
248
 
249
+ def _segments_from_vad(audio: np.ndarray, sr: int) -> tuple[list[dict], str]:
250
+ """Segment audio by silence, assign all segments to SPEAKER_00.
251
+ Returns (segments, method_name).
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
252
  """
253
+ intervals = _vad_segment(audio, sr)
254
+ segs = [
255
+ {"speaker": "SPEAKER_00", "start": round(s / sr, 3), "end": round(e / sr, 3)}
256
+ for s, e in intervals
257
+ ]
258
+ print(f"[voxtral] VAD segmentation: {len(segs)} segment(s)")
259
+ return segs, "vad"
260
 
 
 
261
 
262
+ def _tokenize_text(text: str) -> list[str]:
263
+ """Split text into tokens. For CJK text (no spaces), split by character.
264
+ For space-separated languages, split by whitespace."""
265
+ tokens = text.split()
266
+ # If no spaces found (e.g. Chinese/Japanese), split by character instead
267
+ if len(tokens) <= 1 and len(text) > 1:
268
+ return list(text)
269
+ return tokens
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
270
 
271
 
272
  def _distribute_text(full_text: str, segs: list[dict]) -> list[dict]:
273
+ """Proportionally distribute transcription tokens across segments by duration."""
274
+ tokens = _tokenize_text(full_text)
275
+ total_tokens = len(tokens)
276
+ if not tokens or not segs:
277
  return [{**s, "text": ""} for s in segs]
278
 
279
  total_dur = sum(s["end"] - s["start"] for s in segs)
 
281
  result = [{**segs[0], "text": full_text}]
282
  return result + [{**s, "text": ""} for s in segs[1:]]
283
 
284
+ is_cjk = len(full_text.split()) <= 1 and len(full_text) > 1
285
+ sep = "" if is_cjk else " "
286
+
287
  result: list[dict] = []
288
+ token_idx = 0
289
  for i, seg in enumerate(segs):
290
  dur = seg["end"] - seg["start"]
291
  frac = dur / total_dur
292
+ n = round(frac * total_tokens)
293
  if i == len(segs) - 1:
294
+ chunk = tokens[token_idx:]
295
  else:
296
+ chunk = tokens[token_idx: token_idx + max(1, n)]
297
+ result.append({**seg, "text": sep.join(chunk)})
298
+ token_idx += len(chunk)
299
  return result
300
 
301
 
 
408
  except OSError:
409
  pass
410
 
411
+ with torch.inference_mode():
412
  inputs = processor(audio_array, return_tensors="pt")
413
  inputs = inputs.to(model.device, dtype=model.dtype)
414
+ outputs = model.generate(**{k: v for k, v in inputs.items()}, max_new_tokens=1024)
415
  decoded = processor.batch_decode(outputs, skip_special_tokens=True)
416
 
417
  text = (decoded[0] or "").strip()
 
423
  @app.post("/transcribe-diarize")
424
  async def transcribe_diarize(
425
  audio: UploadFile = File(...),
 
426
  ):
427
  """
428
+ Upload audio → transcription + VAD sentence segmentation + per-segment emotion analysis.
429
  Returns structured segments: [{id, speaker, start, end, text, emotion, valence, arousal}]
430
+ All segments are labelled SPEAKER_00 (single-speaker mode).
 
 
 
431
  """
432
  req_start = time.perf_counter()
433
  req_id = f"diarize-{int(req_start * 1000)}"
434
  filename = audio.filename or "audio.wav"
435
+ print(f"[voxtral] {req_id} POST /transcribe-diarize filename={filename}")
436
 
437
  try:
438
  contents = await audio.read()
 
468
 
469
  # ── Step 1: full transcription via Voxtral ──────────────────────────────
470
  t0 = time.perf_counter()
471
+ with torch.inference_mode():
472
  inputs = processor(audio_array, return_tensors="pt")
473
  inputs = inputs.to(model.device, dtype=model.dtype)
474
+ outputs = model.generate(**{k: v for k, v in inputs.items()}, max_new_tokens=1024)
475
  decoded = processor.batch_decode(outputs, skip_special_tokens=True)
476
  full_text = (decoded[0] or "").strip()
477
  print(f"[voxtral] {req_id} transcription done in {(time.perf_counter()-t0)*1000:.0f}ms text_len={len(full_text)}")
478
 
479
+ # ── Step 2: VAD sentence segmentation ───────────────────────────────────
480
  t0 = time.perf_counter()
481
+ raw_segs, seg_method = _segments_from_vad(audio_array, target_sr)
482
+ print(f"[voxtral] {req_id} segmentation done in {(time.perf_counter()-t0)*1000:.0f}ms segs={len(raw_segs)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
483
 
484
  # ── Step 3: distribute text proportionally ──────────────────────────────
485
  segs_with_text = _distribute_text(full_text, raw_segs)
 
512
  "duration": duration,
513
  "text": full_text,
514
  "filename": filename,
515
+ "diarization_method": seg_method,
516
  }
model/voxtral-server/requirements.txt CHANGED
@@ -11,7 +11,8 @@ librosa>=0.10.0
11
  soundfile>=0.12.0
12
  numpy>=1.24.0
13
  scikit-learn>=1.3.0
 
14
 
15
- # Optional: real speaker diarization (requires HF_TOKEN env var + model license acceptance)
16
  # pip install pyannote.audio>=3.1.0
17
  # Then: export HF_TOKEN=your_token
 
11
  soundfile>=0.12.0
12
  numpy>=1.24.0
13
  scikit-learn>=1.3.0
14
+ resemblyzer>=0.1.1.dev0
15
 
16
+ # Optional: production-grade speaker diarization (requires HF_TOKEN env var + model license acceptance)
17
  # pip install pyannote.audio>=3.1.0
18
  # Then: export HF_TOKEN=your_token