saadmannan's picture
initial commit
b77cba7
#!/usr/bin/env python3
"""
Unit tests for VAD module
"""
import pytest
import torch
import numpy as np
from pathlib import Path
import sys
# Add src to path
sys.path.insert(0, str(Path(__file__).parent.parent))
from src.vad import SileroVAD
class TestSileroVAD:
    """Test cases for Silero VAD.

    All synthetic audio is drawn from a seeded NumPy generator so that every
    run feeds the detector exactly the same samples and any failure can be
    reproduced.
    """

    @staticmethod
    def _noise(num_samples, seed=0):
        """Return reproducible float32 Gaussian noise.

        Args:
            num_samples: Number of samples to generate.
            seed: Seed for the NumPy generator; a fixed default keeps the
                test inputs identical from run to run.

        Returns:
            A 1-D ``np.float32`` array of length ``num_samples``.
        """
        rng = np.random.default_rng(seed)
        return rng.standard_normal(num_samples).astype(np.float32)

    @pytest.fixture
    def vad(self):
        """Create VAD instance for testing."""
        return SileroVAD(threshold=0.5)

    def test_initialization(self, vad):
        """Test VAD initialization."""
        assert vad is not None
        assert vad.threshold == 0.5
        assert vad.sampling_rate == 16000
        assert vad.model is not None

    def test_process_chunk(self, vad):
        """Test processing a single audio chunk."""
        # Create test audio
        chunk = self._noise(1536)

        # Process
        prob = vad.process_chunk(chunk)

        # Verify: the result is a Python float inside the closed unit interval
        assert isinstance(prob, float)
        assert 0.0 <= prob <= 1.0

    def test_get_speech_timestamps(self, vad):
        """Test getting speech timestamps."""
        # Create test audio with speech-like pattern
        sr = 16000
        duration = 5
        audio = np.zeros(sr * duration, dtype=np.float32)

        # Add "speech" in middle (higher energy): seconds 1 to 3 are noise,
        # the rest of the buffer stays silent.
        audio[sr:sr * 3] = 0.5 * self._noise(sr * 2)

        # Get timestamps
        timestamps = vad.get_speech_timestamps(audio, return_seconds=True)

        # Verify: a list of segments, each with an end after its start
        assert isinstance(timestamps, list)
        for ts in timestamps:
            assert 'start' in ts
            assert 'end' in ts
            assert ts['end'] > ts['start']

    def test_reset_states(self, vad):
        """Test state reset."""
        # Process some audio
        chunk = self._noise(1536)
        vad.process_chunk(chunk)

        # Reset
        vad.reset_states()

        # Should work without error
        prob = vad.process_chunk(chunk)
        assert isinstance(prob, float)

    def test_benchmark_latency(self, vad):
        """Test latency benchmarking."""
        metrics = vad.benchmark_latency(duration_seconds=1.0)

        # Verify metrics
        assert 'total_processing_time_ms' in metrics
        assert 'audio_duration_s' in metrics
        assert 'latency_per_second_ms' in metrics
        assert 'real_time_factor' in metrics

        # Check latency target
        assert metrics['latency_per_second_ms'] < 1000  # Should be much faster

    # Parametrized so each threshold is reported as its own test case and a
    # failure for one value does not hide the results for the others.
    @pytest.mark.parametrize("threshold", [0.3, 0.5, 0.7])
    def test_different_thresholds(self, threshold):
        """Test VAD with different thresholds."""
        vad = SileroVAD(threshold=threshold)
        assert vad.threshold == threshold

        # Test processing
        audio = self._noise(16000)
        timestamps = vad.get_speech_timestamps(audio)
        assert isinstance(timestamps, list)
def test_vad_import():
    """Check that ``SileroVAD`` can be imported from ``src.vad``."""
    # Import inside the test so an import failure is reported against this
    # test rather than at module load.
    from src.vad import SileroVAD as imported_vad

    assert imported_vad is not None
if __name__ == "__main__":
    # Propagate pytest's exit code so that running this file directly reports
    # test failures to the shell instead of always exiting with status 0.
    raise SystemExit(pytest.main([__file__, "-v"]))