mazesmazes committed on
Commit
43c4368
·
verified ·
1 Parent(s): 0e60196

Update custom model files, README, and requirements

Browse files
Files changed (2) hide show
  1. asr_pipeline.py +54 -168
  2. diarization.py +29 -121
asr_pipeline.py CHANGED
@@ -1,7 +1,6 @@
1
  """ASR pipeline for audio-to-text transcription with optional timestamps and diarization."""
2
 
3
  import re
4
- from dataclasses import dataclass
5
  from pathlib import Path
6
  from typing import Any
7
 
@@ -24,135 +23,8 @@ def _get_device() -> str:
24
  return "cpu"
25
 
26
 
27
- @dataclass
28
- class _AlignPoint:
29
- """A point in the alignment path."""
30
-
31
- token_index: int
32
- time_index: int
33
- score: float
34
-
35
-
36
- @dataclass
37
- class _AlignSegment:
38
- """An aligned character/word segment."""
39
-
40
- label: str
41
- start: int
42
- end: int
43
- score: float
44
-
45
- @property
46
- def length(self):
47
- return self.end - self.start
48
-
49
-
50
- def _get_trellis(emission: torch.Tensor, tokens: list[int], blank_id: int = 0) -> torch.Tensor:
51
- """Build dynamic programming trellis for CTC alignment.
52
-
53
- Based on WhisperX's alignment algorithm for improved accuracy.
54
- """
55
- num_frame = emission.size(0)
56
- num_tokens = len(tokens)
57
-
58
- trellis = torch.zeros((num_frame, num_tokens))
59
- trellis[1:, 0] = torch.cumsum(emission[1:, blank_id], 0)
60
- trellis[0, 1:] = -float("inf")
61
- trellis[-num_tokens + 1 :, 0] = float("inf")
62
-
63
- for t in range(num_frame - 1):
64
- trellis[t + 1, 1:] = torch.maximum(
65
- # Score for staying at the same token
66
- trellis[t, 1:] + emission[t, blank_id],
67
- # Score for changing to the next token
68
- trellis[t, :-1] + emission[t, tokens[1:]],
69
- )
70
- return trellis
71
-
72
-
73
- def _backtrack(
74
- trellis: torch.Tensor,
75
- emission: torch.Tensor,
76
- tokens: list[int],
77
- blank_id: int = 0,
78
- ) -> list[_AlignPoint]:
79
- """Backtrack through trellis to find optimal alignment path."""
80
- t, j = trellis.size(0) - 1, trellis.size(1) - 1
81
-
82
- path = [_AlignPoint(j, t, emission[t, blank_id].exp().item())]
83
- while j > 0:
84
- assert t > 0
85
-
86
- p_stay = emission[t - 1, blank_id]
87
- p_change = emission[t - 1, tokens[j]]
88
-
89
- stayed = trellis[t - 1, j] + p_stay
90
- changed = trellis[t - 1, j - 1] + p_change
91
-
92
- t -= 1
93
- if changed > stayed:
94
- j -= 1
95
-
96
- prob = (p_change if changed > stayed else p_stay).exp().item()
97
- path.append(_AlignPoint(j, t, prob))
98
-
99
- while t > 0:
100
- prob = emission[t - 1, blank_id].exp().item()
101
- path.append(_AlignPoint(j, t - 1, prob))
102
- t -= 1
103
-
104
- return path[::-1]
105
-
106
-
107
- def _merge_repeats(path: list[_AlignPoint], transcript: str) -> list[_AlignSegment]:
108
- """Merge repeated tokens into character segments."""
109
- i1, i2 = 0, 0
110
- segments = []
111
- while i1 < len(path):
112
- while i2 < len(path) and path[i1].token_index == path[i2].token_index:
113
- i2 += 1
114
- score = sum(path[k].score for k in range(i1, i2)) / (i2 - i1)
115
- segments.append(
116
- _AlignSegment(
117
- transcript[path[i1].token_index],
118
- path[i1].time_index,
119
- path[i2 - 1].time_index + 1,
120
- score,
121
- )
122
- )
123
- i1 = i2
124
- return segments
125
-
126
-
127
- def _merge_words(segments: list[_AlignSegment], separator: str = "|") -> list[_AlignSegment]:
128
- """Merge character segments into word segments."""
129
- words = []
130
- i1, i2 = 0, 0
131
- while i1 < len(segments):
132
- if i2 >= len(segments) or segments[i2].label == separator:
133
- if i1 != i2:
134
- segs = segments[i1:i2]
135
- word = "".join([seg.label for seg in segs])
136
- total_length = sum(seg.length for seg in segs)
137
- score = (
138
- sum(seg.score * seg.length for seg in segs) / total_length
139
- if total_length > 0
140
- else 0
141
- )
142
- words.append(_AlignSegment(word, segments[i1].start, segments[i2 - 1].end, score))
143
- i1 = i2 + 1
144
- i2 = i1
145
- else:
146
- i2 += 1
147
- return words
148
-
149
-
150
  class ForcedAligner:
