mazesmazes committed on
Commit
8fa5959
·
verified ·
1 Parent(s): 43c4368

Update custom model files, README, and requirements

Browse files
Files changed (1) hide show
  1. diarization.py +60 -5
diarization.py CHANGED
@@ -275,8 +275,9 @@ class LocalSpeakerDiarizer:
275
  # ==================== TUNABLE PARAMETERS ====================
276
 
277
  # Sliding window for embedding extraction
278
- WINDOW_SIZE = 0.75 # seconds - shorter window for finer resolution
279
- STEP_SIZE = 0.15 # seconds (80% overlap for more votes)
 
280
  TAIL_COVERAGE_RATIO = 0.1 # Add extra window if tail > this ratio of window
281
 
282
  # VAD hysteresis parameters
@@ -290,8 +291,8 @@ class LocalSpeakerDiarizer:
290
  VOTING_RATE = 0.01 # 10ms resolution for consensus voting
291
 
292
  # Post-processing
293
- MIN_SEGMENT_DURATION = 0.15 # Minimum final segment duration (seconds)
294
- SHORT_SEGMENT_GAP = 0.1 # Gap threshold for merging short segments
295
  SAME_SPEAKER_GAP = 0.5 # Gap threshold for merging same-speaker segments
296
 
297
  # ===========================================================
@@ -381,7 +382,10 @@ class LocalSpeakerDiarizer:
381
  clusterer = SpeakerClusterer(min_num_spks=min_speakers, max_num_spks=max_speakers)
382
  labels = clusterer(embeddings, num_speakers)
383
 
384
- # Step 4: Post-process with consensus voting (VAD-aware)
 
 
 
385
  return cls._postprocess_segments(window_segments, labels, total_duration, vad_frames)
386
 
387
  @classmethod
@@ -479,6 +483,57 @@ class LocalSpeakerDiarizer:
479
 
480
  return filtered
481
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
482
  @classmethod
483
  def _extract_embeddings(
484
  cls, audio_array: np.ndarray, segments: list[dict], sample_rate: int
 
275
  # ==================== TUNABLE PARAMETERS ====================
276
 
277
  # Sliding window for embedding extraction
278
+ # Longer windows (1.5-2.0s) capture more prosody, reducing speaker confusion
279
+ WINDOW_SIZE = 1.5 # seconds
280
+ STEP_SIZE = 0.5 # seconds (67% overlap)
281
  TAIL_COVERAGE_RATIO = 0.1 # Add extra window if tail > this ratio of window
282
 
283
  # VAD hysteresis parameters
 
291
  VOTING_RATE = 0.01 # 10ms resolution for consensus voting
292
 
293
  # Post-processing
294
+ MIN_SEGMENT_DURATION = 0.3 # Minimum final segment duration (seconds)
295
+ SHORT_SEGMENT_GAP = 0.3 # Gap threshold for merging short segments
296
  SAME_SPEAKER_GAP = 0.5 # Gap threshold for merging same-speaker segments
297
 
298
  # ===========================================================
 
382
  clusterer = SpeakerClusterer(min_num_spks=min_speakers, max_num_spks=max_speakers)
383
  labels = clusterer(embeddings, num_speakers)
384
 
385
+ # Step 4: Centroid refinement - reduces flickering/confusion
386
+ labels = cls._refine_with_centroids(embeddings, labels)
387
+
388
+ # Step 5: Post-process with consensus voting (VAD-aware)
389
  return cls._postprocess_segments(window_segments, labels, total_duration, vad_frames)
390
 
391
  @classmethod
 
483
 
484
  return filtered
485
 
486
@classmethod
def _refine_with_centroids(cls, embeddings: np.ndarray, labels: np.ndarray) -> np.ndarray:
    """Refine cluster assignments by snapping each embedding to its nearest centroid.

    This reduces "flickering" where embeddings rapidly switch between speakers.
    A unit-length centroid is computed per speaker from the L2-normalized
    embeddings currently assigned to it, and every embedding is then re-assigned
    to the speaker whose centroid has the highest cosine similarity.

    Embeddings for which cosine similarity is not meaningful (zero-norm vectors,
    or rows whose similarities are all NaN) keep their original label instead of
    being pushed onto an arbitrary speaker.

    Args:
        embeddings: Speaker embeddings of shape [N, D]
        labels: Initial cluster labels of shape [N]

    Returns:
        Refined labels of shape [N]. The input ``labels`` object is returned
        unchanged when there are no embeddings or at most one distinct label.
    """
    if len(embeddings) == 0:
        return labels

    unique_labels = np.unique(labels)
    if len(unique_labels) <= 1:
        return labels

    eps = 1e-10
    label_arr = np.asarray(labels)

    # L2-normalize so that a plain dot product equals cosine similarity.
    # The clamp only guards the division; true zero vectors are tracked
    # separately via `has_direction` below.
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    unit_embeddings = embeddings / np.maximum(norms, eps)

    # One centroid per speaker, in sorted-label order (same order np.unique returns).
    centroids = np.stack(
        [unit_embeddings[label_arr == label].mean(axis=0) for label in unique_labels]
    )
    centroid_norms = np.linalg.norm(centroids, axis=1, keepdims=True)
    # Re-normalize centroids; a (near-)zero centroid is left as-is rather than
    # being blown up by a division by ~0.
    centroids = np.where(
        centroid_norms > eps,
        centroids / np.maximum(centroid_norms, eps),
        centroids,
    )

    # Cosine similarity of every embedding against every centroid: shape [N, K].
    similarities = unit_embeddings @ centroids.T
    # NaN similarities must never win the argmax (a NaN comparison is always False).
    similarities = np.where(np.isnan(similarities), -np.inf, similarities)

    # argmax returns the first maximum, so ties resolve to the lowest label.
    best_idx = similarities.argmax(axis=1)
    best_sim = similarities[np.arange(len(similarities)), best_idx]

    # Only move an embedding when it has a usable direction and some centroid
    # actually beats the -1.0 floor; otherwise keep the clusterer's label.
    has_direction = norms[:, 0] > eps
    reassign = has_direction & (best_sim > -1.0)

    return np.where(reassign, unique_labels[best_idx], label_arr)
536
+
537
  @classmethod
538
  def _extract_embeddings(
539
  cls, audio_array: np.ndarray, segments: list[dict], sample_rate: int