Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Silero VAD Wrapper for Real-Time Voice Activity Detection | |
| Optimized for <100ms latency with streaming support | |
| """ | |
| import torch | |
| import numpy as np | |
| from typing import List, Dict, Optional, Tuple | |
| import time | |
| from pathlib import Path | |
class SileroVAD:
    """
    Production-ready Silero VAD wrapper with streaming support.

    Features:
    - Real-time processing with <100ms latency
    - Configurable sensitivity thresholds
    - Streaming audio buffer management
    - ONNX runtime support for optimization
    """

    def __init__(
        self,
        threshold: float = 0.5,
        sampling_rate: int = 16000,
        min_speech_duration_ms: int = 250,
        min_silence_duration_ms: int = 100,
        window_size_samples: int = 1536,
        use_onnx: bool = False
    ):
        """
        Initialize Silero VAD.

        Args:
            threshold: Speech probability threshold (0.0-1.0).
            sampling_rate: Audio sample rate (8000 or 16000).
            min_speech_duration_ms: Minimum speech segment duration.
            min_silence_duration_ms: Minimum silence duration between segments
                (only honored by the silero_vad package path, not the manual
                fallback — see _get_speech_timestamps_manual).
            window_size_samples: VAD window size (512, 1024, or 1536).
            use_onnx: Use ONNX runtime for faster inference.
        """
        self.threshold = threshold
        self.sampling_rate = sampling_rate
        self.min_speech_duration_ms = min_speech_duration_ms
        self.min_silence_duration_ms = min_silence_duration_ms
        self.window_size_samples = window_size_samples
        self.use_onnx = use_onnx

        # Load model once at construction; recurrent streaming state lives
        # on the model object and is cleared via reset_states().
        self.model = self._load_model()
        self.reset_states()

        print(f"✓ Silero VAD initialized (threshold={threshold}, sr={sampling_rate}Hz)")

    def _load_model(self):
        """Load the Silero VAD model: pip package first, torch hub fallback."""
        try:
            # Preferred: the standalone silero_vad package.
            from silero_vad import load_silero_vad
            return load_silero_vad(onnx=self.use_onnx)
        except ImportError:
            # Fallback: fetch via torch hub (cached locally after first call).
            model, utils = torch.hub.load(
                repo_or_dir='snakers4/silero-vad',
                model='silero_vad',
                force_reload=False,
                onnx=self.use_onnx
            )
            return model

    def reset_states(self):
        """Reset the model's internal recurrent state between audio streams."""
        self.model.reset_states()

    def process_chunk(self, audio_chunk: np.ndarray) -> float:
        """
        Process a single audio chunk and return its speech probability.

        Args:
            audio_chunk: Audio data (numpy array or torch tensor, float32,
                mono, expected length == window_size_samples).

        Returns:
            Speech probability (0.0-1.0).
        """
        if isinstance(audio_chunk, np.ndarray):
            audio_tensor = torch.from_numpy(audio_chunk).float()
        else:
            audio_tensor = audio_chunk

        # Inference only — no gradients needed.
        with torch.no_grad():
            return self.model(audio_tensor, self.sampling_rate).item()

    def get_speech_timestamps(
        self,
        audio: np.ndarray,
        return_seconds: bool = False
    ) -> List[Dict[str, float]]:
        """
        Get speech timestamps from a full audio buffer.

        Args:
            audio: Audio data (numpy array or torch tensor, float32, mono).
            return_seconds: Return timestamps in seconds instead of samples.

        Returns:
            List of dicts with 'start' and 'end' keys.
        """
        try:
            from silero_vad import get_speech_timestamps

            if isinstance(audio, np.ndarray):
                audio_tensor = torch.from_numpy(audio).float()
            else:
                audio_tensor = audio

            # NOTE(review): newer silero-vad releases dropped the
            # window_size_samples kwarg — confirm against the pinned version.
            return get_speech_timestamps(
                audio_tensor,
                self.model,
                threshold=self.threshold,
                sampling_rate=self.sampling_rate,
                min_speech_duration_ms=self.min_speech_duration_ms,
                min_silence_duration_ms=self.min_silence_duration_ms,
                window_size_samples=self.window_size_samples,
                return_seconds=return_seconds
            )
        except ImportError:
            # Package unavailable: use the built-in windowed implementation.
            return self._get_speech_timestamps_manual(audio, return_seconds)

    def _get_speech_timestamps_manual(
        self,
        audio: np.ndarray,
        return_seconds: bool = False
    ) -> List[Dict[str, float]]:
        """
        Manual speech timestamp detection (fallback path).

        Runs the model over fixed-size windows, then converts the resulting
        probability sequence into segments. Unlike the package implementation,
        this does NOT merge segments separated by less than
        min_silence_duration_ms.
        """
        if isinstance(audio, np.ndarray):
            audio_tensor = torch.from_numpy(audio).float()
        else:
            audio_tensor = audio

        window_size = self.window_size_samples
        speech_probs = []
        self.reset_states()
        for i in range(0, len(audio_tensor), window_size):
            chunk = audio_tensor[i:i + window_size]
            if len(chunk) < window_size:
                # Zero-pad the final partial window to full size.
                chunk = torch.nn.functional.pad(chunk, (0, window_size - len(chunk)))
            speech_probs.append(self.process_chunk(chunk))

        return self._segments_from_probs(speech_probs, len(audio_tensor), return_seconds)

    def _segments_from_probs(
        self,
        speech_probs: List[float],
        total_samples: int,
        return_seconds: bool
    ) -> List[Dict[str, float]]:
        """
        Convert per-window speech probabilities into start/end segments.

        Args:
            speech_probs: One probability per window_size_samples window.
            total_samples: Length of the original audio in samples (used as
                the end of a segment that runs to the end of the buffer).
            return_seconds: Express timestamps in seconds instead of samples.

        Returns:
            List of dicts with 'start' and 'end' keys.
        """
        window_size = self.window_size_samples

        def _stamp(start: int, end: int) -> Dict[str, float]:
            if return_seconds:
                return {'start': start / self.sampling_rate,
                        'end': end / self.sampling_rate}
            return {'start': start, 'end': end}

        timestamps: List[Dict[str, float]] = []
        in_speech = False
        speech_start = 0
        for i, prob in enumerate(speech_probs):
            sample_idx = i * window_size
            if prob >= self.threshold and not in_speech:
                # Transition silence -> speech.
                in_speech = True
                speech_start = sample_idx
            elif prob < self.threshold and in_speech:
                # Transition speech -> silence: close the segment if it is
                # long enough.
                in_speech = False
                duration_ms = (sample_idx - speech_start) / self.sampling_rate * 1000
                if duration_ms >= self.min_speech_duration_ms:
                    timestamps.append(_stamp(speech_start, sample_idx))

        if in_speech:
            # Speech continues to the end of the audio; kept regardless of
            # min_speech_duration_ms (preserves the original behavior).
            timestamps.append(_stamp(speech_start, total_samples))

        return timestamps

    def process_file(self, audio_path: str) -> Tuple[List[Dict], float]:
        """
        Process an audio file and return speech segments plus latency.

        Args:
            audio_path: Path to audio file.

        Returns:
            Tuple of (timestamps in seconds, processing_time_ms). Load time
            is excluded from the measurement.
        """
        audio = self.read_audio(audio_path)

        start_time = time.time()
        timestamps = self.get_speech_timestamps(audio, return_seconds=True)
        processing_time = (time.time() - start_time) * 1000  # ms

        return timestamps, processing_time

    def read_audio(self, path: str, sampling_rate: Optional[int] = None) -> torch.Tensor:
        """
        Read an audio file and convert it to the required format.

        Bug fix: this was previously defined without `self`, so the
        `self.read_audio(audio_path)` call in process_file bound the instance
        to `path` and crashed. It is now a proper instance method whose
        target rate defaults to the VAD's configured sampling rate.

        Args:
            path: Path to audio file.
            sampling_rate: Target sample rate; defaults to self.sampling_rate.

        Returns:
            Audio tensor (mono, float32).
        """
        if sampling_rate is None:
            sampling_rate = self.sampling_rate
        try:
            from silero_vad import read_audio
            return read_audio(path, sampling_rate=sampling_rate)
        except ImportError:
            # Fallback: decode and resample with librosa.
            import librosa
            audio, sr = librosa.load(path, sr=sampling_rate, mono=True)
            return torch.from_numpy(audio).float()

    def benchmark_latency(self, duration_seconds: float = 10.0) -> Dict[str, float]:
        """
        Benchmark VAD latency on synthetic (random noise) audio.

        Args:
            duration_seconds: Duration of the test audio.

        Returns:
            Dict with latency metrics: total_processing_time_ms,
            audio_duration_s, latency_per_second_ms, real_time_factor,
            num_segments.
        """
        num_samples = int(duration_seconds * self.sampling_rate)
        test_audio = torch.randn(num_samples)

        # Warm-up pass so one-time costs don't skew the measurement.
        self.reset_states()
        _ = self.get_speech_timestamps(test_audio.numpy())

        # Timed pass.
        self.reset_states()
        start_time = time.time()
        timestamps = self.get_speech_timestamps(test_audio.numpy())
        end_time = time.time()

        processing_time_ms = (end_time - start_time) * 1000
        latency_per_second = processing_time_ms / duration_seconds

        return {
            'total_processing_time_ms': processing_time_ms,
            'audio_duration_s': duration_seconds,
            'latency_per_second_ms': latency_per_second,
            'real_time_factor': processing_time_ms / (duration_seconds * 1000),
            'num_segments': len(timestamps)
        }
def demo():
    """Run a quick demonstration: build a VAD and benchmark its latency."""
    banner = "=" * 60
    print("\n" + banner)
    print("SILERO VAD DEMO")
    print(banner)

    # Build the detector with the default sensitivity.
    vad = SileroVAD(threshold=0.5)

    # Run the synthetic-audio latency benchmark and report each metric.
    print("\n📊 Benchmarking latency...")
    metrics = vad.benchmark_latency(duration_seconds=10.0)
    total_ms = metrics['total_processing_time_ms']
    per_second_ms = metrics['latency_per_second_ms']
    print(f" Total processing time: {total_ms:.2f}ms")
    print(f" Audio duration: {metrics['audio_duration_s']:.1f}s")
    print(f" Latency per second: {per_second_ms:.2f}ms")
    print(f" Real-time factor: {metrics['real_time_factor']:.4f}x")

    # Compare against the 100ms-per-second-of-audio target.
    if per_second_ms < 100:
        print(" ✅ Target latency achieved (<100ms)")
    else:
        print(" ⚠️ Latency above target (>100ms)")

    print("\n" + banner)


if __name__ == "__main__":
    demo()