# Whisper-Transcriber / utils / diarization.py
# Whisper Transcriber Bot
# Initial commit: Complete Whisper Transcriber implementation
# 4051511
import os
import torch
from typing import Dict, List, Optional
from pyannote.audio import Pipeline
import logging
# Request basic INFO-level logging configuration when this module is imported.
logging.basicConfig(level=logging.INFO)
# Module-level logger, named after this module, used by SpeakerDiarizer below.
logger = logging.getLogger(__name__)
class SpeakerDiarizer:
    """Handles speaker diarization using pyannote.audio.

    The pipeline is loaded lazily: constructing the object only records the
    access token and the target device; the model itself is fetched on the
    first call to ``load_pipeline()`` or ``diarize()``.
    """

    def __init__(self, hf_token: Optional[str] = None):
        """
        Initialize speaker diarization pipeline

        Args:
            hf_token: Hugging Face access token (required for pyannote models).
                Falls back to the ``HF_TOKEN`` environment variable when not
                given (or empty).
        """
        self.hf_token = hf_token or os.getenv('HF_TOKEN')
        if not self.hf_token:
            logger.warning("No HF_TOKEN provided. Diarization may fail.")
        # Set by load_pipeline(); None means "not loaded yet".
        self.pipeline = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def load_pipeline(self, progress_callback=None):
        """
        Load the diarization pipeline

        Args:
            progress_callback: Optional callable taking a status message string

        Raises:
            Exception: If the model cannot be loaded (including the case where
                ``Pipeline.from_pretrained`` returns ``None``) or cannot be
                moved to the GPU. The original error is chained as the cause.
        """
        if progress_callback:
            progress_callback("Loading speaker diarization model...")

        try:
            pipeline = Pipeline.from_pretrained(
                "pyannote/speaker-diarization-3.1",
                use_auth_token=self.hf_token,
            )
            # from_pretrained can return None instead of raising when the
            # model cannot be downloaded; treat that as a load failure rather
            # than reporting success and failing later inside diarize().
            if pipeline is None:
                raise RuntimeError(
                    "Pipeline.from_pretrained returned None "
                    "(model could not be downloaded)"
                )

            # Move to GPU if available
            if self.device == "cuda":
                pipeline.to(torch.device("cuda"))

            # Only publish the pipeline once it is fully set up, so a failed
            # load leaves self.pipeline as None and a later call retries.
            self.pipeline = pipeline

            if progress_callback:
                progress_callback("Diarization model loaded successfully")
            logger.info("Diarization pipeline loaded on %s", self.device)
        except Exception as e:
            logger.error("Failed to load diarization pipeline: %s", e)
            raise Exception(
                f"Failed to load diarization model. "
                f"Make sure you have accepted the terms at: "
                f"https://huggingface.co/pyannote/speaker-diarization-3.1 "
                f"and provided a valid HF_TOKEN. Error: {str(e)}"
            ) from e

    def diarize(self, audio_path: str, progress_callback=None) -> Dict:
        """
        Perform speaker diarization on audio file

        Args:
            audio_path: Path to audio file
            progress_callback: Optional callback for progress updates

        Returns:
            Dictionary with a single key ``'segments'``: a list of dicts, each
            holding ``'start'``, ``'end'`` and ``'speaker'`` for one turn.

        Raises:
            Exception: If the pipeline cannot be loaded or diarization fails.
        """
        if self.pipeline is None:
            self.load_pipeline(progress_callback)

        if progress_callback:
            progress_callback("Analyzing speakers in audio...")

        try:
            # Run diarization
            diarization = self.pipeline(audio_path)

            # Convert to a list of plain-dict segments
            segments = [
                {'start': turn.start, 'end': turn.end, 'speaker': speaker}
                for turn, _, speaker in diarization.itertracks(yield_label=True)
            ]

            if progress_callback:
                num_speakers = len({seg['speaker'] for seg in segments})
                progress_callback(f"Diarization complete. Found {num_speakers} speakers")
            logger.info("Diarization found %d segments", len(segments))
            return {'segments': segments}
        except Exception as e:
            logger.error("Diarization failed: %s", e)
            raise Exception(f"Speaker diarization failed: {str(e)}") from e

    @staticmethod
    def _best_speaker(chunk_start, chunk_end, segments) -> Optional[str]:
        """Pick the speaker label for one chunk, or None if nothing overlaps.

        The first segment (in list order) containing the chunk midpoint wins;
        otherwise the segment with the largest positive overlap is chosen.
        """
        chunk_mid = (chunk_start + chunk_end) / 2
        best_speaker = None
        best_overlap = 0
        for seg in segments:
            seg_start = seg['start']
            seg_end = seg['end']

            # Chunk midpoint inside this segment: take it immediately.
            if seg_start <= chunk_mid <= seg_end:
                return seg['speaker']

            # Negative values mean "no overlap"; they never beat best_overlap,
            # which starts at 0 and only grows.
            overlap = min(chunk_end, seg_end) - max(chunk_start, seg_start)
            if overlap > best_overlap:
                best_overlap = overlap
                best_speaker = seg['speaker']
        return best_speaker

    def align_with_transcription(
        self,
        diarization_result: Dict,
        transcription_result: Dict,
        progress_callback=None
    ) -> Dict[int, str]:
        """
        Align speaker labels with transcription chunks

        Args:
            diarization_result: Result from diarize()
            transcription_result: Result from transcription; its ``'chunks'``
                entries are read for a ``'timestamp'`` (start, end) pair
            progress_callback: Optional callback for progress updates

        Returns:
            Dictionary mapping chunk index to speaker label. Chunks without a
            start time, or without any matching speaker segment, are omitted.
        """
        if progress_callback:
            progress_callback("Aligning speakers with transcription...")

        speaker_labels = {}
        diarization_segments = diarization_result.get('segments', [])
        transcription_chunks = transcription_result.get('chunks', [])

        for chunk_idx, chunk in enumerate(transcription_chunks):
            # `or` also covers an explicit `'timestamp': None`, which would
            # otherwise fail on indexing.
            timestamp = chunk.get('timestamp') or (None, None)
            if timestamp[0] is None:
                continue

            chunk_start = timestamp[0]
            # A missing end time is replaced by a nominal 1-second duration.
            chunk_end = timestamp[1] if timestamp[1] is not None else chunk_start + 1.0

            best_speaker = self._best_speaker(
                chunk_start, chunk_end, diarization_segments
            )
            if best_speaker:
                speaker_labels[chunk_idx] = best_speaker

        if progress_callback:
            progress_callback("Speaker alignment complete")
        logger.info("Aligned %d chunks with speakers", len(speaker_labels))
        return speaker_labels

    @staticmethod
    def is_available() -> bool:
        """Check if diarization is available (non-empty HF_TOKEN set)"""
        # An empty HF_TOKEN is treated as missing, matching __init__.
        return bool(os.getenv('HF_TOKEN'))