ConvxO2 committed on
Commit
789006e
·
1 Parent(s): 411e5d6

Reduce speaker over-segmentation in auto clustering

Browse files
Files changed (3) hide show
  1. app/main.py +2 -1
  2. app/pipeline.py +2 -1
  3. models/clusterer.py +9 -4
app/main.py CHANGED
@@ -74,7 +74,7 @@ def get_pipeline():
74
  use_pyannote_diarization=os.getenv("USE_PYANNOTE_DIARIZATION", "true").lower() in {"1", "true", "yes"},
75
  pyannote_diarization_model=os.getenv("PYANNOTE_DIARIZATION_MODEL", "pyannote/speaker-diarization-3.1"),
76
  hf_token=os.getenv("HF_TOKEN"),
77
- max_speakers=10,
78
  cache_dir=cache_dir,
79
  )
80
  return _pipeline
@@ -287,3 +287,4 @@ if static_dir.exists():
287
  app.mount("/static", StaticFiles(directory=str(static_dir)), name="static")
288
 
289
 
 
 
74
  use_pyannote_diarization=os.getenv("USE_PYANNOTE_DIARIZATION", "true").lower() in {"1", "true", "yes"},
75
  pyannote_diarization_model=os.getenv("PYANNOTE_DIARIZATION_MODEL", "pyannote/speaker-diarization-3.1"),
76
  hf_token=os.getenv("HF_TOKEN"),
77
+ max_speakers=int(os.getenv("MAX_SPEAKERS", "6")),
78
  cache_dir=cache_dir,
79
  )
80
  return _pipeline
 
287
  app.mount("/static", StaticFiles(directory=str(static_dir)), name="static")
288
 
289
 
290
+
app/pipeline.py CHANGED
@@ -73,7 +73,7 @@ class DiarizationPipeline:
73
  pyannote_diarization_model: str = "pyannote/speaker-diarization-3.1",
74
  hf_token: Optional[str] = None,
75
  num_speakers: Optional[int] = None,
76
- max_speakers: int = 10,
77
  cache_dir: str = "./model_cache",
78
  ):
79
  self.device = self._resolve_device(device)
@@ -411,3 +411,4 @@ class DiarizationPipeline:
411
  processing_time=processing_time,
412
  sample_rate=sample_rate,
413
  )
 
 
73
  pyannote_diarization_model: str = "pyannote/speaker-diarization-3.1",
74
  hf_token: Optional[str] = None,
75
  num_speakers: Optional[int] = None,
76
+ max_speakers: int = 6,
77
  cache_dir: str = "./model_cache",
78
  ):
79
  self.device = self._resolve_device(device)
 
411
  processing_time=processing_time,
412
  sample_rate=sample_rate,
413
  )
414
+
models/clusterer.py CHANGED
@@ -39,11 +39,13 @@ class SpeakerClusterer:
39
  if n <= 2:
40
  return n
41
 
42
- best_k = max(2, self.min_speakers)
43
- best_score = -1.0
44
  upper_k = min(self.max_speakers, n - 1)
45
 
46
- for k in range(max(2, self.min_speakers), upper_k + 1):
 
 
 
47
  labels = fcluster(linkage_matrix, k, criterion="maxclust")
48
  if len(np.unique(labels)) < 2:
49
  continue
@@ -63,10 +65,13 @@ class SpeakerClusterer:
63
  k_threshold = len(np.unique(threshold_labels))
64
  k_threshold = int(np.clip(k_threshold, self.min_speakers, min(self.max_speakers, n)))
65
 
 
66
  if best_score < 0.08:
67
  chosen_k = k_threshold
68
  else:
69
- chosen_k = max(best_k, k_threshold)
 
 
70
 
71
  logger.info(
72
  f"Optimal speaker count: {chosen_k} "
 
39
  if n <= 2:
40
  return n
41
 
42
+ min_k = max(2, self.min_speakers)
 
43
  upper_k = min(self.max_speakers, n - 1)
44
 
45
+ best_k = min_k
46
+ best_score = -1.0
47
+
48
+ for k in range(min_k, upper_k + 1):
49
  labels = fcluster(linkage_matrix, k, criterion="maxclust")
50
  if len(np.unique(labels)) < 2:
51
  continue
 
65
  k_threshold = len(np.unique(threshold_labels))
66
  k_threshold = int(np.clip(k_threshold, self.min_speakers, min(self.max_speakers, n)))
67
 
68
+ # Be conservative to avoid severe over-segmentation in open-domain audio.
69
  if best_score < 0.08:
70
  chosen_k = k_threshold
71
  else:
72
+ chosen_k = min(best_k, k_threshold) if k_threshold >= 2 else best_k
73
+
74
+ chosen_k = int(np.clip(chosen_k, self.min_speakers, min(self.max_speakers, n)))
75
 
76
  logger.info(
77
  f"Optimal speaker count: {chosen_k} "