mazesmazes committed · verified · commit c9c127a · parent addb26d

Update custom model files, README, and requirements

Files changed (1): diarization.py (new file, +853 lines)
"""Speaker diarization with support for pyannote and local (tiny-audio) backends.

Provides two diarization backends:
- pyannote: Uses pyannote-audio pipeline (requires HF token with model access)
- local: Uses TEN-VAD + ERes2NetV2 + spectral clustering (no token required)

Spectral clustering implementation adapted from FunASR/3D-Speaker:
https://github.com/alibaba-damo-academy/FunASR
MIT License (https://opensource.org/licenses/MIT)
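
Typical use (a sketch; assumes a 16 kHz mono numpy waveform ``audio``):

    segments = SpeakerDiarizer.diarize(audio, backend="local")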
"""

import numpy as np
import scipy.linalg
import sklearn.metrics.pairwise
import torch
from sklearn.cluster import k_means


def _get_device() -> torch.device:
    """Get best available device for inference."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


class SpectralCluster:
    """Spectral clustering using the unnormalized Laplacian of an affinity matrix.

    Adapted from FunASR/3D-Speaker and SpeechBrain implementations.
    Uses the eigenvalue gap to automatically determine the number of speakers.
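
    Example (illustrative sketch; random embeddings, so the grouping is arbitrary):
        >>> rng = np.random.default_rng(0)
        >>> embs = rng.standard_normal((20, 192))
        >>> labels = SpectralCluster(min_num_spks=1, max_num_spks=4)(embs)
        >>> labels.shape
        (20,)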
    """

    def __init__(self, min_num_spks: int = 1, max_num_spks: int = 15, pval: float = 0.06):
        self.min_num_spks = min_num_spks
        self.max_num_spks = max_num_spks
        self.pval = pval

    def __call__(self, embeddings: np.ndarray, oracle_num: int | None = None) -> np.ndarray:
        """Run spectral clustering on embeddings.

        Args:
            embeddings: Speaker embeddings of shape [N, D]
            oracle_num: Optional known number of speakers

        Returns:
            Cluster labels of shape [N]
        """
        # Similarity matrix computation
        sim_mat = self.get_sim_mat(embeddings)

        # Refine the similarity matrix with p-pruning
        pruned_sim_mat = self.p_pruning(sim_mat)

        # Symmetrization
        sym_pruned_sim_mat = 0.5 * (pruned_sim_mat + pruned_sim_mat.T)

        # Laplacian calculation
        laplacian = self.get_laplacian(sym_pruned_sim_mat)

        # Get spectral embeddings
        emb, num_of_spk = self.get_spec_embs(laplacian, oracle_num)

        # Perform clustering
        return self.cluster_embs(emb, num_of_spk)

    def get_sim_mat(self, embeddings: np.ndarray) -> np.ndarray:
        """Compute cosine similarity matrix."""
        return sklearn.metrics.pairwise.cosine_similarity(embeddings, embeddings)

    def p_pruning(self, affinity: np.ndarray) -> np.ndarray:
        """Prune low similarity values, keeping each row's strongest neighbors."""
        pval = 6.0 / affinity.shape[0] if affinity.shape[0] * self.pval < 6 else self.pval
        n_elems = int((1 - pval) * affinity.shape[0])

        # For each row in the affinity matrix, zero out the n_elems lowest similarities
        for i in range(affinity.shape[0]):
            low_indexes = np.argsort(affinity[i, :])
            low_indexes = low_indexes[0:n_elems]
            affinity[i, low_indexes] = 0
        return affinity

    def get_laplacian(self, sim_mat: np.ndarray) -> np.ndarray:
        """Compute the unnormalized Laplacian L = D - A, with D the degree matrix."""
        sim_mat[np.diag_indices(sim_mat.shape[0])] = 0
        degree = np.sum(np.abs(sim_mat), axis=1)
        degree_mat = np.diag(degree)
        return degree_mat - sim_mat

    def get_spec_embs(
        self, laplacian: np.ndarray, k_oracle: int | None = None
    ) -> tuple[np.ndarray, int]:
        """Extract spectral embeddings from the Laplacian.

        Unless an oracle count is given, the speaker count is chosen at the
        largest gap between consecutive eigenvalues within the allowed range.
        """
        lambdas, eig_vecs = scipy.linalg.eigh(laplacian)

        if k_oracle is not None:
            num_of_spk = k_oracle
        else:
            lambda_gap_list = self.get_eigen_gaps(
                lambdas[self.min_num_spks - 1 : self.max_num_spks + 1]
            )
            num_of_spk = np.argmax(lambda_gap_list) + self.min_num_spks

        emb = eig_vecs[:, :num_of_spk]
        return emb, num_of_spk

    def cluster_embs(self, emb: np.ndarray, k: int) -> np.ndarray:
        """Cluster spectral embeddings using k-means."""
        _, labels, _ = k_means(emb, k, n_init=10)
        return labels

    def get_eigen_gaps(self, eig_vals: np.ndarray) -> list[float]:
        """Compute gaps between consecutive eigenvalues."""
        eig_vals_gap_list = []
        for i in range(len(eig_vals) - 1):
            gap = float(eig_vals[i + 1]) - float(eig_vals[i])
            eig_vals_gap_list.append(gap)
        return eig_vals_gap_list


class SpeakerClusterer:
    """Speaker clustering backend using spectral clustering with speaker merging.

    Features:
    - Spectral clustering with eigenvalue gap for auto speaker count detection
    - P-pruning for affinity matrix refinement
    - Post-clustering speaker merging by cosine similarity
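
    Example (illustrative sketch; random embeddings, so labels are arbitrary):
        >>> clusterer = SpeakerClusterer(min_num_spks=2, max_num_spks=4)
        >>> labels = clusterer(np.random.default_rng(0).standard_normal((30, 192)))
        >>> len(labels)
        30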
    """

    def __init__(
        self,
        min_num_spks: int = 2,
        max_num_spks: int = 10,
        merge_thr: float = 0.90,  # Moderate merging
    ):
        self.min_num_spks = min_num_spks
        self.max_num_spks = max_num_spks
        self.merge_thr = merge_thr
        self._spectral_cluster: SpectralCluster | None = None

    def _get_spectral_cluster(self) -> SpectralCluster:
        """Lazy-load spectral clusterer."""
        if self._spectral_cluster is None:
            self._spectral_cluster = SpectralCluster(
                min_num_spks=self.min_num_spks,
                max_num_spks=self.max_num_spks,
            )
        return self._spectral_cluster

    def __call__(self, embeddings: np.ndarray, num_speakers: int | None = None) -> np.ndarray:
        """Cluster speaker embeddings and return labels.

        Args:
            embeddings: Speaker embeddings of shape [N, D]
            num_speakers: Optional oracle number of speakers

        Returns:
            Cluster labels of shape [N]
        """
        import warnings

        if len(embeddings.shape) != 2:
            raise ValueError(f"Expected 2D array, got shape {embeddings.shape}")

        # Handle edge cases
        if embeddings.shape[0] == 0:
            return np.array([], dtype=int)
        if embeddings.shape[0] == 1:
            return np.array([0], dtype=int)
        if embeddings.shape[0] < 6:
            return np.zeros(embeddings.shape[0], dtype=int)

        # Normalize embeddings
        norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
        norms = np.maximum(norms, 1e-10)
        embeddings = embeddings / norms

        # Replace NaN/inf with zeros
        embeddings = np.nan_to_num(embeddings, nan=0.0, posinf=0.0, neginf=0.0)

        spectral = self._get_spectral_cluster()

        # Update min/max for oracle case
        if num_speakers is not None:
            spectral.min_num_spks = num_speakers
            spectral.max_num_spks = num_speakers

        # Run spectral clustering (suppress numerical warnings)
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=RuntimeWarning)
            labels = spectral(embeddings, oracle_num=num_speakers)

        # Reset min/max
        if num_speakers is not None:
            spectral.min_num_spks = self.min_num_spks
            spectral.max_num_spks = self.max_num_spks

        # Merge similar speakers if no oracle
        if num_speakers is None:
            labels = self._merge_by_cos(labels, embeddings, self.merge_thr)

        # Re-index labels sequentially
        _, labels = np.unique(labels, return_inverse=True)

        return labels

    def _merge_by_cos(self, labels: np.ndarray, embs: np.ndarray, cos_thr: float) -> np.ndarray:
        """Iteratively merge the most similar pair of speakers by centroid cosine similarity."""
        labels = labels.copy()

        while True:
            spk_num = labels.max() + 1
            if spk_num == 1:
                break

            # Compute speaker centroids
            spk_center = []
            for i in range(spk_num):
                spk_emb = embs[labels == i].mean(0)
                spk_center.append(spk_emb)

            if len(spk_center) == 0:
                break

            spk_center = np.stack(spk_center, axis=0)
            norm_spk_center = spk_center / np.linalg.norm(spk_center, axis=1, keepdims=True)
            affinity = np.matmul(norm_spk_center, norm_spk_center.T)
            affinity = np.triu(affinity, 1)

            # Find the most similar pair; stop once it falls below the threshold
            spks = np.unravel_index(np.argmax(affinity), affinity.shape)
            if affinity[spks] < cos_thr:
                break

            # Merge the pair and shift higher labels down
            for i in range(len(labels)):
                if labels[i] == spks[1]:
                    labels[i] = spks[0]
                elif labels[i] > spks[1]:
                    labels[i] -= 1

        return labels


class LocalSpeakerDiarizer:
    """Local speaker diarization using TEN-VAD + ERes2NetV2 + spectral clustering.

    Pipeline:
    1. TEN-VAD detects speech segments
    2. Sliding window (0.75 s window, 0.15 s step, i.e. 80% overlap) for uniform embedding extraction
    3. ERes2NetV2 extracts speaker embeddings per window
    4. Spectral clustering with eigenvalue gap for auto speaker detection
    5. Frame-level consensus voting for segment reconstruction
    6. Post-processing merges short segments to reduce flicker

    Tunable Parameters (class attributes):
    - WINDOW_SIZE: Embedding extraction window size in seconds
    - STEP_SIZE: Sliding window step size (overlap = WINDOW_SIZE - STEP_SIZE)
    - VAD_THRESHOLD: Speech detection threshold (lower = more sensitive)
    - VAD_MIN_DURATION: Minimum speech segment duration
    - VAD_MAX_GAP: Maximum gap to bridge between segments
    - VAD_PAD_ONSET/OFFSET: Padding added to speech segments
    - VOTING_RATE: Frame resolution for consensus voting
    - MIN_SEGMENT_DURATION: Minimum final segment duration
    - SAME_SPEAKER_GAP: Maximum gap to merge same-speaker segments
    - TAIL_COVERAGE_RATIO: Minimum tail coverage to add extra window
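
    Example (sketch; assumes ``audio`` is a 16 kHz mono float32 numpy array):
        >>> segments = LocalSpeakerDiarizer.diarize(audio, num_speakers=2)
        >>> # each entry is {'speaker': 'SPEAKER_<i>', 'start': ..., 'end': ...}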
    """

    _ten_vad_model = None
    _eres2netv2_model = None
    _device = None

    # ==================== TUNABLE PARAMETERS ====================

    # Sliding window for embedding extraction
    WINDOW_SIZE = 0.75  # seconds - shorter window for finer resolution
    STEP_SIZE = 0.15  # seconds (80% overlap for more votes)
    TAIL_COVERAGE_RATIO = 0.1  # Add extra window if tail > this ratio of window

    # VAD hysteresis parameters
    VAD_THRESHOLD = 0.25  # Balanced threshold
    VAD_MIN_DURATION = 0.05  # Minimum speech segment duration (seconds)
    VAD_MAX_GAP = 0.50  # Bridge gaps shorter than this (seconds)
    VAD_PAD_ONSET = 0.05  # Padding at segment start (seconds)
    VAD_PAD_OFFSET = 0.05  # Padding at segment end (seconds)

    # Frame-level voting
    VOTING_RATE = 0.01  # 10 ms resolution for consensus voting

    # Post-processing
    MIN_SEGMENT_DURATION = 0.15  # Minimum final segment duration (seconds)
    SHORT_SEGMENT_GAP = 0.1  # Gap threshold for merging short segments
    SAME_SPEAKER_GAP = 0.5  # Gap threshold for merging same-speaker segments

    # ===========================================================

    @classmethod
    def _get_ten_vad_model(cls):
        """Lazy-load TEN-VAD model (singleton)."""
        if cls._ten_vad_model is None:
            from ten_vad import TenVad

            cls._ten_vad_model = TenVad(hop_size=256, threshold=cls.VAD_THRESHOLD)
        return cls._ten_vad_model

    @classmethod
    def _get_device(cls) -> torch.device:
        """Get the best available device."""
        if cls._device is None:
            cls._device = _get_device()
        return cls._device

    @classmethod
    def _get_eres2netv2_model(cls):
        """Lazy-load ERes2NetV2 speaker embedding model (singleton)."""
        if cls._eres2netv2_model is None:
            from modelscope.pipelines import pipeline
            from modelscope.utils.constant import Tasks

            sv_pipeline = pipeline(
                task=Tasks.speaker_verification,
                model="iic/speech_eres2netv2_sv_zh-cn_16k-common",
            )
            cls._eres2netv2_model = sv_pipeline.model

            # Move model to GPU if available
            device = cls._get_device()
            cls._eres2netv2_model = cls._eres2netv2_model.to(device)
            cls._eres2netv2_model.device = device
            cls._eres2netv2_model.eval()

        return cls._eres2netv2_model

    @classmethod
    def diarize(
        cls,
        audio: np.ndarray | str,
        sample_rate: int = 16000,
        num_speakers: int | None = None,
        min_speakers: int = 2,
        max_speakers: int = 10,
        **_kwargs,
    ) -> list[dict]:
        """Run speaker diarization on audio.

        Args:
            audio: Audio waveform as numpy array or path to audio file
            sample_rate: Audio sample rate (default 16000)
            num_speakers: Exact number of speakers (if known)
            min_speakers: Minimum number of speakers
            max_speakers: Maximum number of speakers

        Returns:
            List of dicts with 'speaker', 'start', 'end' keys
        """
        # Handle file path input
        if isinstance(audio, str):
            import librosa

            audio, sample_rate = librosa.load(audio, sr=16000)

        # Ensure correct sample rate
        if sample_rate != 16000:
            import librosa

            audio = librosa.resample(audio, orig_sr=sample_rate, target_sr=16000)
            sample_rate = 16000

        audio = audio.astype(np.float32)
        total_duration = len(audio) / sample_rate

        # Step 1: VAD (returns segments and raw frame-level decisions)
        segments, vad_frames = cls._get_speech_segments(audio, sample_rate)
        if not segments:
            return []

        # Step 2: Extract embeddings
        embeddings, window_segments = cls._extract_embeddings(audio, segments, sample_rate)
        if len(embeddings) == 0:
            return []

        # Step 3: Cluster
        clusterer = SpeakerClusterer(min_num_spks=min_speakers, max_num_spks=max_speakers)
        labels = clusterer(embeddings, num_speakers)

        # Step 4: Post-process with consensus voting (VAD-aware)
        return cls._postprocess_segments(window_segments, labels, total_duration, vad_frames)

    @classmethod
    def _get_speech_segments(
        cls, audio_array: np.ndarray, sample_rate: int = 16000
    ) -> tuple[list[dict], list[bool]]:
        """Get speech segments using TEN-VAD.

        Returns:
            Tuple of (segments list, vad_frames list of per-frame speech decisions)
        """
        vad_model = cls._get_ten_vad_model()

        # Convert to int16 as required by TEN-VAD;
        # clip to prevent integer overflow
        if audio_array.dtype != np.int16:
            audio_int16 = (np.clip(audio_array, -1.0, 1.0) * 32767).astype(np.int16)
        else:
            audio_int16 = audio_array

        # Process frame by frame
        hop_size = 256
        frame_duration = hop_size / sample_rate
        speech_frames: list[bool] = []

        for i in range(0, len(audio_int16) - hop_size, hop_size):
            frame = audio_int16[i : i + hop_size]
            _, is_speech = vad_model.process(frame)
            speech_frames.append(is_speech)

        # Convert frame-level decisions to segments
        segments = []
        in_speech = False
        start_idx = 0

        for i, is_speech in enumerate(speech_frames):
            if is_speech and not in_speech:
                start_idx = i
                in_speech = True
            elif not is_speech and in_speech:
                start_time = start_idx * frame_duration
                end_time = i * frame_duration
                segments.append(
                    {
                        "start": start_time,
                        "end": end_time,
                        "start_sample": int(start_time * sample_rate),
                        "end_sample": int(end_time * sample_rate),
                    }
                )
                in_speech = False

        # Handle trailing speech
        if in_speech:
            start_time = start_idx * frame_duration
            end_time = len(speech_frames) * frame_duration
            segments.append(
                {
                    "start": start_time,
                    "end": end_time,
                    "start_sample": int(start_time * sample_rate),
                    "end_sample": int(end_time * sample_rate),
                }
            )

        return cls._apply_vad_hysteresis(segments, sample_rate), speech_frames

    @classmethod
    def _apply_vad_hysteresis(cls, segments: list[dict], sample_rate: int = 16000) -> list[dict]:
        """Apply hysteresis-like post-processing to VAD segments."""
        if not segments:
            return segments

        segments = sorted(segments, key=lambda x: x["start"])

        # Fill short gaps
        merged = [segments[0].copy()]
        for seg in segments[1:]:
            gap = seg["start"] - merged[-1]["end"]
            if gap <= cls.VAD_MAX_GAP:
                merged[-1]["end"] = seg["end"]
                merged[-1]["end_sample"] = seg["end_sample"]
            else:
                merged.append(seg.copy())

        # Remove short segments
        filtered = [seg for seg in merged if (seg["end"] - seg["start"]) >= cls.VAD_MIN_DURATION]

        # Dilate segments (add padding)
        for seg in filtered:
            seg["start"] = max(0.0, seg["start"] - cls.VAD_PAD_ONSET)
            seg["end"] = seg["end"] + cls.VAD_PAD_OFFSET
            seg["start_sample"] = int(seg["start"] * sample_rate)
            seg["end_sample"] = int(seg["end"] * sample_rate)

        return filtered

    @classmethod
    def _extract_embeddings(
        cls, audio_array: np.ndarray, segments: list[dict], sample_rate: int
    ) -> tuple[np.ndarray, list[dict]]:
        """Extract speaker embeddings using sliding windows."""
        speaker_model = cls._get_eres2netv2_model()
        device = cls._get_device()

        window_samples = int(cls.WINDOW_SIZE * sample_rate)
        step_samples = int(cls.STEP_SIZE * sample_rate)

        embeddings = []
        window_segments = []

        with torch.no_grad():
            for seg in segments:
                seg_start = seg["start_sample"]
                seg_end = seg["end_sample"]
                seg_len = seg_end - seg_start

                # Generate window positions
                if seg_len <= window_samples:
                    starts = [seg_start]
                    ends = [seg_end]
                else:
                    starts = list(range(seg_start, seg_end - window_samples + 1, step_samples))
                    ends = [s + window_samples for s in starts]

                # Cover tail if > TAIL_COVERAGE_RATIO of window remains
                if ends and ends[-1] < seg_end:
                    remainder = seg_end - ends[-1]
                    if remainder > (window_samples * cls.TAIL_COVERAGE_RATIO):
                        starts.append(seg_end - window_samples)
                        ends.append(seg_end)

                for c_start, c_end in zip(starts, ends):
                    chunk = audio_array[c_start:c_end]

                    # Pad short chunks by tiling (reflection padding fails when
                    # the pad width exceeds the chunk length, as it does for
                    # segments much shorter than the window)
                    if len(chunk) < window_samples:
                        reps = int(np.ceil(window_samples / max(len(chunk), 1)))
                        chunk = np.tile(chunk, reps)[:window_samples]

                    # Extract embedding
                    chunk_tensor = torch.from_numpy(chunk).float().unsqueeze(0).to(device)
                    embedding = speaker_model.forward(chunk_tensor).squeeze(0).cpu().numpy()

                    # Validate and normalize
                    if not np.isfinite(embedding).all():
                        continue
                    norm = np.linalg.norm(embedding)
                    if norm > 1e-8:
                        embeddings.append(embedding / norm)
                        window_segments.append(
                            {
                                "start": c_start / sample_rate,
                                "end": c_end / sample_rate,
                            }
                        )

        if embeddings:
            return np.array(embeddings), window_segments
        return np.array([]), []

    @classmethod
    def _resample_vad(cls, vad_frames: list[bool], num_frames: int) -> np.ndarray:
        """Resample VAD frame decisions to match the voting grid resolution.

        VAD operates at 256 samples / 16000 Hz = 16 ms per frame, while voting
        operates at VOTING_RATE (default 10 ms) per frame. This maps VAD
        decisions onto the finer voting grid: e.g. voting frame 5 sits at
        50 ms, which falls in VAD frame int(0.050 / 0.016) = 3.
        """
        if not vad_frames:
            return np.zeros(num_frames, dtype=bool)

        vad_rate = 256 / 16000  # 16 ms per VAD frame
        result = np.zeros(num_frames, dtype=bool)

        for i in range(num_frames):
            voting_time = i * cls.VOTING_RATE
            vad_frame = int(voting_time / vad_rate)
            if vad_frame < len(vad_frames):
                result[i] = vad_frames[vad_frame]

        return result

    @classmethod
    def _postprocess_segments(
        cls,
        window_segments: list[dict],
        labels: np.ndarray,
        total_duration: float,
        vad_frames: list[bool],
    ) -> list[dict]:
        """Post-process using frame-level consensus voting with VAD-aware silence."""
        if not window_segments or len(labels) == 0:
            return []

        # Correct labels to be contiguous
        unique_labels = np.unique(labels)
        label_map = {old: new for new, old in enumerate(unique_labels)}
        clean_labels = np.array([label_map[lbl] for lbl in labels])
        num_speakers = len(unique_labels)

        if num_speakers == 0:
            return []

        # Create voting grid
        num_frames = int(np.ceil(total_duration / cls.VOTING_RATE)) + 1
        votes = np.zeros((num_frames, num_speakers), dtype=np.float32)

        # Accumulate votes
        for win, label in zip(window_segments, clean_labels):
            start_frame = int(win["start"] / cls.VOTING_RATE)
            end_frame = int(win["end"] / cls.VOTING_RATE)
            end_frame = min(end_frame, num_frames)
            if start_frame < end_frame:
                votes[start_frame:end_frame, label] += 1.0

        # Determine winner per frame
        frame_speakers = np.argmax(votes, axis=1)
        max_votes = np.max(votes, axis=1)

        # Resample VAD to voting grid resolution for silence-aware voting
        vad_resampled = cls._resample_vad(vad_frames, num_frames)

        # Convert frames to segments
        final_segments = []
        current_speaker = -1
        seg_start = 0.0

        for f in range(num_frames):
            speaker = int(frame_speakers[f])
            score = max_votes[f]

            # Force silence if VAD says no speech OR no votes
            if score == 0 or not vad_resampled[f]:
                speaker = -1

            if speaker != current_speaker:
                if current_speaker != -1:
                    final_segments.append(
                        {
                            "speaker": f"SPEAKER_{current_speaker}",
                            "start": seg_start,
                            "end": f * cls.VOTING_RATE,
                        }
                    )
                current_speaker = speaker
                seg_start = f * cls.VOTING_RATE

        # Close last segment
        if current_speaker != -1:
            final_segments.append(
                {
                    "speaker": f"SPEAKER_{current_speaker}",
                    "start": seg_start,
                    "end": num_frames * cls.VOTING_RATE,
                }
            )

        return cls._merge_short_segments(final_segments)

    @classmethod
    def _merge_short_segments(cls, segments: list[dict]) -> list[dict]:
        """Merge short segments to reduce flicker.
        if not segments:
            return []

        clean: list[dict] = []
        for seg in segments:
            dur = seg["end"] - seg["start"]
            if dur < cls.MIN_SEGMENT_DURATION:
                # Absorb a short segment into the previous same-speaker
                # segment if the gap is small; otherwise drop it
                if (
                    clean
                    and clean[-1]["speaker"] == seg["speaker"]
                    and seg["start"] - clean[-1]["end"] < cls.SHORT_SEGMENT_GAP
                ):
                    clean[-1]["end"] = seg["end"]
                continue

            if (
                clean
                and clean[-1]["speaker"] == seg["speaker"]
                and seg["start"] - clean[-1]["end"] < cls.SAME_SPEAKER_GAP
            ):
                clean[-1]["end"] = seg["end"]
            else:
                clean.append(seg)

        return clean

    @classmethod
    def assign_speakers_to_words(
        cls,
        words: list[dict],
        speaker_segments: list[dict],
    ) -> list[dict]:
        """Assign speaker labels to words based on timestamp overlap.

        Args:
            words: List of word dicts with 'word', 'start', 'end' keys
            speaker_segments: List of speaker dicts with 'speaker', 'start', 'end' keys

        Returns:
            Words list with 'speaker' key added to each word
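
        Example (illustrative):
            >>> words = [{"word": "hi", "start": 0.1, "end": 0.3}]
            >>> segs = [{"speaker": "SPEAKER_0", "start": 0.0, "end": 1.0}]
            >>> LocalSpeakerDiarizer.assign_speakers_to_words(words, segs)[0]["speaker"]
            'SPEAKER_0'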
        """
        for word in words:
            word_mid = (word["start"] + word["end"]) / 2

            # Find the speaker segment that contains this word's midpoint
            best_speaker = None
            for seg in speaker_segments:
                if seg["start"] <= word_mid <= seg["end"]:
                    best_speaker = seg["speaker"]
                    break

            # If no exact match, fall back to the segment whose midpoint is closest
            if best_speaker is None and speaker_segments:
                min_dist = float("inf")
                for seg in speaker_segments:
                    seg_mid = (seg["start"] + seg["end"]) / 2
                    dist = abs(word_mid - seg_mid)
                    if dist < min_dist:
                        min_dist = dist
                        best_speaker = seg["speaker"]

            word["speaker"] = best_speaker

        return words


class SpeakerDiarizer:
    """Unified speaker diarization interface supporting multiple backends.

    Backends:
    - 'pyannote': Uses pyannote-audio pipeline (requires HF token)
    - 'local': Uses TEN-VAD + ERes2NetV2 + spectral clustering

    Example:
        >>> segments = SpeakerDiarizer.diarize(audio_array, backend="local")
        >>> for seg in segments:
        ...     print(f"{seg['speaker']}: {seg['start']:.2f} - {seg['end']:.2f}")
    """

    _pyannote_pipeline = None

    @classmethod
    def _get_pyannote_pipeline(cls, hf_token: str | None = None):
        """Get or create the pyannote diarization pipeline."""
        if cls._pyannote_pipeline is None:
            from pyannote.audio import Pipeline

            cls._pyannote_pipeline = Pipeline.from_pretrained(
                "pyannote/speaker-diarization-3.1",
                use_auth_token=hf_token,
            )
            cls._pyannote_pipeline.to(_get_device())

        return cls._pyannote_pipeline

    @classmethod
    def diarize(
        cls,
        audio: np.ndarray | str,
        sample_rate: int = 16000,
        num_speakers: int | None = None,
        min_speakers: int | None = None,
        max_speakers: int | None = None,
        hf_token: str | None = None,
        backend: str = "pyannote",
    ) -> list[dict]:
        """Run speaker diarization on audio.

        Args:
            audio: Audio waveform as numpy array or path to audio file
            sample_rate: Audio sample rate (default 16000)
            num_speakers: Exact number of speakers (if known)
            min_speakers: Minimum number of speakers
            max_speakers: Maximum number of speakers
            hf_token: HuggingFace token for pyannote models
            backend: Diarization backend ("pyannote" or "local")

        Returns:
            List of dicts with 'speaker', 'start', 'end' keys
        """
        if backend == "local":
            return LocalSpeakerDiarizer.diarize(
                audio,
                sample_rate=sample_rate,
                num_speakers=num_speakers,
                min_speakers=min_speakers or 2,
                max_speakers=max_speakers or 10,
            )

        # Default to pyannote
        return cls._diarize_pyannote(
            audio,
            sample_rate=sample_rate,
            num_speakers=num_speakers,
            min_speakers=min_speakers,
            max_speakers=max_speakers,
            hf_token=hf_token,
        )

    @classmethod
    def _diarize_pyannote(
        cls,
        audio: np.ndarray | str,
        sample_rate: int = 16000,
        num_speakers: int | None = None,
        min_speakers: int | None = None,
        max_speakers: int | None = None,
        hf_token: str | None = None,
    ) -> list[dict]:
        """Run pyannote diarization."""
        pipeline = cls._get_pyannote_pipeline(hf_token)

        # Prepare audio input: pyannote expects a (channel, time) waveform,
        # so add a channel dimension only when the input is 1D
        if isinstance(audio, np.ndarray):
            waveform = torch.from_numpy(audio.copy())
            if waveform.dim() == 1:
                waveform = waveform.unsqueeze(0)
            audio_input = {"waveform": waveform, "sample_rate": sample_rate}
        else:
            audio_input = audio

        # Run diarization
        diarization_args = {}
        if num_speakers is not None:
            diarization_args["num_speakers"] = num_speakers
        if min_speakers is not None:
            diarization_args["min_speakers"] = min_speakers
        if max_speakers is not None:
            diarization_args["max_speakers"] = max_speakers

        diarization = pipeline(audio_input, **diarization_args)

        # Handle different pyannote return types
        if hasattr(diarization, "itertracks"):
            annotation = diarization
        elif hasattr(diarization, "speaker_diarization"):
            annotation = diarization.speaker_diarization
        elif isinstance(diarization, tuple):
            annotation = diarization[0]
        else:
            raise TypeError(f"Unexpected diarization output type: {type(diarization)}")

        # Convert to simple format
        segments = []
        for turn, _, speaker in annotation.itertracks(yield_label=True):
            segments.append(
                {
                    "speaker": speaker,
                    "start": turn.start,
                    "end": turn.end,
                }
            )

        return segments

    @classmethod
    def assign_speakers_to_words(
        cls,
        words: list[dict],
        speaker_segments: list[dict],
    ) -> list[dict]:
        """Assign speaker labels to words based on timestamp overlap."""
        return LocalSpeakerDiarizer.assign_speakers_to_words(words, speaker_segments)
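

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original module): diarize
    # a WAV file with the token-free "local" backend and print the turns.
    # The default path below is a placeholder, not a bundled asset.
    import sys

    wav_path = sys.argv[1] if len(sys.argv) > 1 else "example.wav"
    for seg in SpeakerDiarizer.diarize(wav_path, backend="local"):
        print(f"{seg['speaker']}: {seg['start']:.2f}s - {seg['end']:.2f}s")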