# Source: Who-Spoke-When / models/clusterer.py
# Latest commit (ConvxO2, 789006e): Reduce speaker over-segmentation in auto clustering
"""
Agglomerative Hierarchical Clustering (AHC) for speaker identity assignment.
Uses cosine similarity on ECAPA-TDNN embeddings to cluster segments by speaker.
"""
import numpy as np
from typing import List, Tuple, Optional
from scipy.cluster.hierarchy import linkage, fcluster
from scipy.spatial.distance import squareform
from sklearn.metrics import silhouette_score
from loguru import logger
class SpeakerClusterer:
    """Agglomerative hierarchical clustering (AHC) of speaker embeddings.

    Segment embeddings are compared with cosine distance, linked into a
    dendrogram with SciPy, and cut into flat clusters (one per speaker).
    When the number of speakers is not given, it is estimated by combining
    silhouette analysis with a fixed cosine-distance cut of the dendrogram.
    """

    def __init__(
        self,
        linkage_method: str = "average",
        distance_threshold: float = 0.55,
        min_speakers: int = 1,
        max_speakers: int = 10,
        min_silhouette: float = 0.08,
    ):
        """Store the clustering configuration.

        Args:
            linkage_method: Linkage criterion passed to
                ``scipy.cluster.hierarchy.linkage``.
            distance_threshold: Cosine-distance cut used for the
                threshold-based speaker count.
            min_speakers: Lower bound on the estimated speaker count.
            max_speakers: Upper bound on the estimated speaker count.
            min_silhouette: Silhouette score below which the silhouette
                estimate is ignored in favour of the threshold-based count.
        """
        self.linkage_method = linkage_method
        self.distance_threshold = distance_threshold
        self.min_speakers = min_speakers
        self.max_speakers = max_speakers
        self.min_silhouette = min_silhouette

    def _cosine_distance_matrix(self, embeddings: np.ndarray) -> np.ndarray:
        """Return the pairwise cosine-distance matrix of the row vectors.

        Rows are L2-normalised first, so the result is a true cosine distance
        even when the embeddings are not unit length. All-zero rows are left
        as zeros, which places them at distance 1 from every other row.

        Args:
            embeddings: Array of shape ``(n, dim)``.

        Returns:
            ``(n, n)`` array with values in ``[0, 2]`` and a zero diagonal.
        """
        vectors = np.asarray(embeddings, dtype=np.float64)
        norms = np.linalg.norm(vectors, axis=1, keepdims=True)
        # Divide zero-norm rows by 1 instead of 0 so they do not become NaN.
        unit = vectors / np.where(norms > 0.0, norms, 1.0)
        distance = np.clip(1.0 - unit @ unit.T, 0.0, 2.0)
        # Rounding can leave tiny non-zero self-distances; pin them to zero.
        np.fill_diagonal(distance, 0.0)
        return distance

    def _clamp_speaker_count(self, k: int, n: int) -> int:
        """Clamp *k* to ``[min_speakers, min(max_speakers, n)]``, never below 1."""
        upper = min(self.max_speakers, n)
        # When the bounds conflict the upper bound wins, matching np.clip.
        return max(1, min(max(int(k), self.min_speakers), upper))

    def _estimate_num_speakers(self, embeddings: np.ndarray, linkage_matrix: np.ndarray) -> int:
        """Estimate the number of speakers from the dendrogram.

        Two candidate counts are computed: the ``k`` with the best cosine
        silhouette score, and the number of clusters obtained by cutting the
        dendrogram at ``distance_threshold``. The smaller one is preferred
        so that speakers are not over-segmented.

        Args:
            embeddings: Array of shape ``(n, dim)`` used for silhouette scoring.
            linkage_matrix: SciPy linkage matrix built from the same embeddings.

        Returns:
            The chosen speaker count, clamped to the configured bounds.
        """
        n = len(embeddings)
        if n <= 1:
            return n

        threshold_labels = fcluster(
            linkage_matrix,
            t=self.distance_threshold,
            criterion="distance",
        )
        k_threshold = self._clamp_speaker_count(len(np.unique(threshold_labels)), n)

        if n == 2:
            # Silhouette needs 2 <= k <= n - 1, which is impossible for two
            # points, so the distance threshold alone decides whether the two
            # segments belong to the same speaker.
            return k_threshold

        min_k = max(2, self.min_speakers)
        upper_k = min(self.max_speakers, n - 1)
        best_k = min_k
        best_score = -1.0
        for k in range(min_k, upper_k + 1):
            labels = fcluster(linkage_matrix, k, criterion="maxclust")
            if len(np.unique(labels)) < 2:
                continue
            try:
                score = silhouette_score(embeddings, labels, metric="cosine")
            except ValueError as exc:
                # Scoring is best-effort: skip this k but leave a trace.
                logger.debug(f"Silhouette scoring failed for k={k}: {exc}")
                continue
            if score > best_score:
                best_score = score
                best_k = k

        # Be conservative to avoid severe over-segmentation in open-domain audio.
        if best_score < self.min_silhouette:
            chosen_k = k_threshold
        elif k_threshold >= 2:
            chosen_k = min(best_k, k_threshold)
        else:
            chosen_k = best_k
        chosen_k = self._clamp_speaker_count(chosen_k, n)

        logger.info(
            f"Optimal speaker count: {chosen_k} "
            f"(silhouette_k={best_k}, silhouette={best_score:.4f}, threshold_k={k_threshold})"
        )
        return chosen_k

    def cluster(
        self,
        embeddings: np.ndarray,
        num_speakers: Optional[int] = None,
    ) -> np.ndarray:
        """Cluster embeddings into speaker identities.

        Args:
            embeddings: Array of shape ``(n, dim)``, one row per segment.
            num_speakers: Desired number of speakers. It is clamped to
                ``[1, n]``. When ``None`` the count is estimated.

        Returns:
            Integer array of length ``n`` with zero-based speaker labels.

        Raises:
            ValueError: If two or more embeddings are given but they do not
                form a 2-D array.
        """
        n = len(embeddings)
        if n == 0:
            return np.array([], dtype=int)
        if n == 1:
            return np.array([0], dtype=int)

        vectors = np.asarray(embeddings, dtype=np.float64)
        if vectors.ndim != 2:
            raise ValueError(
                f"embeddings must be a 2-D array of shape (n, dim), got ndim={vectors.ndim}"
            )

        dist_matrix = self._cosine_distance_matrix(vectors)
        condensed = squareform(dist_matrix, checks=False)
        Z = linkage(condensed, method=self.linkage_method)

        if num_speakers is not None:
            k = max(1, min(num_speakers, n))
        else:
            k = self._estimate_num_speakers(vectors, Z)

        # Trivial partitions are resolved directly so the result does not
        # depend on how a "maxclust" cut treats a degenerate dendrogram.
        if k <= 1:
            return np.zeros(n, dtype=int)
        if n == 2:
            return np.array([0, 1], dtype=int)

        labels = fcluster(Z, k, criterion="maxclust") - 1
        return labels.astype(int)

    def merge_consecutive_same_speaker(
        self,
        segments: List[Tuple[float, float]],
        labels: np.ndarray,
        gap_tolerance: float = 0.3,
    ) -> List[Tuple[float, float, int]]:
        """Merge consecutive segments assigned to the same speaker.

        Segments are processed in the order given; they are not sorted here.
        A segment is folded into the current run when it carries the same
        label and starts no later than ``gap_tolerance`` after the run's end.

        Args:
            segments: ``(start, end)`` pairs.
            labels: Speaker label for each segment, aligned with ``segments``.
            gap_tolerance: Largest gap between two same-speaker segments
                that is still bridged.

        Returns:
            List of ``(start, end, speaker)`` tuples.

        Raises:
            IndexError: If ``labels`` has fewer entries than ``segments``.
        """
        if not segments:
            return []
        if len(labels) < len(segments):
            raise IndexError(
                f"labels has {len(labels)} entries but {len(segments)} segments were given"
            )

        merged = []
        current_start, current_end = segments[0]
        current_label = labels[0]
        for (seg_start, seg_end), seg_label in zip(segments[1:], labels[1:]):
            gap = seg_start - current_end
            if seg_label == current_label and gap <= gap_tolerance:
                # A segment nested inside the current run must not shrink it.
                current_end = max(current_end, seg_end)
            else:
                merged.append((current_start, current_end, int(current_label)))
                current_start, current_end = seg_start, seg_end
                current_label = seg_label
        merged.append((current_start, current_end, int(current_label)))
        return merged