mazesmazes committed on
Commit
08c708e
·
verified ·
1 Parent(s): 8fa5959

Update custom model files, README, and requirements

Browse files
Files changed (2) hide show
  1. asr_pipeline.py +222 -1
  2. diarization.py +80 -164
asr_pipeline.py CHANGED
@@ -30,6 +30,12 @@ class ForcedAligner:
30
  _model = None
31
  _labels = None
32
  _dictionary = None
 
 
 
 
 
 
33
 
34
  @classmethod
35
  def get_instance(cls, device: str = "cuda"):
@@ -51,6 +57,135 @@ class ForcedAligner:
51
  cls._dictionary = {c: i for i, c in enumerate(cls._labels)}
52
  return cls._model, cls._labels, cls._dictionary
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  @classmethod
55
  def align(
56
  cls,
@@ -59,6 +194,7 @@ class ForcedAligner:
59
  sample_rate: int = 16000,
60
  _language: str = "eng",
61
  _batch_size: int = 16,
 
62
  ) -> list[dict]:
63
  """Align transcript to audio and return word-level timestamps.
64
 
@@ -68,9 +204,10 @@ class ForcedAligner:
68
  sample_rate: Audio sample rate (default 16000)
69
  _language: ISO-639-3 language code (default "eng" for English, unused)
70
  _batch_size: Batch size for alignment model (unused)
 
71
 
72
  Returns:
73
- List of dicts with 'word', 'start', 'end' keys
74
  """
75
  import torchaudio
76
  from torchaudio.functional import forced_align, merge_tokens
@@ -78,6 +215,11 @@ class ForcedAligner:
78
  device = _get_device()
79
  model, labels, dictionary = cls.get_instance(device)
80
 
 
 
 
 
 
81
  # Convert audio to tensor (copy to ensure array is writable)
82
  if isinstance(audio, np.ndarray):
83
  waveform = torch.from_numpy(audio.copy()).float()
@@ -130,43 +272,122 @@ class ForcedAligner:
130
  frame_duration = 320 / cls._bundle.sample_rate
131
 
132
  # Group token spans into words based on pipe separator
 
133
  words = text.split()
134
  word_timestamps = []
135
  current_word_start = None
136
  current_word_end = None
 
137
  word_idx = 0
138
 
139
  for span in token_spans:
140
  token_char = labels[span.token]
141
  if token_char == "|": # Word separator
142
  if current_word_start is not None and word_idx < len(words):
 
 
 
 
 
 
143
  word_timestamps.append(
144
  {
145
  "word": words[word_idx],
146
  "start": current_word_start * frame_duration,
147
  "end": current_word_end * frame_duration,
 
148
  }
149
  )
150
  word_idx += 1
151
  current_word_start = None
152
  current_word_end = None
 
153
  else:
154
  if current_word_start is None:
155
  current_word_start = span.start
156
  current_word_end = span.end
 
157
 
158
  # Don't forget the last word
159
  if current_word_start is not None and word_idx < len(words):
 
 
 
160
  word_timestamps.append(
161
  {
162
  "word": words[word_idx],
163
  "start": current_word_start * frame_duration,
164
  "end": current_word_end * frame_duration,
 
165
  }
166
  )
167
 
 
 
 
 
168
  return word_timestamps
169
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
 
171
  try:
172
  from .diarization import SpeakerDiarizer
 
30
  _model = None
31
  _labels = None
32
  _dictionary = None
33
+ _vad_model = None
34
+
35
+ # VAD parameters
36
+ VAD_HOP_SIZE = 256 # TEN-VAD frame size (16ms at 16kHz)
37
+ VAD_THRESHOLD = 0.5 # Speech detection threshold
38
+ VAD_MAX_GAP = 0.15 # Max gap to merge speech segments (seconds)
39
 
40
  @classmethod
41
  def get_instance(cls, device: str = "cuda"):
 
57
  cls._dictionary = {c: i for i, c in enumerate(cls._labels)}
58
  return cls._model, cls._labels, cls._dictionary
59
 