151
- """Forced aligner for word-level timestamps using wav2vec2.
152
-
153
- Uses WhisperX-style dynamic programming alignment for improved accuracy
154
- over simple CTC greedy alignment.
155
- """
156
 
157
  _bundle = None
158
  _model = None
@@ -172,8 +44,7 @@ class ForcedAligner:
172
  if cls._model is None:
173
  import torchaudio
174
 
175
- # Use LARGE model for better accuracy (same as WhisperX recommendation)
176
- cls._bundle = torchaudio.pipelines.WAV2VEC2_ASR_LARGE_960H
177
  cls._model = cls._bundle.get_model().to(device)
178
  cls._model.eval()
179
  cls._labels = cls._bundle.get_labels()
@@ -186,29 +57,28 @@ class ForcedAligner:
186
  audio: np.ndarray,
187
  text: str,
188
  sample_rate: int = 16000,
189
- _language: str = "en",
190
  _batch_size: int = 16,
191
  ) -> list[dict]:
192
  """Align transcript to audio and return word-level timestamps.
193
 
194
- Uses WhisperX-style dynamic programming for improved alignment accuracy.
195
-
196
  Args:
197
  audio: Audio waveform as numpy array
198
  text: Transcript text to align
199
  sample_rate: Audio sample rate (default 16000)
200
- _language: Language code (unused, English only)
201
- _batch_size: Batch size (unused)
202
 
203
  Returns:
204
  List of dicts with 'word', 'start', 'end' keys
205
  """
206
  import torchaudio
 
207
 
208
  device = _get_device()
209
  model, labels, dictionary = cls.get_instance(device)
210
 
211
- # Convert audio to tensor
212
  if isinstance(audio, np.ndarray):
213
  waveform = torch.from_numpy(audio.copy()).float()
214
  else:
@@ -218,7 +88,7 @@ class ForcedAligner:
218
  if waveform.dim() == 1:
219
  waveform = waveform.unsqueeze(0)
220
 
221
- # Resample if needed
222
  if sample_rate != cls._bundle.sample_rate:
223
  waveform = torchaudio.functional.resample(
224
  waveform, sample_rate, cls._bundle.sample_rate
@@ -233,47 +103,67 @@ class ForcedAligner:
233
 
234
  emission = emissions[0].cpu()
235
 
236
- # Normalize text and build token sequence
237
  transcript = text.upper()
 
238
  tokens = []
239
- clean_transcript = ""
240
-
241
  for char in transcript:
242
  if char in dictionary:
243
  tokens.append(dictionary[char])
244
- clean_transcript += char
245
  elif char == " ":
246
- sep_token = dictionary.get("|", dictionary.get(" ", 0))
247
- tokens.append(sep_token)
248
- clean_transcript += "|"
249
 
250
  if not tokens:
251
  return []
252
 
253
- # Build trellis and find optimal path (WhisperX-style DP alignment)
254
- trellis = _get_trellis(emission, tokens, blank_id=0)
255
- path = _backtrack(trellis, emission, tokens, blank_id=0)
 
 
 
256
 
257
- # Merge into character segments, then word segments
258
- char_segments = _merge_repeats(path, clean_transcript)
259
- word_segments = _merge_words(char_segments, separator="|")
260
 
261
- # Convert frame indices to time
262
- frame_duration = 320 / cls._bundle.sample_rate # 20ms per frame
263
 
264
- # Build output with original words
265
  words = text.split()
266
  word_timestamps = []
267
-
268
- for i, seg in enumerate(word_segments):
269
- if i < len(words):
270
- word_timestamps.append(
271
- {
272
- "word": words[i],
273
- "start": seg.start * frame_duration,
274
- "end": seg.end * frame_duration,
275
- }
276
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
277
 
278
  return word_timestamps
279
 
@@ -339,8 +229,6 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
339
  num_speakers: Exact number of speakers (if known, for diarization)
340
  min_speakers: Minimum number of speakers (for diarization)
341
  max_speakers: Maximum number of speakers (for diarization)
342
- hf_token: HuggingFace token for pyannote models (or set HF_TOKEN env var)
343
- diarization_backend: Backend for diarization ("pyannote" or "local")
344
  **kwargs: Additional arguments passed to the pipeline
345
 
346
  Returns:
@@ -355,8 +243,6 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
355
  "num_speakers": kwargs.pop("num_speakers", None),
356
  "min_speakers": kwargs.pop("min_speakers", None),
357
  "max_speakers": kwargs.pop("max_speakers", None),
358
- "hf_token": kwargs.pop("hf_token", None),
359
- "backend": kwargs.pop("diarization_backend", "pyannote"),
360
  }
361
 
362
  if return_speakers:
 
1
  """ASR pipeline for audio-to-text transcription with optional timestamps and diarization."""
2
 
3
  import re
 
4
  from pathlib import Path
5
  from typing import Any
6
 
 
23
  return "cpu"
24
 
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  class ForcedAligner:
27
+ """Lazy-loaded forced aligner for word-level timestamps using torchaudio wav2vec2."""
 
 
 
 
28
 
29
  _bundle = None
30
  _model = None
 
44
  if cls._model is None:
45
  import torchaudio
46
 
47
+ cls._bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
 
48
  cls._model = cls._bundle.get_model().to(device)
49
  cls._model.eval()
50
  cls._labels = cls._bundle.get_labels()
 
57
  audio: np.ndarray,
58
  text: str,
59
  sample_rate: int = 16000,
60
+ _language: str = "eng",
61
  _batch_size: int = 16,
62
  ) -> list[dict]:
63
  """Align transcript to audio and return word-level timestamps.
64
 
 
 
65
  Args:
66
  audio: Audio waveform as numpy array
67
  text: Transcript text to align
68
  sample_rate: Audio sample rate (default 16000)
69
+ _language: ISO-639-3 language code (default "eng" for English, unused)
70
+ _batch_size: Batch size for alignment model (unused)
71
 
72
  Returns:
73
  List of dicts with 'word', 'start', 'end' keys
74
  """
75
  import torchaudio
76
+ from torchaudio.functional import forced_align, merge_tokens
77
 
78
  device = _get_device()
79
  model, labels, dictionary = cls.get_instance(device)
80
 
81
+ # Convert audio to tensor (copy to ensure array is writable)
82
  if isinstance(audio, np.ndarray):
83
  waveform = torch.from_numpy(audio.copy()).float()
84
  else:
 
88
  if waveform.dim() == 1:
89
  waveform = waveform.unsqueeze(0)
90
 
91
+ # Resample if needed (wav2vec2 expects 16kHz)
92
  if sample_rate != cls._bundle.sample_rate:
93
  waveform = torchaudio.functional.resample(
94
  waveform, sample_rate, cls._bundle.sample_rate
 
103
 
104
  emission = emissions[0].cpu()
105
 
106
+ # Normalize text: uppercase, keep only valid characters
107
  transcript = text.upper()
108
+ # Build tokens from transcript
109
  tokens = []
 
 
110
  for char in transcript:
111
  if char in dictionary:
112
  tokens.append(dictionary[char])
 
113
  elif char == " ":
114
+ tokens.append(dictionary.get("|", dictionary.get(" ", 0)))
 
 
115
 
116
  if not tokens:
117
  return []
118
 
119
+ targets = torch.tensor([tokens], dtype=torch.int32)
120
+
121
+ # Run forced alignment
122
+ # Note: forced_align is deprecated in torchaudio 2.6+ and will be removed in 2.9 (late 2025)
123
+ # No official replacement announced yet. See https://github.com/pytorch/audio/issues/3902
124
+ aligned_tokens, scores = forced_align(emission.unsqueeze(0), targets, blank=0)
125
 
126
+ # Use torchaudio's merge_tokens to get token spans (removes blanks and merges repeats)
127
+ token_spans = merge_tokens(aligned_tokens[0], scores[0])
 
128
 
129
+ # Convert frame indices to time (model stride is 320 samples at 16kHz = 20ms)
130
+ frame_duration = 320 / cls._bundle.sample_rate
131
 
132
+ # Group token spans into words based on pipe separator
133
  words = text.split()
134
  word_timestamps = []
135
+ current_word_start = None
136
+ current_word_end = None
137
+ word_idx = 0
138
+
139
+ for span in token_spans:
140
+ token_char = labels[span.token]
141
+ if token_char == "|": # Word separator
142
+ if current_word_start is not None and word_idx < len(words):
143
+ word_timestamps.append(
144
+ {
145
+ "word": words[word_idx],
146
+ "start": current_word_start * frame_duration,
147
+ "end": current_word_end * frame_duration,
148
+ }
149
+ )
150
+ word_idx += 1
151
+ current_word_start = None
152
+ current_word_end = None
153
+ else:
154
+ if current_word_start is None:
155
+ current_word_start = span.start
156
+ current_word_end = span.end
157
+
158
+ # Don't forget the last word
159
+ if current_word_start is not None and word_idx < len(words):
160
+ word_timestamps.append(
161
+ {
162
+ "word": words[word_idx],
163
+ "start": current_word_start * frame_duration,
164
+ "end": current_word_end * frame_duration,
165
+ }
166
+ )
167
 
168
  return word_timestamps
169
 
 
229
  num_speakers: Exact number of speakers (if known, for diarization)
230
  min_speakers: Minimum number of speakers (for diarization)
231
  max_speakers: Maximum number of speakers (for diarization)
 
 
232
  **kwargs: Additional arguments passed to the pipeline
233
 
234
  Returns:
 
243
  "num_speakers": kwargs.pop("num_speakers", None),
244
  "min_speakers": kwargs.pop("min_speakers", None),
245
  "max_speakers": kwargs.pop("max_speakers", None),
 
 
246
  }
247
 
248
  if return_speakers:
diarization.py CHANGED
@@ -1,8 +1,9 @@
1
- """Speaker diarization with support for pyannote and local (tiny-audio) backends.
2
 
3
- Provides two diarization backends:
4
- - pyannote: Uses pyannote-audio pipeline (requires HF token with model access)
5
- - local: Uses TEN-VAD + ERes2NetV2 + spectral clustering (no token required)
 
6
 
7
  Spectral clustering implementation adapted from FunASR/3D-Speaker:
8
  https://github.com/alibaba-damo-academy/FunASR
@@ -244,12 +245,12 @@ class SpeakerClusterer:
244
 
245
 
246
  class LocalSpeakerDiarizer:
247
- """Local speaker diarization using TEN-VAD + ERes2NetV2 + spectral clustering.
248
 
249
  Pipeline:
250
  1. TEN-VAD detects speech segments
251
  2. Sliding window (1.0s, 75% overlap) for uniform embedding extraction
