"""
Speaker Diarization - identify who spoke when.
"""
import torch
import numpy as np
from typing import List, Dict, Optional
from dataclasses import dataclass, field
from sklearn.cluster import AgglomerativeClustering


@dataclass
class SpeakerSegment:
"""A segment of speech from a specific speaker."""
start: float
end: float
    speaker_id: str

    @property
def duration(self) -> float:
        return self.end - self.start


@dataclass
class SpeakerInfo:
"""Information about a speaker."""
speaker_id: str
total_seconds: float = 0.0
segments: List[SpeakerSegment] = field(default_factory=list)
    embedding: Optional[np.ndarray] = None

    def add_segment(self, segment: SpeakerSegment):
self.segments.append(segment)
        self.total_seconds += segment.duration


class SpeakerDiarizer:
    """Speaker diarization using embedding clustering."""

    def __init__(self, device: Optional[str] = None):
        # Default to GPU when available; an explicit device string overrides this
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        self._embedding_model = None

    @property
    def embedding_model(self):
        """Lazily load the speaker-embedding model on first access."""
        if self._embedding_model is None:
            import os

            from speechbrain.inference.speaker import SpeakerRecognition

            # SpeechBrain downloads the pretrained model into savedir on first
            # use; MODEL_DIR lets deployments point that at a writable volume
            # instead of the read-only /app mount.
            model_dir = os.environ.get("MODEL_DIR", "pretrained_models")
            self._embedding_model = SpeakerRecognition.from_hparams(
                source="speechbrain/spkrec-ecapa-voxceleb",
                savedir=os.path.join(model_dir, "spkrec"),
                run_opts={"device": self.device}
            )
        return self._embedding_model

    def diarize(self, audio_path: str,
                speech_segments: Optional[List] = None,
                window_size: float = 2.0,
                hop_size: float = 0.5,
                num_speakers: Optional[int] = None,
                min_speakers: int = 1,
                max_speakers: int = 5) -> Dict[str, SpeakerInfo]:
        """
        Perform speaker diarization.

        Args:
            audio_path: Path to the audio file
            speech_segments: Optional list of speech segments from VAD; each
                item needs ``start`` and ``end`` attributes (in seconds)
            window_size: Window size for embedding extraction, in seconds
            hop_size: Hop size between windows, in seconds
            num_speakers: Known number of speakers (None to estimate)
            min_speakers: Minimum number of speakers to detect
            max_speakers: Maximum number of speakers to detect

        Returns:
            Dict mapping speaker_id to SpeakerInfo, ordered by total speech
            time (main speaker first)
        """
        import torchaudio

        # Load audio with the soundfile backend to avoid the torchcodec dependency
        waveform, sample_rate = torchaudio.load(audio_path, backend="soundfile")
        # Downmix multi-channel audio to mono
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        duration = waveform.shape[1] / sample_rate
# Extract embeddings for windows
windows = []
embeddings = []
current = 0.0
while current + window_size <= duration:
start_sample = int(current * sample_rate)
end_sample = int((current + window_size) * sample_rate)
window_audio = waveform[:, start_sample:end_sample]
            # Keep this window only if its midpoint falls inside a VAD speech
            # segment; without VAD output, assume every window contains speech
            has_speech = True
            if speech_segments:
                has_speech = any(
                    s.start <= current + window_size / 2 <= s.end
                    for s in speech_segments
                )
if has_speech and window_audio.shape[1] > 0:
# Extract embedding
emb = self.embedding_model.encode_batch(window_audio)
emb = emb.squeeze().cpu().numpy()
windows.append({'start': current, 'end': current + window_size})
embeddings.append(emb)
current += hop_size
if len(embeddings) < 2:
# Not enough data for clustering
speaker_info = SpeakerInfo(speaker_id="speaker_A")
for seg in (speech_segments or []):
speaker_info.add_segment(SpeakerSegment(
start=seg.start, end=seg.end, speaker_id="speaker_A"
))
if embeddings:
speaker_info.embedding = embeddings[0]
return {"speaker_A": speaker_info}
        embeddings_array = np.array(embeddings)

        # Cluster embeddings with average-linkage agglomerative clustering
        # over cosine distance
        if num_speakers is None:
            # Unknown speaker count: cut the dendrogram at a fixed
            # cosine-distance threshold (0.7 here, a heuristic) to estimate
            # the number of speakers
            clustering = AgglomerativeClustering(
                n_clusters=None,
                distance_threshold=0.7,
                metric='cosine',
                linkage='average'
            )
        else:
            clustering = AgglomerativeClustering(
                n_clusters=num_speakers,
                metric='cosine',
                linkage='average'
            )
        labels = clustering.fit_predict(embeddings_array)
        # Clamp the detected speaker count to the [min_speakers, max_speakers] range
        unique_labels = np.unique(labels)
        if len(unique_labels) > max_speakers:
            # Too many clusters: re-cluster with exactly max_speakers
            clustering = AgglomerativeClustering(
                n_clusters=max_speakers,
                metric='cosine',
                linkage='average'
            )
            labels = clustering.fit_predict(embeddings_array)
            unique_labels = np.unique(labels)
        elif len(unique_labels) < min_speakers <= len(embeddings):
            # Too few clusters: re-cluster with exactly min_speakers
            clustering = AgglomerativeClustering(
                n_clusters=min_speakers,
                metric='cosine',
                linkage='average'
            )
            labels = clustering.fit_predict(embeddings_array)
            unique_labels = np.unique(labels)

        # Build per-speaker info; clusters are named speaker_A, speaker_B, ...
        speaker_names = ['speaker_A', 'speaker_B', 'speaker_C', 'speaker_D', 'speaker_E']
        label_to_id = {
            label: speaker_names[label] if label < len(speaker_names) else f"speaker_{label}"
            for label in unique_labels
        }
        speakers = {}
        for label in unique_labels:
            speaker_id = label_to_id[label]
            speakers[speaker_id] = SpeakerInfo(speaker_id=speaker_id)
            # Represent the speaker by the mean of its cluster's window embeddings
            mask = labels == label
            speakers[speaker_id].embedding = np.mean(embeddings_array[mask], axis=0)

        # Assign each embedding window to its speaker as a segment
        for window, label in zip(windows, labels):
            speaker_id = label_to_id[label]
            speakers[speaker_id].add_segment(SpeakerSegment(
                start=window['start'],
                end=window['end'],
                speaker_id=speaker_id
            ))

# Sort by total speech time (main speaker first)
speakers = dict(sorted(
speakers.items(),
key=lambda x: x[1].total_seconds,
reverse=True
))
return speakers

    def get_main_speaker(self, speakers: Dict[str, SpeakerInfo]) -> Optional[SpeakerInfo]:
        """Get the speaker with the most speech time."""
        if not speakers:
            return None
        # diarize() returns an insertion-ordered dict sorted by speech time,
        # so the first entry is the main speaker
        return next(iter(speakers.values()))

    def get_additional_speakers(self, speakers: Dict[str, SpeakerInfo]) -> List[SpeakerInfo]:
        """Get all speakers except the main one."""
        items = list(speakers.values())
        return items[1:] if len(items) > 1 else []
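

if __name__ == "__main__":
    # Minimal usage sketch (an illustration, not part of the original module):
    # "meeting.wav" is a placeholder path, and no VAD segments are passed,
    # so every window is treated as speech.
    diarizer = SpeakerDiarizer()
    speakers = diarizer.diarize("meeting.wav", max_speakers=4)
    for speaker in speakers.values():
        print(f"{speaker.speaker_id}: {speaker.total_seconds:.1f}s "
              f"across {len(speaker.segments)} windows")
    main = diarizer.get_main_speaker(speakers)
    if main is not None:
        print(f"Main speaker: {main.speaker_id}")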