60
+ @classmethod
61
+ def _get_vad_model(cls):
62
+ """Lazy-load TEN-VAD model (singleton)."""
63
+ if cls._vad_model is None:
64
+ from ten_vad import TenVad
65
+
66
+ cls._vad_model = TenVad(hop_size=cls.VAD_HOP_SIZE, threshold=cls.VAD_THRESHOLD)
67
+ return cls._vad_model
68
+
69
+ @classmethod
70
+ def _get_speech_regions(
71
+ cls, audio: np.ndarray, sample_rate: int = 16000
72
+ ) -> list[tuple[float, float]]:
73
+ """Get speech regions using TEN-VAD.
74
+
75
+ Args:
76
+ audio: Audio waveform as numpy array
77
+ sample_rate: Audio sample rate
78
+
79
+ Returns:
80
+ List of (start_time, end_time) tuples for speech regions
81
+ """
82
+ vad_model = cls._get_vad_model()
83
+
84
+ # Convert to int16 as required by TEN-VAD
85
+ if audio.dtype != np.int16:
86
+ audio_int16 = (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16)
87
+ else:
88
+ audio_int16 = audio
89
+
90
+ # Process frame by frame
91
+ hop_size = cls.VAD_HOP_SIZE
92
+ frame_duration = hop_size / sample_rate
93
+ speech_frames: list[bool] = []
94
+
95
+ for i in range(0, len(audio_int16) - hop_size, hop_size):
96
+ frame = audio_int16[i : i + hop_size]
97
+ _, is_speech = vad_model.process(frame)
98
+ speech_frames.append(is_speech)
99
+
100
+ # Convert frame-level decisions to segments
101
+ segments: list[tuple[float, float]] = []
102
+ in_speech = False
103
+ start_idx = 0
104
+
105
+ for i, is_speech in enumerate(speech_frames):
106
+ if is_speech and not in_speech:
107
+ start_idx = i
108
+ in_speech = True
109
+ elif not is_speech and in_speech:
110
+ start_time = start_idx * frame_duration
111
+ end_time = i * frame_duration
112
+ segments.append((start_time, end_time))
113
+ in_speech = False
114
+
115
+ # Handle trailing speech
116
+ if in_speech:
117
+ start_time = start_idx * frame_duration
118
+ end_time = len(speech_frames) * frame_duration
119
+ segments.append((start_time, end_time))
120
+
121
+ # Merge segments with small gaps
122
+ return cls._merge_speech_segments(segments)
123
+
124
+ @classmethod
125
+ def _merge_speech_segments(
126
+ cls, segments: list[tuple[float, float]]
127
+ ) -> list[tuple[float, float]]:
128
+ """Merge speech segments with small gaps."""
129
+ if not segments:
130
+ return segments
131
+
132
+ merged: list[tuple[float, float]] = [segments[0]]
133
+ for start, end in segments[1:]:
134
+ prev_start, prev_end = merged[-1]
135
+ if start - prev_end <= cls.VAD_MAX_GAP:
136
+ merged[-1] = (prev_start, end)
137
+ else:
138
+ merged.append((start, end))
139
+ return merged
140
+
141
+ @classmethod
142
+ def _is_in_speech(cls, time: float, speech_regions: list[tuple[float, float]]) -> bool:
143
+ """Check if a timestamp falls within any speech region."""
144
+ return any(start <= time <= end for start, end in speech_regions)
145
+
146
+ @classmethod
147
+ def _find_nearest_speech_boundary(
148
+ cls, time: float, speech_regions: list[tuple[float, float]], direction: str = "any"
149
+ ) -> float:
150
+ """Find the nearest speech region boundary to a timestamp.
151
+
152
+ Args:
153
+ time: Timestamp to find boundary for
154
+ speech_regions: List of (start, end) speech regions
155
+ direction: "start" for word starts, "end" for word ends, "any" for closest
156
+
157
+ Returns:
158
+ Adjusted timestamp snapped to nearest speech boundary
159
+ """
160
+ if not speech_regions:
161
+ return time
162
+
163
+ best_time = time
164
+ min_dist = float("inf")
165
+
166
+ for start, end in speech_regions:
167
+ # If time is inside this region, return as-is
168
+ if start <= time <= end:
169
+ return time
170
+
171
+ # Check distance to boundaries
172
+ if direction in ("start", "any"):
173
+ dist = abs(time - start)
174
+ if dist < min_dist:
175
+ min_dist = dist
176
+ best_time = start
177
+
178
+ if direction in ("end", "any"):
179
+ dist = abs(time - end)
180
+ if dist < min_dist:
181
+ min_dist = dist
182
+ best_time = end
183
+
184
+ return best_time
185
+
186
+ # Confidence threshold for alignment scores (log probability)
187
+ MIN_CONFIDENCE = -5.0 # Tokens with scores below this are considered low-confidence
188
+
189
  @classmethod
190
  def align(
191
  cls,
 
194
  sample_rate: int = 16000,
195
  _language: str = "eng",
196
  _batch_size: int = 16,
197
+ use_vad: bool = True,
198
  ) -> list[dict]:
199
  """Align transcript to audio and return word-level timestamps.
200
 
 
204
  sample_rate: Audio sample rate (default 16000)
205
  _language: ISO-639-3 language code (default "eng" for English, unused)
206
  _batch_size: Batch size for alignment model (unused)
207
+ use_vad: If True, use VAD to refine word boundaries (default True)
208
 
209
  Returns:
210
+ List of dicts with 'word', 'start', 'end', 'confidence' keys
211
  """
212
  import torchaudio
213
  from torchaudio.functional import forced_align, merge_tokens
 
215
  device = _get_device()
216
  model, labels, dictionary = cls.get_instance(device)
217
 
218
+ # Step 1: Get speech regions using VAD (before any processing)
219
+ speech_regions = []
220
+ if use_vad:
221
+ speech_regions = cls._get_speech_regions(audio, sample_rate)
222
+
223
  # Convert audio to tensor (copy to ensure array is writable)
224
  if isinstance(audio, np.ndarray):
225
  waveform = torch.from_numpy(audio.copy()).float()
 
272
  frame_duration = 320 / cls._bundle.sample_rate
273
 
274
  # Group token spans into words based on pipe separator
275
+ # Track confidence scores per word
276
  words = text.split()
277
  word_timestamps = []
278
  current_word_start = None
279
  current_word_end = None
280
+ current_word_scores: list[float] = []
281
  word_idx = 0
282
 
283
  for span in token_spans:
284
  token_char = labels[span.token]
285
  if token_char == "|": # Word separator
286
  if current_word_start is not None and word_idx < len(words):
287
+ # Calculate word confidence as mean of token scores
288
+ confidence = (
289
+ sum(current_word_scores) / len(current_word_scores)
290
+ if current_word_scores
291
+ else 0.0
292
+ )
293
  word_timestamps.append(
294
  {
295
  "word": words[word_idx],
296
  "start": current_word_start * frame_duration,
297
  "end": current_word_end * frame_duration,
298
+ "confidence": confidence,
299
  }
300
  )
301
  word_idx += 1
302
  current_word_start = None
303
  current_word_end = None
304
+ current_word_scores = []
305
  else:
306
  if current_word_start is None:
307
  current_word_start = span.start
308
  current_word_end = span.end
309
+ current_word_scores.append(span.score)
310
 
311
  # Don't forget the last word
312
  if current_word_start is not None and word_idx < len(words):
313
+ confidence = (
314
+ sum(current_word_scores) / len(current_word_scores) if current_word_scores else 0.0
315
+ )
316
  word_timestamps.append(
317
  {
318
  "word": words[word_idx],
319
  "start": current_word_start * frame_duration,
320
  "end": current_word_end * frame_duration,
321
+ "confidence": confidence,
322
  }
323
  )
324
 
325
+ # Step 2: Refine timestamps using VAD
326
+ if use_vad and speech_regions:
327
+ word_timestamps = cls._refine_with_vad(word_timestamps, speech_regions)
328
+
329
  return word_timestamps
330
 
331
+ @classmethod
332
+ def _refine_with_vad(
333
+ cls, word_timestamps: list[dict], speech_regions: list[tuple[float, float]]
334
+ ) -> list[dict]:
335
+ """Refine word timestamps using VAD speech regions.
336
+
337
+ - Words with low confidence that fall outside speech regions are flagged
338
+ - Word boundaries are snapped to speech region boundaries when close
339
+
340
+ Args:
341
+ word_timestamps: List of word dicts with 'start', 'end', 'confidence'
342
+ speech_regions: List of (start, end) speech regions
343
+
344
+ Returns:
345
+ Refined word timestamps
346
+ """
347
+ if not word_timestamps or not speech_regions:
348
+ return word_timestamps
349
+
350
+ refined = []
351
+ for word in word_timestamps:
352
+ start = word["start"]
353
+ end = word["end"]
354
+ confidence = word.get("confidence", 0.0)
355
+
356
+ # Check if word midpoint is in a speech region
357
+ midpoint = (start + end) / 2
358
+ in_speech = cls._is_in_speech(midpoint, speech_regions)
359
+
360
+ # For low-confidence words outside speech, snap to nearest speech boundary
361
+ if not in_speech and confidence < cls.MIN_CONFIDENCE:
362
+ # Find the nearest speech region and snap boundaries
363
+ start = cls._find_nearest_speech_boundary(start, speech_regions, "start")
364
+ end = cls._find_nearest_speech_boundary(end, speech_regions, "end")
365
+ # Ensure start < end
366
+ if start >= end:
367
+ end = start + 0.01
368
+
369
+ # For words near speech boundaries, snap to the boundary
370
+ # This helps align word edges with actual speech onset/offset
371
+ snap_threshold = 0.05 # 50ms
372
+ for region_start, region_end in speech_regions:
373
+ # Snap start to speech region start if close
374
+ if abs(start - region_start) < snap_threshold:
375
+ start = region_start
376
+ # Snap end to speech region end if close
377
+ if abs(end - region_end) < snap_threshold:
378
+ end = region_end
379
+
380
+ refined.append(
381
+ {
382
+ "word": word["word"],
383
+ "start": start,
384
+ "end": end,
385
+ "confidence": confidence,
386
+ }
387
+ )
388
+
389
+ return refined
390
+
391
 