252
- 3. ERes2NetV2 extracts speaker embeddings per window
253
  4. Spectral clustering with eigenvalue gap for auto speaker detection
254
  5. Frame-level consensus voting for segment reconstruction
255
  6. Post-processing merges short segments to reduce flicker
@@ -268,7 +269,7 @@ class LocalSpeakerDiarizer:
268
  """
269
 
270
  _ten_vad_model = None
271
- _eres2netv2_model = None
272
  _device = None
273
 
274
  # ==================== TUNABLE PARAMETERS ====================
@@ -312,25 +313,21 @@ class LocalSpeakerDiarizer:
312
  return cls._device
313
 
314
  @classmethod
315
- def _get_eres2netv2_model(cls):
316
- """Lazy-load ERes2NetV2 speaker embedding model (singleton)."""
317
- if cls._eres2netv2_model is None:
318
- from modelscope.pipelines import pipeline
319
- from modelscope.utils.constant import Tasks
320
-
321
- sv_pipeline = pipeline(
322
- task=Tasks.speaker_verification,
323
- model="iic/speech_eres2netv2_sv_zh-cn_16k-common",
324
  )
325
- cls._eres2netv2_model = sv_pipeline.model
326
 
327
- # Move model to GPU if available
328
  device = cls._get_device()
329
- cls._eres2netv2_model = cls._eres2netv2_model.to(device)
330
- cls._eres2netv2_model.device = device
331
- cls._eres2netv2_model.eval()
332
 
333
- return cls._eres2netv2_model
334
 
335
  @classmethod
336
  def diarize(
@@ -487,7 +484,7 @@ class LocalSpeakerDiarizer:
487
  cls, audio_array: np.ndarray, segments: list[dict], sample_rate: int
488
  ) -> tuple[np.ndarray, list[dict]]:
489
  """Extract speaker embeddings using sliding windows."""
490
- speaker_model = cls._get_eres2netv2_model()
491
  device = cls._get_device()
492
 
493
  window_samples = int(cls.WINDOW_SIZE * sample_rate)
@@ -525,9 +522,10 @@ class LocalSpeakerDiarizer:
525
  pad_width = window_samples - len(chunk)
526
  chunk = np.pad(chunk, (0, pad_width), mode="reflect")
527
 
528
- # Extract embedding
529
  chunk_tensor = torch.from_numpy(chunk).float().unsqueeze(0).to(device)
530
- embedding = speaker_model.forward(chunk_tensor).squeeze(0).cpu().numpy()
 
531
 
532
  # Validate and normalize
533
  if not np.isfinite(embedding).all():
@@ -715,34 +713,14 @@ class LocalSpeakerDiarizer:
715
 
716
 
717
  class SpeakerDiarizer:
718
- """Unified speaker diarization interface supporting multiple backends.
719
-
720
- Backends:
721
- - 'pyannote': Uses pyannote-audio pipeline (requires HF token)
722
- - 'local': Uses TEN-VAD + ERes2NetV2 + spectral clustering
723
 
724
  Example:
725
- >>> segments = SpeakerDiarizer.diarize(audio_array, backend="local")
726
  >>> for seg in segments:
727
  ... print(f"{seg['speaker']}: {seg['start']:.2f} - {seg['end']:.2f}")
728
  """
729
 
730
- _pyannote_pipeline = None
731
-
732
- @classmethod
733
- def _get_pyannote_pipeline(cls, hf_token: str | None = None):
734
- """Get or create the pyannote diarization pipeline."""
735
- if cls._pyannote_pipeline is None:
736
- from pyannote.audio import Pipeline
737
-
738
- cls._pyannote_pipeline = Pipeline.from_pretrained(
739
- "pyannote/speaker-diarization-3.1",
740
- token=hf_token,
741
- )
742
- cls._pyannote_pipeline.to(torch.device(_get_device()))
743
-
744
- return cls._pyannote_pipeline
745
-
746
  @classmethod
747
  def diarize(
748
  cls,
@@ -751,8 +729,7 @@ class SpeakerDiarizer:
751
  num_speakers: int | None = None,
752
  min_speakers: int | None = None,
753
  max_speakers: int | None = None,
754
- hf_token: str | None = None,
755
- backend: str = "pyannote",
756
  ) -> list[dict]:
757
  """Run speaker diarization on audio.
