ConvxO2 committed on
Commit
8d04859
·
1 Parent(s): 4b8c370

Add pyannote-first diarization path and tune fallback clustering

Browse files
Files changed (3) hide show
  1. app/main.py +5 -1
  2. app/pipeline.py +155 -33
  3. models/clusterer.py +20 -4
app/main.py CHANGED
@@ -1,4 +1,4 @@
1
- """Speaker Diarization API - FastAPI Application."""
2
 
3
  import asyncio
4
  import tempfile
@@ -71,6 +71,8 @@ def get_pipeline():
71
  _pipeline = DiarizationPipeline(
72
  device="auto",
73
  use_pyannote_vad=True,
 
 
74
  hf_token=os.getenv("HF_TOKEN"),
75
  max_speakers=10,
76
  cache_dir=cache_dir,
@@ -283,3 +285,5 @@ async def debug():
283
  static_dir = Path(__file__).resolve().parent.parent / "static"
284
  if static_dir.exists():
285
  app.mount("/static", StaticFiles(directory=str(static_dir)), name="static")
 
 
 
1
+ """Speaker Diarization API - FastAPI Application."""
2
 
3
  import asyncio
4
  import tempfile
 
71
  _pipeline = DiarizationPipeline(
72
  device="auto",
73
  use_pyannote_vad=True,
74
+ use_pyannote_diarization=os.getenv("USE_PYANNOTE_DIARIZATION", "true").lower() in {"1", "true", "yes"},
75
+ pyannote_diarization_model=os.getenv("PYANNOTE_DIARIZATION_MODEL", "pyannote/speaker-diarization-3.1"),
76
  hf_token=os.getenv("HF_TOKEN"),
77
  max_speakers=10,
78
  cache_dir=cache_dir,
 
285
  static_dir = Path(__file__).resolve().parent.parent / "static"
286
  if static_dir.exists():
287
  app.mount("/static", StaticFiles(directory=str(static_dir)), name="static")
288
+
289
+
app/pipeline.py CHANGED
@@ -1,14 +1,17 @@
1
  """
2
  Speaker Diarization Pipeline
3
- Combines: Voice Activity Detection -> Segmentation -> ECAPA-TDNN Embeddings -> AHC Clustering
4
  """
5
 
6
- import torch
7
- import torchaudio
8
- import numpy as np
9
  from pathlib import Path
10
  from typing import Optional, List, Union, BinaryIO
11
  from dataclasses import dataclass, field
 
 
 
 
12
  from loguru import logger
13
 
14
  from models.embedder import EcapaTDNNEmbedder
@@ -55,25 +58,19 @@ class DiarizationResult:
55
 
56
 
57
  class DiarizationPipeline:
58
- """
59
- End-to-end speaker diarization pipeline.
60
- 1. Audio loading & preprocessing
61
- 2. Voice Activity Detection (VAD) via pyannote or energy-based fallback
62
- 3. Sliding-window segmentation of speech regions
63
- 4. ECAPA-TDNN speaker embedding extraction per segment
64
- 5. Agglomerative Hierarchical Clustering
65
- 6. Post-processing: merge consecutive same-speaker segments
66
- """
67
 
68
  SAMPLE_RATE = 16000
69
- WINDOW_DURATION = 1.5
70
- WINDOW_STEP = 0.75
71
- MIN_SEGMENT_DURATION = 0.5
72
 
73
  def __init__(
74
  self,
75
  device: str = "auto",
76
  use_pyannote_vad: bool = True,
 
 
77
  hf_token: Optional[str] = None,
78
  num_speakers: Optional[int] = None,
79
  max_speakers: int = 10,
@@ -81,15 +78,18 @@ class DiarizationPipeline:
81
  ):
82
  self.device = self._resolve_device(device)
83
  self.use_pyannote_vad = use_pyannote_vad
 
 
84
  self.hf_token = hf_token
85
  self.num_speakers = num_speakers
86
  self.max_speakers = max_speakers
87
  self.cache_dir = Path(cache_dir)
88
 
89
  self.embedder = EcapaTDNNEmbedder(device=self.device, cache_dir=str(cache_dir))
90
- self.clusterer = SpeakerClusterer(max_speakers=max_speakers)
91
 
92
  self._vad_pipeline = None
 
93
  logger.info(f"DiarizationPipeline ready | device={self.device}")
94
 
95
  def _resolve_device(self, device: str) -> str:
@@ -98,7 +98,6 @@ class DiarizationPipeline:
98
  return device
99
 
100
  def _to_mono_1d(self, audio: torch.Tensor) -> torch.Tensor:
101
- """Convert waveform to a mono 1D tensor for duration and preprocessing."""
102
  if audio.dim() == 1:
103
  return audio
104
  if audio.dim() >= 2:
@@ -107,26 +106,141 @@ class DiarizationPipeline:
107
  return audio.mean(dim=0)
108
  return audio.reshape(-1)
109
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  def _load_vad(self):
111
  if self._vad_pipeline is not None:
112
  return
113
  try:
114
- from pyannote.audio import Pipeline
115
  logger.info("Loading pyannote VAD pipeline...")
116
- self._vad_pipeline = Pipeline.from_pretrained(
117
- "pyannote/voice-activity-detection",
118
- use_auth_token=self.hf_token,
119
- )
120
- self._vad_pipeline.to(torch.device(self.device))
121
  logger.success("Pyannote VAD loaded.")
122
  except Exception as e:
123
  logger.warning(f"Could not load pyannote VAD: {e}. Falling back to energy-based VAD.")
124
  self._vad_pipeline = "energy"
125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  def _energy_vad(
127
  self, audio: torch.Tensor, frame_duration: float = 0.02, threshold_db: float = -40.0
128
  ) -> List[tuple]:
129
- """Simple energy-based VAD as fallback."""
130
  frame_samples = int(frame_duration * self.SAMPLE_RATE)
131
  audio_np = audio.numpy()
132
  frames = [
@@ -206,9 +320,6 @@ class DiarizationPipeline:
206
  sample_rate: int = None,
207
  num_speakers: Optional[int] = None,
208
  ) -> DiarizationResult:
209
- """Run full diarization pipeline on audio."""
210
- import time
211
-
212
  t_start = time.time()
213
 
214
  if isinstance(audio, (str, Path)):
@@ -232,6 +343,18 @@ class DiarizationPipeline:
232
  sample_rate=sample_rate,
233
  )
234
 
 
 
 
 
 
 
 
 
 
 
 
 
235
  processed = self.embedder.preprocess_audio(audio_tensor, sample_rate)
236
 
237
  speech_regions = self._get_speech_regions(processed)
@@ -262,10 +385,10 @@ class DiarizationPipeline:
262
  sample_rate=sample_rate,
263
  )