392
  try:
393
  from .diarization import SpeakerDiarizer
diarization.py CHANGED
@@ -1,20 +1,18 @@
1
- """Speaker diarization using TEN-VAD + WavLM + spectral clustering.
2
-
3
- Pipeline:
4
- 1. TEN-VAD detects speech segments
5
- 2. WavLM (microsoft/wavlm-base-plus-sv) extracts speaker embeddings
6
- 3. Spectral clustering groups embeddings by speaker
7
 
8
  Spectral clustering implementation adapted from FunASR/3D-Speaker:
9
  https://github.com/alibaba-damo-academy/FunASR
10
  MIT License (https://opensource.org/licenses/MIT)
11
  """
12
 
 
 
13
  import numpy as np
14
  import scipy
15
  import sklearn.metrics.pairwise
16
  import torch
17
  from sklearn.cluster._kmeans import k_means
 
18
 
19
 
20
  def _get_device() -> torch.device:
@@ -71,23 +69,24 @@ class SpectralCluster:
71
  return sklearn.metrics.pairwise.cosine_similarity(embeddings, embeddings)
72
 
73
  def p_pruning(self, affinity: np.ndarray) -> np.ndarray:
74
- """Prune low similarity values in affinity matrix."""
75
- pval = 6.0 / affinity.shape[0] if affinity.shape[0] * self.pval < 6 else self.pval
76
- n_elems = int((1 - pval) * affinity.shape[0])
77
-
78
- # For each row in affinity matrix, zero out low similarities
79
- for i in range(affinity.shape[0]):
80
- low_indexes = np.argsort(affinity[i, :])
81
- low_indexes = low_indexes[0:n_elems]
82
- affinity[i, low_indexes] = 0
 
83
  return affinity
84
 
85
  def get_laplacian(self, sim_mat: np.ndarray) -> np.ndarray:
86
  """Compute unnormalized Laplacian matrix."""
87
- sim_mat[np.diag_indices(sim_mat.shape[0])] = 0
88
- degree = np.sum(np.abs(sim_mat), axis=1)
89
- degree_mat = np.diag(degree)
90
- return degree_mat - sim_mat
91
 
