#!/usr/bin/env python3
"""
Unit tests for the VAD module.

Exercises the ``SileroVAD`` wrapper from ``src.vad``:
- construction;
- per-chunk speech probabilities;
- speech-timestamp extraction;
- state reset;
- latency benchmarking.

Random inputs come from a seeded generator so failures are reproducible.
"""
import sys
from pathlib import Path

import numpy as np
import pytest
import torch  # noqa: F401  (not referenced directly in this file)

# Add the project root to sys.path so ``src`` is importable when this file is
# executed directly instead of through a configured pytest rootdir.
sys.path.insert(0, str(Path(__file__).parent.parent))

from src.vad import SileroVAD  # noqa: E402  (must follow the sys.path tweak)

# Sample count of the single-chunk inputs used below.
CHUNK_SAMPLES = 1536
# Fixed seed so the "random" audio is identical on every run.
RNG_SEED = 0


class TestSileroVAD:
    """Test cases for Silero VAD."""

    @pytest.fixture
    def vad(self):
        """Create VAD instance for testing."""
        return SileroVAD(threshold=0.5)

    @pytest.fixture
    def rng(self):
        """Seeded NumPy generator, so every run sees the same test audio."""
        return np.random.default_rng(RNG_SEED)

    def test_initialization(self, vad):
        """Test VAD initialization."""
        assert vad is not None
        assert vad.threshold == 0.5
        assert vad.sampling_rate == 16000
        assert vad.model is not None

    def test_process_chunk(self, vad, rng):
        """Processing a single chunk yields a probability in [0, 1]."""
        chunk = rng.standard_normal(CHUNK_SAMPLES, dtype=np.float32)

        prob = vad.process_chunk(chunk)

        assert isinstance(prob, float)
        assert 0.0 <= prob <= 1.0

    def test_get_speech_timestamps(self, vad, rng):
        """Timestamps are a list of dicts with ``start`` < ``end``."""
        sr = 16000
        duration = 5
        audio = np.zeros(sr * duration, dtype=np.float32)
        # Put a higher-energy "speech-like" burst in seconds 1..3.
        audio[sr:sr * 3] = 0.5 * rng.standard_normal(sr * 2, dtype=np.float32)

        timestamps = vad.get_speech_timestamps(audio, return_seconds=True)

        assert isinstance(timestamps, list)
        for ts in timestamps:
            assert 'start' in ts
            assert 'end' in ts
            assert ts['end'] > ts['start']

    def test_reset_states(self, vad, rng):
        """After ``reset_states`` the VAD behaves like a freshly built one."""
        chunk = rng.standard_normal(CHUNK_SAMPLES, dtype=np.float32)
        # The fixture gives a fresh instance, so this is the initial-state result.
        fresh_prob = vad.process_chunk(chunk)

        # Advance internal state further, then reset it.
        vad.process_chunk(chunk)
        vad.reset_states()

        prob = vad.process_chunk(chunk)
        assert isinstance(prob, float)
        # Same input from a reset state should reproduce the fresh-state result.
        assert prob == pytest.approx(fresh_prob, abs=1e-5)

    def test_benchmark_latency(self, vad):
        """Test latency benchmarking."""
        metrics = vad.benchmark_latency(duration_seconds=1.0)

        assert 'total_processing_time_ms' in metrics
        assert 'audio_duration_s' in metrics
        assert 'latency_per_second_ms' in metrics
        assert 'real_time_factor' in metrics

        # Must process one second of audio in under one second (faster than
        # real time); in practice it should be much faster.
        assert metrics['latency_per_second_ms'] < 1000

    @pytest.mark.parametrize("threshold", [0.3, 0.5, 0.7])
    def test_different_thresholds(self, threshold, rng):
        """VAD constructs and produces timestamps for several thresholds."""
        vad = SileroVAD(threshold=threshold)
        assert vad.threshold == threshold

        audio = rng.standard_normal(16000, dtype=np.float32)
        timestamps = vad.get_speech_timestamps(audio)
        assert isinstance(timestamps, list)


def test_vad_import():
    """Test that VAD can be imported."""
    from src.vad import SileroVAD
    assert SileroVAD is not None


if __name__ == "__main__":
    # Propagate pytest's exit status so direct runs report failures.
    sys.exit(pytest.main([__file__, "-v"]))