|
|
""" |
|
|
Integration tests for speech pathology diagnosis API. |
|
|
|
|
|
Tests API endpoints, error mapping, and therapy recommendations. |
|
|
""" |
|
|
|
|
|
import json
import logging
import sys
import tempfile
from pathlib import Path

import numpy as np
import soundfile as sf
|
|
|
|
|
# Module-level logger used by every test in this file. It is only obtained
# here; level/handlers are set by the ``__main__`` entry point below.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
def test_phoneme_mapping():
    """Test phoneme mapping functionality.

    Exercises three ``PhonemeMapper`` calls -- ``text_to_phonemes``,
    ``align_phonemes_to_frames`` and ``map_text_to_frames`` -- and asserts
    that phonemes are produced and that each frame-level result has exactly
    the requested number of frames.

    Returns:
        bool: ``True`` when every assertion passes. ``False`` when the import
        fails (logged as a warning) or when any check raises (logged with a
        traceback). Nothing is re-raised so the caller can aggregate results.
    """
    logger.info("Testing phoneme mapping...")

    try:
        # Imported lazily so that a missing dependency is reported through
        # the ImportError branch below instead of breaking module import.
        from models.phoneme_mapper import PhonemeMapper

        mapper = PhonemeMapper(frame_duration_ms=20)

        # Text -> phoneme sequence.
        phonemes = mapper.text_to_phonemes("robot")
        assert len(phonemes) > 0, "Should extract phonemes"
        logger.info(
            f"✅ 'robot' → {len(phonemes)} phonemes: {[p[0] for p in phonemes]}"
        )

        # Phoneme sequence -> one entry per frame.
        frame_phonemes = mapper.align_phonemes_to_frames(phonemes, num_frames=25)
        assert len(frame_phonemes) == 25, "Should have 25 frames"
        logger.info(f"✅ Aligned to {len(frame_phonemes)} frames")

        # Text -> frames in a single call.
        cat_frames = mapper.map_text_to_frames("cat", num_frames=15)
        assert len(cat_frames) == 15, "Should have 15 frames"
        logger.info(f"✅ 'cat' → {len(cat_frames)} frame phonemes")

        return True

    except ImportError as e:
        logger.warning(f"⚠️ G2P library not available: {e}")
        return False
    except Exception as e:
        # logger.exception keeps the traceback: a bare ``assert`` failure has
        # an empty message, so ``{e}`` alone would not say which check failed.
        logger.exception(f"❌ Phoneme mapping test failed: {e}")
        return False
|
|
|
|
|
|
|
|
def test_error_taxonomy():
    """Test error taxonomy and therapy mapping.

    Feeds classifier outputs with class ids 0-3 through
    ``ErrorMapper.map_classifier_output`` and asserts the resulting
    ``error_type`` (NORMAL, SUBSTITUTION, OMISSION, DISTORTION respectively),
    then checks ``ErrorMapper.get_severity_level`` at four sample scores.

    Returns:
        bool: ``True`` when every assertion passes; ``False`` (after logging
        the failure with a traceback) when anything raises, including a
        failed import.
    """
    logger.info("Testing error taxonomy...")

    try:
        from models.error_taxonomy import ErrorMapper, ErrorType, SeverityLevel

        mapper = ErrorMapper()

        # Class 0: normal production, expected to carry zero severity.
        error = mapper.map_classifier_output(0, 0.95, "/k/")
        assert error.error_type == ErrorType.NORMAL, "class 0 should map to NORMAL"
        assert error.severity == 0.0, "NORMAL should have severity 0.0"
        logger.info(f"✅ Normal error mapping: {error.error_type}")

        # Class 1: substitution, which must also name the substituted sound.
        error = mapper.map_classifier_output(1, 0.78, "/s/")
        assert error.error_type == ErrorType.SUBSTITUTION, (
            "class 1 should map to SUBSTITUTION"
        )
        assert error.wrong_sound is not None, "substitution needs a wrong_sound"
        logger.info(
            f"✅ Substitution error: {error.error_type}, "
            f"wrong_sound={error.wrong_sound}"
        )
        logger.info(f" Therapy: {error.therapy[:60]}...")

        # Class 2: omission.
        error = mapper.map_classifier_output(2, 0.85, "/r/")
        assert error.error_type == ErrorType.OMISSION, "class 2 should map to OMISSION"
        logger.info(f"✅ Omission error: {error.error_type}")
        logger.info(f" Therapy: {error.therapy[:60]}...")

        # Class 3: distortion.
        error = mapper.map_classifier_output(3, 0.65, "/s/")
        assert error.error_type == ErrorType.DISTORTION, (
            "class 3 should map to DISTORTION"
        )
        logger.info(f"✅ Distortion error: {error.error_type}")
        logger.info(f" Therapy: {error.therapy[:60]}...")

        # Severity score -> severity level, one sample per level.
        expected_levels = (
            (0.0, SeverityLevel.NONE),
            (0.2, SeverityLevel.LOW),
            (0.5, SeverityLevel.MEDIUM),
            (0.8, SeverityLevel.HIGH),
        )
        for score, level in expected_levels:
            assert mapper.get_severity_level(score) == level, (
                f"severity {score} should map to {level}"
            )
        logger.info("✅ Severity level mapping correct")

        return True

    except Exception as e:
        # Keep the traceback so the failing check can be identified.
        logger.exception(f"❌ Error taxonomy test failed: {e}")
        return False
