# Source: Trans_for_doctors / tests/test_stt.py (last commit by Mintik24)
# Commit e275025: "🎉 Полный рефакторинг проекта Medical Transcriber"
# (English: "Complete refactoring of the Medical Transcriber project")
"""
Tests for STT module
"""
import pytest
from pathlib import Path
import numpy as np
from stt import WhisperTranscriber
from stt.audio_processor import load_audio, resample_audio, get_audio_duration
class TestWhisperTranscriber:
    """Tests for WhisperTranscriber.

    These tests need real model files on disk. Every test that requests the
    ``model_path`` fixture is skipped when ``config.json`` is missing from
    that directory.
    """

    @pytest.fixture
    def model_path(self):
        """Return the model directory, skipping the test if model files are absent.

        The directory is the parent of this test file's directory. Presumably
        that is the project root where the model files live (TODO confirm).
        """
        path = Path(__file__).parent.parent
        # Do the skip once here instead of repeating it in every test.
        if not (path / "config.json").exists():
            pytest.skip("Model files not found")
        return path

    def test_initialization(self, model_path):
        """An explicitly requested CPU/float32 transcriber reports device 'cpu'."""
        transcriber = WhisperTranscriber(
            model_path=model_path,
            device="cpu",  # Use CPU for testing
            dtype="float32",
        )
        assert transcriber is not None
        assert transcriber.device == "cpu"

    def test_device_resolution(self, model_path):
        """device='auto' must resolve to one of the concrete supported devices."""
        transcriber = WhisperTranscriber(
            model_path=model_path,
            device="auto",
        )
        # Should resolve to cpu, cuda, or mps
        assert transcriber.device in ["cpu", "cuda", "mps"]

    def test_get_model_info(self, model_path):
        """get_model_info() exposes at least model_path, device and language keys."""
        transcriber = WhisperTranscriber(
            model_path=model_path,
            device="cpu",
        )
        info = transcriber.get_model_info()
        assert 'model_path' in info
        assert 'device' in info
        assert 'language' in info
class TestAudioProcessor:
    """Tests for audio processing functions."""

    @pytest.fixture
    def rng(self):
        """Return a seeded NumPy generator so test signals are reproducible."""
        return np.random.default_rng(0)

    def test_resample_audio(self, rng):
        """Downsampling 1 s of 16 kHz audio to 8 kHz yields 8000 samples."""
        audio = rng.standard_normal(16000)  # 1 second at 16 kHz
        resampled = resample_audio(audio, orig_sr=16000, target_sr=8000)
        assert len(resampled) == 8000

    def test_resample_same_rate(self, rng):
        """Resampling to the same rate should return the same audio."""
        audio = rng.standard_normal(16000)
        resampled = resample_audio(audio, orig_sr=16000, target_sr=16000)
        assert len(resampled) == len(audio)
        np.testing.assert_array_equal(audio, resampled)

    def test_get_audio_duration(self, rng):
        """16000 samples at 16 kHz last 1 second."""
        audio = rng.standard_normal(16000)
        duration = get_audio_duration(audio, sr=16000)
        # Durations are floats, so compare with a tolerance instead of exact ==.
        assert duration == pytest.approx(1.0)

    def test_get_audio_duration_different_sr(self, rng):
        """16000 samples at 8 kHz last 2 seconds."""
        audio = rng.standard_normal(16000)
        duration = get_audio_duration(audio, sr=8000)
        assert duration == pytest.approx(2.0)
if __name__ == "__main__":
    # pytest.main() returns the run's exit status; propagate it so that
    # running this file directly fails (non-zero exit) when tests fail.
    raise SystemExit(pytest.main([__file__, "-v"]))