264
 
265
- k = num_speakers or self.num_speakers
266
  labels = self.clusterer.cluster(embeddings, num_speakers=k)
267
-
268
- merged = self.clusterer.merge_consecutive_same_speaker(valid_windows, labels)
 
269
 
270
  speaker_names = {i: f"SPEAKER_{i:02d}" for i in range(self.max_speakers)}
271
  segments = [
@@ -277,7 +400,7 @@ class DiarizationPipeline:
277
  processing_time = time.time() - t_start
278
 
279
  logger.success(
280
- f"Diarization complete: {num_unique} speakers, "
281
  f"{len(segments)} segments, {processing_time:.2f}s"
282
  )
283
 
@@ -288,4 +411,3 @@ class DiarizationPipeline:
288
  processing_time=processing_time,
289
  sample_rate=sample_rate,
290
  )
291
-
 
1
  """
2
  Speaker Diarization Pipeline
3
+ Combines: pyannote diarization (preferred) -> fallback VAD + ECAPA-TDNN + AHC clustering
4
  """
5
 
6
+ import tempfile
7
+ import time
 
8
  from pathlib import Path
9
  from typing import Optional, List, Union, BinaryIO
10
  from dataclasses import dataclass, field
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torchaudio
15
  from loguru import logger
16
 
17
  from models.embedder import EcapaTDNNEmbedder
 
