# voice-tools / src/services/speaker_separation.py
# Last change: jcudit (HF Staff), commit 3fb465f —
# "fix: resolve ZeroGPU pickling errors across all audio processing services"
"""
Speaker Separation Service
Performs speaker diarization and separation using pyannote.audio.
Extracts individual speakers from multi-speaker audio files.
"""
import json
import logging
import os
import time
from pathlib import Path
from typing import Callable, Dict, List, Optional
import numpy as np
import torch
try:
    import spaces
except ImportError:
    # Fallback for environments without the HuggingFace `spaces` package:
    # provide a stand-in whose GPU decorator is a no-op so this module still
    # runs (on CPU) outside a ZeroGPU Space.
    class spaces:
        @staticmethod
        def GPU(func=None, duration=60, **kwargs):
            """No-op replacement for ``spaces.GPU``.

            Mirrors both usages of the real decorator: bare ``@spaces.GPU``
            and parameterized ``@spaces.GPU(duration=...)``. Any extra
            keyword arguments are accepted and ignored so call sites stay
            signature-compatible with the real package.
            """
            if callable(func):
                # Bare decorator form: @spaces.GPU
                return func

            def decorator(f):
                return f

            return decorator
# PyTorch 2.6 changed torch.load() to default to weights_only=True, which
# refuses the pickled objects embedded in pyannote checkpoints. Those
# checkpoints come from a trusted source (HuggingFace), so install a wrapper
# that always opts back out of the restriction.
_torch_load_unrestricted = torch.load


def _torch_load_trusting(*args, **kwargs):
    """Delegate to the real torch.load with weights_only forced to False."""
    kwargs["weights_only"] = False
    return _torch_load_unrestricted(*args, **kwargs)


torch.load = _torch_load_trusting
from pyannote.audio import Pipeline
from pyannote.audio.pipelines.utils.hook import ProgressHook
from ..config.gpu_config import GPUConfig
from ..lib.audio_io import (
AudioIOError,
convert_m4a_to_wav,
convert_wav_to_m4a,
extract_segment,
get_audio_duration,
read_audio,
write_audio,
)
from ..lib.progress import SPEAKER_SEPARATION_STAGES
from ..models.audio_segment import AudioSegment, SegmentType
from ..models.error_report import ErrorReport
from ..models.speaker_profile import SpeakerProfile
# Module-level logger, named after this module per logging convention.
logger = logging.getLogger(__name__)


