|
|
"""
Tests for the ``stt`` package: ``WhisperTranscriber`` and the audio helpers
in ``stt.audio_processor``.
"""
|
|
|
|
|
import pytest |
|
|
from pathlib import Path |
|
|
import numpy as np |
|
|
from stt import WhisperTranscriber |
|
|
from stt.audio_processor import load_audio, resample_audio, get_audio_duration |
|
|
|
|
|
|
|
|
class TestWhisperTranscriber:
    """Tests for ``WhisperTranscriber``.

    These tests construct the transcriber from real model files and are
    skipped when the model's ``config.json`` is not present.
    """

    @pytest.fixture
    def model_path(self):
        """Return the model directory, skipping the test if the model is absent.

        The model is looked up in the parent of the directory containing this
        file. The existence check lives here, rather than being repeated in
        every test, so the skip condition is defined in exactly one place.
        """
        path = Path(__file__).parent.parent
        if not (path / "config.json").exists():
            pytest.skip("Model files not found")
        return path

    def test_initialization(self, model_path):
        """An explicit CPU / float32 configuration constructs and keeps its device."""
        transcriber = WhisperTranscriber(
            model_path=model_path,
            device="cpu",
            dtype="float32",
        )

        assert transcriber is not None
        assert transcriber.device == "cpu"

    def test_device_resolution(self, model_path):
        """``device="auto"`` resolves to one of the concrete device names."""
        transcriber = WhisperTranscriber(model_path=model_path, device="auto")

        assert transcriber.device in ("cpu", "cuda", "mps")

    def test_get_model_info(self, model_path):
        """``get_model_info`` reports the model path, device and language."""
        transcriber = WhisperTranscriber(model_path=model_path, device="cpu")

        info = transcriber.get_model_info()

        for key in ("model_path", "device", "language"):
            assert key in info
|
|
|
|
|
|
|
class TestAudioProcessor:
    """Tests for the audio helpers in ``stt.audio_processor``."""

    @staticmethod
    def _signal(num_samples):
        """Return a reproducible white-noise signal with ``num_samples`` samples.

        A fixed seed makes every run see the same input, so a failure can be
        reproduced exactly (unlike the global, unseeded ``np.random`` state).
        """
        return np.random.default_rng(0).standard_normal(num_samples)

    def test_resample_audio(self):
        """Resampling one second from 16 kHz to 8 kHz yields 8000 samples."""
        audio = self._signal(16000)

        resampled = resample_audio(audio, orig_sr=16000, target_sr=8000)

        assert len(resampled) == 8000

    def test_resample_same_rate(self):
        """Resampling to the same rate returns the audio unchanged."""
        audio = self._signal(16000)

        resampled = resample_audio(audio, orig_sr=16000, target_sr=16000)

        assert len(resampled) == len(audio)
        np.testing.assert_array_equal(audio, resampled)

    def test_get_audio_duration(self):
        """16000 samples at 16 kHz last one second."""
        audio = self._signal(16000)

        duration = get_audio_duration(audio, sr=16000)

        # Compare with a tolerance rather than exact float equality.
        assert duration == pytest.approx(1.0)

    def test_get_audio_duration_different_sr(self):
        """16000 samples at 8 kHz last two seconds."""
        audio = self._signal(16000)

        duration = get_audio_duration(audio, sr=8000)

        # Compare with a tolerance rather than exact float equality.
        assert duration == pytest.approx(2.0)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Propagate pytest's exit code so running this file directly reports
    # failures to the shell/CI instead of always exiting with status 0.
    raise SystemExit(pytest.main([__file__, "-v"]))
|
|
|