58
 
59
 
60
  class DiarizationPipeline:
61
+ """End-to-end speaker diarization with pyannote-first fallback behavior."""
 
 
 
 
 
 
 
 
62
 
63
  SAMPLE_RATE = 16000
64
+ WINDOW_DURATION = 2.0
65
+ WINDOW_STEP = 1.0
66
+ MIN_SEGMENT_DURATION = 0.8
67
 
68
  def __init__(
69
  self,
70
  device: str = "auto",
71
  use_pyannote_vad: bool = True,
72
+ use_pyannote_diarization: bool = True,
73
+ pyannote_diarization_model: str = "pyannote/speaker-diarization-3.1",
74
  hf_token: Optional[str] = None,
75
  num_speakers: Optional[int] = None,
76
  max_speakers: int = 10,
 
78
  ):
79
  self.device = self._resolve_device(device)
80
  self.use_pyannote_vad = use_pyannote_vad
81
+ self.use_pyannote_diarization = use_pyannote_diarization
82
+ self.pyannote_diarization_model = pyannote_diarization_model
83
  self.hf_token = hf_token
84
  self.num_speakers = num_speakers
85
  self.max_speakers = max_speakers
86
  self.cache_dir = Path(cache_dir)
87
 
88
  self.embedder = EcapaTDNNEmbedder(device=self.device, cache_dir=str(cache_dir))
89
+ self.clusterer = SpeakerClusterer(max_speakers=max_speakers, distance_threshold=0.55)
90
 
91
  self._vad_pipeline = None
92
+ self._full_diar_pipeline = None
93
  logger.info(f"DiarizationPipeline ready | device={self.device}")
94
 
95
  def _resolve_device(self, device: str) -> str:
 
98
  return device
99
 
100
  def _to_mono_1d(self, audio: torch.Tensor) -> torch.Tensor:
 
101
  if audio.dim() == 1:
102
  return audio
103
  if audio.dim() >= 2:
 
106
  return audio.mean(dim=0)
107
  return audio.reshape(-1)
108
 
109
+ def _load_pyannote_pipeline(self, model_id: str):
110
+ from pyannote.audio import Pipeline
111
+
112
+ try:
113
+ if self.hf_token:
114
+ try:
115
+ pipeline = Pipeline.from_pretrained(model_id, use_auth_token=self.hf_token)
116
+ except TypeError:
117
+ pipeline = Pipeline.from_pretrained(model_id, token=self.hf_token)
118
+ else:
119
+ pipeline = Pipeline.from_pretrained(model_id)
120
+ except TypeError:
121
+ pipeline = Pipeline.from_pretrained(model_id)
122
+
123
+ if pipeline is None:
124
+ raise RuntimeError(f"Pipeline.from_pretrained returned None for {model_id}")
125
+
126
+ try:
127
+ pipeline.to(torch.device(self.device))
128
+ except Exception:
129
+ pass
130
+
131
+ return pipeline
132
+
133
+ def _load_full_diarization(self):
134
+ if self._full_diar_pipeline is not None:
135
+ return
136
+ try:
137
+ logger.info(f"Loading pyannote diarization pipeline: {self.pyannote_diarization_model}")
138
+ self._full_diar_pipeline = self._load_pyannote_pipeline(self.pyannote_diarization_model)
139
+ logger.success("Pyannote speaker diarization pipeline loaded.")
140
+ except Exception as e:
141
+ logger.warning(f"Could not load pyannote diarization pipeline: {e}.")
142
+ self._full_diar_pipeline = "unavailable"
143
+
144
  def _load_vad(self):
145
  if self._vad_pipeline is not None:
146
  return
147
  try:
 
