Spaces:
Running
Running
| """ | |
| Agglomerative Hierarchical Clustering (AHC) for speaker identity assignment. | |
| Uses cosine similarity on ECAPA-TDNN embeddings to cluster segments by speaker. | |
| """ | |
| import numpy as np | |
| from typing import List, Tuple, Optional | |
| from scipy.cluster.hierarchy import linkage, fcluster | |
| from scipy.spatial.distance import squareform | |
| from sklearn.metrics import silhouette_score | |
| from loguru import logger | |
| class SpeakerClusterer: | |
| """ | |
| Agglomerative Hierarchical Clustering for speaker diarization. | |
| Supports automatic speaker count estimation via silhouette analysis. | |
| """ | |
| def __init__( | |
| self, | |
| linkage_method: str = "average", | |
| distance_threshold: float = 0.55, | |
| min_speakers: int = 1, | |
| max_speakers: int = 10, | |
| ): | |
| self.linkage_method = linkage_method | |
| self.distance_threshold = distance_threshold | |
| self.min_speakers = min_speakers | |
| self.max_speakers = max_speakers | |
| def _cosine_distance_matrix(self, embeddings: np.ndarray) -> np.ndarray: | |
| similarity = embeddings @ embeddings.T | |
| distance = np.clip(1.0 - similarity, 0.0, 2.0) | |
| return distance | |
| def _estimate_num_speakers(self, embeddings: np.ndarray, linkage_matrix: np.ndarray) -> int: | |
| n = len(embeddings) | |
| if n <= 2: | |
| return n | |
| min_k = max(2, self.min_speakers) | |
| upper_k = min(self.max_speakers, n - 1) | |
| best_k = min_k | |
| best_score = -1.0 | |
| for k in range(min_k, upper_k + 1): | |
| labels = fcluster(linkage_matrix, k, criterion="maxclust") | |
| if len(np.unique(labels)) < 2: | |
| continue | |
| try: | |
| score = silhouette_score(embeddings, labels, metric="cosine") | |
| if score > best_score: | |
| best_score = score | |
| best_k = k | |
| except Exception: | |
| continue | |
| threshold_labels = fcluster( | |
| linkage_matrix, | |
| t=self.distance_threshold, | |
| criterion="distance", | |
| ) | |
| k_threshold = len(np.unique(threshold_labels)) | |
| k_threshold = int(np.clip(k_threshold, self.min_speakers, min(self.max_speakers, n))) | |
| # Be conservative to avoid severe over-segmentation in open-domain audio. | |
| if best_score < 0.08: | |
| chosen_k = k_threshold | |
| else: | |
| chosen_k = min(best_k, k_threshold) if k_threshold >= 2 else best_k | |
| chosen_k = int(np.clip(chosen_k, self.min_speakers, min(self.max_speakers, n))) | |
| logger.info( | |
| f"Optimal speaker count: {chosen_k} " | |
| f"(silhouette_k={best_k}, silhouette={best_score:.4f}, threshold_k={k_threshold})" | |
| ) | |
| return chosen_k | |
| def cluster( | |
| self, | |
| embeddings: np.ndarray, | |
| num_speakers: Optional[int] = None, | |
| ) -> np.ndarray: | |
| """Cluster embeddings into speaker identities.""" | |
| n = len(embeddings) | |
| if n == 0: | |
| return np.array([], dtype=int) | |
| if n == 1: | |
| return np.array([0], dtype=int) | |
| dist_matrix = self._cosine_distance_matrix(embeddings) | |
| condensed = squareform(dist_matrix, checks=False) | |
| Z = linkage(condensed, method=self.linkage_method) | |
| if num_speakers is not None: | |
| k = max(1, min(num_speakers, n)) | |
| else: | |
| k = self._estimate_num_speakers(embeddings, Z) | |
| labels = fcluster(Z, k, criterion="maxclust") - 1 | |
| return labels.astype(int) | |
| def merge_consecutive_same_speaker( | |
| self, | |
| segments: List[Tuple[float, float]], | |
| labels: np.ndarray, | |
| gap_tolerance: float = 0.3, | |
| ) -> List[Tuple[float, float, int]]: | |
| """Merge consecutive segments assigned to the same speaker.""" | |
| if not segments: | |
| return [] | |
| merged = [] | |
| current_start, current_end = segments[0] | |
| current_label = labels[0] | |
| for i in range(1, len(segments)): | |
| seg_start, seg_end = segments[i] | |
| seg_label = labels[i] | |
| gap = seg_start - current_end | |
| if seg_label == current_label and gap <= gap_tolerance: | |
| current_end = seg_end | |
| else: | |
| merged.append((current_start, current_end, int(current_label))) | |
| current_start, current_end = seg_start, seg_end | |
| current_label = seg_label | |
| merged.append((current_start, current_end, int(current_label))) | |
| return merged | |