zlaqa-version-c-ai-enginee / tests /integration_tests.py
anfastech's picture
New: implemented many, many changes. 10% Phone-level detection: WORKING
278e294
"""
Integration tests for speech pathology diagnosis API.
Tests API endpoints, error mapping, and therapy recommendations.
"""
import logging
import numpy as np
import tempfile
import soundfile as sf
from pathlib import Path
import json
# Module-level logger used by every test function in this file. When the file
# is run as a script, the __main__ block configures output with
# logging.basicConfig(level=logging.INFO).
logger = logging.getLogger(__name__)
def test_phoneme_mapping():
    """Exercise PhonemeMapper text-to-phoneme conversion and frame alignment.

    Three checks run in order on a mapper built with frame_duration_ms=20:
    converting "robot" to phonemes, aligning those phonemes onto 25 frames,
    and mapping "cat" directly to 15 per-frame phonemes.

    Returns:
        bool: True when every check passes. False when an ImportError is
        raised (logged as a missing G2P library) or when any other exception
        is raised (logged as a failure).
    """
    logger.info("Testing phoneme mapping...")
    try:
        from models.phoneme_mapper import PhonemeMapper

        g2p = PhonemeMapper(frame_duration_ms=20)

        # Check 1: a single word yields at least one phoneme entry
        phonemes = g2p.text_to_phonemes("robot")
        assert len(phonemes) > 0, "Should extract phonemes"
        logger.info(f"βœ… 'robot' β†’ {len(phonemes)} phonemes: {[p[0] for p in phonemes]}")

        # Check 2: those phonemes are spread over exactly 25 frames
        frame_phonemes = g2p.align_phonemes_to_frames(phonemes, num_frames=25)
        assert len(frame_phonemes) == 25, "Should have 25 frames"
        logger.info(f"βœ… Aligned to {len(frame_phonemes)} frames")

        # Check 3: text goes to per-frame phonemes in a single call
        cat_frames = g2p.map_text_to_frames("cat", num_frames=15)
        assert len(cat_frames) == 15, "Should have 15 frames"
        logger.info(f"βœ… 'cat' β†’ {len(cat_frames)} frame phonemes")
    except ImportError as e:
        # Catches any ImportError raised in the try body, not only the
        # import of PhonemeMapper itself.
        logger.warning(f"⚠️ G2P library not available: {e}")
        return False
    except Exception as e:
        logger.error(f"❌ Phoneme mapping test failed: {e}")
        return False
    else:
        return True
def test_error_taxonomy():
    """Verify ErrorMapper classification, therapy text and severity levels.

    Classifier classes 0-3 are expected to map to NORMAL, SUBSTITUTION,
    OMISSION and DISTORTION respectively. Four severity scores are then
    checked against their SeverityLevel buckets.

    Returns:
        bool: True if every check passes, False otherwise (the failure is
        logged).
    """
    logger.info("Testing error taxonomy...")
    try:
        from models.error_taxonomy import ErrorMapper, ErrorType, SeverityLevel

        taxonomy = ErrorMapper()

        # Class 0: no error and zero severity
        error = taxonomy.map_classifier_output(0, 0.95, "/k/")
        assert error.error_type == ErrorType.NORMAL
        assert error.severity == 0.0
        logger.info(f"βœ… Normal error mapping: {error.error_type}")

        # Class 1: substitution, which must report the substituted sound
        error = taxonomy.map_classifier_output(1, 0.78, "/s/")
        assert error.error_type == ErrorType.SUBSTITUTION
        assert error.wrong_sound is not None
        logger.info(f"βœ… Substitution error: {error.error_type}, wrong_sound={error.wrong_sound}")
        logger.info(f" Therapy: {error.therapy[:60]}...")

        # Class 2: omission
        error = taxonomy.map_classifier_output(2, 0.85, "/r/")
        assert error.error_type == ErrorType.OMISSION
        logger.info(f"βœ… Omission error: {error.error_type}")
        logger.info(f" Therapy: {error.therapy[:60]}...")

        # Class 3: distortion
        error = taxonomy.map_classifier_output(3, 0.65, "/s/")
        assert error.error_type == ErrorType.DISTORTION
        logger.info(f"βœ… Distortion error: {error.error_type}")
        logger.info(f" Therapy: {error.therapy[:60]}...")

        # Severity score -> level buckets
        assert taxonomy.get_severity_level(0.0) == SeverityLevel.NONE
        assert taxonomy.get_severity_level(0.2) == SeverityLevel.LOW
        assert taxonomy.get_severity_level(0.5) == SeverityLevel.MEDIUM
        assert taxonomy.get_severity_level(0.8) == SeverityLevel.HIGH
        logger.info("βœ… Severity level mapping correct")
    except Exception as e:
        logger.error(f"❌ Error taxonomy test failed: {e}")
        return False
    else:
        return True
