"""
Speaker Diarization Pipeline
Combines: pyannote diarization (preferred) -> fallback VAD + ECAPA-TDNN + AHC clustering
"""
import tempfile
import time
from pathlib import Path
from typing import Optional, List, Union, BinaryIO
from dataclasses import dataclass, field
import numpy as np
import torch
import torchaudio
from loguru import logger
from models.embedder import EcapaTDNNEmbedder
from models.clusterer import SpeakerClusterer
@dataclass
class DiarizationSegment:
    """One contiguous stretch of speech attributed to a single speaker."""
start: float
end: float
speaker: str
duration: float = field(init=False)
def __post_init__(self):
self.duration = round(self.end - self.start, 3)
def to_dict(self) -> dict:
return {
"start": round(self.start, 3),
"end": round(self.end, 3),
"duration": self.duration,
"speaker": self.speaker,
}
@dataclass
class DiarizationResult:
    """Complete diarization output: labeled segments plus audio and timing metadata."""
segments: List[DiarizationSegment]
num_speakers: int
audio_duration: float
processing_time: float
sample_rate: int
def to_dict(self) -> dict:
speakers = sorted(set(s.speaker for s in self.segments))
return {
"num_speakers": self.num_speakers,
"audio_duration": round(self.audio_duration, 3),
"processing_time": round(self.processing_time, 3),
"sample_rate": self.sample_rate,
"speakers": speakers,
"segments": [s.to_dict() for s in self.segments],
}
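# Example DiarizationResult.to_dict() payload (all values illustrative):
# {"num_speakers": 2, "audio_duration": 12.5, "processing_time": 3.1,
#  "sample_rate": 16000, "speakers": ["SPEAKER_00", "SPEAKER_01"],
#  "segments": [{"start": 0.0, "end": 4.2, "duration": 4.2, "speaker": "SPEAKER_00"}, ...]}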
class DiarizationPipeline:
"""End-to-end speaker diarization with pyannote-first fallback behavior."""
    SAMPLE_RATE = 16000  # the fallback path operates on 16 kHz mono audio
    WINDOW_DURATION = 2.0  # embedding window length in seconds
    WINDOW_STEP = 1.0  # hop between windows in seconds (50% overlap)
    MIN_SEGMENT_DURATION = 0.8  # drop speech regions/tails shorter than this
def __init__(
self,
device: str = "auto",
use_pyannote_vad: bool = True,
use_pyannote_diarization: bool = True,
pyannote_diarization_model: str = "pyannote/speaker-diarization-3.1",
hf_token: Optional[str] = None,
num_speakers: Optional[int] = None,
max_speakers: int = 6,
cache_dir: str = "./model_cache",
):
self.device = self._resolve_device(device)
self.use_pyannote_vad = use_pyannote_vad
self.use_pyannote_diarization = use_pyannote_diarization
self.pyannote_diarization_model = pyannote_diarization_model
self.hf_token = hf_token
self.num_speakers = num_speakers
self.max_speakers = max_speakers
self.cache_dir = Path(cache_dir)
self.embedder = EcapaTDNNEmbedder(device=self.device, cache_dir=str(cache_dir))
self.clusterer = SpeakerClusterer(max_speakers=max_speakers, distance_threshold=0.55)
self._vad_pipeline = None
self._full_diar_pipeline = None
logger.info(f"DiarizationPipeline ready | device={self.device}")
def _resolve_device(self, device: str) -> str:
if device == "auto":
return "cuda" if torch.cuda.is_available() else "cpu"
return device
    def _to_mono_1d(self, audio: torch.Tensor) -> torch.Tensor:
        """Collapse any channel layout to a 1-D mono waveform."""
if audio.dim() == 1:
return audio
if audio.dim() >= 2:
if audio.shape[0] == 1:
return audio[0]
return audio.mean(dim=0)
return audio.reshape(-1)
def _load_pyannote_pipeline(self, model_id: str):
        from pyannote.audio import Pipeline
        # Newer huggingface_hub/pyannote versions renamed `use_auth_token` to
        # `token`; try the old keyword first and fall back on TypeError.
        try:
if self.hf_token:
try:
pipeline = Pipeline.from_pretrained(model_id, use_auth_token=self.hf_token)
except TypeError:
pipeline = Pipeline.from_pretrained(model_id, token=self.hf_token)
else:
pipeline = Pipeline.from_pretrained(model_id)
except TypeError:
pipeline = Pipeline.from_pretrained(model_id)
if pipeline is None:
raise RuntimeError(f"Pipeline.from_pretrained returned None for {model_id}")
        try:
            # Not every pipeline version supports .to(); keep the default device if it fails.
            pipeline.to(torch.device(self.device))
except Exception:
pass
return pipeline
def _load_full_diarization(self):
if self._full_diar_pipeline is not None:
return
try:
logger.info(f"Loading pyannote diarization pipeline: {self.pyannote_diarization_model}")
self._full_diar_pipeline = self._load_pyannote_pipeline(self.pyannote_diarization_model)
logger.success("Pyannote speaker diarization pipeline loaded.")
except Exception as e:
logger.warning(f"Could not load pyannote diarization pipeline: {e}.")
self._full_diar_pipeline = "unavailable"
def _load_vad(self):
if self._vad_pipeline is not None:
return
try:
logger.info("Loading pyannote VAD pipeline...")
self._vad_pipeline = self._load_pyannote_pipeline("pyannote/voice-activity-detection")
logger.success("Pyannote VAD loaded.")
except Exception as e:
logger.warning(f"Could not load pyannote VAD: {e}. Falling back to energy-based VAD.")
self._vad_pipeline = "energy"
def _merge_named_segments(
self, segments: List[DiarizationSegment], gap_tolerance: float = 0.35
    ) -> List[DiarizationSegment]:
        """Merge consecutive segments of the same speaker separated by gaps <= gap_tolerance."""
if not segments:
return []
merged = [segments[0]]
for seg in segments[1:]:
last = merged[-1]
if seg.speaker == last.speaker and seg.start - last.end <= gap_tolerance:
merged[-1] = DiarizationSegment(start=last.start, end=seg.end, speaker=last.speaker)
else:
merged.append(seg)
return merged
def _run_full_pyannote(
self,
audio: Union[str, Path, torch.Tensor],
sample_rate: int,
num_speakers: Optional[int],
audio_duration: float,
t_start: float,
    ) -> Optional[DiarizationResult]:
        """Run the end-to-end pyannote pipeline; return None to signal the fallback path."""
if not self.use_pyannote_diarization:
return None
self._load_full_diarization()
if self._full_diar_pipeline == "unavailable":
return None
tmp_path = None
source = audio
try:
if not isinstance(audio, (str, Path)):
mono = self._to_mono_1d(audio).detach().cpu().float()
wav = mono.unsqueeze(0)
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
tmp_path = tmp.name
torchaudio.save(tmp_path, wav, sample_rate)
source = tmp_path
kwargs = {}
if num_speakers is not None:
kwargs["num_speakers"] = int(num_speakers)
diar_output = self._full_diar_pipeline(str(source), **kwargs)
raw_segments = []
speaker_map = {}
next_id = 0
for turn, _, speaker in diar_output.itertracks(yield_label=True):
start = float(turn.start)
end = float(turn.end)
if end - start < 0.2:
continue
if speaker not in speaker_map:
speaker_map[speaker] = f"SPEAKER_{next_id:02d}"
next_id += 1
raw_segments.append(
DiarizationSegment(start=start, end=end, speaker=speaker_map[speaker])
)
if not raw_segments:
return None
raw_segments.sort(key=lambda s: (s.start, s.end))
merged_segments = self._merge_named_segments(raw_segments)
num_unique = len(set(s.speaker for s in merged_segments))
logger.success(
f"Pyannote diarization complete: {num_unique} speakers, {len(merged_segments)} segments"
)
return DiarizationResult(
segments=merged_segments,
num_speakers=num_unique,
audio_duration=audio_duration,
processing_time=time.time() - t_start,
sample_rate=sample_rate,
)
except Exception as e:
logger.warning(f"Full pyannote diarization failed: {e}. Falling back to ECAPA+AHC.")
return None
finally:
if tmp_path:
Path(tmp_path).unlink(missing_ok=True)
    def _energy_vad(
        self, audio: torch.Tensor, frame_duration: float = 0.02, threshold_db: float = -40.0
    ) -> List[tuple]:
        """Energy-based VAD: flag fixed-length frames whose RMS level exceeds threshold_db."""
        frame_samples = int(frame_duration * self.SAMPLE_RATE)
        # detach().cpu() so GPU or grad-tracking tensors can be converted safely
        audio_np = audio.detach().cpu().numpy()
frames = [
audio_np[i: i + frame_samples]
for i in range(0, len(audio_np) - frame_samples, frame_samples)
]
energies_db = []
for frame in frames:
rms = np.sqrt(np.mean(frame ** 2) + 1e-10)
energies_db.append(20 * np.log10(rms))
is_speech = np.array(energies_db) > threshold_db
speech_regions = []
in_speech = False
start = 0.0
for i, active in enumerate(is_speech):
t = i * frame_duration
if active and not in_speech:
start = t
in_speech = True
elif not active and in_speech:
speech_regions.append((start, t))
in_speech = False
if in_speech:
speech_regions.append((start, len(audio_np) / self.SAMPLE_RATE))
return speech_regions
def _get_speech_regions(self, audio: torch.Tensor) -> List[tuple]:
if self.use_pyannote_vad:
self._load_vad()
if self._vad_pipeline == "energy" or not self.use_pyannote_vad:
return self._energy_vad(audio)
try:
audio_dict = {
"waveform": audio.unsqueeze(0).to(self.device),
"sample_rate": self.SAMPLE_RATE,
}
vad_output = self._vad_pipeline(audio_dict)
regions = [(seg.start, seg.end) for seg in vad_output.get_timeline().support()]
logger.info(f"Pyannote VAD: {len(regions)} speech regions found")
return regions
except Exception as e:
logger.warning(f"Pyannote VAD failed: {e}. Using energy VAD.")
return self._energy_vad(audio)
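    # Worked example for the windowing below (WINDOW_DURATION=2.0, WINDOW_STEP=1.0,
    # MIN_SEGMENT_DURATION=0.8): a speech region (0.0, 5.0) yields windows
    # (0, 2), (1, 3), (2, 4), (3, 5), plus the tail (4, 5), because the 1.0 s
    # remainder clears the minimum duration.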
    def _sliding_window_segments(self, speech_regions: List[tuple]) -> List[tuple]:
        """Cut speech regions into overlapping fixed-length windows for embedding."""
segments = []
for region_start, region_end in speech_regions:
duration = region_end - region_start
if duration < self.MIN_SEGMENT_DURATION:
continue
t = region_start
while t + self.WINDOW_DURATION <= region_end:
segments.append((t, t + self.WINDOW_DURATION))
t += self.WINDOW_STEP
if region_end - t >= self.MIN_SEGMENT_DURATION:
segments.append((t, region_end))
return segments
    def load_audio(self, path: Union[str, Path, BinaryIO]) -> tuple:
        """Load audio from a path or file-like object as (waveform, sample_rate)."""
waveform, sample_rate = torchaudio.load(path)
return waveform, sample_rate
    def process(
        self,
        audio: Union[str, Path, torch.Tensor],
        sample_rate: Optional[int] = None,
        num_speakers: Optional[int] = None,
    ) -> DiarizationResult:
        """Diarize a file path or an audio tensor; tensor input requires sample_rate."""
t_start = time.time()
if isinstance(audio, (str, Path)):
waveform, sample_rate = self.load_audio(audio)
audio_tensor = self._to_mono_1d(waveform)
else:
            if sample_rate is None:
                raise ValueError("sample_rate is required when passing a raw tensor")
audio_tensor = self._to_mono_1d(audio)
num_samples = int(audio_tensor.numel())
audio_duration = num_samples / float(sample_rate)
logger.info(f"Processing {audio_duration:.1f}s audio at {sample_rate}Hz")
if num_samples == 0:
logger.warning("Received empty audio input.")
return DiarizationResult(
segments=[],
num_speakers=0,
audio_duration=0.0,
processing_time=time.time() - t_start,
sample_rate=sample_rate,
)
        k = num_speakers if num_speakers is not None else self.num_speakers
pyannote_result = self._run_full_pyannote(
audio=audio,
sample_rate=sample_rate,
num_speakers=k,
audio_duration=audio_duration,
t_start=t_start,
)
if pyannote_result is not None:
return pyannote_result
processed = self.embedder.preprocess_audio(audio_tensor, sample_rate)
speech_regions = self._get_speech_regions(processed)
if not speech_regions:
logger.warning("No speech detected in audio.")
return DiarizationResult(
segments=[],
num_speakers=0,
audio_duration=audio_duration,
processing_time=time.time() - t_start,
sample_rate=sample_rate,
)
windows = self._sliding_window_segments(speech_regions)
logger.info(f"Generated {len(windows)} embedding windows")
embeddings, valid_windows = self.embedder.extract_embeddings_from_segments(
processed, self.SAMPLE_RATE, windows
)
if len(embeddings) == 0:
logger.warning("No valid embeddings extracted.")
return DiarizationResult(
segments=[],
num_speakers=0,
audio_duration=audio_duration,
processing_time=time.time() - t_start,
sample_rate=sample_rate,
)
labels = self.clusterer.cluster(embeddings, num_speakers=k)
merged = self.clusterer.merge_consecutive_same_speaker(
valid_windows, labels, gap_tolerance=0.45
)
        # Name speakers from the labels the clusterer actually produced, so a requested
        # num_speakers above max_speakers cannot raise a KeyError here.
        speaker_names = {spk_id: f"SPEAKER_{spk_id:02d}" for spk_id in sorted(set(labels))}
        segments = [
            DiarizationSegment(start=start, end=end, speaker=speaker_names[spk_id])
            for start, end, spk_id in merged
        ]
num_unique = len(set(labels))
processing_time = time.time() - t_start
logger.success(
f"Fallback diarization complete: {num_unique} speakers, "
f"{len(segments)} segments, {processing_time:.2f}s"
)
return DiarizationResult(
segments=segments,
num_speakers=num_unique,
audio_duration=audio_duration,
processing_time=processing_time,
sample_rate=sample_rate,
)
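

# Hedged smoke test (an illustrative sketch, not an established entry point of
# this app): diarize a file given on the command line and print the JSON payload.
if __name__ == "__main__":
    import json
    import sys

    if len(sys.argv) != 2:
        print("usage: python pipeline.py <audio-file>", file=sys.stderr)
        raise SystemExit(1)
    _pipeline = DiarizationPipeline(device="auto")
    _result = _pipeline.process(sys.argv[1])
    print(json.dumps(_result.to_dict(), indent=2))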