""" Integration tests for speech pathology diagnosis API. Tests API endpoints, error mapping, and therapy recommendations. """ import logging import numpy as np import tempfile import soundfile as sf from pathlib import Path import json logger = logging.getLogger(__name__) def test_phoneme_mapping(): """Test phoneme mapping functionality.""" logger.info("Testing phoneme mapping...") try: from models.phoneme_mapper import PhonemeMapper mapper = PhonemeMapper(frame_duration_ms=20) # Test 1: Simple word phonemes = mapper.text_to_phonemes("robot") assert len(phonemes) > 0, "Should extract phonemes" logger.info(f"✅ 'robot' → {len(phonemes)} phonemes: {[p[0] for p in phonemes]}") # Test 2: Frame alignment frame_phonemes = mapper.align_phonemes_to_frames(phonemes, num_frames=25) assert len(frame_phonemes) == 25, "Should have 25 frames" logger.info(f"✅ Aligned to {len(frame_phonemes)} frames") # Test 3: Complete pipeline cat_frames = mapper.map_text_to_frames("cat", num_frames=15) assert len(cat_frames) == 15, "Should have 15 frames" logger.info(f"✅ 'cat' → {len(cat_frames)} frame phonemes") return True except ImportError as e: logger.warning(f"⚠️ G2P library not available: {e}") return False except Exception as e: logger.error(f"❌ Phoneme mapping test failed: {e}") return False def test_error_taxonomy(): """Test error taxonomy and therapy mapping.""" logger.info("Testing error taxonomy...") try: from models.error_taxonomy import ErrorMapper, ErrorType, SeverityLevel mapper = ErrorMapper() # Test 1: Normal (class 0) error = mapper.map_classifier_output(0, 0.95, "/k/") assert error.error_type == ErrorType.NORMAL assert error.severity == 0.0 logger.info(f"✅ Normal error mapping: {error.error_type}") # Test 2: Substitution (class 1) error = mapper.map_classifier_output(1, 0.78, "/s/") assert error.error_type == ErrorType.SUBSTITUTION assert error.wrong_sound is not None logger.info(f"✅ Substitution error: {error.error_type}, wrong_sound={error.wrong_sound}") logger.info(f" Therapy: {error.therapy[:60]}...") # Test 3: Omission (class 2) error = mapper.map_classifier_output(2, 0.85, "/r/") assert error.error_type == ErrorType.OMISSION logger.info(f"✅ Omission error: {error.error_type}") logger.info(f" Therapy: {error.therapy[:60]}...") # Test 4: Distortion (class 3) error = mapper.map_classifier_output(3, 0.65, "/s/") assert error.error_type == ErrorType.DISTORTION logger.info(f"✅ Distortion error: {error.error_type}") logger.info(f" Therapy: {error.therapy[:60]}...") # Test 5: Severity levels assert mapper.get_severity_level(0.0) == SeverityLevel.NONE assert mapper.get_severity_level(0.2) == SeverityLevel.LOW assert mapper.get_severity_level(0.5) == SeverityLevel.MEDIUM assert mapper.get_severity_level(0.8) == SeverityLevel.HIGH logger.info("✅ Severity level mapping correct") return True except Exception as e: logger.error(f"❌ Error taxonomy test failed: {e}") return False def test_batch_diagnosis_endpoint(pipeline, phoneme_mapper, error_mapper): """Test batch diagnosis endpoint functionality.""" logger.info("Testing batch diagnosis endpoint...") try: # Generate test audio duration = 2.0 sample_rate = 16000 num_samples = int(duration * sample_rate) audio = 0.5 * np.sin(2 * np.pi * 440 * np.linspace(0, duration, num_samples)) audio = audio.astype(np.float32) # Save to temp file with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as f: temp_path = f.name sf.write(temp_path, audio, sample_rate) try: # Run inference result = pipeline.predict_phone_level(temp_path, return_timestamps=True) # Map phonemes text = "test audio" frame_phonemes = phoneme_mapper.map_text_to_frames( text, num_frames=result.num_frames, audio_duration=result.duration ) # Process errors errors = [] for i, frame_pred in enumerate(result.frame_predictions): class_id = frame_pred.articulation_class if frame_pred.fluency_label == 'stutter': class_id += 4 error_detail = error_mapper.map_classifier_output( class_id=class_id, confidence=frame_pred.confidence, phoneme=frame_phonemes[i] if i < len(frame_phonemes) else '', fluency_label=frame_pred.fluency_label ) if error_detail.error_type != ErrorType.NORMAL: errors.append(error_detail) logger.info(f"✅ Batch diagnosis: {result.num_frames} frames, {len(errors)} errors detected") return True finally: import os if os.path.exists(temp_path): os.remove(temp_path) except Exception as e: logger.error(f"❌ Batch diagnosis test failed: {e}") return False def test_therapy_recommendations(): """Test therapy recommendation coverage.""" logger.info("Testing therapy recommendations...") try: from models.error_taxonomy import ErrorMapper, ErrorType mapper = ErrorMapper() # Test common phonemes test_cases = [ ("/s/", ErrorType.SUBSTITUTION, "/θ/"), ("/r/", ErrorType.OMISSION, None), ("/s/", ErrorType.DISTORTION, None), ] for phoneme, error_type, wrong_sound in test_cases: therapy = mapper.get_therapy(error_type, phoneme, wrong_sound) assert therapy and len(therapy) > 0, f"Therapy should not be empty for {phoneme}" logger.info(f"✅ {phoneme} {error_type.value}: {therapy[:50]}...") return True except Exception as e: logger.error(f"❌ Therapy recommendations test failed: {e}") return False def run_all_integration_tests(): """Run all integration tests.""" logger.info("=" * 60) logger.info("Running Integration Tests") logger.info("=" * 60) results = {} # Test 1: Phoneme mapping logger.info("\n1. Phoneme Mapping Test") results["phoneme_mapping"] = test_phoneme_mapping() # Test 2: Error taxonomy logger.info("\n2. Error Taxonomy Test") results["error_taxonomy"] = test_error_taxonomy() # Test 3: Therapy recommendations logger.info("\n3. Therapy Recommendations Test") results["therapy_recommendations"] = test_therapy_recommendations() # Test 4: Batch diagnosis (if pipeline available) try: from inference.inference_pipeline import create_inference_pipeline from models.phoneme_mapper import PhonemeMapper from models.error_taxonomy import ErrorMapper logger.info("\n4. Batch Diagnosis Test") pipeline = create_inference_pipeline() phoneme_mapper = PhonemeMapper() error_mapper = ErrorMapper() results["batch_diagnosis"] = test_batch_diagnosis_endpoint( pipeline, phoneme_mapper, error_mapper ) except Exception as e: logger.warning(f"⚠️ Batch diagnosis test skipped: {e}") results["batch_diagnosis"] = False # Summary logger.info("\n" + "=" * 60) logger.info("Integration Test Summary") logger.info("=" * 60) for test_name, passed in results.items(): status = "✅ PASSED" if passed else "❌ FAILED" logger.info(f"{status}: {test_name}") all_passed = all(results.values()) return all_passed, results if __name__ == "__main__": logging.basicConfig(level=logging.INFO) all_passed, results = run_all_integration_tests() if all_passed: logger.info("\n✅ All integration tests passed!") exit(0) else: logger.error("\n❌ Some integration tests failed!") exit(1)