def test_batch_diagnosis_endpoint(pipeline, phoneme_mapper, error_mapper):
    """Run the batch-diagnosis flow end to end on a synthetic tone.

    A 2-second, 440 Hz sine wave is written to a temporary WAV file, run
    through ``pipeline.predict_phone_level``, aligned with phonemes from
    ``phoneme_mapper.map_text_to_frames`` and converted to per-frame error
    details with ``error_mapper.map_classifier_output``. Non-NORMAL errors
    are counted and logged.

    Args:
        pipeline: Object exposing
            ``predict_phone_level(path, return_timestamps=...)``.
        phoneme_mapper: Object exposing ``map_text_to_frames``
            (e.g. ``PhonemeMapper``).
        error_mapper: Object exposing ``map_classifier_output``
            (e.g. ``ErrorMapper``).

    Returns:
        bool: True if the whole flow ran without raising, False otherwise
        (the failure is logged with its traceback). The temporary WAV file
        is removed in both cases.
    """
    import os

    logger.info("Testing batch diagnosis endpoint...")
    temp_path = None
    try:
        # ErrorType is not imported at module level. Without this import the
        # NORMAL comparison below raises NameError and the test always fails.
        from models.error_taxonomy import ErrorType

        # Generate a 440 Hz test tone. np.arange(n) / sample_rate spaces the
        # samples exactly 1/sample_rate apart. linspace with its default
        # endpoint=True would stretch the spacing slightly.
        duration = 2.0
        sample_rate = 16000
        num_samples = int(duration * sample_rate)
        t = np.arange(num_samples) / sample_rate
        audio = (0.5 * np.sin(2 * np.pi * 440 * t)).astype(np.float32)

        # Only reserve a unique path here. The handle is closed before
        # soundfile writes to it, so the write also works where open files
        # are locked (Windows). Writing inside the try guarantees cleanup
        # even if the write itself fails.
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as f:
            temp_path = f.name
        sf.write(temp_path, audio, sample_rate)

        # Run inference
        result = pipeline.predict_phone_level(temp_path, return_timestamps=True)

        # Map phonemes onto the same number of frames as the predictions
        text = "test audio"
        frame_phonemes = phoneme_mapper.map_text_to_frames(
            text,
            num_frames=result.num_frames,
            audio_duration=result.duration,
        )

        # Process errors
        errors = []
        for i, frame_pred in enumerate(result.frame_predictions):
            class_id = frame_pred.articulation_class
            # Stutter frames are offset by 4, presumably past the four
            # articulation classes (0-3) used in test_error_taxonomy.
            # TODO confirm against ErrorMapper's class table.
            if frame_pred.fluency_label == 'stutter':
                class_id += 4
            error_detail = error_mapper.map_classifier_output(
                class_id=class_id,
                confidence=frame_pred.confidence,
                # Fall back to '' when predictions outnumber mapped frames
                phoneme=frame_phonemes[i] if i < len(frame_phonemes) else '',
                fluency_label=frame_pred.fluency_label,
            )
            if error_detail.error_type != ErrorType.NORMAL:
                errors.append(error_detail)

        logger.info(f"βœ… Batch diagnosis: {result.num_frames} frames, {len(errors)} errors detected")
        return True
    except Exception as e:
        logger.exception(f"❌ Batch diagnosis test failed: {e}")
        return False
    finally:
        if temp_path is not None and os.path.exists(temp_path):
            os.remove(temp_path)