92
  def get_spec_embs(
93
  self, laplacian: np.ndarray, k_oracle: int | None = None
@@ -111,13 +110,9 @@ class SpectralCluster:
111
  _, labels, _ = k_means(emb, k, n_init=10)
112
  return labels
113
 
114
- def get_eigen_gaps(self, eig_vals: np.ndarray) -> list[float]:
115
  """Compute gaps between consecutive eigenvalues."""
116
- eig_vals_gap_list = []
117
- for i in range(len(eig_vals) - 1):
118
- gap = float(eig_vals[i + 1]) - float(eig_vals[i])
119
- eig_vals_gap_list.append(gap)
120
- return eig_vals_gap_list
121
 
122
 
123
  class SpeakerClusterer:
@@ -172,13 +167,9 @@ class SpeakerClusterer:
172
  if embeddings.shape[0] < 6:
173
  return np.zeros(embeddings.shape[0], dtype=int)
174
 
175
- # Normalize embeddings
176
- norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
177
- norms = np.maximum(norms, 1e-10)
178
- embeddings = embeddings / norms
179
-
180
- # Replace NaN/inf with zeros
181
  embeddings = np.nan_to_num(embeddings, nan=0.0, posinf=0.0, neginf=0.0)
 
182
 
183
  # Run spectral clustering (suppress numerical warnings)
184
  spectral = self._get_spectral_cluster()
@@ -208,49 +199,34 @@ class SpeakerClusterer:
208
 
209
  def _merge_by_cos(self, labels: np.ndarray, embs: np.ndarray, cos_thr: float) -> np.ndarray:
210
  """Merge similar speakers by cosine similarity of centroids."""
211
- labels = labels.copy()
212
-
213
- while True:
214
- spk_num = labels.max() + 1
215
- if spk_num == 1:
216
- break
217
-
218
- # Compute speaker centroids
219
- spk_center = []
220
- for i in range(spk_num):
221
- spk_emb = embs[labels == i].mean(0)
222
- spk_center.append(spk_emb)
223
-
224
- if len(spk_center) == 0:
225
- break
226
-
227
- spk_center = np.stack(spk_center, axis=0)
228
- norm_spk_center = spk_center / np.linalg.norm(spk_center, axis=1, keepdims=True)
229
- affinity = np.matmul(norm_spk_center, norm_spk_center.T)
230
- affinity = np.triu(affinity, 1)
231
-
232
- # Find most similar pair
233
- spks = np.unravel_index(np.argmax(affinity), affinity.shape)
234
- if affinity[spks] < cos_thr:
235
- break
236
-
237
- # Merge speakers
238
- for i in range(len(labels)):
239
- if labels[i] == spks[1]:
240
- labels[i] = spks[0]
241
- elif labels[i] > spks[1]:
242
- labels[i] -= 1
243
 
244
- return labels
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245
 
246
 
247
  class LocalSpeakerDiarizer:
248
- """Local speaker diarization using TEN-VAD + WavLM + spectral clustering.
249
 
250
  Pipeline:
251
  1. TEN-VAD detects speech segments
252
  2. Sliding window (1.0s, 75% overlap) for uniform embedding extraction
253
- 3. WavLM extracts speaker embeddings per window
254
  4. Spectral clustering with eigenvalue gap for auto speaker detection
255
  5. Frame-level consensus voting for segment reconstruction
256
  6. Post-processing merges short segments to reduce flicker
@@ -269,15 +245,14 @@ class LocalSpeakerDiarizer:
269
  """
270
 
271
  _ten_vad_model = None
272
- _speaker_model = None
273
  _device = None
274
 
275
  # ==================== TUNABLE PARAMETERS ====================
276
 
277
  # Sliding window for embedding extraction
278
- # Longer windows (1.5-2.0s) capture more prosody, reducing speaker confusion
279
- WINDOW_SIZE = 1.5 # seconds
280
- STEP_SIZE = 0.5 # seconds (67% overlap)
281
  TAIL_COVERAGE_RATIO = 0.1 # Add extra window if tail > this ratio of window
282
 
283
  # VAD hysteresis parameters
@@ -291,8 +266,8 @@ class LocalSpeakerDiarizer:
291
  VOTING_RATE = 0.01 # 10ms resolution for consensus voting
292
 
293
  # Post-processing
294
- MIN_SEGMENT_DURATION = 0.3 # Minimum final segment duration (seconds)
295
- SHORT_SEGMENT_GAP = 0.3 # Gap threshold for merging short segments
296
  SAME_SPEAKER_GAP = 0.5 # Gap threshold for merging same-speaker segments
297
 
298
  # ===========================================================
@@ -314,21 +289,21 @@ class LocalSpeakerDiarizer:
314
  return cls._device
315
 
316
  @classmethod
317
- def _get_speaker_model(cls):
318
- """Lazy-load WavLM speaker embedding model (singleton)."""
319
- if cls._speaker_model is None:
320
- from transformers import WavLMForXVector
321
-
322
- cls._speaker_model = WavLMForXVector.from_pretrained(
323
- "microsoft/wavlm-base-plus-sv",
324
- )
325
-
326
- # Move model to best available device (MPS/CUDA/CPU)
327
- device = cls._get_device()
328
- cls._speaker_model = cls._speaker_model.to(device)
329
- cls._speaker_model.eval()
330
 
331
- return cls._speaker_model
332
 
333
  @classmethod
334
  def diarize(
@@ -382,10 +357,7 @@ class LocalSpeakerDiarizer:
382
  clusterer = SpeakerClusterer(min_num_spks=min_speakers, max_num_spks=max_speakers)
383
  labels = clusterer(embeddings, num_speakers)
384
 
385
- # Step 4: Centroid refinement - reduces flickering/confusion
386
- labels = cls._refine_with_centroids(embeddings, labels)
387
-
388
- # Step 5: Post-process with consensus voting (VAD-aware)
389
  return cls._postprocess_segments(window_segments, labels, total_duration, vad_frames)
390
 
391
  @classmethod
@@ -483,64 +455,12 @@ class LocalSpeakerDiarizer:
483
 
484
  return filtered
485
 
486
- @classmethod
487
- def _refine_with_centroids(cls, embeddings: np.ndarray, labels: np.ndarray) -> np.ndarray:
488
- """Refine cluster assignments using nearest centroid.
489
-
490
- This reduces "flickering" where embeddings rapidly switch between speakers.
491
- For each embedding, we re-assign it to the speaker whose centroid is closest
492
- (by cosine similarity).
493
-
494
- Args:
495
- embeddings: Speaker embeddings of shape [N, D]
496
- labels: Initial cluster labels of shape [N]
497
-
498
- Returns:
499
- Refined labels of shape [N]
500
- """
501
- if len(embeddings) == 0 or len(np.unique(labels)) <= 1:
502
- return labels
503
-
504
- # Normalize embeddings for cosine similarity
505
- norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
506
- norms = np.maximum(norms, 1e-10)
507
- norm_embeddings = embeddings / norms
508
-
509
- # Calculate centroid for each speaker
510
- unique_labels = np.unique(labels)
511
- centroids = {}
512
- for label in unique_labels:
513
- mask = labels == label
514
- speaker_embs = norm_embeddings[mask]
515
- centroid = speaker_embs.mean(axis=0)
516
- # Normalize centroid
517
- centroid_norm = np.linalg.norm(centroid)
518
- if centroid_norm > 1e-10:
519
- centroids[label] = centroid / centroid_norm
520
- else:
521
- centroids[label] = centroid
522
-
523
- # Re-assign each embedding to nearest centroid
524
- refined_labels = np.zeros_like(labels)
525
- for i, emb in enumerate(norm_embeddings):
526
- best_label = labels[i]
527
- best_sim = -1.0
528
- for label, centroid in centroids.items():
529
- sim = np.dot(emb, centroid)
530
- if sim > best_sim:
531
- best_sim = sim
532
- best_label = label
533
- refined_labels[i] = best_label
534
-
535
- return refined_labels
536
-
537
  @classmethod
538
  def _extract_embeddings(
539
  cls, audio_array: np.ndarray, segments: list[dict], sample_rate: int
540
  ) -> tuple[np.ndarray, list[dict]]:
541
  """Extract speaker embeddings using sliding windows."""
542
- speaker_model = cls._get_speaker_model()
543
- device = cls._get_device()
544
 
545
  window_samples = int(cls.WINDOW_SIZE * sample_rate)
546
  step_samples = int(cls.STEP_SIZE * sample_rate)
@@ -577,17 +497,15 @@ class LocalSpeakerDiarizer:
577
  pad_width = window_samples - len(chunk)
578
  chunk = np.pad(chunk, (0, pad_width), mode="reflect")
579
 
580
- # Extract embedding (WavLMForXVector returns XVectorOutput with .embeddings)
581
- chunk_tensor = torch.from_numpy(chunk).float().unsqueeze(0).to(device)
582
- output = speaker_model(chunk_tensor)
583
- embedding = output.embeddings.squeeze(0).cpu().numpy()
584
-
585
- # Validate and normalize
586
- if not np.isfinite(embedding).all():
587
- continue
588
- norm = np.linalg.norm(embedding)
589
- if norm > 1e-8:
590
- embeddings.append(embedding / norm)
591
  window_segments.append(
592
  {
593
  "start": c_start / sample_rate,
@@ -595,8 +513,9 @@ class LocalSpeakerDiarizer:
595
  }
596
  )
597
 
 
598
  if embeddings:
599
- return np.array(embeddings), window_segments
600
  return np.array([]), []
601
 
602
  @classmethod
@@ -611,15 +530,12 @@ class LocalSpeakerDiarizer:
611
  return np.zeros(num_frames, dtype=bool)
612
 
613
  vad_rate = 256 / 16000 # 16ms per VAD frame
614
- result = np.zeros(num_frames, dtype=bool)
615
-
616
- for i in range(num_frames):
617
- voting_time = i * cls.VOTING_RATE
618
- vad_frame = int(voting_time / vad_rate)
619
- if vad_frame < len(vad_frames):
620
- result[i] = vad_frames[vad_frame]
621
 
622
- return result
 
 
 
623
 
624
  @classmethod
625
  def _postprocess_segments(
@@ -768,7 +684,7 @@ class LocalSpeakerDiarizer:
768
 
769
 
770
  class SpeakerDiarizer:
771
- """Speaker diarization using TEN-VAD + WavLM + spectral clustering.
772
 
773
  Example:
774
  >>> segments = SpeakerDiarizer.diarize(audio_array)
 
1
+ """Speaker diarization using TEN-VAD + ECAPA-TDNN + spectral clustering.
 
 
 
 
 
2
 
3
  Spectral clustering implementation adapted from FunASR/3D-Speaker:
4
  https://github.com/alibaba-damo-academy/FunASR
5
  MIT License (https://opensource.org/licenses/MIT)
6
  """
7
 
8
+ import warnings
9
+
10
  import numpy as np
11
  import scipy
12
  import sklearn.metrics.pairwise
13
  import torch
14
  from sklearn.cluster._kmeans import k_means
15
+ from sklearn.preprocessing import normalize
16
 
17
 
18
  def _get_device() -> torch.device:
 
69
  return sklearn.metrics.pairwise.cosine_similarity(embeddings, embeddings)
70
 
71
  def p_pruning(self, affinity: np.ndarray) -> np.ndarray:
72
+ """Prune low similarity values in affinity matrix (keep top pval fraction)."""
73
+ n = affinity.shape[0]
74
+ pval = max(self.pval, 6.0 / n)
75
+ k_keep = max(1, int(pval * n))
76
+
77
+ # Vectorized: find top-k indices per row and zero out the rest
78
+ top_k_idx = np.argpartition(affinity, -k_keep, axis=1)[:, -k_keep:]
79
+ mask = np.zeros_like(affinity, dtype=bool)
80
+ np.put_along_axis(mask, top_k_idx, True, axis=1)
81
+ affinity[~mask] = 0
82
  return affinity
83
 
84
  def get_laplacian(self, sim_mat: np.ndarray) -> np.ndarray:
85
  """Compute unnormalized Laplacian matrix."""
86
+ from scipy.sparse.csgraph import laplacian
87
+
88
+ np.fill_diagonal(sim_mat, 0)
89
+ return laplacian(sim_mat, normed=False)
90
 
91
  def get_spec_embs(
92
  self, laplacian: np.ndarray, k_oracle: int | None = None
 
110
  _, labels, _ = k_means(emb, k, n_init=10)
111
  return labels
112
 
113
+ def get_eigen_gaps(self, eig_vals: np.ndarray) -> np.ndarray:
114
  """Compute gaps between consecutive eigenvalues."""
115
+ return np.diff(eig_vals)
 
 
 
 
116
 
117
 
118
  class SpeakerClusterer:
 
167
  if embeddings.shape[0] < 6:
168
  return np.zeros(embeddings.shape[0], dtype=int)
169
 
170
+ # Normalize embeddings and replace NaN/inf
 
 
 
 
 
171
  embeddings = np.nan_to_num(embeddings, nan=0.0, posinf=0.0, neginf=0.0)
172
+ embeddings = normalize(embeddings)
173
 
174
  # Run spectral clustering (suppress numerical warnings)
175
  spectral = self._get_spectral_cluster()
 
199
 
200
  def _merge_by_cos(self, labels: np.ndarray, embs: np.ndarray, cos_thr: float) -> np.ndarray:
201
  """Merge similar speakers by cosine similarity of centroids."""
202
+ from scipy.cluster.hierarchy import fcluster, linkage
203
+ from scipy.spatial.distance import pdist
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
 
205
+ unique_labels = np.unique(labels)
206
+ if len(unique_labels) <= 1:
207
+ return labels
208
+
209
+ # Compute normalized speaker centroids
210
+ centroids = np.array([embs[labels == lbl].mean(0) for lbl in unique_labels])
211
+ centroids = normalize(centroids)
212
+
213
+ # Hierarchical clustering with cosine distance
214
+ distances = pdist(centroids, metric="cosine")
215
+ linkage_matrix = linkage(distances, method="average")
216
+ merged_labels = fcluster(linkage_matrix, t=1.0 - cos_thr, criterion="distance") - 1
217
+
218
+ # Map original labels to merged labels
219
+ label_map = dict(zip(unique_labels, merged_labels))
220
+ return np.array([label_map[lbl] for lbl in labels])
221
 
222
 
223
  class LocalSpeakerDiarizer:
224
+ """Local speaker diarization using TEN-VAD + ECAPA-TDNN + spectral clustering.
225
 
226
  Pipeline:
227
  1. TEN-VAD detects speech segments
228
  2. Sliding window (0.75s, 80% overlap) for uniform embedding extraction
229
+ 3. ECAPA-TDNN extracts speaker embeddings per window
230
  4. Spectral clustering with eigenvalue gap for auto speaker detection
231
  5. Frame-level consensus voting for segment reconstruction
232
  6. Post-processing merges short segments to reduce flicker
 
245
  """
246
 
247
  _ten_vad_model = None
248
+ _ecapa_model = None
249
  _device = None
250
 
251
  # ==================== TUNABLE PARAMETERS ====================
252
 
253
  # Sliding window for embedding extraction
254
+ WINDOW_SIZE = 0.75 # seconds - shorter window for finer resolution
255
+ STEP_SIZE = 0.15 # seconds (80% overlap for more votes)
 
256
  TAIL_COVERAGE_RATIO = 0.1 # Add extra window if tail > this ratio of window
257
 
258
  # VAD hysteresis parameters
 
266
  VOTING_RATE = 0.01 # 10ms resolution for consensus voting
267
 
268
  # Post-processing
269
+ MIN_SEGMENT_DURATION = 0.15 # Minimum final segment duration (seconds)
270
+ SHORT_SEGMENT_GAP = 0.1 # Gap threshold for merging short segments
271
  SAME_SPEAKER_GAP = 0.5 # Gap threshold for merging same-speaker segments
272
 
273
  # ===========================================================
 
289
  return cls._device
290
 
291
  @classmethod
292
+ def _get_ecapa_model(cls):
293
+ """Lazy-load ECAPA-TDNN speaker embedding model (singleton)."""
294
+ if cls._ecapa_model is None:
295
+ # Suppress torchaudio deprecation warning from SpeechBrain
296
+ with warnings.catch_warnings():
297
+ warnings.filterwarnings("ignore", message="torchaudio._backend")
298
+ from speechbrain.inference.speaker import EncoderClassifier
299
+
300
+ device = cls._get_device()
301
+ cls._ecapa_model = EncoderClassifier.from_hparams(
302
+ source="speechbrain/spkrec-ecapa-voxceleb",
303
+ run_opts={"device": str(device)},
304
+ )
305
 
306
+ return cls._ecapa_model
307
 
308
  @classmethod
309
  def diarize(
 
357
  clusterer = SpeakerClusterer(min_num_spks=min_speakers, max_num_spks=max_speakers)
358
  labels = clusterer(embeddings, num_speakers)
359
 
360
+ # Step 4: Post-process with consensus voting (VAD-aware)
 
 
 
361
  return cls._postprocess_segments(window_segments, labels, total_duration, vad_frames)
362
 
363
  @classmethod
 
455
 
456
  return filtered
457
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
458
  @classmethod
459
  def _extract_embeddings(
460
  cls, audio_array: np.ndarray, segments: list[dict], sample_rate: int
461
  ) -> tuple[np.ndarray, list[dict]]:
462
  """Extract speaker embeddings using sliding windows."""
463
+ speaker_model = cls._get_ecapa_model()
 
464
 
465
  window_samples = int(cls.WINDOW_SIZE * sample_rate)
466
  step_samples = int(cls.STEP_SIZE * sample_rate)
 
497
  pad_width = window_samples - len(chunk)
498
  chunk = np.pad(chunk, (0, pad_width), mode="reflect")
499
 
500
+ # Extract embedding using SpeechBrain's encode_batch
501
+ chunk_tensor = torch.from_numpy(chunk).float().unsqueeze(0)
502
+ embedding = (
503
+ speaker_model.encode_batch(chunk_tensor).squeeze(0).squeeze(0).cpu().numpy()
504
+ )
505
+
506
+ # Validate embedding
507
+ if np.isfinite(embedding).all() and np.linalg.norm(embedding) > 1e-8:
508
+ embeddings.append(embedding)
 
 
509
  window_segments.append(
510
  {
511
  "start": c_start / sample_rate,
 
513
  }
514
  )
515
 
516
+ # Normalize all embeddings at once
517
  if embeddings:
518
+ return normalize(np.array(embeddings)), window_segments
519
  return np.array([]), []
520
 
521
  @classmethod
 
530
  return np.zeros(num_frames, dtype=bool)
531
 
532
  vad_rate = 256 / 16000 # 16ms per VAD frame
533
+ vad_arr = np.array(vad_frames)
 
 
 
 
 
 
534
 
535
+ # Vectorized: compute VAD frame indices for each voting frame
536
+ voting_times = np.arange(num_frames) * cls.VOTING_RATE
537
+ vad_indices = np.clip((voting_times / vad_rate).astype(int), 0, len(vad_arr) - 1)
538
+ return vad_arr[vad_indices]
539
 
540
  @classmethod
541
  def _postprocess_segments(
 
684
 
685
 
686
  class SpeakerDiarizer:
687
+ """Speaker diarization using TEN-VAD + ECAPA-TDNN + spectral clustering.
688
 
689
  Example:
690
  >>> segments = SpeakerDiarizer.diarize(audio_array)