# File size: 8,712 Bytes
# 278e294
"""
Integration tests for speech pathology diagnosis API.
Tests API endpoints, error mapping, and therapy recommendations.
"""
import logging
import numpy as np
import tempfile
import soundfile as sf
from pathlib import Path
import json
# Module-level logger used by every test in this file; when run as a
# script, the __main__ block configures output via logging.basicConfig.
logger = logging.getLogger(__name__)
def test_phoneme_mapping():
    """Exercise ``PhonemeMapper`` text-to-phoneme conversion and frame alignment.

    Runs three checks: conversion of a single word to phonemes, alignment of
    those phonemes onto a fixed number of frames, and the combined
    ``map_text_to_frames`` pipeline.

    Returns:
        True if every check passes; False if the mapper cannot be imported
        or any check fails (the reason is logged).
    """
    logger.info("Testing phoneme mapping...")
    try:
        from models.phoneme_mapper import PhonemeMapper

        mapper = PhonemeMapper(frame_duration_ms=20)

        # Test 1: Simple word
        phonemes = mapper.text_to_phonemes("robot")
        assert len(phonemes) > 0, "Should extract phonemes"
        logger.info(f"✅ 'robot' → {len(phonemes)} phonemes: {[p[0] for p in phonemes]}")

        # Test 2: Frame alignment
        frame_phonemes = mapper.align_phonemes_to_frames(phonemes, num_frames=25)
        assert len(frame_phonemes) == 25, "Should have 25 frames"
        logger.info(f"✅ Aligned to {len(frame_phonemes)} frames")

        # Test 3: Complete pipeline
        cat_frames = mapper.map_text_to_frames("cat", num_frames=15)
        assert len(cat_frames) == 15, "Should have 15 frames"
        logger.info(f"✅ 'cat' → {len(cat_frames)} frame phonemes")
        return True
    except ImportError as e:
        # A missing mapper/G2P dependency is reported as a warning rather
        # than as a test failure with a traceback.
        logger.warning(f"⚠️ G2P library not available: {e}")
        return False
    except Exception as e:
        # logger.exception keeps the traceback so the failing check can be found.
        logger.exception(f"❌ Phoneme mapping test failed: {e}")
        return False
def test_error_taxonomy():
    """Check ``ErrorMapper`` classification of classifier outputs and severity bands.

    Maps one sample per articulation class (0 normal, 1 substitution,
    2 omission, 3 distortion), verifies the resulting ``ErrorType``, then
    checks ``get_severity_level`` at representative scores.

    Returns:
        True if every check passes, otherwise False (the failure is logged).
    """
    logger.info("Testing error taxonomy...")
    try:
        from models.error_taxonomy import ErrorMapper, ErrorType, SeverityLevel

        mapper = ErrorMapper()

        # Test 1: Normal (class 0)
        error = mapper.map_classifier_output(0, 0.95, "/k/")
        assert error.error_type == ErrorType.NORMAL, f"class 0 mapped to {error.error_type}"
        assert error.severity == 0.0, f"normal severity was {error.severity}"
        logger.info(f"✅ Normal error mapping: {error.error_type}")

        # Test 2: Substitution (class 1)
        error = mapper.map_classifier_output(1, 0.78, "/s/")
        assert error.error_type == ErrorType.SUBSTITUTION, f"class 1 mapped to {error.error_type}"
        assert error.wrong_sound is not None, "substitution should report wrong_sound"
        logger.info(f"✅ Substitution error: {error.error_type}, wrong_sound={error.wrong_sound}")
        logger.info(f" Therapy: {error.therapy[:60]}...")

        # Test 3: Omission (class 2)
        error = mapper.map_classifier_output(2, 0.85, "/r/")
        assert error.error_type == ErrorType.OMISSION, f"class 2 mapped to {error.error_type}"
        logger.info(f"✅ Omission error: {error.error_type}")
        logger.info(f" Therapy: {error.therapy[:60]}...")

        # Test 4: Distortion (class 3)
        error = mapper.map_classifier_output(3, 0.65, "/s/")
        assert error.error_type == ErrorType.DISTORTION, f"class 3 mapped to {error.error_type}"
        logger.info(f"✅ Distortion error: {error.error_type}")
        logger.info(f" Therapy: {error.therapy[:60]}...")

        # Test 5: Severity levels — one representative score per band.
        severity_cases = [
            (0.0, SeverityLevel.NONE),
            (0.2, SeverityLevel.LOW),
            (0.5, SeverityLevel.MEDIUM),
            (0.8, SeverityLevel.HIGH),
        ]
        for score, expected in severity_cases:
            level = mapper.get_severity_level(score)
            assert level == expected, f"severity {score} mapped to {level}, expected {expected}"
        logger.info("✅ Severity level mapping correct")
        return True
    except Exception as e:
        # Bare asserts used to produce an empty message here; the messages
        # above plus the traceback make the failing check identifiable.
        logger.exception(f"❌ Error taxonomy test failed: {e}")
        return False