148
  logger.info("Loading pyannote VAD pipeline...")
149
+ self._vad_pipeline = self._load_pyannote_pipeline("pyannote/voice-activity-detection")
 
 
 
 
150
  logger.success("Pyannote VAD loaded.")
151
  except Exception as e:
152
  logger.warning(f"Could not load pyannote VAD: {e}. Falling back to energy-based VAD.")
153
  self._vad_pipeline = "energy"
154
 
155
+ def _merge_named_segments(
156
+ self, segments: List[DiarizationSegment], gap_tolerance: float = 0.35
157
+ ) -> List[DiarizationSegment]:
158
+ if not segments:
159
+ return []
160
+
161
+ merged = [segments[0]]
162
+ for seg in segments[1:]:
163
+ last = merged[-1]
164
+ if seg.speaker == last.speaker and seg.start - last.end <= gap_tolerance:
165
+ merged[-1] = DiarizationSegment(start=last.start, end=seg.end, speaker=last.speaker)
166
+ else:
167
+ merged.append(seg)
168
+ return merged
169
+
170
+ def _run_full_pyannote(
171
+ self,
172
+ audio: Union[str, Path, torch.Tensor],
173
+ sample_rate: int,
174
+ num_speakers: Optional[int],
175
+ audio_duration: float,
176
+ t_start: float,
177
+ ) -> Optional[DiarizationResult]:
178
+ if not self.use_pyannote_diarization:
179
+ return None
180
+
181
+ self._load_full_diarization()
182
+ if self._full_diar_pipeline == "unavailable":
183
+ return None
184
+
185
+ tmp_path = None
186
+ source = audio
187
+ try:
188
+ if not isinstance(audio, (str, Path)):
189
+ mono = self._to_mono_1d(audio).detach().cpu().float()
190
+ wav = mono.unsqueeze(0)
191
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
192
+ tmp_path = tmp.name
193
+ torchaudio.save(tmp_path, wav, sample_rate)
194
+ source = tmp_path
195
+
196
+ kwargs = {}
197
+ if num_speakers is not None:
198
+ kwargs["num_speakers"] = int(num_speakers)
199
+
200
+ diar_output = self._full_diar_pipeline(str(source), **kwargs)
201
+
202
+ raw_segments = []
203
+ speaker_map = {}
204
+ next_id = 0
205
+ for turn, _, speaker in diar_output.itertracks(yield_label=True):
206
+ start = float(turn.start)
207
+ end = float(turn.end)
208
+ if end - start < 0.2:
209
+ continue
210
+ if speaker not in speaker_map:
211
+ speaker_map[speaker] = f"SPEAKER_{next_id:02d}"
212
+ next_id += 1
213
+ raw_segments.append(
214
+ DiarizationSegment(start=start, end=end, speaker=speaker_map[speaker])
215
+ )
216
+
217
+ if not raw_segments:
218
+ return None
219
+
220
+ raw_segments.sort(key=lambda s: (s.start, s.end))
221
+ merged_segments = self._merge_named_segments(raw_segments)
222
+ num_unique = len(set(s.speaker for s in merged_segments))
223
+
224
+ logger.success(
225
+ f"Pyannote diarization complete: {num_unique} speakers, {len(merged_segments)} segments"
226
+ )
227
+ return DiarizationResult(
228
+ segments=merged_segments,
229
+ num_speakers=num_unique,
230
+ audio_duration=audio_duration,
231
+ processing_time=time.time() - t_start,
232
+ sample_rate=sample_rate,
233
+ )
234
+ except Exception as e:
235
+ logger.warning(f"Full pyannote diarization failed: {e}. Falling back to ECAPA+AHC.")
236
+ return None
237
+ finally:
238
+ if tmp_path:
239
+ Path(tmp_path).unlink(missing_ok=True)
240
+
241
  def _energy_vad(
242
  self, audio: torch.Tensor, frame_duration: float = 0.02, threshold_db: float = -40.0
243
  ) -> List[tuple]:
 
