# Source: ustwo-api/src/stage1/diarization.py (commit c857b85, deployed from GitHub 2026-04-23T03:56:31Z)
"""Stage 1 — Speaker Diarization Module.
Uses pyannote-audio 4.x (speaker-diarization-3.1) to identify who spoke when.
"""
import copy
import logging
import os
from dataclasses import dataclass

import torch
import torchaudio
from pyannote.audio import Pipeline
logger = logging.getLogger(__name__)
@dataclass
class DiarSegment:
    """One speaker turn: a speaker label plus the time span it covers.

    Instances are deliberately mutable: the rest of this module rewrites
    ``end`` while merging turns and ``speaker_id`` while relabelling speakers.
    """

    # Speaker label for this turn.
    speaker_id: str
    # Turn boundaries; the merge helper compares them against a threshold
    # expressed in seconds.
    start: float
    end: float
# Process-wide cache holding at most one loaded pipeline.
_pipeline: Pipeline | None = None
# (model, resolved device) the cached pipeline was loaded for. ``None`` means
# no key was recorded: nothing is loaded yet, or ``_pipeline`` was assigned
# directly from outside this function.
_pipeline_key: tuple[str, str] | None = None


def _load_pipeline(model: str, device: str) -> Pipeline:
    """Load the pyannote diarization pipeline, reusing the cached instance.

    The cached pipeline is reused only when it was loaded for the same model
    and the same resolved device. A request for a different model or device
    loads a fresh pipeline and replaces the cache, instead of silently
    returning whichever pipeline happened to be loaded first.

    Args:
        model: Model identifier passed to ``Pipeline.from_pretrained``.
        device: Requested torch device string. It is honoured only when it
            contains ``"cuda"`` and CUDA is available; otherwise the pipeline
            is placed on ``"cpu"``.

    Returns:
        The pipeline, moved to the resolved device.

    Raises:
        RuntimeError: If ``HF_TOKEN`` is unset or empty, or if
            ``Pipeline.from_pretrained`` returns no pipeline.
    """
    global _pipeline, _pipeline_key

    # Resolve the device first so the cache key reflects where the pipeline
    # actually lives ("cuda:0" without a GPU and "cpu" are the same thing).
    use_cuda = "cuda" in device and torch.cuda.is_available()
    torch_device = torch.device(device if use_cuda else "cpu")
    key = (model, str(torch_device))

    # A pipeline with no recorded key was injected directly; trust it as-is.
    if _pipeline is not None and _pipeline_key in (None, key):
        return _pipeline

    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        raise RuntimeError(
            "HF_TOKEN environment variable required for pyannote model download. "
            "Set it with: export HF_TOKEN=your_huggingface_token"
        )

    logger.info("Loading pyannote pipeline: %s", model)
    pipeline = Pipeline.from_pretrained(model, token=hf_token)
    # NOTE(review): some pyannote releases appear to return None rather than
    # raise when the checkpoint cannot be fetched; fail with a clear message
    # instead of an AttributeError on ``.to`` below.
    if pipeline is None:
        raise RuntimeError(f"Failed to load pyannote pipeline: {model}")
    pipeline.to(torch_device)

    # Publish to the cache only after the pipeline is fully initialised, so a
    # failure above never leaves a half-configured pipeline behind.
    _pipeline = pipeline
    _pipeline_key = key
    logger.info("Pyannote pipeline loaded on %s", torch_device)
    return pipeline
