Spaces:
Running
Running
Add pyannote-first diarization path and tune fallback clustering
Browse files- app/main.py +5 -1
- app/pipeline.py +155 -33
- models/clusterer.py +20 -4
app/main.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
|
| 2 |
|
| 3 |
import asyncio
|
| 4 |
import tempfile
|
|
@@ -71,6 +71,8 @@ def get_pipeline():
|
|
| 71 |
_pipeline = DiarizationPipeline(
|
| 72 |
device="auto",
|
| 73 |
use_pyannote_vad=True,
|
|
|
|
|
|
|
| 74 |
hf_token=os.getenv("HF_TOKEN"),
|
| 75 |
max_speakers=10,
|
| 76 |
cache_dir=cache_dir,
|
|
@@ -283,3 +285,5 @@ async def debug():
|
|
| 283 |
static_dir = Path(__file__).resolve().parent.parent / "static"
|
| 284 |
if static_dir.exists():
|
| 285 |
app.mount("/static", StaticFiles(directory=str(static_dir)), name="static")
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Speaker Diarization API - FastAPI Application."""
|
| 2 |
|
| 3 |
import asyncio
|
| 4 |
import tempfile
|
|
|
|
| 71 |
_pipeline = DiarizationPipeline(
|
| 72 |
device="auto",
|
| 73 |
use_pyannote_vad=True,
|
| 74 |
+
use_pyannote_diarization=os.getenv("USE_PYANNOTE_DIARIZATION", "true").lower() in {"1", "true", "yes"},
|
| 75 |
+
pyannote_diarization_model=os.getenv("PYANNOTE_DIARIZATION_MODEL", "pyannote/speaker-diarization-3.1"),
|
| 76 |
hf_token=os.getenv("HF_TOKEN"),
|
| 77 |
max_speakers=10,
|
| 78 |
cache_dir=cache_dir,
|
|
|
|
| 285 |
static_dir = Path(__file__).resolve().parent.parent / "static"
|
| 286 |
if static_dir.exists():
|
| 287 |
app.mount("/static", StaticFiles(directory=str(static_dir)), name="static")
|
| 288 |
+
|
| 289 |
+
|
app/pipeline.py
CHANGED
|
@@ -1,14 +1,17 @@
|
|
| 1 |
"""
|
| 2 |
Speaker Diarization Pipeline
|
| 3 |
-
Combines:
|
| 4 |
"""
|
| 5 |
|
| 6 |
-
import
|
| 7 |
-
import
|
| 8 |
-
import numpy as np
|
| 9 |
from pathlib import Path
|
| 10 |
from typing import Optional, List, Union, BinaryIO
|
| 11 |
from dataclasses import dataclass, field
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
from loguru import logger
|
| 13 |
|
| 14 |
from models.embedder import EcapaTDNNEmbedder
|
|
@@ -55,25 +58,19 @@ class DiarizationResult:
|
|
| 55 |
|
| 56 |
|
| 57 |
class DiarizationPipeline:
|
| 58 |
-
"""
|
| 59 |
-
End-to-end speaker diarization pipeline.
|
| 60 |
-
1. Audio loading & preprocessing
|
| 61 |
-
2. Voice Activity Detection (VAD) via pyannote or energy-based fallback
|
| 62 |
-
3. Sliding-window segmentation of speech regions
|
| 63 |
-
4. ECAPA-TDNN speaker embedding extraction per segment
|
| 64 |
-
5. Agglomerative Hierarchical Clustering
|
| 65 |
-
6. Post-processing: merge consecutive same-speaker segments
|
| 66 |
-
"""
|
| 67 |
|
| 68 |
SAMPLE_RATE = 16000
|
| 69 |
-
WINDOW_DURATION =
|
| 70 |
-
WINDOW_STEP =
|
| 71 |
-
MIN_SEGMENT_DURATION = 0.
|
| 72 |
|
| 73 |
def __init__(
|
| 74 |
self,
|
| 75 |
device: str = "auto",
|
| 76 |
use_pyannote_vad: bool = True,
|
|
|
|
|
|
|
| 77 |
hf_token: Optional[str] = None,
|
| 78 |
num_speakers: Optional[int] = None,
|
| 79 |
max_speakers: int = 10,
|
|
@@ -81,15 +78,18 @@ class DiarizationPipeline:
|
|
| 81 |
):
|
| 82 |
self.device = self._resolve_device(device)
|
| 83 |
self.use_pyannote_vad = use_pyannote_vad
|
|
|
|
|
|
|
| 84 |
self.hf_token = hf_token
|
| 85 |
self.num_speakers = num_speakers
|
| 86 |
self.max_speakers = max_speakers
|
| 87 |
self.cache_dir = Path(cache_dir)
|
| 88 |
|
| 89 |
self.embedder = EcapaTDNNEmbedder(device=self.device, cache_dir=str(cache_dir))
|
| 90 |
-
self.clusterer = SpeakerClusterer(max_speakers=max_speakers)
|
| 91 |
|
| 92 |
self._vad_pipeline = None
|
|
|
|
| 93 |
logger.info(f"DiarizationPipeline ready | device={self.device}")
|
| 94 |
|
| 95 |
def _resolve_device(self, device: str) -> str:
|
|
@@ -98,7 +98,6 @@ class DiarizationPipeline:
|
|
| 98 |
return device
|
| 99 |
|
| 100 |
def _to_mono_1d(self, audio: torch.Tensor) -> torch.Tensor:
|
| 101 |
-
"""Convert waveform to a mono 1D tensor for duration and preprocessing."""
|
| 102 |
if audio.dim() == 1:
|
| 103 |
return audio
|
| 104 |
if audio.dim() >= 2:
|
|
@@ -107,26 +106,141 @@ class DiarizationPipeline:
|
|
| 107 |
return audio.mean(dim=0)
|
| 108 |
return audio.reshape(-1)
|
| 109 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
def _load_vad(self):
|
| 111 |
if self._vad_pipeline is not None:
|
| 112 |
return
|
| 113 |
try:
|
| 114 |
-
from pyannote.audio import Pipeline
|
| 115 |
logger.info("Loading pyannote VAD pipeline...")
|
| 116 |
-
self._vad_pipeline =
|
| 117 |
-
"pyannote/voice-activity-detection",
|
| 118 |
-
use_auth_token=self.hf_token,
|
| 119 |
-
)
|
| 120 |
-
self._vad_pipeline.to(torch.device(self.device))
|
| 121 |
logger.success("Pyannote VAD loaded.")
|
| 122 |
except Exception as e:
|
| 123 |
logger.warning(f"Could not load pyannote VAD: {e}. Falling back to energy-based VAD.")
|
| 124 |
self._vad_pipeline = "energy"
|
| 125 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
def _energy_vad(
|
| 127 |
self, audio: torch.Tensor, frame_duration: float = 0.02, threshold_db: float = -40.0
|
| 128 |
) -> List[tuple]:
|
| 129 |
-
"""Simple energy-based VAD as fallback."""
|
| 130 |
frame_samples = int(frame_duration * self.SAMPLE_RATE)
|
| 131 |
audio_np = audio.numpy()
|
| 132 |
frames = [
|
|
@@ -206,9 +320,6 @@ class DiarizationPipeline:
|
|
| 206 |
sample_rate: int = None,
|
| 207 |
num_speakers: Optional[int] = None,
|
| 208 |
) -> DiarizationResult:
|
| 209 |
-
"""Run full diarization pipeline on audio."""
|
| 210 |
-
import time
|
| 211 |
-
|
| 212 |
t_start = time.time()
|
| 213 |
|
| 214 |
if isinstance(audio, (str, Path)):
|
|
@@ -232,6 +343,18 @@ class DiarizationPipeline:
|
|
| 232 |
sample_rate=sample_rate,
|
| 233 |
)
|
| 234 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 235 |
processed = self.embedder.preprocess_audio(audio_tensor, sample_rate)
|
| 236 |
|
| 237 |
speech_regions = self._get_speech_regions(processed)
|
|
@@ -262,10 +385,10 @@ class DiarizationPipeline:
|
|
| 262 |
sample_rate=sample_rate,
|
| 263 |
)
|
| 264 |
|
| 265 |
-
k = num_speakers or self.num_speakers
|
| 266 |
labels = self.clusterer.cluster(embeddings, num_speakers=k)
|
| 267 |
-
|
| 268 |
-
|
|
|
|
| 269 |
|
| 270 |
speaker_names = {i: f"SPEAKER_{i:02d}" for i in range(self.max_speakers)}
|
| 271 |
segments = [
|
|
@@ -277,7 +400,7 @@ class DiarizationPipeline:
|
|
| 277 |
processing_time = time.time() - t_start
|
| 278 |
|
| 279 |
logger.success(
|
| 280 |
-
f"
|
| 281 |
f"{len(segments)} segments, {processing_time:.2f}s"
|
| 282 |
)
|
| 283 |
|
|
@@ -288,4 +411,3 @@ class DiarizationPipeline:
|
|
| 288 |
processing_time=processing_time,
|
| 289 |
sample_rate=sample_rate,
|
| 290 |
)
|
| 291 |
-
|
|
|
|
| 1 |
"""
|
| 2 |
Speaker Diarization Pipeline
|
| 3 |
+
Combines: pyannote diarization (preferred) -> fallback VAD + ECAPA-TDNN + AHC clustering
|
| 4 |
"""
|
| 5 |
|
| 6 |
+
import tempfile
|
| 7 |
+
import time
|
|
|
|
| 8 |
from pathlib import Path
|
| 9 |
from typing import Optional, List, Union, BinaryIO
|
| 10 |
from dataclasses import dataclass, field
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
import torchaudio
|
| 15 |
from loguru import logger
|
| 16 |
|
| 17 |
from models.embedder import EcapaTDNNEmbedder
|
|
|
|
| 58 |
|
| 59 |
|
| 60 |
class DiarizationPipeline:
|
| 61 |
+
"""End-to-end speaker diarization with pyannote-first fallback behavior."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
SAMPLE_RATE = 16000
|
| 64 |
+
WINDOW_DURATION = 2.0
|
| 65 |
+
WINDOW_STEP = 1.0
|
| 66 |
+
MIN_SEGMENT_DURATION = 0.8
|
| 67 |
|
| 68 |
def __init__(
|
| 69 |
self,
|
| 70 |
device: str = "auto",
|
| 71 |
use_pyannote_vad: bool = True,
|
| 72 |
+
use_pyannote_diarization: bool = True,
|
| 73 |
+
pyannote_diarization_model: str = "pyannote/speaker-diarization-3.1",
|
| 74 |
hf_token: Optional[str] = None,
|
| 75 |
num_speakers: Optional[int] = None,
|
| 76 |
max_speakers: int = 10,
|
|
|
|
| 78 |
):
|
| 79 |
self.device = self._resolve_device(device)
|
| 80 |
self.use_pyannote_vad = use_pyannote_vad
|
| 81 |
+
self.use_pyannote_diarization = use_pyannote_diarization
|
| 82 |
+
self.pyannote_diarization_model = pyannote_diarization_model
|
| 83 |
self.hf_token = hf_token
|
| 84 |
self.num_speakers = num_speakers
|
| 85 |
self.max_speakers = max_speakers
|
| 86 |
self.cache_dir = Path(cache_dir)
|
| 87 |
|
| 88 |
self.embedder = EcapaTDNNEmbedder(device=self.device, cache_dir=str(cache_dir))
|
| 89 |
+
self.clusterer = SpeakerClusterer(max_speakers=max_speakers, distance_threshold=0.55)
|
| 90 |
|
| 91 |
self._vad_pipeline = None
|
| 92 |
+
self._full_diar_pipeline = None
|
| 93 |
logger.info(f"DiarizationPipeline ready | device={self.device}")
|
| 94 |
|
| 95 |
def _resolve_device(self, device: str) -> str:
|
|
|
|
| 98 |
return device
|
| 99 |
|
| 100 |
def _to_mono_1d(self, audio: torch.Tensor) -> torch.Tensor:
|
|
|
|
| 101 |
if audio.dim() == 1:
|
| 102 |
return audio
|
| 103 |
if audio.dim() >= 2:
|
|
|
|
| 106 |
return audio.mean(dim=0)
|
| 107 |
return audio.reshape(-1)
|
| 108 |
|
| 109 |
+
def _load_pyannote_pipeline(self, model_id: str):
|
| 110 |
+
from pyannote.audio import Pipeline
|
| 111 |
+
|
| 112 |
+
try:
|
| 113 |
+
if self.hf_token:
|
| 114 |
+
try:
|
| 115 |
+
pipeline = Pipeline.from_pretrained(model_id, use_auth_token=self.hf_token)
|
| 116 |
+
except TypeError:
|
| 117 |
+
pipeline = Pipeline.from_pretrained(model_id, token=self.hf_token)
|
| 118 |
+
else:
|
| 119 |
+
pipeline = Pipeline.from_pretrained(model_id)
|
| 120 |
+
except TypeError:
|
| 121 |
+
pipeline = Pipeline.from_pretrained(model_id)
|
| 122 |
+
|
| 123 |
+
if pipeline is None:
|
| 124 |
+
raise RuntimeError(f"Pipeline.from_pretrained returned None for {model_id}")
|
| 125 |
+
|
| 126 |
+
try:
|
| 127 |
+
pipeline.to(torch.device(self.device))
|
| 128 |
+
except Exception:
|
| 129 |
+
pass
|
| 130 |
+
|
| 131 |
+
return pipeline
|
| 132 |
+
|
| 133 |
+
def _load_full_diarization(self):
|
| 134 |
+
if self._full_diar_pipeline is not None:
|
| 135 |
+
return
|
| 136 |
+
try:
|
| 137 |
+
logger.info(f"Loading pyannote diarization pipeline: {self.pyannote_diarization_model}")
|
| 138 |
+
self._full_diar_pipeline = self._load_pyannote_pipeline(self.pyannote_diarization_model)
|
| 139 |
+
logger.success("Pyannote speaker diarization pipeline loaded.")
|
| 140 |
+
except Exception as e:
|
| 141 |
+
logger.warning(f"Could not load pyannote diarization pipeline: {e}.")
|
| 142 |
+
self._full_diar_pipeline = "unavailable"
|
| 143 |
+
|
| 144 |
def _load_vad(self):
|
| 145 |
if self._vad_pipeline is not None:
|
| 146 |
return
|
| 147 |
try:
|
|
|
|
| 148 |
logger.info("Loading pyannote VAD pipeline...")
|
| 149 |
+
self._vad_pipeline = self._load_pyannote_pipeline("pyannote/voice-activity-detection")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
logger.success("Pyannote VAD loaded.")
|
| 151 |
except Exception as e:
|
| 152 |
logger.warning(f"Could not load pyannote VAD: {e}. Falling back to energy-based VAD.")
|
| 153 |
self._vad_pipeline = "energy"
|
| 154 |
|
| 155 |
+
def _merge_named_segments(
|
| 156 |
+
self, segments: List[DiarizationSegment], gap_tolerance: float = 0.35
|
| 157 |
+
) -> List[DiarizationSegment]:
|
| 158 |
+
if not segments:
|
| 159 |
+
return []
|
| 160 |
+
|
| 161 |
+
merged = [segments[0]]
|
| 162 |
+
for seg in segments[1:]:
|
| 163 |
+
last = merged[-1]
|
| 164 |
+
if seg.speaker == last.speaker and seg.start - last.end <= gap_tolerance:
|
| 165 |
+
merged[-1] = DiarizationSegment(start=last.start, end=seg.end, speaker=last.speaker)
|
| 166 |
+
else:
|
| 167 |
+
merged.append(seg)
|
| 168 |
+
return merged
|
| 169 |
+
|
| 170 |
+
def _run_full_pyannote(
|
| 171 |
+
self,
|
| 172 |
+
audio: Union[str, Path, torch.Tensor],
|
| 173 |
+
sample_rate: int,
|
| 174 |
+
num_speakers: Optional[int],
|
| 175 |
+
audio_duration: float,
|
| 176 |
+
t_start: float,
|
| 177 |
+
) -> Optional[DiarizationResult]:
|
| 178 |
+
if not self.use_pyannote_diarization:
|
| 179 |
+
return None
|
| 180 |
+
|
| 181 |
+
self._load_full_diarization()
|
| 182 |
+
if self._full_diar_pipeline == "unavailable":
|
| 183 |
+
return None
|
| 184 |
+
|
| 185 |
+
tmp_path = None
|
| 186 |
+
source = audio
|
| 187 |
+
try:
|
| 188 |
+
if not isinstance(audio, (str, Path)):
|
| 189 |
+
mono = self._to_mono_1d(audio).detach().cpu().float()
|
| 190 |
+
wav = mono.unsqueeze(0)
|
| 191 |
+
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
|
| 192 |
+
tmp_path = tmp.name
|
| 193 |
+
torchaudio.save(tmp_path, wav, sample_rate)
|
| 194 |
+
source = tmp_path
|
| 195 |
+
|
| 196 |
+
kwargs = {}
|
| 197 |
+
if num_speakers is not None:
|
| 198 |
+
kwargs["num_speakers"] = int(num_speakers)
|
| 199 |
+
|
| 200 |
+
diar_output = self._full_diar_pipeline(str(source), **kwargs)
|
| 201 |
+
|
| 202 |
+
raw_segments = []
|
| 203 |
+
speaker_map = {}
|
| 204 |
+
next_id = 0
|
| 205 |
+
for turn, _, speaker in diar_output.itertracks(yield_label=True):
|
| 206 |
+
start = float(turn.start)
|
| 207 |
+
end = float(turn.end)
|
| 208 |
+
if end - start < 0.2:
|
| 209 |
+
continue
|
| 210 |
+
if speaker not in speaker_map:
|
| 211 |
+
speaker_map[speaker] = f"SPEAKER_{next_id:02d}"
|
| 212 |
+
next_id += 1
|
| 213 |
+
raw_segments.append(
|
| 214 |
+
DiarizationSegment(start=start, end=end, speaker=speaker_map[speaker])
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
if not raw_segments:
|
| 218 |
+
return None
|
| 219 |
+
|
| 220 |
+
raw_segments.sort(key=lambda s: (s.start, s.end))
|
| 221 |
+
merged_segments = self._merge_named_segments(raw_segments)
|
| 222 |
+
num_unique = len(set(s.speaker for s in merged_segments))
|
| 223 |
+
|
| 224 |
+
logger.success(
|
| 225 |
+
f"Pyannote diarization complete: {num_unique} speakers, {len(merged_segments)} segments"
|
| 226 |
+
)
|
| 227 |
+
return DiarizationResult(
|
| 228 |
+
segments=merged_segments,
|
| 229 |
+
num_speakers=num_unique,
|
| 230 |
+
audio_duration=audio_duration,
|
| 231 |
+
processing_time=time.time() - t_start,
|
| 232 |
+
sample_rate=sample_rate,
|
| 233 |
+
)
|
| 234 |
+
except Exception as e:
|
| 235 |
+
logger.warning(f"Full pyannote diarization failed: {e}. Falling back to ECAPA+AHC.")
|
| 236 |
+
return None
|
| 237 |
+
finally:
|
| 238 |
+
if tmp_path:
|
| 239 |
+
Path(tmp_path).unlink(missing_ok=True)
|
| 240 |
+
|
| 241 |
def _energy_vad(
|
| 242 |
self, audio: torch.Tensor, frame_duration: float = 0.02, threshold_db: float = -40.0
|
| 243 |
) -> List[tuple]:
|
|
|
|
| 244 |
frame_samples = int(frame_duration * self.SAMPLE_RATE)
|
| 245 |
audio_np = audio.numpy()
|
| 246 |
frames = [
|
|
|
|
| 320 |
sample_rate: int = None,
|
| 321 |
num_speakers: Optional[int] = None,
|
| 322 |
) -> DiarizationResult:
|
|
|
|
|
|
|
|
|
|
| 323 |
t_start = time.time()
|
| 324 |
|
| 325 |
if isinstance(audio, (str, Path)):
|
|
|
|
| 343 |
sample_rate=sample_rate,
|
| 344 |
)
|
| 345 |
|
| 346 |
+
k = num_speakers or self.num_speakers
|
| 347 |
+
|
| 348 |
+
pyannote_result = self._run_full_pyannote(
|
| 349 |
+
audio=audio,
|
| 350 |
+
sample_rate=sample_rate,
|
| 351 |
+
num_speakers=k,
|
| 352 |
+
audio_duration=audio_duration,
|
| 353 |
+
t_start=t_start,
|
| 354 |
+
)
|
| 355 |
+
if pyannote_result is not None:
|
| 356 |
+
return pyannote_result
|
| 357 |
+
|
| 358 |
processed = self.embedder.preprocess_audio(audio_tensor, sample_rate)
|
| 359 |
|
| 360 |
speech_regions = self._get_speech_regions(processed)
|
|
|
|
| 385 |
sample_rate=sample_rate,
|
| 386 |
)
|
| 387 |
|
|
|
|
| 388 |
labels = self.clusterer.cluster(embeddings, num_speakers=k)
|
| 389 |
+
merged = self.clusterer.merge_consecutive_same_speaker(
|
| 390 |
+
valid_windows, labels, gap_tolerance=0.45
|
| 391 |
+
)
|
| 392 |
|
| 393 |
speaker_names = {i: f"SPEAKER_{i:02d}" for i in range(self.max_speakers)}
|
| 394 |
segments = [
|
|
|
|
| 400 |
processing_time = time.time() - t_start
|
| 401 |
|
| 402 |
logger.success(
|
| 403 |
+
f"Fallback diarization complete: {num_unique} speakers, "
|
| 404 |
f"{len(segments)} segments, {processing_time:.2f}s"
|
| 405 |
)
|
| 406 |
|
|
|
|
| 411 |
processing_time=processing_time,
|
| 412 |
sample_rate=sample_rate,
|
| 413 |
)
|
|
|
models/clusterer.py
CHANGED
|
@@ -20,7 +20,7 @@ class SpeakerClusterer:
|
|
| 20 |
def __init__(
|
| 21 |
self,
|
| 22 |
linkage_method: str = "average",
|
| 23 |
-
distance_threshold: float = 0.
|
| 24 |
min_speakers: int = 1,
|
| 25 |
max_speakers: int = 10,
|
| 26 |
):
|
|
@@ -39,7 +39,7 @@ class SpeakerClusterer:
|
|
| 39 |
if n <= 2:
|
| 40 |
return n
|
| 41 |
|
| 42 |
-
best_k = self.min_speakers
|
| 43 |
best_score = -1.0
|
| 44 |
upper_k = min(self.max_speakers, n - 1)
|
| 45 |
|
|
@@ -55,8 +55,24 @@ class SpeakerClusterer:
|
|
| 55 |
except Exception:
|
| 56 |
continue
|
| 57 |
|
| 58 |
-
|
| 59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
def cluster(
|
| 62 |
self,
|
|
|
|
| 20 |
def __init__(
|
| 21 |
self,
|
| 22 |
linkage_method: str = "average",
|
| 23 |
+
distance_threshold: float = 0.55,
|
| 24 |
min_speakers: int = 1,
|
| 25 |
max_speakers: int = 10,
|
| 26 |
):
|
|
|
|
| 39 |
if n <= 2:
|
| 40 |
return n
|
| 41 |
|
| 42 |
+
best_k = max(2, self.min_speakers)
|
| 43 |
best_score = -1.0
|
| 44 |
upper_k = min(self.max_speakers, n - 1)
|
| 45 |
|
|
|
|
| 55 |
except Exception:
|
| 56 |
continue
|
| 57 |
|
| 58 |
+
threshold_labels = fcluster(
|
| 59 |
+
linkage_matrix,
|
| 60 |
+
t=self.distance_threshold,
|
| 61 |
+
criterion="distance",
|
| 62 |
+
)
|
| 63 |
+
k_threshold = len(np.unique(threshold_labels))
|
| 64 |
+
k_threshold = int(np.clip(k_threshold, self.min_speakers, min(self.max_speakers, n)))
|
| 65 |
+
|
| 66 |
+
if best_score < 0.08:
|
| 67 |
+
chosen_k = k_threshold
|
| 68 |
+
else:
|
| 69 |
+
chosen_k = max(best_k, k_threshold)
|
| 70 |
+
|
| 71 |
+
logger.info(
|
| 72 |
+
f"Optimal speaker count: {chosen_k} "
|
| 73 |
+
f"(silhouette_k={best_k}, silhouette={best_score:.4f}, threshold_k={k_threshold})"
|
| 74 |
+
)
|
| 75 |
+
return chosen_k
|
| 76 |
|
| 77 |
def cluster(
|
| 78 |
self,
|