244
  frame_samples = int(frame_duration * self.SAMPLE_RATE)
245
  audio_np = audio.numpy()
246
  frames = [
 
320
  sample_rate: int = None,
321
  num_speakers: Optional[int] = None,
322
  ) -> DiarizationResult:
 
 
 
323
  t_start = time.time()
324
 
325
  if isinstance(audio, (str, Path)):
 
343
  sample_rate=sample_rate,
344
  )
345
 
346
+ k = num_speakers or self.num_speakers
347
+
348
+ pyannote_result = self._run_full_pyannote(
349
+ audio=audio,
350
+ sample_rate=sample_rate,
351
+ num_speakers=k,
352
+ audio_duration=audio_duration,
353
+ t_start=t_start,
354
+ )
355
+ if pyannote_result is not None:
356
+ return pyannote_result
357
+
358
  processed = self.embedder.preprocess_audio(audio_tensor, sample_rate)
359
 
360
  speech_regions = self._get_speech_regions(processed)
 
385
  sample_rate=sample_rate,
386
  )
387
 
 
388
  labels = self.clusterer.cluster(embeddings, num_speakers=k)
389
+ merged = self.clusterer.merge_consecutive_same_speaker(
390
+ valid_windows, labels, gap_tolerance=0.45
391
+ )
392
 
393
  speaker_names = {i: f"SPEAKER_{i:02d}" for i in range(self.max_speakers)}
394
  segments = [
 
400
  processing_time = time.time() - t_start
401
 
402
  logger.success(
403
+ f"Fallback diarization complete: {num_unique} speakers, "
404
  f"{len(segments)} segments, {processing_time:.2f}s"
405
  )
406
 
 
411
  processing_time=processing_time,
412
  sample_rate=sample_rate,
413
  )
 
models/clusterer.py CHANGED
@@ -20,7 +20,7 @@ class SpeakerClusterer:
20
  def __init__(
21
  self,
22
  linkage_method: str = "average",
23
- distance_threshold: float = 0.7,
24
  min_speakers: int = 1,
25
  max_speakers: int = 10,
26
  ):
@@ -39,7 +39,7 @@ class SpeakerClusterer:
39
  if n <= 2:
40
  return n
41
 
42
- best_k = self.min_speakers
43
  best_score = -1.0
44
  upper_k = min(self.max_speakers, n - 1)
45
 
@@ -55,8 +55,24 @@ class SpeakerClusterer:
55
  except Exception:
56
  continue
57
 
58
- logger.info(f"Optimal speaker count: {best_k} (silhouette={best_score:.4f})")
59
- return best_k
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
  def cluster(
62
  self,
 
20
  def __init__(
21
  self,
22
  linkage_method: str = "average",
23
+ distance_threshold: float = 0.55,
24
  min_speakers: int = 1,
25
  max_speakers: int = 10,
26
  ):
 
39
  if n <= 2:
40
  return n
41
 
42
+ best_k = max(2, self.min_speakers)
43
  best_score = -1.0
44
  upper_k = min(self.max_speakers, n - 1)
45
 
 
55
  except Exception:
56
  continue
57
 
58
+ threshold_labels = fcluster(
59
+ linkage_matrix,
60
+ t=self.distance_threshold,
61
+ criterion="distance",
62
+ )
63
+ k_threshold = len(np.unique(threshold_labels))
64
+ k_threshold = int(np.clip(k_threshold, self.min_speakers, min(self.max_speakers, n)))
65
+
66
+ if best_score < 0.08:
67
+ chosen_k = k_threshold
68
+ else:
69
+ chosen_k = max(best_k, k_threshold)
70
+
71
+ logger.info(
72
+ f"Optimal speaker count: {chosen_k} "
73
+ f"(silhouette_k={best_k}, silhouette={best_score:.4f}, threshold_k={k_threshold})"
74
+ )
75
+ return chosen_k
76
 
77
  def cluster(
78
  self,