# NOTE: web-scrape/VCS artifact lines, commented out so the module parses:
# saadmannan's picture
# initial commit
# b77cba7
#!/usr/bin/env python3
"""
Silero VAD Wrapper for Real-Time Voice Activity Detection
Optimized for <100ms latency with streaming support
"""
import torch
import numpy as np
from typing import List, Dict, Optional, Tuple
import time
from pathlib import Path
class SileroVAD:
    """
    Production-ready Silero VAD wrapper with streaming support.

    Features:
    - Real-time processing with <100ms latency
    - Configurable sensitivity thresholds
    - Streaming audio buffer management
    - ONNX runtime support for optimization
    """

    def __init__(
        self,
        threshold: float = 0.5,
        sampling_rate: int = 16000,
        min_speech_duration_ms: int = 250,
        min_silence_duration_ms: int = 100,
        window_size_samples: int = 1536,
        use_onnx: bool = False
    ):
        """
        Initialize Silero VAD.

        Args:
            threshold: Speech probability threshold (0.0-1.0).
            sampling_rate: Audio sample rate (8000 or 16000).
            min_speech_duration_ms: Minimum speech segment duration.
            min_silence_duration_ms: Minimum silence duration between segments;
                in the manual fallback, segments separated by a shorter gap
                are merged into one.
            window_size_samples: VAD window size (512, 1024, or 1536).
                NOTE(review): silero-vad v5+ models expect exactly 512 samples
                at 16 kHz (256 at 8 kHz) — confirm against the installed model.
            use_onnx: Use ONNX runtime for faster inference.
        """
        self.threshold = threshold
        self.sampling_rate = sampling_rate
        self.min_speech_duration_ms = min_speech_duration_ms
        self.min_silence_duration_ms = min_silence_duration_ms
        self.window_size_samples = window_size_samples
        self.use_onnx = use_onnx

        # Load the model once up front; the streaming state lives inside it.
        self.model = self._load_model()
        self.reset_states()

        print(f"✓ Silero VAD initialized (threshold={threshold}, sr={sampling_rate}Hz)")

    def _load_model(self):
        """Load the Silero VAD model.

        Prefers the pip-installed ``silero_vad`` package; falls back to
        ``torch.hub`` (downloads on first use, requires network access).
        """
        try:
            from silero_vad import load_silero_vad
            return load_silero_vad(onnx=self.use_onnx)
        except ImportError:
            model, utils = torch.hub.load(
                repo_or_dir='snakers4/silero-vad',
                model='silero_vad',
                force_reload=False,
                onnx=self.use_onnx
            )
            return model

    def reset_states(self):
        """Reset the model's internal recurrent state (call between streams)."""
        self.model.reset_states()

    def process_chunk(self, audio_chunk: np.ndarray) -> float:
        """
        Process a single audio chunk and return its speech probability.

        Args:
            audio_chunk: Audio data (numpy array or torch tensor, float32, mono).

        Returns:
            Speech probability in [0.0, 1.0].
        """
        if isinstance(audio_chunk, np.ndarray):
            audio_tensor = torch.from_numpy(audio_chunk).float()
        else:
            audio_tensor = audio_chunk

        # Inference only — no gradient tracking needed.
        with torch.no_grad():
            return self.model(audio_tensor, self.sampling_rate).item()

    def get_speech_timestamps(
        self,
        audio: np.ndarray,
        return_seconds: bool = False
    ) -> List[Dict[str, float]]:
        """
        Get speech timestamps from audio.

        Uses the ``silero_vad`` package implementation when available and
        falls back to the manual windowed implementation otherwise.

        Args:
            audio: Audio data (numpy array or torch tensor, float32, mono).
            return_seconds: Return timestamps in seconds instead of samples.

        Returns:
            List of dicts with 'start' and 'end' keys.
        """
        try:
            from silero_vad import get_speech_timestamps

            if isinstance(audio, np.ndarray):
                audio_tensor = torch.from_numpy(audio).float()
            else:
                audio_tensor = audio

            # NOTE(review): newer silero-vad releases dropped the
            # window_size_samples argument — verify against the installed
            # package version.
            timestamps = get_speech_timestamps(
                audio_tensor,
                self.model,
                threshold=self.threshold,
                sampling_rate=self.sampling_rate,
                min_speech_duration_ms=self.min_speech_duration_ms,
                min_silence_duration_ms=self.min_silence_duration_ms,
                window_size_samples=self.window_size_samples,
                return_seconds=return_seconds
            )
            return timestamps
        except ImportError:
            # Fallback: manual implementation.
            return self._get_speech_timestamps_manual(audio, return_seconds)

    def _get_speech_timestamps_manual(
        self,
        audio: np.ndarray,
        return_seconds: bool = False
    ) -> List[Dict[str, float]]:
        """Manual implementation of speech timestamp detection.

        Thresholds per-window speech probabilities, merges segments separated
        by less than ``min_silence_duration_ms`` of silence, and drops
        segments shorter than ``min_speech_duration_ms``.
        """
        if isinstance(audio, np.ndarray):
            audio_tensor = torch.from_numpy(audio).float()
        else:
            audio_tensor = audio

        # Per-window speech probabilities (last window is zero-padded).
        window_size = self.window_size_samples
        speech_probs = []
        self.reset_states()
        for i in range(0, len(audio_tensor), window_size):
            chunk = audio_tensor[i:i + window_size]
            if len(chunk) < window_size:
                chunk = torch.nn.functional.pad(chunk, (0, window_size - len(chunk)))
            speech_probs.append(self.process_chunk(chunk))

        # Raw segments: runs of windows at or above the threshold.
        segments = []
        in_speech = False
        speech_start = 0
        for i, prob in enumerate(speech_probs):
            sample_idx = i * window_size
            if prob >= self.threshold and not in_speech:
                in_speech = True
                speech_start = sample_idx
            elif prob < self.threshold and in_speech:
                in_speech = False
                segments.append([speech_start, sample_idx])
        if in_speech:
            # Speech continues to the end of the audio.
            segments.append([speech_start, len(audio_tensor)])

        # Merge segments separated by less than min_silence_duration_ms
        # (the original implementation stored but never applied this knob).
        min_silence_samples = int(self.min_silence_duration_ms / 1000 * self.sampling_rate)
        merged: List[List[int]] = []
        for seg in segments:
            if merged and seg[0] - merged[-1][1] < min_silence_samples:
                merged[-1][1] = seg[1]
            else:
                merged.append(seg)

        # Drop segments shorter than min_speech_duration_ms — applied
        # uniformly, including the trailing segment (the original skipped
        # the duration check for speech running to the end of the audio).
        min_speech_samples = self.min_speech_duration_ms / 1000 * self.sampling_rate
        timestamps = []
        for start, end in merged:
            if end - start >= min_speech_samples:
                if return_seconds:
                    timestamps.append({
                        'start': start / self.sampling_rate,
                        'end': end / self.sampling_rate
                    })
                else:
                    timestamps.append({'start': start, 'end': end})
        return timestamps

    def process_file(self, audio_path: str) -> Tuple[List[Dict], float]:
        """
        Process an audio file and return speech segments with latency.

        Args:
            audio_path: Path to audio file.

        Returns:
            Tuple of (timestamps, processing_time_ms).
        """
        # Load at the instance's configured rate (the original always
        # resampled to the static 16 kHz default, breaking 8 kHz instances).
        audio = self.read_audio(audio_path, sampling_rate=self.sampling_rate)

        # Monotonic clock: immune to wall-clock adjustments.
        start_time = time.perf_counter()
        timestamps = self.get_speech_timestamps(audio, return_seconds=True)
        processing_time = (time.perf_counter() - start_time) * 1000  # ms

        return timestamps, processing_time

    @staticmethod
    def read_audio(path: str, sampling_rate: int = 16000) -> torch.Tensor:
        """
        Read audio file and convert to required format.

        Args:
            path: Path to audio file.
            sampling_rate: Target sample rate.

        Returns:
            Audio tensor (mono, float32).
        """
        try:
            from silero_vad import read_audio
            return read_audio(path, sampling_rate=sampling_rate)
        except ImportError:
            # Fallback: use librosa (resamples and downmixes to mono).
            import librosa
            audio, sr = librosa.load(path, sr=sampling_rate, mono=True)
            return torch.from_numpy(audio).float()

    def benchmark_latency(self, duration_seconds: float = 10.0) -> Dict[str, float]:
        """
        Benchmark VAD latency on synthetic (random-noise) audio.

        Args:
            duration_seconds: Duration of test audio.

        Returns:
            Dict with latency metrics (total time, per-second latency,
            real-time factor, segment count).
        """
        num_samples = int(duration_seconds * self.sampling_rate)
        test_audio = torch.randn(num_samples)

        # Warm-up pass so one-time costs don't skew the measurement.
        self.reset_states()
        _ = self.get_speech_timestamps(test_audio.numpy())

        # Timed pass on a monotonic clock.
        self.reset_states()
        start_time = time.perf_counter()
        timestamps = self.get_speech_timestamps(test_audio.numpy())
        end_time = time.perf_counter()

        processing_time_ms = (end_time - start_time) * 1000
        latency_per_second = processing_time_ms / duration_seconds

        return {
            'total_processing_time_ms': processing_time_ms,
            'audio_duration_s': duration_seconds,
            'latency_per_second_ms': latency_per_second,
            'real_time_factor': processing_time_ms / (duration_seconds * 1000),
            'num_segments': len(timestamps)
        }
def demo():
    """Demonstrate the VAD wrapper: initialize it and report latency metrics."""
    divider = "=" * 60
    print("\n" + divider)
    print("SILERO VAD DEMO")
    print(divider)

    # Spin up the detector with the default sensitivity.
    detector = SileroVAD(threshold=0.5)

    # Run the synthetic-audio latency benchmark and report each metric.
    print("\n📊 Benchmarking latency...")
    stats = detector.benchmark_latency(duration_seconds=10.0)
    print(f" Total processing time: {stats['total_processing_time_ms']:.2f}ms")
    print(f" Audio duration: {stats['audio_duration_s']:.1f}s")
    print(f" Latency per second: {stats['latency_per_second_ms']:.2f}ms")
    print(f" Real-time factor: {stats['real_time_factor']:.4f}x")

    # Flag whether the real-time latency budget was met.
    within_target = stats['latency_per_second_ms'] < 100
    if within_target:
        print(" ✅ Target latency achieved (<100ms)")
    else:
        print(" ⚠️ Latency above target (>100ms)")

    print("\n" + divider)


if __name__ == "__main__":
    demo()