|
|
|
|
|
|
|
|
def test_batch_diagnosis_endpoint(pipeline, phoneme_mapper, error_mapper):
    """Test batch diagnosis endpoint functionality.

    Synthesizes a 2-second 440 Hz tone at 16 kHz, writes it to a temporary
    WAV file, runs phone-level inference on that file, maps a placeholder
    transcript onto the predicted frames, and passes every frame prediction
    through the error mapper, collecting those whose type is not NORMAL.

    Args:
        pipeline: Object exposing
            ``predict_phone_level(path, return_timestamps=...)``; the result
            is read for ``num_frames``, ``duration`` and ``frame_predictions``.
        phoneme_mapper: Object exposing
            ``map_text_to_frames(text, num_frames=..., audio_duration=...)``.
        error_mapper: Object exposing ``map_classifier_output(class_id=...,
            confidence=..., phoneme=..., fluency_label=...)``.

    Returns:
        bool: ``True`` when the whole flow runs without raising; ``False``
        (after logging the failure with a traceback) otherwise.
    """
    logger.info("Testing batch diagnosis endpoint...")

    try:
        # Needed for the NORMAL comparison below. It was previously never
        # imported in this scope, so the comparison raised NameError, which
        # the broad handler turned into a silent failure.
        from models.error_taxonomy import ErrorType

        duration = 2.0
        sample_rate = 16000
        num_samples = int(duration * sample_rate)
        # endpoint=False keeps the sample spacing at exactly 1 / sample_rate.
        timeline = np.linspace(0, duration, num_samples, endpoint=False)
        audio = (0.5 * np.sin(2 * np.pi * 440 * timeline)).astype(np.float32)

        # Only reserve a unique path here; the handle is closed before the
        # file is written so it is never opened twice at the same time.
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as f:
            temp_path = Path(f.name)

        try:
            # Written inside the try so the file is removed even if the
            # write itself fails.
            sf.write(str(temp_path), audio, sample_rate)

            result = pipeline.predict_phone_level(
                str(temp_path), return_timestamps=True
            )

            text = "test audio"
            frame_phonemes = phoneme_mapper.map_text_to_frames(
                text,
                num_frames=result.num_frames,
                audio_duration=result.duration,
            )

            errors = []
            for i, frame_pred in enumerate(result.frame_predictions):
                class_id = frame_pred.articulation_class
                if frame_pred.fluency_label == 'stutter':
                    # NOTE(review): stuttered frames are shifted by 4,
                    # presumably into a second bank of class ids -- confirm
                    # against ErrorMapper.
                    class_id += 4

                error_detail = error_mapper.map_classifier_output(
                    class_id=class_id,
                    confidence=frame_pred.confidence,
                    # Fall back to an empty phoneme when the alignment is
                    # shorter than the number of predicted frames.
                    phoneme=frame_phonemes[i] if i < len(frame_phonemes) else '',
                    fluency_label=frame_pred.fluency_label,
                )

                if error_detail.error_type != ErrorType.NORMAL:
                    errors.append(error_detail)

            logger.info(
                f"✅ Batch diagnosis: {result.num_frames} frames, "
                f"{len(errors)} errors detected"
            )

            return True

        finally:
            if temp_path.exists():
                temp_path.unlink()

    except Exception as e:
        # Keep the traceback so the failing step can be identified.
        logger.exception(f"❌ Batch diagnosis test failed: {e}")
        return False