def test_therapy_recommendations():
    """Check that ErrorMapper.get_therapy returns text for common errors.

    Returns:
        bool: True if every sampled (phoneme, error type, wrong sound)
        combination yields a non-empty therapy string, False otherwise
        (the failure is logged).
    """
    logger.info("Testing therapy recommendations...")
    try:
        from models.error_taxonomy import ErrorMapper, ErrorType

        recommender = ErrorMapper()
        # (target phoneme, error type, substituted sound or None)
        # NOTE(review): "/ΞΈ/" looks like a mis-encoded "/θ/" (UTF-8 read as
        # cp1253). Confirm it matches the phoneme keys ErrorMapper expects.
        samples = (
            ("/s/", ErrorType.SUBSTITUTION, "/ΞΈ/"),
            ("/r/", ErrorType.OMISSION, None),
            ("/s/", ErrorType.DISTORTION, None),
        )
        for phoneme, error_type, wrong_sound in samples:
            therapy = recommender.get_therapy(error_type, phoneme, wrong_sound)
            assert therapy and len(therapy) > 0, f"Therapy should not be empty for {phoneme}"
            logger.info(f"βœ… {phoneme} {error_type.value}: {therapy[:50]}...")
    except Exception as e:
        logger.error(f"❌ Therapy recommendations test failed: {e}")
        return False
    else:
        return True
def run_all_integration_tests():
    """Run every integration test in sequence and log a pass/fail summary.

    The batch-diagnosis test runs only when the inference pipeline, phoneme
    mapper and error mapper can be imported and constructed. If that setup
    raises, the test is logged as skipped and recorded as False.

    Returns:
        tuple[bool, dict]: ``(all_passed, results)``. ``results`` maps each
        test name to the value its test function returned. ``all_passed`` is
        True only if every value is truthy.
    """
    banner = "=" * 60
    for line in (banner, "Running Integration Tests", banner):
        logger.info(line)

    results = {}

    # Self-contained tests: (result key, section heading, test function)
    standalone = (
        ("phoneme_mapping", "\n1. Phoneme Mapping Test", test_phoneme_mapping),
        ("error_taxonomy", "\n2. Error Taxonomy Test", test_error_taxonomy),
        ("therapy_recommendations", "\n3. Therapy Recommendations Test",
         test_therapy_recommendations),
    )
    for key, heading, run_test in standalone:
        logger.info(heading)
        results[key] = run_test()

    # Batch diagnosis needs the full inference stack. Because
    # test_batch_diagnosis_endpoint catches its own exceptions, this handler
    # effectively covers import and construction failures.
    try:
        from inference.inference_pipeline import create_inference_pipeline
        from models.phoneme_mapper import PhonemeMapper
        from models.error_taxonomy import ErrorMapper

        logger.info("\n4. Batch Diagnosis Test")
        results["batch_diagnosis"] = test_batch_diagnosis_endpoint(
            create_inference_pipeline(), PhonemeMapper(), ErrorMapper()
        )
    except Exception as e:
        logger.warning(f"⚠️ Batch diagnosis test skipped: {e}")
        results["batch_diagnosis"] = False

    logger.info("\n" + banner)
    logger.info("Integration Test Summary")
    logger.info(banner)
    for test_name, passed in results.items():
        status = "βœ… PASSED" if passed else "❌ FAILED"
        logger.info(f"{status}: {test_name}")

    return all(results.values()), results
if __name__ == "__main__":
    # sys.exit is always available. The bare exit() builtin is injected by
    # the site module and is missing under `python -S` and in some embedded
    # or frozen interpreters.
    import sys

    logging.basicConfig(level=logging.INFO)
    all_passed, results = run_all_integration_tests()
    if all_passed:
        logger.info("\nβœ… All integration tests passed!")
    else:
        logger.error("\n❌ Some integration tests failed!")
    # Exit status 0 when every test passed, 1 otherwise
    sys.exit(0 if all_passed else 1)