def test_batch_diagnosis_endpoint(pipeline, phoneme_mapper, error_mapper):
    """Run the batch-diagnosis flow end to end on a synthetic tone.

    Writes a 2 s, 440 Hz sine wave to a temporary WAV file, runs
    ``pipeline.predict_phone_level`` on it, aligns placeholder text to the
    predicted frames and maps every frame prediction through
    ``error_mapper``, collecting the non-normal results.

    Args:
        pipeline: object exposing ``predict_phone_level(path, return_timestamps=...)``.
        phoneme_mapper: object exposing ``map_text_to_frames``.
        error_mapper: object exposing ``map_classifier_output``.

    Returns:
        True if the flow completes without raising, otherwise False
        (the failure is logged).
    """
    import os

    logger.info("Testing batch diagnosis endpoint...")
    try:
        # Needed for the NORMAL check below; it is not imported at module
        # level, so without this the loop raised NameError on every run.
        from models.error_taxonomy import ErrorType

        # Generate test audio: 2 s of a 440 Hz sine at half amplitude.
        duration = 2.0
        sample_rate = 16000
        num_samples = int(duration * sample_rate)
        t = np.linspace(0, duration, num_samples)
        audio = (0.5 * np.sin(2 * np.pi * 440 * t)).astype(np.float32)

        # delete=False keeps the file after this handle closes so it can be
        # written and read by path; it is removed in the finally block below.
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as f:
            temp_path = f.name
        try:
            # Written inside the try so the temp file is removed even if
            # encoding fails.
            sf.write(temp_path, audio, sample_rate)

            # Run inference
            result = pipeline.predict_phone_level(temp_path, return_timestamps=True)

            # Map placeholder text onto the predicted frames
            text = "test audio"
            frame_phonemes = phoneme_mapper.map_text_to_frames(
                text,
                num_frames=result.num_frames,
                audio_duration=result.duration,
            )

            # Process errors
            errors = []
            for i, frame_pred in enumerate(result.frame_predictions):
                class_id = frame_pred.articulation_class
                # NOTE(review): stutter frames are shifted by +4 into a
                # separate id range — confirm this matches ErrorMapper's ids.
                if frame_pred.fluency_label == 'stutter':
                    class_id += 4
                error_detail = error_mapper.map_classifier_output(
                    class_id=class_id,
                    confidence=frame_pred.confidence,
                    # Frames beyond the aligned phonemes get an empty phoneme.
                    phoneme=frame_phonemes[i] if i < len(frame_phonemes) else '',
                    fluency_label=frame_pred.fluency_label,
                )
                if error_detail.error_type != ErrorType.NORMAL:
                    errors.append(error_detail)

            logger.info(f"✅ Batch diagnosis: {result.num_frames} frames, {len(errors)} errors detected")
            return True
        finally:
            if os.path.exists(temp_path):
                os.remove(temp_path)
    except Exception as e:
        logger.exception(f"❌ Batch diagnosis test failed: {e}")
        return False