|
|
|
|
|
|
|
|
def test_therapy_recommendations():
    """Test therapy recommendation coverage.

    Calls ``ErrorMapper.get_therapy`` for one substitution, one omission and
    one distortion case and asserts that each returns a non-empty value.

    Returns:
        bool: ``True`` when every case yields a therapy; ``False`` (after
        logging the failure with a traceback) when anything raises, including
        a failed import.
    """
    logger.info("Testing therapy recommendations...")

    try:
        from models.error_taxonomy import ErrorMapper, ErrorType

        mapper = ErrorMapper()

        # (target phoneme, error type, substituted sound or None)
        test_cases = [
            ("/s/", ErrorType.SUBSTITUTION, "/θ/"),
            ("/r/", ErrorType.OMISSION, None),
            ("/s/", ErrorType.DISTORTION, None),
        ]

        for phoneme, error_type, wrong_sound in test_cases:
            therapy = mapper.get_therapy(error_type, phoneme, wrong_sound)
            assert therapy, f"Therapy should not be empty for {phoneme}"
            logger.info(f"✅ {phoneme} {error_type.value}: {therapy[:50]}...")

        return True

    except Exception as e:
        # Keep the traceback so the failing case can be identified.
        logger.exception(f"❌ Therapy recommendations test failed: {e}")
        return False
|
|
|
|
|
|
|
|
def run_all_integration_tests():
    """Run all integration tests.

    Runs the phoneme-mapping, error-taxonomy and therapy-recommendation tests
    unconditionally, then attempts the batch-diagnosis test, which needs an
    inference pipeline plus both mappers to be importable and constructible.

    Returns:
        tuple: ``(all_passed, results)`` where ``results`` maps each test name
        to its boolean outcome and ``all_passed`` is ``True`` only when every
        outcome is truthy. A batch-diagnosis test that could not be set up is
        recorded as ``False`` and therefore counts as a failure.
    """
    logger.info("=" * 60)
    logger.info("Running Integration Tests")
    logger.info("=" * 60)

    results = {}

    logger.info("\n1. Phoneme Mapping Test")
    results["phoneme_mapping"] = test_phoneme_mapping()

    logger.info("\n2. Error Taxonomy Test")
    results["error_taxonomy"] = test_error_taxonomy()

    logger.info("\n3. Therapy Recommendations Test")
    results["therapy_recommendations"] = test_therapy_recommendations()

    try:
        # Imported here so the three tests above still run when these
        # modules (or whatever they import) are unavailable.
        from inference.inference_pipeline import create_inference_pipeline
        from models.phoneme_mapper import PhonemeMapper
        from models.error_taxonomy import ErrorMapper

        logger.info("\n4. Batch Diagnosis Test")
        pipeline = create_inference_pipeline()
        phoneme_mapper = PhonemeMapper()
        error_mapper = ErrorMapper()

        results["batch_diagnosis"] = test_batch_diagnosis_endpoint(
            pipeline, phoneme_mapper, error_mapper
        )
    except Exception as e:
        # Setup failure: logged as a skip, but still recorded as not passed.
        logger.warning(f"⚠️ Batch diagnosis test skipped: {e}")
        results["batch_diagnosis"] = False

    logger.info("\n" + "=" * 60)
    logger.info("Integration Test Summary")
    logger.info("=" * 60)

    for test_name, passed in results.items():
        status = "✅ PASSED" if passed else "❌ FAILED"
        logger.info(f"{status}: {test_name}")

    all_passed = all(results.values())
    return all_passed, results
|
|
|
|
|
|
|
|
# Script entry point: run every integration test and report the outcome
# through the process exit status (0 = all passed, 1 = at least one failure).
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    all_passed, _results = run_all_integration_tests()

    if all_passed:
        logger.info("\n✅ All integration tests passed!")
    else:
        logger.error("\n❌ Some integration tests failed!")

    # sys.exit rather than the bare exit(): the latter is injected by the
    # ``site`` module for interactive use and is not guaranteed to exist.
    sys.exit(0 if all_passed else 1)
|
|
|
|
|
|