#!/usr/bin/env python3
"""
Silero VAD Wrapper for Real-Time Voice Activity Detection
Optimized for <100ms latency with streaming support
"""

import torch
import numpy as np
from typing import List, Dict, Optional, Tuple
import time
from pathlib import Path


class SileroVAD:
    """
    Production-ready Silero VAD wrapper with streaming support.

    Features:
    - Real-time processing with <100ms latency
    - Configurable sensitivity thresholds
    - Streaming audio buffer management
    - ONNX runtime support for optimization
    """

    def __init__(
        self,
        threshold: float = 0.5,
        sampling_rate: int = 16000,
        min_speech_duration_ms: int = 250,
        min_silence_duration_ms: int = 100,
        window_size_samples: int = 1536,
        use_onnx: bool = False
    ):
        """
        Initialize Silero VAD.

        Args:
            threshold: Speech probability threshold (0.0-1.0)
            sampling_rate: Audio sample rate (8000 or 16000)
            min_speech_duration_ms: Minimum speech segment duration
            min_silence_duration_ms: Minimum silence duration between segments
            window_size_samples: VAD window size (512, 1024, or 1536)
                NOTE(review): Silero VAD v5 models only accept 512 samples at
                16 kHz (256 at 8 kHz) — confirm against the installed version.
            use_onnx: Use ONNX runtime for faster inference
        """
        self.threshold = threshold
        self.sampling_rate = sampling_rate
        self.min_speech_duration_ms = min_speech_duration_ms
        self.min_silence_duration_ms = min_silence_duration_ms
        self.window_size_samples = window_size_samples
        self.use_onnx = use_onnx

        # Load model once; reused by every process_* call.
        self.model = self._load_model()

        # Streaming state lives inside the model; clear it before first use.
        self.reset_states()

        print(f"āœ“ Silero VAD initialized (threshold={threshold}, sr={sampling_rate}Hz)")

    def _load_model(self):
        """Load the Silero VAD model.

        Prefers the pip-installed ``silero_vad`` package; falls back to
        downloading via torch.hub (requires network on first call).
        """
        try:
            # Try importing from silero_vad package
            from silero_vad import load_silero_vad
            model = load_silero_vad(onnx=self.use_onnx)
            return model
        except ImportError:
            # Fallback: load from torch hub
            model, utils = torch.hub.load(
                repo_or_dir='snakers4/silero-vad',
                model='silero_vad',
                force_reload=False,
                onnx=self.use_onnx
            )
            return model

    def reset_states(self):
        """Reset the model's internal recurrent state (start a new stream)."""
        self.model.reset_states()

    def process_chunk(self, audio_chunk: np.ndarray) -> float:
        """
        Process a single audio chunk and return speech probability.

        Args:
            audio_chunk: Audio data (numpy array, float32, mono)

        Returns:
            Speech probability (0.0-1.0)
        """
        # Accept either numpy arrays or pre-built torch tensors.
        if isinstance(audio_chunk, np.ndarray):
            audio_tensor = torch.from_numpy(audio_chunk).float()
        else:
            audio_tensor = audio_chunk

        # Inference only — no gradients needed.
        with torch.no_grad():
            speech_prob = self.model(audio_tensor, self.sampling_rate).item()

        return speech_prob

    def get_speech_timestamps(
        self,
        audio: np.ndarray,
        return_seconds: bool = False
    ) -> List[Dict[str, float]]:
        """
        Get speech timestamps from audio.

        Args:
            audio: Audio data (numpy array, float32, mono)
            return_seconds: Return timestamps in seconds instead of samples

        Returns:
            List of dicts with 'start' and 'end' keys
        """
        try:
            from silero_vad import get_speech_timestamps

            if isinstance(audio, np.ndarray):
                audio_tensor = torch.from_numpy(audio).float()
            else:
                audio_tensor = audio

            # NOTE(review): newer silero_vad releases dropped the
            # window_size_samples argument — verify against the installed API.
            timestamps = get_speech_timestamps(
                audio_tensor,
                self.model,
                threshold=self.threshold,
                sampling_rate=self.sampling_rate,
                min_speech_duration_ms=self.min_speech_duration_ms,
                min_silence_duration_ms=self.min_silence_duration_ms,
                window_size_samples=self.window_size_samples,
                return_seconds=return_seconds
            )

            return timestamps

        except ImportError:
            # Fallback: manual implementation
            return self._get_speech_timestamps_manual(audio, return_seconds)

    def _get_speech_timestamps_manual(
        self,
        audio: np.ndarray,
        return_seconds: bool = False
    ) -> List[Dict[str, float]]:
        """Manual fallback: window the audio, score each window, segment.

        Args:
            audio: Audio data (numpy array or torch tensor, float32, mono)
            return_seconds: Return timestamps in seconds instead of samples

        Returns:
            List of dicts with 'start' and 'end' keys
        """
        if isinstance(audio, np.ndarray):
            audio_tensor = torch.from_numpy(audio).float()
        else:
            audio_tensor = audio

        # Score each fixed-size window with the model.
        window_size = self.window_size_samples
        speech_probs = []
        self.reset_states()

        for i in range(0, len(audio_tensor), window_size):
            chunk = audio_tensor[i:i + window_size]
            if len(chunk) < window_size:
                # Zero-pad the final partial window so the model sees a full frame.
                chunk = torch.nn.functional.pad(chunk, (0, window_size - len(chunk)))
            speech_probs.append(self.process_chunk(chunk))

        return self._segments_from_probs(
            speech_probs, window_size, len(audio_tensor), return_seconds
        )

    def _segments_from_probs(
        self,
        speech_probs: List[float],
        window_size: int,
        total_samples: int,
        return_seconds: bool = False
    ) -> List[Dict[str, float]]:
        """Convert per-window speech probabilities into speech segments.

        Pure segmentation logic (no model access), applying the configured
        thresholding, silence-gap merging, and minimum-duration filtering.

        Fixes vs. the previous inline version:
        - min_silence_duration_ms is now honored: segments separated by a
          silence gap shorter than it are merged.
        - The minimum-speech-duration filter is applied uniformly, including
          to a segment that runs to the end of the audio (previously the
          trailing segment bypassed it).

        Args:
            speech_probs: One speech probability per window, in order.
            window_size: Samples per window (start of window i = i * window_size).
            total_samples: Total length of the original audio in samples;
                used as the end of a segment still open at the last window.
            return_seconds: Emit 'start'/'end' in seconds instead of samples.

        Returns:
            List of dicts with 'start' and 'end' keys.
        """
        sr = self.sampling_rate

        # Pass 1: raw (start, end) sample spans from thresholded probabilities.
        raw: List[Tuple[int, int]] = []
        in_speech = False
        seg_start = 0
        for i, prob in enumerate(speech_probs):
            sample_idx = i * window_size
            if prob >= self.threshold and not in_speech:
                in_speech = True
                seg_start = sample_idx
            elif prob < self.threshold and in_speech:
                in_speech = False
                raw.append((seg_start, sample_idx))
        if in_speech:
            # Speech continues to the end of the audio.
            raw.append((seg_start, total_samples))

        # Pass 2: merge segments separated by less than the minimum silence.
        min_silence_samples = self.min_silence_duration_ms * sr / 1000
        merged: List[Tuple[int, int]] = []
        for seg in raw:
            if merged and seg[0] - merged[-1][1] < min_silence_samples:
                merged[-1] = (merged[-1][0], seg[1])
            else:
                merged.append(seg)

        # Pass 3: drop segments shorter than the minimum speech duration.
        min_speech_samples = self.min_speech_duration_ms * sr / 1000
        timestamps: List[Dict[str, float]] = []
        for start, end in merged:
            if end - start >= min_speech_samples:
                if return_seconds:
                    timestamps.append({'start': start / sr, 'end': end / sr})
                else:
                    timestamps.append({'start': start, 'end': end})

        return timestamps

    def process_file(self, audio_path: str) -> Tuple[List[Dict], float]:
        """
        Process an audio file and return speech segments with latency.

        Args:
            audio_path: Path to audio file

        Returns:
            Tuple of (timestamps, processing_time_ms)
        """
        # Load audio at this instance's configured rate (previously the
        # static 16000 default was used regardless of self.sampling_rate).
        audio = self.read_audio(audio_path, sampling_rate=self.sampling_rate)

        # perf_counter is the monotonic clock intended for interval timing.
        start_time = time.perf_counter()
        timestamps = self.get_speech_timestamps(audio, return_seconds=True)
        processing_time = (time.perf_counter() - start_time) * 1000  # Convert to ms

        return timestamps, processing_time

    @staticmethod
    def read_audio(path: str, sampling_rate: int = 16000) -> torch.Tensor:
        """
        Read audio file and convert to required format.

        Args:
            path: Path to audio file
            sampling_rate: Target sample rate

        Returns:
            Audio tensor (mono, float32)
        """
        try:
            from silero_vad import read_audio
            return read_audio(path, sampling_rate=sampling_rate)
        except ImportError:
            # Fallback: use librosa (resamples and downmixes to mono)
            import librosa
            audio, sr = librosa.load(path, sr=sampling_rate, mono=True)
            return torch.from_numpy(audio).float()

    def benchmark_latency(self, duration_seconds: float = 10.0) -> Dict[str, float]:
        """
        Benchmark VAD latency on synthetic audio.

        Args:
            duration_seconds: Duration of test audio

        Returns:
            Dict with latency metrics
        """
        # Generate test audio (white noise — content is irrelevant for timing).
        num_samples = int(duration_seconds * self.sampling_rate)
        test_audio = torch.randn(num_samples)

        # Warm-up pass so lazy initialization doesn't skew the measurement.
        self.reset_states()
        _ = self.get_speech_timestamps(test_audio.numpy())

        # Benchmark with the monotonic interval clock.
        self.reset_states()
        start_time = time.perf_counter()
        timestamps = self.get_speech_timestamps(test_audio.numpy())
        end_time = time.perf_counter()

        processing_time_ms = (end_time - start_time) * 1000
        latency_per_second = processing_time_ms / duration_seconds

        return {
            'total_processing_time_ms': processing_time_ms,
            'audio_duration_s': duration_seconds,
            'latency_per_second_ms': latency_per_second,
            'real_time_factor': processing_time_ms / (duration_seconds * 1000),
            'num_segments': len(timestamps)
        }


def demo():
    """Demo VAD functionality."""
    print("\n" + "="*60)
    print("SILERO VAD DEMO")
    print("="*60)

    # Initialize VAD
    vad = SileroVAD(threshold=0.5)

    # Benchmark latency
    print("\nšŸ“Š Benchmarking latency...")
    metrics = vad.benchmark_latency(duration_seconds=10.0)

    print(f"   Total processing time: {metrics['total_processing_time_ms']:.2f}ms")
    print(f"   Audio duration: {metrics['audio_duration_s']:.1f}s")
    print(f"   Latency per second: {metrics['latency_per_second_ms']:.2f}ms")
    print(f"   Real-time factor: {metrics['real_time_factor']:.4f}x")

    if metrics['latency_per_second_ms'] < 100:
        print("   āœ… Target latency achieved (<100ms)")
    else:
        print("   āš ļø  Latency above target (>100ms)")

    print("\n" + "="*60)


if __name__ == "__main__":
    demo()