def test_therapy_recommendations():
    """Check that ``ErrorMapper.get_therapy`` returns non-empty advice for common cases.

    Returns:
        True if every case yields a non-empty recommendation, otherwise
        False (the failure is logged).
    """
    logger.info("Testing therapy recommendations...")
    try:
        from models.error_taxonomy import ErrorMapper, ErrorType

        mapper = ErrorMapper()

        # (phoneme, error type, wrong_sound) — wrong_sound is supplied only
        # for the substitution case.
        test_cases = [
            ("/s/", ErrorType.SUBSTITUTION, "/θ/"),
            ("/r/", ErrorType.OMISSION, None),
            ("/s/", ErrorType.DISTORTION, None),
        ]
        for phoneme, error_type, wrong_sound in test_cases:
            therapy = mapper.get_therapy(error_type, phoneme, wrong_sound)
            assert therapy, f"Therapy should not be empty for {phoneme}"
            logger.info(f"✅ {phoneme} {error_type.value}: {therapy[:50]}...")
        return True
    except Exception as e:
        logger.exception(f"❌ Therapy recommendations test failed: {e}")
        return False
def run_all_integration_tests():
    """Run every integration test and log a pass/fail summary.

    The phoneme-mapping, error-taxonomy and therapy tests always run. The
    batch-diagnosis test runs only if the inference pipeline and mappers can
    be imported and constructed; otherwise it is logged as skipped.

    Returns:
        ``(all_passed, results)`` where ``results`` maps each test name to
        its boolean outcome and ``all_passed`` is True only if every entry
        is True.
    """
    logger.info("=" * 60)
    logger.info("Running Integration Tests")
    logger.info("=" * 60)

    results = {}

    # Test 1: Phoneme mapping
    logger.info("\n1. Phoneme Mapping Test")
    results["phoneme_mapping"] = test_phoneme_mapping()

    # Test 2: Error taxonomy
    logger.info("\n2. Error Taxonomy Test")
    results["error_taxonomy"] = test_error_taxonomy()

    # Test 3: Therapy recommendations
    logger.info("\n3. Therapy Recommendations Test")
    results["therapy_recommendations"] = test_therapy_recommendations()

    # Test 4: Batch diagnosis (if pipeline available)
    try:
        from inference.inference_pipeline import create_inference_pipeline
        from models.phoneme_mapper import PhonemeMapper
        from models.error_taxonomy import ErrorMapper

        logger.info("\n4. Batch Diagnosis Test")
        pipeline = create_inference_pipeline()
        phoneme_mapper = PhonemeMapper()
        error_mapper = ErrorMapper()
        results["batch_diagnosis"] = test_batch_diagnosis_endpoint(
            pipeline, phoneme_mapper, error_mapper
        )
    except Exception as e:
        # NOTE(review): a skip is recorded as False, so it shows as FAILED in
        # the summary and makes all_passed False — confirm this is intended.
        logger.warning(f"⚠️ Batch diagnosis test skipped: {e}")
        results["batch_diagnosis"] = False

    # Summary
    logger.info("\n" + "=" * 60)
    logger.info("Integration Test Summary")
    logger.info("=" * 60)
    for test_name, passed in results.items():
        status = "✅ PASSED" if passed else "❌ FAILED"
        logger.info(f"{status}: {test_name}")

    all_passed = all(results.values())
    return all_passed, results
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    all_passed, _results = run_all_integration_tests()
    # raise SystemExit instead of calling exit(): the latter is injected by
    # the site module for interactive use and is absent under `python -S`.
    if all_passed:
        logger.info("\n✅ All integration tests passed!")
        raise SystemExit(0)
    logger.error("\n❌ Some integration tests failed!")
    raise SystemExit(1)
|