Spaces:
Sleeping
Sleeping
File size: 5,677 Bytes
4051511 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 |
import os
import torch
from typing import Dict, List, Optional
from pyannote.audio import Pipeline
import logging
# Configure the root logger at INFO level as a side effect of importing this module.
logging.basicConfig(level=logging.INFO)
# Module-level logger named after this module; used by SpeakerDiarizer below.
logger = logging.getLogger(__name__)
class SpeakerDiarizer:
    """Handles speaker diarization using pyannote.audio.

    The pyannote pipeline is not loaded in ``__init__``; it is loaded on the
    first call to :meth:`diarize` (or explicitly via :meth:`load_pipeline`).
    """

    def __init__(self, hf_token: Optional[str] = None):
        """
        Initialize speaker diarization pipeline

        Args:
            hf_token: Hugging Face access token (required for pyannote models).
                Falls back to the ``HF_TOKEN`` environment variable when
                omitted or empty.
        """
        self.hf_token = hf_token or os.getenv('HF_TOKEN')
        if not self.hf_token:
            logger.warning("No HF_TOKEN provided. Diarization may fail.")
        # Loaded lazily by load_pipeline().
        self.pipeline = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def load_pipeline(self, progress_callback=None):
        """
        Load the diarization pipeline

        Args:
            progress_callback: Optional callable taking a single status string.

        Raises:
            Exception: If the pipeline cannot be loaded. The original error is
                attached as ``__cause__``.
        """
        if progress_callback:
            progress_callback("Loading speaker diarization model...")
        try:
            # NOTE(review): newer pyannote.audio releases may expect `token=`
            # instead of `use_auth_token=` -- confirm against the pinned version.
            self.pipeline = Pipeline.from_pretrained(
                "pyannote/speaker-diarization-3.1",
                use_auth_token=self.hf_token
            )
            # Move to GPU if available
            if self.device == "cuda":
                self.pipeline.to(torch.device("cuda"))
            if progress_callback:
                progress_callback("Diarization model loaded successfully")
            logger.info("Diarization pipeline loaded on %s", self.device)
        except Exception as e:
            logger.error("Failed to load diarization pipeline: %s", e)
            # Chain the cause so the underlying traceback is preserved.
            raise Exception(
                "Failed to load diarization model. "
                "Make sure you have accepted the terms at: "
                "https://huggingface.co/pyannote/speaker-diarization-3.1 "
                f"and provided a valid HF_TOKEN. Error: {str(e)}"
            ) from e

    def diarize(self, audio_path: str, progress_callback=None) -> Dict:
        """
        Perform speaker diarization on audio file

        Args:
            audio_path: Path to audio file
            progress_callback: Optional callback for progress updates

        Returns:
            ``{'segments': [...]}`` where each segment is a dict with
            ``'start'``, ``'end'`` and ``'speaker'`` keys.

        Raises:
            Exception: If loading the pipeline or running diarization fails.
        """
        if self.pipeline is None:
            self.load_pipeline(progress_callback)
        if progress_callback:
            progress_callback("Analyzing speakers in audio...")
        try:
            # Run diarization
            diarization = self.pipeline(audio_path)
            # Convert each speaker turn to a plain dict
            segments = [
                {'start': turn.start, 'end': turn.end, 'speaker': speaker}
                for turn, _, speaker in diarization.itertracks(yield_label=True)
            ]
            if progress_callback:
                num_speakers = len({seg['speaker'] for seg in segments})
                progress_callback(f"Diarization complete. Found {num_speakers} speakers")
            logger.info("Diarization found %d segments", len(segments))
            return {'segments': segments}
        except Exception as e:
            logger.error("Diarization failed: %s", e)
            raise Exception(f"Speaker diarization failed: {str(e)}") from e

    @staticmethod
    def _best_speaker(
        chunk_start: float,
        chunk_end: float,
        diarization_segments: List[Dict]
    ) -> Optional[str]:
        """Pick the speaker for one chunk, or None if no segment overlaps it.

        The first segment containing the chunk midpoint wins outright;
        otherwise the segment with the largest positive time overlap is used.
        """
        chunk_mid = (chunk_start + chunk_end) / 2
        best_speaker = None
        best_overlap = 0
        for seg in diarization_segments:
            seg_start = seg['start']
            seg_end = seg['end']
            # Check if chunk midpoint is in this segment
            if seg_start <= chunk_mid <= seg_end:
                return seg['speaker']
            # Calculate overlap
            overlap = max(0, min(chunk_end, seg_end) - max(chunk_start, seg_start))
            if overlap > best_overlap:
                best_overlap = overlap
                best_speaker = seg['speaker']
        return best_speaker

    def align_with_transcription(
        self,
        diarization_result: Dict,
        transcription_result: Dict,
        progress_callback=None
    ) -> Dict[int, str]:
        """
        Align speaker labels with transcription chunks

        Args:
            diarization_result: Result from diarize()
            transcription_result: Result from transcription; its ``'chunks'``
                entries are read for a ``'timestamp'`` ``(start, end)`` pair.
            progress_callback: Optional callback for progress updates

        Returns:
            Dictionary mapping chunk index to speaker label. Chunks without a
            start time, or with no matching segment, are omitted.
        """
        if progress_callback:
            progress_callback("Aligning speakers with transcription...")

        # `or` also covers keys that are present but set to None.
        diarization_segments = diarization_result.get('segments') or []
        transcription_chunks = transcription_result.get('chunks') or []

        speaker_labels: Dict[int, str] = {}
        for chunk_idx, chunk in enumerate(transcription_chunks):
            # A chunk may carry 'timestamp': None; treat it like a missing one.
            timestamp = chunk.get('timestamp') or (None, None)
            chunk_start = timestamp[0]
            if chunk_start is None:
                continue
            # With no end time, assume the chunk lasts one second.
            chunk_end = timestamp[1] if timestamp[1] is not None else chunk_start + 1.0

            best_speaker = self._best_speaker(chunk_start, chunk_end, diarization_segments)
            if best_speaker:
                speaker_labels[chunk_idx] = best_speaker

        if progress_callback:
            progress_callback("Speaker alignment complete")
        logger.info("Aligned %d chunks with speakers", len(speaker_labels))
        return speaker_labels

    @staticmethod
    def is_available() -> bool:
        """Check if diarization is available (HF_TOKEN set to a non-empty value)"""
        # Truthiness check so an empty HF_TOKEN counts as missing, as in __init__.
        return bool(os.getenv('HF_TOKEN'))
|