File size: 4,497 Bytes
d7a2919
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8d04859
d7a2919
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
789006e
d7a2919
 
789006e
 
 
 
d7a2919
 
 
 
 
 
 
 
 
 
 
8d04859
 
 
 
 
 
 
 
789006e
8d04859
 
 
789006e
 
 
8d04859
 
 
 
 
 
d7a2919
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
"""
Agglomerative Hierarchical Clustering (AHC) for speaker identity assignment.
Uses cosine similarity on ECAPA-TDNN embeddings to cluster segments by speaker.
"""

import numpy as np
from typing import List, Tuple, Optional
from scipy.cluster.hierarchy import linkage, fcluster
from scipy.spatial.distance import squareform
from sklearn.metrics import silhouette_score
from loguru import logger


class SpeakerClusterer:
    """
    Agglomerative Hierarchical Clustering (AHC) for speaker diarization.

    Segment embeddings are compared with cosine distance, linked with SciPy's
    hierarchical clustering, and cut into speaker clusters. When the number
    of speakers is not supplied it is estimated from a silhouette sweep
    combined with a distance-threshold cut of the dendrogram.
    """

    # Silhouette scores below this value are treated as too weak to trust;
    # the distance-threshold estimate is used on its own in that case.
    _MIN_RELIABLE_SILHOUETTE = 0.08

    def __init__(
        self,
        linkage_method: str = "average",
        distance_threshold: float = 0.55,
        min_speakers: int = 1,
        max_speakers: int = 10,
    ):
        """
        Store the clustering configuration.

        Args:
            linkage_method: Method name forwarded to
                ``scipy.cluster.hierarchy.linkage``.
            distance_threshold: Cosine-distance cut used for the
                threshold-based speaker count estimate.
            min_speakers: Lower bound for the estimated speaker count.
            max_speakers: Upper bound for the estimated speaker count.
        """
        self.linkage_method = linkage_method
        self.distance_threshold = distance_threshold
        self.min_speakers = min_speakers
        self.max_speakers = max_speakers

    def _cosine_distance_matrix(self, embeddings: np.ndarray) -> np.ndarray:
        """
        Return the pairwise cosine distance matrix of the row vectors.

        Rows are L2-normalised first so the result is a true cosine distance
        even when the caller passes embeddings that are not unit length
        (a plain dot product is only a cosine similarity for unit vectors).

        Args:
            embeddings: Array-like of shape (n, dim).

        Returns:
            (n, n) array with values in [0, 2] and an exactly-zero diagonal.
        """
        vectors = np.asarray(embeddings, dtype=np.float64)
        norms = np.linalg.norm(vectors, axis=1, keepdims=True)
        # Avoid division by zero: an all-zero row stays all-zero, which gives
        # it similarity 0 (distance 1) to every other row.
        unit = vectors / np.where(norms > 0.0, norms, 1.0)

        similarity = unit @ unit.T
        distance = np.clip(1.0 - similarity, 0.0, 2.0)
        # Rounding (or a zero row) can leave a non-zero self-distance.
        np.fill_diagonal(distance, 0.0)
        return distance

    def _clamp_speaker_count(self, k: int, n: int) -> int:
        """Clamp *k* to [min_speakers, min(max_speakers, n)], never below 1."""
        upper = min(self.max_speakers, n)
        # Same precedence as np.clip: the upper bound wins if bounds cross.
        return max(1, int(min(max(k, self.min_speakers), upper)))

    def _estimate_num_speakers(self, embeddings: np.ndarray, linkage_matrix: np.ndarray) -> int:
        """
        Estimate the number of speakers from the dendrogram.

        Two estimates are combined:

        * the ``k`` maximising the cosine silhouette score over
          ``max(2, min_speakers) .. min(max_speakers, n - 1)``;
        * the number of clusters obtained by cutting the dendrogram at
          ``distance_threshold``.

        If the best silhouette is below ``_MIN_RELIABLE_SILHOUETTE`` the
        threshold estimate is used; otherwise the smaller of the two is used
        (or the silhouette estimate when the threshold cut yields a single
        cluster). The result is clamped to the configured speaker bounds.

        Args:
            embeddings: Array of shape (n, dim) that produced the linkage.
            linkage_matrix: Output of ``scipy.cluster.hierarchy.linkage``.

        Returns:
            The chosen number of speakers.
        """
        n = len(embeddings)
        if n < 2:
            # Nothing to cut: no dendrogram exists for fewer than two points.
            return n

        min_k = max(2, self.min_speakers)
        # silhouette_score requires 2 <= n_labels <= n - 1, so for n == 2 the
        # sweep below is empty and the threshold estimate decides.
        upper_k = min(self.max_speakers, n - 1)

        best_k = min_k
        best_score = -1.0

        for k in range(min_k, upper_k + 1):
            labels = fcluster(linkage_matrix, k, criterion="maxclust")
            if len(np.unique(labels)) < 2:
                continue
            try:
                score = silhouette_score(embeddings, labels, metric="cosine")
            except ValueError as exc:
                # Raised for an invalid label count; skip this k but keep a
                # trace instead of hiding the reason.
                logger.debug(f"Silhouette score unavailable for k={k}: {exc}")
                continue
            if score > best_score:
                best_score = score
                best_k = k

        threshold_labels = fcluster(
            linkage_matrix,
            t=self.distance_threshold,
            criterion="distance",
        )
        k_threshold = self._clamp_speaker_count(len(np.unique(threshold_labels)), n)

        # Be conservative to avoid severe over-segmentation in open-domain audio.
        if best_score < self._MIN_RELIABLE_SILHOUETTE:
            chosen_k = k_threshold
        elif k_threshold >= 2:
            chosen_k = min(best_k, k_threshold)
        else:
            chosen_k = best_k

        chosen_k = self._clamp_speaker_count(chosen_k, n)

        logger.info(
            f"Optimal speaker count: {chosen_k} "
            f"(silhouette_k={best_k}, silhouette={best_score:.4f}, threshold_k={k_threshold})"
        )
        return chosen_k

    def cluster(
        self,
        embeddings: np.ndarray,
        num_speakers: Optional[int] = None,
    ) -> np.ndarray:
        """
        Cluster embeddings into speaker identities.

        Args:
            embeddings: Array-like of shape (n, dim), one row per segment.
            num_speakers: Fixed number of clusters to produce. It is limited
                to the range [1, n]. When ``None`` the count is estimated.

        Returns:
            Integer array of length n with zero-based cluster labels.

        Raises:
            ValueError: If two or more embeddings are given and the input is
                not two-dimensional.
        """
        n = len(embeddings)

        if n == 0:
            return np.array([], dtype=int)
        if n == 1:
            return np.array([0], dtype=int)

        embeddings = np.asarray(embeddings, dtype=np.float64)
        if embeddings.ndim != 2:
            raise ValueError(
                f"embeddings must be a 2-D array of shape (n, dim), got ndim={embeddings.ndim}"
            )

        dist_matrix = self._cosine_distance_matrix(embeddings)
        # checks=False: the matrix is symmetric with a zero diagonal by
        # construction, so the validation pass is skipped.
        condensed = squareform(dist_matrix, checks=False)
        Z = linkage(condensed, method=self.linkage_method)

        if num_speakers is not None:
            k = max(1, min(int(num_speakers), n))
        else:
            k = self._estimate_num_speakers(embeddings, Z)

        # fcluster labels start at 1; shift to zero-based ids.
        labels = fcluster(Z, k, criterion="maxclust") - 1
        return labels.astype(int)

    def merge_consecutive_same_speaker(
        self,
        segments: List[Tuple[float, float]],
        labels: np.ndarray,
        gap_tolerance: float = 0.3,
    ) -> List[Tuple[float, float, int]]:
        """
        Merge consecutive segments assigned to the same speaker.

        Adjacent segments are fused when they carry the same label and the
        gap between the running end and the next start is at most
        ``gap_tolerance``. Overlapping segments (negative gap) are merged
        too, and the merged end never moves backwards.

        Args:
            segments: (start, end) pairs in the order they should be merged.
            labels: Speaker label for each segment, indexed like ``segments``.
            gap_tolerance: Largest gap that still allows a merge.

        Returns:
            List of (start, end, label) tuples for the merged segments.
        """
        if not segments:
            return []

        merged = []
        current_start, current_end = segments[0]
        current_label = labels[0]

        for i, (seg_start, seg_end) in enumerate(segments[1:], start=1):
            seg_label = labels[i]
            gap = seg_start - current_end

            if seg_label == current_label and gap <= gap_tolerance:
                # A segment contained in the current span must not shrink it.
                current_end = max(current_end, seg_end)
            else:
                merged.append((current_start, current_end, int(current_label)))
                current_start, current_end = seg_start, seg_end
                current_label = seg_label

        merged.append((current_start, current_end, int(current_label)))
        return merged