# Module-level function for GPU-accelerated diarization.
# This avoids pickling issues with ZeroGPU by not depending on class instance
# state: ZeroGPU serializes the decorated callable and its arguments, so a
# plain function with picklable inputs is required.
@spaces.GPU(duration=90)
def _run_diarization_on_gpu(
    audio_dict: Dict,
    hf_token: str,
    min_speakers: int,
    max_speakers: int,
    progress_callback: Optional[Callable] = None,
):
    """
    Run diarization on GPU (or CPU if unavailable).

    This is a module-level function to avoid pickling issues with ZeroGPU.
    The pipeline is loaded fresh within this GPU context.

    Args:
        audio_dict: Audio data dict with 'waveform' (torch tensor; the caller
            in separate_speakers builds it as shape (1, samples)) and
            'sample_rate'
        hf_token: HuggingFace token for model access
        min_speakers: Minimum number of speakers
        max_speakers: Maximum number of speakers
        progress_callback: Optional progress callback, called as
            (stage_name, progress, total)

    Returns:
        Diarization result from pyannote
    """
    # Load pipeline fresh in GPU context (avoids pickling the model object
    # across the ZeroGPU boundary).
    logger.info("Loading pyannote pipeline in GPU context...")
    pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1", token=hf_token)
    # Move to available device; CPU is the fallback when CUDA is absent.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pipeline.to(device)
    logger.info(f"Pipeline loaded on {device}")
    try:
        # Custom progress hook that bridges pyannote progress to our callback.
        class CustomProgressHook(ProgressHook):
            def __init__(self, callback=None):
                super().__init__()
                self.callback = callback

            def __call__(self, step_name, step_artefact, file=None, total=None, completed=None):
                # Call parent to maintain pyannote's internal tracking
                result = super().__call__(step_name, step_artefact, file, total, completed)
                # Forward progress to our callback
                if self.callback and completed is not None and total is not None and total > 0:
                    # Map step names to user-friendly descriptions
                    stage = SPEAKER_SEPARATION_STAGES.get(step_name, step_name)
                    # Calculate percentage within this step (0.0 to 1.0)
                    step_progress = completed / total
                    # Scale to 0.3-0.8 range (30% to 80% of overall progress),
                    # reserving the remainder for segment export downstream.
                    overall_progress = 0.3 + (step_progress * 0.5)
                    self.callback(stage, overall_progress, 1.0)
                return result

        # Use custom hook for pyannote progress with callback forwarding
        with CustomProgressHook(callback=progress_callback) as hook:
            diarization = pipeline(
                audio_dict, min_speakers=min_speakers, max_speakers=max_speakers, hook=hook
            )
        if progress_callback:
            progress_callback("Speaker detection complete", 0.8, 1.0)
        # Count speakers by iterating through speaker_diarization.
        # NOTE(review): unpacking `.speaker_diarization` as (turn, speaker)
        # pairs looks like the pyannote.audio 4.x output API, while the model
        # id above is the 3.1 pipeline — confirm against the pinned
        # pyannote.audio version.
        speakers = set()
        for turn, speaker in diarization.speaker_diarization:
            speakers.add(speaker)
        logger.info(f"Detected {len(speakers)} speakers: {', '.join(sorted(speakers))}")
        return diarization
    finally:
        # Clean up: drop the pipeline and release cached GPU memory so the
        # short-lived ZeroGPU context does not keep VRAM allocated.
        del pipeline
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
class SpeakerSeparationService:
    """
    Service for speaker diarization and separation.

    Uses pyannote.audio for speaker diarization to identify and separate
    individual speakers from multi-speaker audio files.
    """

    def __init__(self, hf_token: Optional[str] = None):
        """
        Initialize speaker separation service.

        Args:
            hf_token: HuggingFace API token (required for pyannote models)
                If None, will try to get from HF_TOKEN env var

        Raises:
            ValueError: If HuggingFace token not provided
        """
        if hf_token is None:
            hf_token = os.getenv("HF_TOKEN")
        if not hf_token:
            raise ValueError(
                "HuggingFace token required. Set HF_TOKEN environment "
                "variable or pass hf_token parameter."
            )
        self.hf_token = hf_token

    def convert_to_wav(self, input_path: str, sample_rate: int = 16000) -> str:
        """
        Convert M4A/AAC to WAV for pyannote processing.

        Args:
            input_path: Path to M4A file
            sample_rate: Target sample rate (default: 16000 for pyannote)

        Returns:
            Path to converted WAV file
        """
        return convert_m4a_to_wav(input_path, sample_rate=sample_rate)

    def separate_speakers(
        self,
        audio_path: str,
        min_speakers: int = 2,
        max_speakers: int = 5,
        progress_callback: Optional[Callable] = None,
    ):
        """
        Perform speaker diarization on audio file.

        Args:
            audio_path: Path to audio file (M4A or WAV)
            min_speakers: Minimum number of speakers to detect
            max_speakers: Maximum number of speakers to detect
            progress_callback: Optional callback for progress updates

        Returns:
            Diarization result from pyannote

        Raises:
            AudioIOError: If file cannot be read
            ValueError: If parameters are invalid
        """
        if min_speakers > max_speakers:
            raise ValueError(
                f"min_speakers ({min_speakers}) cannot exceed max_speakers ({max_speakers})"
            )
        # Convert M4A to WAV if needed
        audio_path = Path(audio_path)
        if not audio_path.exists():
            raise AudioIOError(f"Audio file not found: {audio_path}")
        if audio_path.suffix.lower() in [".m4a", ".aac", ".mp4"]:
            logger.info(f"Converting {audio_path.name} to WAV for processing...")
            audio_path = Path(self.convert_to_wav(str(audio_path)))
        # Run diarization with progress reporting
        logger.info(f"Performing speaker diarization (min={min_speakers}, max={max_speakers})...")
        if progress_callback:
            progress_callback("Starting speaker detection", 0.0, 1.0)
        # Load audio ourselves and pass as dict to avoid torchcodec issues.
        # 16 kHz is the rate convert_to_wav targets for pyannote above.
        audio_data, sr = read_audio(str(audio_path), target_sr=16000)
        audio_dict = {
            "waveform": torch.from_numpy(audio_data).unsqueeze(0),  # Add channel dimension
            "sample_rate": sr,
        }
        # Call the module-level GPU function (avoids pickling self)
        diarization = _run_diarization_on_gpu(
            audio_dict=audio_dict,
            hf_token=self.hf_token,
            min_speakers=min_speakers,
            max_speakers=max_speakers,
            progress_callback=progress_callback,
        )
        return diarization

    def extract_speaker_segments(self, diarization, speaker_id: str) -> List["AudioSegment"]:
        """
        Extract audio segments for a specific speaker.

        Args:
            diarization: Diarization result from pyannote
            speaker_id: Speaker ID to extract (e.g., "SPEAKER_00")

        Returns:
            List of AudioSegment objects for this speaker
        """
        segments = []
        # pyannote.audio 4.0 API - iterate over speaker_diarization
        for turn, speaker in diarization.speaker_diarization:
            if speaker == speaker_id:
                audio_segment = AudioSegment(
                    start_time=turn.start,
                    end_time=turn.end,
                    speaker_id=speaker_id,
                    confidence=1.0,  # pyannote doesn't provide per-segment confidence
                    segment_type=SegmentType.SPEECH,
                )
                segments.append(audio_segment)
        logger.debug(f"Extracted {len(segments)} segments for {speaker_id}")
        return segments

    def export_speaker_audio(
        self,
        audio: np.ndarray,
        sample_rate: int,
        output_path: str,
        output_sample_rate: int = 44100,
        bitrate: str = "192k",
    ) -> str:
        """
        Export speaker audio to M4A format.

        Args:
            audio: Audio array
            sample_rate: Input sample rate
            output_path: Output M4A file path
            output_sample_rate: Output sample rate (default: 44100)
            bitrate: Output bitrate (default: "192k")

        Returns:
            Path to exported M4A file
        """
        output_path = Path(output_path)
        # Create output directory
        output_path.parent.mkdir(parents=True, exist_ok=True)
        # First write to temporary WAV (the M4A encoder consumes a WAV file)
        temp_wav = output_path.with_suffix(".temp.wav")
        write_audio(str(temp_wav), audio, sample_rate)
        try:
            # Convert to M4A
            m4a_path = convert_wav_to_m4a(
                str(temp_wav), str(output_path), sample_rate=output_sample_rate, bitrate=bitrate
            )
        finally:
            # Bug fix: clean up the temp file even when conversion fails;
            # previously a failed conversion leaked the .temp.wav on disk.
            temp_wav.unlink(missing_ok=True)
        logger.info(f"Exported speaker audio to {output_path.name}")
        return m4a_path

    def generate_separation_report(
        self,
        input_file: str,
        speakers: List[str],
        segments: Dict[str, List["AudioSegment"]],
        processing_time: float,
        output_files: List[Dict],
        input_duration: float,
    ) -> Dict:
        """
        Generate separation report JSON.

        Args:
            input_file: Input file path
            speakers: List of speaker IDs
            segments: Dict mapping speaker IDs to their segments
            processing_time: Processing time in seconds
            output_files: List of output file information
            input_duration: Input audio duration in seconds

        Returns:
            Report dictionary
        """
        # Calculate quality metrics; max(..., 1) guards the empty-segments case
        total_segments = sum(len(segs) for segs in segments.values())
        avg_confidence = sum(seg.confidence for segs in segments.values() for seg in segs) / max(
            total_segments, 1
        )
        # Count overlapping segment pairs across all speakers (O(n^2) over
        # segment count, which is small per file)
        overlapping = 0
        all_segs = [seg for segs in segments.values() for seg in segs]
        for i, seg1 in enumerate(all_segs):
            for seg2 in all_segs[i + 1 :]:
                if seg1.overlaps_with(seg2):
                    overlapping += 1
        report = {
            "input_file": str(input_file),
            "input_duration_seconds": input_duration,
            "speakers_detected": len(speakers),
            "processing_time_seconds": processing_time,
            "output_files": output_files,
            "overlapping_segments": overlapping,
            "quality_metrics": {
                "average_confidence": round(avg_confidence, 3),
                "total_segments": total_segments,
                "low_confidence_segments": sum(
                    1 for segs in segments.values() for seg in segs if seg.confidence < 0.7
                ),
            },
        }
        return report

    def separate_and_export(
        self,
        input_file: str,
        output_dir: str,
        min_speakers: int = 2,
        max_speakers: int = 5,
        output_format: str = "m4a",
        sample_rate: int = 44100,
        bitrate: str = "192k",
        progress_callback: Optional[Callable] = None,
    ) -> Dict:
        """
        Complete workflow: separate speakers and export to individual files.

        Args:
            input_file: Input M4A audio file
            output_dir: Output directory for separated files
            min_speakers: Minimum speakers to detect
            max_speakers: Maximum speakers to detect
            output_format: Output format - m4a, wav, or mp3 (default: "m4a")
                NOTE(review): currently unused — exports are always M4A below.
            sample_rate: Output sample rate (default: 44100)
            bitrate: Output bitrate (default: "192k")
            progress_callback: Optional progress callback

        Returns:
            Separation report dictionary or ErrorReport on failure
        """
        start_time = time.time()
        try:
            input_file = Path(input_file)
            output_dir = Path(output_dir)
            output_dir.mkdir(parents=True, exist_ok=True)
            # Get input duration
            input_duration = get_audio_duration(str(input_file))
        except Exception as e:
            logger.error(f"Failed to initialize speaker separation: {e}")
            error_report: ErrorReport = {
                "status": "failed",
                "error": f"Failed to initialize speaker separation: {e}",
                "error_type": "audio_io",
            }
            return error_report
        try:
            # Perform speaker diarization
            if progress_callback:
                progress_callback("Loading audio", 0.1, 1.0)
            # The callback is deliberately NOT forwarded here: ZeroGPU pickles
            # the GPU function's arguments, and this caller's callback is not
            # guaranteed to be picklable.
            diarization = self.separate_speakers(
                str(input_file),
                min_speakers=min_speakers,
                max_speakers=max_speakers,
                progress_callback=None,  # Cannot pass callback to avoid pickling errors
            )
        except Exception as e:
            logger.error(f"Speaker diarization failed: {e}")
            error_report: ErrorReport = {
                "status": "failed",
                "error": f"Speaker diarization failed: {e}",
                "error_type": "processing",
            }
            return error_report
        try:
            # Get unique speakers by iterating through speaker_diarization
            speakers = set()
            for turn, speaker in diarization.speaker_diarization:
                speakers.add(speaker)
            speakers = sorted(speakers)
            # Extract segments for each speaker
            segments = {}
            for speaker_id in speakers:
                segments[speaker_id] = self.extract_speaker_segments(diarization, speaker_id)
            # Load full audio for extraction
            if progress_callback:
                progress_callback("Performing speaker diarization", 0.2, 1.0)
            # Convert to WAV for processing if needed (at the output sample
            # rate, separate from the 16 kHz diarization copy)
            wav_path = input_file
            if input_file.suffix.lower() in [".m4a", ".aac", ".mp4"]:
                wav_path = Path(self.convert_to_wav(str(input_file), sample_rate=sample_rate))
            audio, sr = read_audio(str(wav_path), target_sr=sample_rate)
        except Exception as e:
            logger.error(f"Failed to load and process audio: {e}")
            error_report: ErrorReport = {
                "status": "failed",
                "error": f"Failed to load and process audio: {e}",
                "error_type": "audio_io",
            }
            return error_report
        try:
            # Export each speaker
            output_files = []
            for i, speaker_id in enumerate(speakers):
                if progress_callback:
                    # Progress from 0.8 to 1.0 for speaker exports
                    export_progress = 0.8 + (0.2 * (i + 1) / len(speakers))
                    progress_callback(
                        f"Exporting speaker {i + 1}/{len(speakers)}", export_progress, 1.0
                    )
                # Extract and concatenate all segments for this speaker
                speaker_segments = segments[speaker_id]
                speaker_audio_parts = []
                for segment in speaker_segments:
                    segment_audio = extract_segment(audio, sr, segment.start_time, segment.end_time)
                    speaker_audio_parts.append(segment_audio)
                # Concatenate segments; a speaker with zero segments is
                # silently skipped (no output file)
                if speaker_audio_parts:
                    speaker_audio = np.concatenate(speaker_audio_parts)
                    # Export to M4A
                    output_file = output_dir / f"speaker_{i:02d}.m4a"
                    self.export_speaker_audio(
                        speaker_audio,
                        sr,
                        str(output_file),
                        output_sample_rate=sample_rate,
                        bitrate=bitrate,
                    )
                    output_files.append(
                        {
                            "speaker_id": speaker_id,
                            "file": str(output_file),
                            "duration": len(speaker_audio) / sr,
                            "segments_count": len(speaker_segments),
                        }
                    )
            # Generate and save report
            processing_time = time.time() - start_time
            report = self.generate_separation_report(
                input_file=str(input_file),
                speakers=speakers,
                segments=segments,
                processing_time=processing_time,
                output_files=output_files,
                input_duration=input_duration,
            )
            # Write report JSON
            report_file = output_dir / "separation_report.json"
            with open(report_file, "w") as f:
                json.dump(report, f, indent=2)
            logger.info(f"Separation complete: {len(speakers)} speakers in {processing_time:.1f}s")
            if progress_callback:
                progress_callback("Complete", 1.0, 1.0)
            return report
        except Exception as e:
            logger.error(f"Failed to export speakers: {e}")
            error_report: ErrorReport = {
                "status": "failed",
                "error": f"Failed to export speakers: {e}",
                "error_type": "processing",
            }
            return error_report