758
 
@@ -762,87 +739,18 @@ class SpeakerDiarizer:
762
  num_speakers: Exact number of speakers (if known)
763
  min_speakers: Minimum number of speakers
764
  max_speakers: Maximum number of speakers
765
- hf_token: HuggingFace token for pyannote models
766
- backend: Diarization backend ("pyannote" or "local")
767
 
768
  Returns:
769
  List of dicts with 'speaker', 'start', 'end' keys
770
  """
771
- if backend == "local":
772
- return LocalSpeakerDiarizer.diarize(
773
- audio,
774
- sample_rate=sample_rate,
775
- num_speakers=num_speakers,
776
- min_speakers=min_speakers or 2,
777
- max_speakers=max_speakers or 10,
778
- )
779
-
780
- # Default to pyannote
781
- return cls._diarize_pyannote(
782
  audio,
783
  sample_rate=sample_rate,
784
  num_speakers=num_speakers,
785
- min_speakers=min_speakers,
786
- max_speakers=max_speakers,
787
- hf_token=hf_token,
788
  )
789
 
790
- @classmethod
791
- def _diarize_pyannote(
792
- cls,
793
- audio: np.ndarray | str,
794
- sample_rate: int = 16000,
795
- num_speakers: int | None = None,
796
- min_speakers: int | None = None,
797
- max_speakers: int | None = None,
798
- hf_token: str | None = None,
799
- ) -> list[dict]:
800
- """Run pyannote diarization."""
801
- pipeline = cls._get_pyannote_pipeline(hf_token)
802
-
803
- # Prepare audio input
804
- if isinstance(audio, np.ndarray):
805
- waveform = torch.from_numpy(audio.copy()).unsqueeze(0)
806
- if waveform.dim() == 1:
807
- waveform = waveform.unsqueeze(0)
808
- audio_input = {"waveform": waveform, "sample_rate": sample_rate}
809
- else:
810
- audio_input = audio
811
-
812
- # Run diarization
813
- diarization_args = {}
814
- if num_speakers is not None:
815
- diarization_args["num_speakers"] = num_speakers
816
- if min_speakers is not None:
817
- diarization_args["min_speakers"] = min_speakers
818
- if max_speakers is not None:
819
- diarization_args["max_speakers"] = max_speakers
820
-
821
- diarization = pipeline(audio_input, **diarization_args)
822
-
823
- # Handle different pyannote return types
824
- if hasattr(diarization, "itertracks"):
825
- annotation = diarization
826
- elif hasattr(diarization, "speaker_diarization"):
827
- annotation = diarization.speaker_diarization
828
- elif isinstance(diarization, tuple):
829
- annotation = diarization[0]
830
- else:
831
- raise TypeError(f"Unexpected diarization output type: {type(diarization)}")
832
-
833
- # Convert to simple format
834
- segments = []
835
- for turn, _, speaker in annotation.itertracks(yield_label=True):
836
- segments.append(
837
- {
838
- "speaker": speaker,
839
- "start": turn.start,
840
- "end": turn.end,
841
- }
842
- )
843
-
844
- return segments
845
-
846
  @classmethod
847
  def assign_speakers_to_words(
848
  cls,
 
1
+ """Speaker diarization using TEN-VAD + WavLM + spectral clustering.
2
 
3
+ Pipeline:
4
+ 1. TEN-VAD detects speech segments
5
+ 2. WavLM (microsoft/wavlm-base-plus-sv) extracts speaker embeddings
6
+ 3. Spectral clustering groups embeddings by speaker
7
 
8
  Spectral clustering implementation adapted from FunASR/3D-Speaker:
9
  https://github.com/alibaba-damo-academy/FunASR
 
245
 
246
 
247
  class LocalSpeakerDiarizer:
248
+ """Local speaker diarization using TEN-VAD + WavLM + spectral clustering.
249
 
250
  Pipeline:
251
  1. TEN-VAD detects speech segments
252
  2. Sliding window (1.0s, 75% overlap) for uniform embedding extraction
253
+ 3. WavLM extracts speaker embeddings per window
254
  4. Spectral clustering with eigenvalue gap for auto speaker detection
255
  5. Frame-level consensus voting for segment reconstruction
256
  6. Post-processing merges short segments to reduce flicker
 
269
  """
270
 
271
  _ten_vad_model = None
272
+ _speaker_model = None
273
  _device = None
274
 
275
  # ==================== TUNABLE PARAMETERS ====================
 
313
  return cls._device
314
 
315
  @classmethod
316
+ def _get_speaker_model(cls):
317
+ """Lazy-load WavLM speaker embedding model (singleton)."""
318
+ if cls._speaker_model is None:
319
+ from transformers import WavLMForXVector
320
+
321
+ cls._speaker_model = WavLMForXVector.from_pretrained(
322
+ "microsoft/wavlm-base-plus-sv",
 
 
323
  )
 
324
 
325
+ # Move model to best available device (MPS/CUDA/CPU)
326
  device = cls._get_device()
327
+ cls._speaker_model = cls._speaker_model.to(device)
328
+ cls._speaker_model.eval()
 
329
 
330
+ return cls._speaker_model
331
 
332
  @classmethod
333
  def diarize(
 
484
  cls, audio_array: np.ndarray, segments: list[dict], sample_rate: int
485
  ) -> tuple[np.ndarray, list[dict]]:
486
  """Extract speaker embeddings using sliding windows."""
487
+ speaker_model = cls._get_speaker_model()
488
  device = cls._get_device()
489
 
490
  window_samples = int(cls.WINDOW_SIZE * sample_rate)
 
522
  pad_width = window_samples - len(chunk)
523
  chunk = np.pad(chunk, (0, pad_width), mode="reflect")
524
 
525
+ # Extract embedding (WavLMForXVector returns XVectorOutput with .embeddings)
526
  chunk_tensor = torch.from_numpy(chunk).float().unsqueeze(0).to(device)
527
+ output = speaker_model(chunk_tensor)
528
+ embedding = output.embeddings.squeeze(0).cpu().numpy()
529
 
530
  # Validate and normalize
531
  if not np.isfinite(embedding).all():
 
713
 
714
 
715
  class SpeakerDiarizer:
716
+ """Speaker diarization using TEN-VAD + WavLM + spectral clustering.
 
 
 
 
717
 
718
  Example:
719
+ >>> segments = SpeakerDiarizer.diarize(audio_array)
720
  >>> for seg in segments:
721
  ... print(f"{seg['speaker']}: {seg['start']:.2f} - {seg['end']:.2f}")
722
  """
723
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
724
  @classmethod
725
  def diarize(
726
  cls,
 
729
  num_speakers: int | None = None,
730
  min_speakers: int | None = None,
731
  max_speakers: int | None = None,
732
+ **_kwargs,
 
733
  ) -> list[dict]:
734
  """Run speaker diarization on audio.
735
 
 
739
  num_speakers: Exact number of speakers (if known)
740
  min_speakers: Minimum number of speakers
741
  max_speakers: Maximum number of speakers
 
 
742
 
743
  Returns:
744
  List of dicts with 'speaker', 'start', 'end' keys
745
  """
746
+ return LocalSpeakerDiarizer.diarize(
 
 
 
 
 
 
 
 
 
 
747
  audio,
748
  sample_rate=sample_rate,
749
  num_speakers=num_speakers,
750
+ min_speakers=min_speakers or 2,
751
+ max_speakers=max_speakers or 10,
 
752
  )
753
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
754
  @classmethod
755
  def assign_speakers_to_words(
756
  cls,