def _merge_adjacent(
    segments: list[DiarSegment], gap_threshold: float
) -> list[DiarSegment]:
    """Merge consecutive same-speaker segments closer than gap_threshold seconds.

    Only neighbouring list entries are compared, so ``segments`` is expected
    to be ordered by start time already. The input list and the segment
    objects in it are left untouched; the result is built from shallow copies.

    Args:
        segments: Segments to merge, in start-time order.
        gap_threshold: A segment is folded into the previous one when both
            share a speaker and the gap between them is strictly smaller
            than this value. Overlapping segments have a negative gap and
            therefore always merge.

    Returns:
        A new list of merged segments (empty when ``segments`` is empty).
    """
    merged: list[DiarSegment] = []
    for seg in segments:
        prev = merged[-1] if merged else None
        if (
            prev is not None
            and seg.speaker_id == prev.speaker_id
            and (seg.start - prev.end) < gap_threshold
        ):
            # max() keeps the longer end when seg lies fully inside prev.
            prev.end = max(prev.end, seg.end)
        else:
            # Copy so that extending ``end`` later never mutates the caller's
            # segment object.
            merged.append(copy.copy(seg))
    return merged
def diarize(audio_path: str, config: dict) -> list[DiarSegment]:
    """Run speaker diarization on audio.

    Args:
        audio_path: Path to 16kHz mono WAV file.
        config: stage1.diarization config section. Keys read here, with the
            default used when a key is absent:
            ``model`` ("pyannote/speaker-diarization-3.1"),
            ``num_speakers`` (2), ``merge_gap_sec`` (0.15),
            ``device`` ("cuda:0").

    Returns:
        List of DiarSegment sorted by start time. Same-speaker segments
        closer than ``merge_gap_sec`` are merged, and speaker labels are
        renamed to ``speaker_0``, ``speaker_1``, ... in order of first
        appearance.

    Raises:
        RuntimeError: If the pipeline output exposes neither a
            ``speaker_diarization`` attribute nor ``itertracks``. Errors
            raised by ``_load_pipeline`` (e.g. a missing ``HF_TOKEN``)
            propagate unchanged.
    """
    model = config.get("model", "pyannote/speaker-diarization-3.1")
    num_speakers = config.get("num_speakers", 2)
    merge_gap = config.get("merge_gap_sec", 0.15)
    device = config.get("device", "cuda:0")

    pipeline = _load_pipeline(model, device)

    # %s rather than %d: num_speakers may be None when the config sets it to
    # null, and %d would make the log record fail to format.
    logger.info("Running diarization (num_speakers=%s) on %s", num_speakers, audio_path)

    # pyannote 4.x: pass in-memory waveform to avoid torchcodec dependency
    waveform, sample_rate = torchaudio.load(audio_path)
    audio_dict = {"waveform": waveform, "sample_rate": sample_rate}
    output = pipeline(audio_dict, num_speakers=num_speakers)

    # pyannote 4.x returns DiarizeOutput with .speaker_diarization Annotation;
    # an object that already offers itertracks is used as the annotation itself.
    if hasattr(output, "speaker_diarization"):
        annotation = output.speaker_diarization
    elif hasattr(output, "itertracks"):
        annotation = output
    else:
        raise RuntimeError(f"Unexpected pyannote output type: {type(output)}")

    # Build and order the raw turns; boundaries are rounded to 3 decimals.
    segments: list[DiarSegment] = sorted(
        (
            DiarSegment(
                speaker_id=speaker,
                start=round(turn.start, 3),
                end=round(turn.end, 3),
            )
            for turn, _, speaker in annotation.itertracks(yield_label=True)
        ),
        key=lambda s: s.start,
    )
    segments = _merge_adjacent(segments, merge_gap)

    # Normalize speaker labels to speaker_0 / speaker_1 / ... in order of
    # first appearance. The f-string is evaluated before setdefault inserts,
    # so a new label gets the index equal to the number of labels seen so far.
    speaker_map: dict[str, str] = {}
    for seg in segments:
        seg.speaker_id = speaker_map.setdefault(
            seg.speaker_id, f"speaker_{len(speaker_map)}"
        )

    logger.info(
        "Diarization complete: %d segments, %d speakers",
        len(segments), len(speaker_map),
    )

    # Warn when fewer speakers were found than requested. With no explicit
    # request (None) keep the historical expectation of two speakers.
    expected_speakers = 2 if num_speakers is None else num_speakers
    if len(speaker_map) < expected_speakers:
        logger.warning("Only %d speaker(s) detected", len(speaker_map))

    return segments