Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Integrated VAD + Speaker Diarization Pipeline | |
| Real-time processing with optimized performance | |
| """ | |
| import torch | |
| import numpy as np | |
| from typing import List, Dict, Optional, Tuple, Union | |
| import time | |
| from pathlib import Path | |
| import json | |
| from .vad import SileroVAD | |
| from .diarization import SpeakerDiarization | |
class VADDiarizationPipeline:
    """
    Two-stage pipeline that runs voice activity detection (VAD) and
    speaker diarization on the same audio file and merges the results.

    Features:
    - VAD (``SileroVAD``) runs first, then diarization (``SpeakerDiarization``)
    - Per-stage and total elapsed time reported in milliseconds
    - Text, JSON and RTTM output formats

    Note: both stages are handed the original audio path. The VAD segments
    are reported next to the diarization output; they are not used to
    pre-filter the audio given to the diarizer.
    """

    # Horizontal rules used by the console output and the text report.
    _HEAVY_RULE = "=" * 60
    _LIGHT_RULE = "-" * 60

    def __init__(
        self,
        vad_threshold: float = 0.5,
        use_auth_token: Optional[str] = None,
        token: Optional[str] = None,
        device: Optional[str] = None,
        num_speakers: Optional[int] = None,
        min_speakers: Optional[int] = None,
        max_speakers: Optional[int] = None,
        use_onnx_vad: bool = False
    ):
        """
        Build the VAD and diarization components.

        Args:
            vad_threshold: Threshold forwarded to ``SileroVAD``.
            use_auth_token: (Deprecated) Hugging Face token for diarization.
                Only used when ``token`` is empty or None.
            token: Hugging Face token for diarization. Takes precedence
                over ``use_auth_token``.
            device: Device forwarded to ``SpeakerDiarization``
                ('cuda' or 'cpu').
            num_speakers: Fixed number of speakers.
            min_speakers: Minimum number of speakers.
            max_speakers: Maximum number of speakers.
            use_onnx_vad: Forwarded to ``SileroVAD`` as ``use_onnx``.
        """
        print("\n" + self._HEAVY_RULE)
        print("INITIALIZING VAD + DIARIZATION PIPELINE")
        print(self._HEAVY_RULE)

        # Accept both the old and the new parameter name; the new one wins.
        auth_token = token or use_auth_token

        print("\n[1/2] Loading Voice Activity Detection...")
        self.vad = SileroVAD(
            threshold=vad_threshold,
            use_onnx=use_onnx_vad
        )

        print("\n[2/2] Loading Speaker Diarization...")
        self.diarization = SpeakerDiarization(
            token=auth_token,
            device=device,
            num_speakers=num_speakers,
            min_speakers=min_speakers,
            max_speakers=max_speakers
        )

        print("\n" + self._HEAVY_RULE)
        print("✅ PIPELINE READY")
        print(self._HEAVY_RULE + "\n")

    def process_file(
        self,
        audio_path: str,
        num_speakers: Optional[int] = None,
        return_vad: bool = True,
        return_stats: bool = True
    ) -> Dict:
        """
        Run VAD and then diarization on one audio file.

        Args:
            audio_path: Path to the audio file.
            num_speakers: Number of speakers (if known); forwarded to the
                diarizer's ``process_file``.
            return_vad: Include the VAD segments under 'vad_segments'.
            return_stats: Include per-speaker statistics under
                'speaker_statistics'.

        Returns:
            Dict with the keys 'audio_path', 'speaker_segments',
            'processing_time' ('vad_ms', 'diarization_ms', 'total_ms') and
            'metadata', plus the optional keys described above.
        """
        print(f"\n📁 Processing: {audio_path}")
        print(self._LIGHT_RULE)

        # perf_counter() is monotonic, so the measured durations cannot go
        # negative or jump when the system clock is adjusted.
        total_start = time.perf_counter()

        # Stage 1: VAD. The timing value returned by the component is
        # ignored; the stage is timed here so both stages are measured
        # the same way.
        print("Stage 1: Voice Activity Detection...")
        vad_start = time.perf_counter()
        vad_segments, _ = self.vad.process_file(audio_path)
        vad_duration = (time.perf_counter() - vad_start) * 1000
        print(f" ✓ Found {len(vad_segments)} speech segments")
        print(f" ✓ Processing time: {vad_duration:.2f}ms")

        # Stage 2: Diarization (runs on the full file, not on VAD segments).
        print("\nStage 2: Speaker Diarization...")
        diar_start = time.perf_counter()
        speaker_segments, _, diar_metadata = self.diarization.process_file(
            audio_path,
            num_speakers=num_speakers
        )
        diar_duration = (time.perf_counter() - diar_start) * 1000
        print(f" ✓ Identified {diar_metadata['num_speakers']} speakers")
        print(f" ✓ Found {diar_metadata['num_segments']} speaker segments")
        print(f" ✓ Processing time: {diar_duration:.2f}ms")

        total_duration = (time.perf_counter() - total_start) * 1000
        print(f"\n⏱️ Total processing time: {total_duration:.2f}ms")
        print(self._LIGHT_RULE)

        result = {
            'audio_path': audio_path,
            'speaker_segments': speaker_segments,
            'processing_time': {
                'vad_ms': vad_duration,
                'diarization_ms': diar_duration,
                'total_ms': total_duration
            },
            'metadata': diar_metadata
        }
        if return_vad:
            result['vad_segments'] = vad_segments
        if return_stats:
            result['speaker_statistics'] = self.diarization.get_speaker_statistics(
                speaker_segments
            )
        return result

    def process_batch(
        self,
        audio_paths: List[str],
        **kwargs
    ) -> List[Dict]:
        """
        Process several audio files one after another.

        Files are processed in order; an exception raised for one file
        propagates and stops the batch.

        Args:
            audio_paths: List of audio file paths.
            **kwargs: Additional arguments for ``process_file``.

        Returns:
            List of ``process_file`` results, in input order.
        """
        results = []
        print(f"\n📦 Batch processing {len(audio_paths)} files...")
        print(self._HEAVY_RULE)

        for i, path in enumerate(audio_paths, 1):
            print(f"\n[{i}/{len(audio_paths)}]")
            results.append(self.process_file(path, **kwargs))

        print("\n" + self._HEAVY_RULE)
        print(f"✅ Batch processing complete ({len(results)} files)")
        print(self._HEAVY_RULE + "\n")
        return results

    def format_output(self, result: Dict, format: str = 'text') -> str:
        """
        Render a ``process_file`` result as a string.

        Args:
            result: Result from ``process_file``.
            format: Output format ('text', 'json', 'rttm'). Any other
                value falls back to 'text'.

        Returns:
            Formatted string.
        """
        if format == 'json':
            return json.dumps(result, indent=2)
        if format == 'rttm':
            return self._format_rttm(result)
        return self._format_text(result)

    @staticmethod
    def _format_rttm(result: Dict) -> str:
        """Render the speaker segments as RTTM lines, one per segment."""
        # RTTM format: SPEAKER file 1 start duration <NA> <NA> speaker <NA> <NA>
        # NOTE(review): fields are space-separated, so a file stem that
        # contains whitespace yields a line with extra fields.
        file_id = Path(result['audio_path']).stem
        return "\n".join(
            f"SPEAKER {file_id} 1 {seg['start']:.3f} {seg['duration']:.3f} "
            f"<NA> <NA> {seg['speaker']} <NA> <NA>"
            for seg in result['speaker_segments']
        )

    def _format_text(self, result: Dict) -> str:
        """Render the human-readable report (metadata, timings, timeline)."""
        metadata = result['metadata']
        timing = result['processing_time']

        lines = [
            self._HEAVY_RULE,
            "VAD + SPEAKER DIARIZATION RESULTS",
            self._HEAVY_RULE,
            f"\nFile: {result['audio_path']}",
            "\nMetadata:",
            f" Speakers: {metadata['num_speakers']}",
            f" Segments: {metadata['num_segments']}",
            f" Total speech: {metadata['total_speech_time']:.2f}s",
            "\nProcessing Time:",
            f" VAD: {timing['vad_ms']:.2f}ms",
            f" Diarization: {timing['diarization_ms']:.2f}ms",
            f" Total: {timing['total_ms']:.2f}ms",
        ]

        # The statistics section only exists when process_file was called
        # with return_stats=True.
        if 'speaker_statistics' in result:
            lines.append("\nSpeaker Statistics:")
            for speaker, stats in result['speaker_statistics'].items():
                lines.append(f" {speaker}:")
                lines.append(f" Total time: {stats['total_time']:.2f}s")
                lines.append(f" Segments: {stats['num_segments']}")
                lines.append(f" Avg duration: {stats['avg_segment_duration']:.2f}s")

        lines.append("\nSpeaker Timeline:")
        lines.append(self._LIGHT_RULE)
        lines.extend(
            f"{seg['start']:7.2f}s - {seg['end']:7.2f}s: {seg['speaker']}"
            for seg in result['speaker_segments']
        )
        lines.append(self._HEAVY_RULE)
        return "\n".join(lines)

    def save_results(
        self,
        result: Dict,
        output_path: str,
        format: str = 'json'
    ):
        """
        Format a result and write it to a file.

        Args:
            result: Result from ``process_file``.
            output_path: Output file path (overwritten if it exists).
            format: Output format ('json', 'rttm', 'text').
        """
        output = self.format_output(result, format=format)
        # Explicit UTF-8: the text and RTTM reports embed the audio path and
        # speaker labels, which may contain non-ASCII characters that the
        # platform default encoding cannot represent.
        with open(output_path, 'w', encoding='utf-8') as f:
            f.write(output)
        print(f"✓ Results saved to: {output_path}")

    def benchmark(
        self,
        test_audio_path: Optional[str] = None,
        duration_seconds: float = 10.0
    ) -> Dict:
        """
        Benchmark the VAD stage and, optionally, the full pipeline.

        Args:
            test_audio_path: Path to test audio (optional). When given,
                the full pipeline is also run once on this file.
            duration_seconds: Duration passed to the VAD latency benchmark.

        Returns:
            Dict with 'vad_metrics' and 'pipeline_metrics'; the latter is
            the 'processing_time' dict of the full run, or None when no
            test audio was given.
        """
        print("\n" + self._HEAVY_RULE)
        print("PIPELINE BENCHMARK")
        print(self._HEAVY_RULE)

        print("\n[1/2] Benchmarking VAD...")
        vad_metrics = self.vad.benchmark_latency(duration_seconds)
        print(f" Latency: {vad_metrics['latency_per_second_ms']:.2f}ms per second")
        print(f" Real-time factor: {vad_metrics['real_time_factor']:.4f}x")
        # Target: under 100 ms of processing per second of audio.
        if vad_metrics['latency_per_second_ms'] < 100:
            print(" ✅ VAD latency target achieved (<100ms)")
        else:
            print(" ⚠️ VAD latency above target")

        # Full pipeline benchmark (only if test audio was provided).
        pipeline_metrics = None
        if test_audio_path:
            print("\n[2/2] Benchmarking full pipeline...")
            result = self.process_file(test_audio_path, return_stats=False)
            pipeline_metrics = result['processing_time']
            print(f" Total time: {pipeline_metrics['total_ms']:.2f}ms")

        print("\n" + self._HEAVY_RULE)
        return {
            'vad_metrics': vad_metrics,
            'pipeline_metrics': pipeline_metrics
        }
def demo():
    """
    Demo the integrated pipeline.

    Reads the Hugging Face token from the ``HF_TOKEN`` environment
    variable. Without a token only the VAD latency benchmark is run;
    with a token the full pipeline is built and benchmarked. Errors raised
    while building or benchmarking the pipeline are printed, not raised.
    """
    import os  # local import: only the demo needs it

    print("\n" + "="*60)
    print("INTEGRATED PIPELINE DEMO")
    print("="*60)

    # Check for HF token
    token = os.environ.get('HF_TOKEN')
    if not token:
        print("\n⚠️ No HF_TOKEN found in environment")
        print("Set it with: export HF_TOKEN='your_token_here'")
        print("\nFor now, will demo VAD only...")
        # VAD-only demo: SileroVAD is constructed without a token here.
        vad = SileroVAD()
        metrics = vad.benchmark_latency()
        print(f"\n✅ VAD latency: {metrics['latency_per_second_ms']:.2f}ms per second")
        return

    try:
        # Pass the token through the current `token` parameter instead of
        # the deprecated `use_auth_token` alias.
        pipeline = VADDiarizationPipeline(
            token=token,
            vad_threshold=0.5
        )
        pipeline.benchmark()
        print("\n✅ Pipeline demo complete!")
    except Exception as e:
        # Top-level demo boundary: report the failure and exit normally.
        print(f"\n❌ Error: {e}")

    print("\n" + "="*60)
# Script entry point: run the demo only when this file is executed
# directly, never as a side effect of importing the module.
if __name__ == "__main__":
    demo()