Update custom model files, README, and requirements
Browse files- diarization.py +60 -5
diarization.py
CHANGED
|
@@ -275,8 +275,9 @@ class LocalSpeakerDiarizer:
|
|
| 275 |
# ==================== TUNABLE PARAMETERS ====================
|
| 276 |
|
| 277 |
# Sliding window for embedding extraction
|
| 278 |
-
|
| 279 |
-
|
|
|
|
| 280 |
TAIL_COVERAGE_RATIO = 0.1 # Add extra window if tail > this ratio of window
|
| 281 |
|
| 282 |
# VAD hysteresis parameters
|
|
@@ -290,8 +291,8 @@ class LocalSpeakerDiarizer:
|
|
| 290 |
VOTING_RATE = 0.01 # 10ms resolution for consensus voting
|
| 291 |
|
| 292 |
# Post-processing
|
| 293 |
-
MIN_SEGMENT_DURATION = 0.
|
| 294 |
-
SHORT_SEGMENT_GAP = 0.
|
| 295 |
SAME_SPEAKER_GAP = 0.5 # Gap threshold for merging same-speaker segments
|
| 296 |
|
| 297 |
# ===========================================================
|
|
@@ -381,7 +382,10 @@ class LocalSpeakerDiarizer:
|
|
| 381 |
clusterer = SpeakerClusterer(min_num_spks=min_speakers, max_num_spks=max_speakers)
|
| 382 |
labels = clusterer(embeddings, num_speakers)
|
| 383 |
|
| 384 |
-
# Step 4:
|
|
|
|
|
|
|
|
|
|
| 385 |
return cls._postprocess_segments(window_segments, labels, total_duration, vad_frames)
|
| 386 |
|
| 387 |
@classmethod
|
|
@@ -479,6 +483,57 @@ class LocalSpeakerDiarizer:
|
|
| 479 |
|
| 480 |
return filtered
|
| 481 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 482 |
@classmethod
|
| 483 |
def _extract_embeddings(
|
| 484 |
cls, audio_array: np.ndarray, segments: list[dict], sample_rate: int
|
|
|
|
| 275 |
# ==================== TUNABLE PARAMETERS ====================
|
| 276 |
|
| 277 |
# Sliding window for embedding extraction
|
| 278 |
+
# Longer windows (1.5-2.0s) capture more prosody, reducing speaker confusion
|
| 279 |
+
WINDOW_SIZE = 1.5 # seconds
|
| 280 |
+
STEP_SIZE = 0.5 # seconds (67% overlap)
|
| 281 |
TAIL_COVERAGE_RATIO = 0.1 # Add extra window if tail > this ratio of window
|
| 282 |
|
| 283 |
# VAD hysteresis parameters
|
|
|
|
| 291 |
VOTING_RATE = 0.01 # 10ms resolution for consensus voting
|
| 292 |
|
| 293 |
# Post-processing
|
| 294 |
+
MIN_SEGMENT_DURATION = 0.3 # Minimum final segment duration (seconds)
|
| 295 |
+
SHORT_SEGMENT_GAP = 0.3 # Gap threshold for merging short segments
|
| 296 |
SAME_SPEAKER_GAP = 0.5 # Gap threshold for merging same-speaker segments
|
| 297 |
|
| 298 |
# ===========================================================
|
|
|
|
| 382 |
clusterer = SpeakerClusterer(min_num_spks=min_speakers, max_num_spks=max_speakers)
|
| 383 |
labels = clusterer(embeddings, num_speakers)
|
| 384 |
|
| 385 |
+
# Step 4: Centroid refinement - reduces flickering/confusion
|
| 386 |
+
labels = cls._refine_with_centroids(embeddings, labels)
|
| 387 |
+
|
| 388 |
+
# Step 5: Post-process with consensus voting (VAD-aware)
|
| 389 |
return cls._postprocess_segments(window_segments, labels, total_duration, vad_frames)
|
| 390 |
|
| 391 |
@classmethod
|
|
|
|
| 483 |
|
| 484 |
return filtered
|
| 485 |
|
| 486 |
+
@classmethod
def _refine_with_centroids(cls, embeddings: np.ndarray, labels: np.ndarray) -> np.ndarray:
    """Re-assign every embedding to the speaker with the nearest centroid.

    Clustering can leave isolated windows attributed to the wrong speaker,
    which shows up as rapid "flickering" between speakers. One
    nearest-centroid pass (by cosine similarity) smooths these out.

    Args:
        embeddings: Speaker embeddings of shape [N, D].
        labels: Initial cluster labels of shape [N].

    Returns:
        Refined labels of shape [N] with the same dtype as ``labels``.
        ``labels`` is returned unchanged when there are no embeddings or
        at most one distinct label.
    """
    # Coerce so list inputs work; ndarray inputs pass through without a copy.
    embeddings = np.asarray(embeddings)
    labels = np.asarray(labels)
    if len(embeddings) == 0:
        return labels

    # np.unique returns sorted labels; that order defines the centroid rows
    # and therefore the tie-break order used by argmax below.
    unique_labels = np.unique(labels)
    if len(unique_labels) <= 1:
        return labels

    # L2-normalize rows so a dot product equals cosine similarity.
    # The floor avoids dividing by zero for all-zero embeddings.
    norms = np.maximum(np.linalg.norm(embeddings, axis=1, keepdims=True), 1e-10)
    unit_embeddings = embeddings / norms

    # One centroid per speaker: mean of its unit embeddings, re-normalized.
    centroids = np.stack(
        [unit_embeddings[labels == label].mean(axis=0) for label in unique_labels]
    )
    centroid_norms = np.linalg.norm(centroids, axis=1, keepdims=True)
    # A degenerate (near-zero) centroid is left unscaled instead of being
    # divided by ~0.
    centroids = centroids / np.where(centroid_norms > 1e-10, centroid_norms, 1.0)

    # [N, K] cosine similarities in a single matrix product.
    similarities = unit_embeddings @ centroids.T
    # NaN never wins a comparison; mask it so argmax skips it as well.
    similarities = np.where(np.isnan(similarities), -np.inf, similarities)

    # argmax returns the first maximum, i.e. the lowest label on ties.
    best_idx = similarities.argmax(axis=1)
    best_sim = similarities[np.arange(len(similarities)), best_idx]

    # Keep the original label when no centroid scores above -1.0
    # (e.g. an embedding containing NaN).
    return np.where(best_sim > -1.0, unique_labels[best_idx], labels)
|
| 536 |
+
|
| 537 |
@classmethod
|
| 538 |
def _extract_embeddings(
|
| 539 |
cls, audio_array: np.ndarray, segments: list[dict], sample_rate: int
|