"""Comprehensive smoke and integration tests for the MEDCARE-DDI system.

Runs the following phases against the hybrid predictor and records every
check in a ``ValidationReport``:

1. File integrity (DDInter data, feature pipeline, model checkpoint,
   optional calibration artifacts)
2. Predictor initialization
3. Health check
4. Exact DDInter lookup
5. ML fallback predictions
6. Confidence band classification
7. Response schema completeness

When run as a script, the report is written to
``MODEL_DIR / 'reports' / 'validation_report.json'`` and the process exits
with status 0 only if no recorded check failed.
"""
from __future__ import annotations

import json
import logging
import sys
from pathlib import Path
from typing import Any, Dict, List

# Setup logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(name)s: %(message)s',
)
logger = logging.getLogger('medcare_ddi.validation')

# Add src to path so the ``inference`` package below can be imported.
ROOT = Path(__file__).resolve().parents[2]
sys.path.insert(0, str(ROOT / 'src'))

from inference.predictor import (  # noqa: E402  (requires the sys.path tweak above)
    HybridDDIPredictor,
    MODEL_DIR,
    PRODUCTION_MODEL_PATH,
    MODEL_PATH,
    CALIBRATION_PATH,
    FEATURE_PIPELINE_MULTISOURCE_PATH,
    DATA_PATH,
)


class ValidationReport:
    """Tracks validation results: every check in order, plus pass/fail tallies."""

    def __init__(self) -> None:
        self.tests: List[Dict[str, Any]] = []  # one {'name', 'passed', 'details'} entry per check
        self.passed = 0
        self.failed = 0
        self.errors: List[str] = []  # names of failed checks, in recording order

    def test(self, name: str, passed: bool, details: str = '') -> None:
        """Record and log a single check result.

        Args:
            name: Human-readable check name.
            passed: Whether the check passed. Coerced to ``bool`` so callers may
                pass truthy values (e.g. ``health.get(...)``) and the JSON report
                still contains true/false.
            details: Optional context logged beneath the result line.
        """
        passed = bool(passed)
        status = '✓' if passed else '✗'
        logger.info(f'{status} {name}')
        if details:
            logger.info(f' {details}')

        self.tests.append({
            'name': name,
            'passed': passed,
            'details': details,
        })
        if passed:
            self.passed += 1
        else:
            self.failed += 1
            self.errors.append(name)

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-serializable summary; ``pass_rate`` is a percentage (0 if no checks)."""
        total = len(self.tests)
        return {
            'total': total,
            'passed': self.passed,
            'failed': self.failed,
            'pass_rate': round(self.passed / total * 100, 2) if total else 0,
            'tests': self.tests,
            'errors': self.errors,
        }

    def save(self, path: Path) -> None:
        """Write the report summary to *path* as indented JSON."""
        with path.open('w', encoding='utf-8') as f:
            json.dump(self.to_dict(), f, indent=2)
        logger.info(f'Report saved to {path}')


def _log_phase(title: str) -> None:
    """Log the banner that introduces a validation phase."""
    logger.info('\n' + '=' * 70)
    logger.info(title)
    logger.info('=' * 70)


def _fmt_confidence(value: Any) -> str:
    """Format a confidence value to 3 decimals, falling back to ``str`` for None/non-numeric."""
    try:
        return f'{value:.3f}'
    except (TypeError, ValueError):
        return str(value)


def test_file_existence(report: ValidationReport) -> None:
    """Phase 1: check that the required data, pipeline and model files exist.

    Calibration artifacts are optional. Their presence is recorded as a pass;
    their absence is only logged and never counts as a failure.
    """
    _log_phase('PHASE 1: FILE INTEGRITY')

    # Check data
    report.test('DDInter data exists', DATA_PATH.exists(), f'Path: {DATA_PATH}')

    # Check feature pipeline
    report.test(
        'Feature pipeline exists',
        FEATURE_PIPELINE_MULTISOURCE_PATH.exists(),
        f'Path: {FEATURE_PIPELINE_MULTISOURCE_PATH}',
    )

    # Check model - prefer production model, fallback to standard
    model_used = PRODUCTION_MODEL_PATH if PRODUCTION_MODEL_PATH.exists() else MODEL_PATH
    report.test('Model checkpoint exists', model_used.exists(), f'Using: {model_used.name}')

    # Check calibration (optional)
    if CALIBRATION_PATH.exists():
        report.test(
            'Calibration artifacts (optional)',
            True,
            f'Path: {CALIBRATION_PATH} (optional)',
        )
    else:
        logger.info('⚠ Calibration artifacts (optional)')
        logger.info(f' Path: {CALIBRATION_PATH} (optional, missing)')


def test_predictor_loading(report: ValidationReport) -> HybridDDIPredictor | None:
    """Phase 2: load the predictor via ``from_default_paths(use_production=True)``.

    Returns:
        The loaded predictor, or ``None`` if loading raised (the failure is recorded).
    """
    _log_phase('PHASE 2: PREDICTOR INITIALIZATION')

    try:
        logger.info('Loading predictor...')
        predictor = HybridDDIPredictor.from_default_paths(use_production=True)
    except Exception as e:
        report.test('Predictor loads successfully', False, f'Error: {str(e)}')
        return None

    report.test(
        'Predictor loads successfully',
        True,
        f'Model version: {predictor.model_version}',
    )
    return predictor


def test_health_check(report: ValidationReport, predictor: HybridDDIPredictor | None) -> None:
    """Phase 3: validate the dict returned by ``predictor.health()``.

    Calibration is optional: a missing calibration is logged, not failed.
    """
    _log_phase('PHASE 3: HEALTH CHECK')

    if predictor is None:
        report.test('Health check', False, 'Predictor not loaded')
        return

    try:
        health = predictor.health()

        report.test(
            'Health check returns valid response',
            health.get('status') in ('ok', 'healthy'),
            f"Status: {health.get('status')}",
        )
        report.test(
            'Model is loaded',
            health.get('model_loaded', False),
            f"Model version: {health.get('model_version')}",
        )
        report.test(
            'Pipeline is loaded',
            health.get('pipeline_loaded', False),
            f"Type: {health.get('model_type')}",
        )

        # A missing/None count must fail this check, not crash the ``:,`` format
        # and skip the remaining health checks.
        pairs_loaded = health.get('pairs_loaded') or 0
        report.test(
            'DDInter pairs loaded',
            pairs_loaded > 0,
            f'Pairs: {pairs_loaded:,}',
        )

        if health.get('calibration_loaded', False):
            report.test(
                'Calibration loaded (optional)',
                True,
                'Optional - improves confidence calibration',
            )
        else:
            logger.info('⚠ Calibration loaded (optional)')
            logger.info(' Optional - improves confidence calibration')
    except Exception as e:
        report.test('Health check', False, f'Error: {str(e)}')


def test_exact_lookup(report: ValidationReport, predictor: HybridDDIPredictor | None) -> None:
    """Phase 4: check that known DDInter pairs are answered by exact lookup.

    Each pair served with ``source == 'ddinter_lookup'`` is recorded as a pass.
    A summary check fails the phase if none of the pairs hit the exact lookup,
    so a broken lookup path cannot go unnoticed.
    """
    _log_phase('PHASE 4: EXACT LOOKUP TEST')

    if predictor is None:
        report.test('Exact lookup test', False, 'Predictor not loaded')
        return

    # Test with known pairs from DDInter
    test_pairs = [
        ('Aspirin', 'Warfarin'),
        ('Metformin', 'Ibuprofen'),
        ('Omeprazole', 'Lisinopril'),
    ]

    exact_hits = 0
    for drug_a, drug_b in test_pairs:
        try:
            result = predictor.predict(drug_a, drug_b)
        except Exception as e:
            logger.warning(f'Exact lookup failed for ({drug_a}, {drug_b}): {e}')
            continue

        if result.get('source') == 'ddinter_lookup':
            exact_hits += 1
            report.test(
                f'Exact lookup: {drug_a} + {drug_b}',
                True,
                f"Severity: {result.get('severity')}, Confidence: {result.get('confidence')}",
            )
        else:
            logger.debug(
                f"Pair ({drug_a}, {drug_b}) not served by exact lookup "
                f"(source: {result.get('source')})"
            )

    report.test(
        'Exact lookup serves known DDInter pairs',
        exact_hits > 0,
        f'{exact_hits}/{len(test_pairs)} pairs served by ddinter_lookup',
    )


def test_ml_fallback(report: ValidationReport, predictor: HybridDDIPredictor | None) -> None:
    """Phase 5: check that at least one pair is answered by the ML model.

    Counts results whose ``source`` is ``'deep_learning_prediction'``.
    """
    _log_phase('PHASE 5: ML FALLBACK TEST')

    if predictor is None:
        report.test('ML fallback test', False, 'Predictor not loaded')
        return

    # Test with unknown pairs to trigger ML
    test_pairs = [
        ('Caffeine', 'Aspirin'),
        ('Vitamin D', 'Calcium'),
        ('Magnesium', 'Iron'),
    ]

    ml_count = 0
    for drug_a, drug_b in test_pairs:
        try:
            result = predictor.predict(drug_a, drug_b)
        except Exception as e:
            logger.warning(f'Prediction failed for ({drug_a}, {drug_b}): {e}')
            continue

        # Any result (exact or ML) is valid; only ML-served ones are counted.
        if result.get('source') == 'deep_learning_prediction':
            ml_count += 1

        # _fmt_confidence keeps a None confidence from raising inside this
        # f-string and being misreported as a failed prediction.
        logger.debug(
            f"Prediction: {drug_a} + {drug_b} = {result.get('severity')} "
            f"(confidence: {_fmt_confidence(result.get('confidence'))}, "
            f"band: {result.get('confidence_band')})"
        )

    report.test(
        'ML fallback predictions work',
        ml_count > 0,
        f'{ml_count} predictions used ML model',
    )


def test_confidence_bands(report: ValidationReport, predictor: HybridDDIPredictor | None) -> None:
    """Phase 6: check every prediction carries a valid confidence band.

    Valid bands are 'high', 'medium' and 'low' (case-insensitive). A missing,
    None or unrecognized band counts as invalid. A failed prediction also
    fails the check.
    """
    _log_phase('PHASE 6: CONFIDENCE BAND TEST')

    if predictor is None:
        report.test('Confidence band test', False, 'Predictor not loaded')
        return

    try:
        # Test multiple predictions and check confidence bands
        test_pairs = [
            ('Aspirin', 'Ibuprofen'),
            ('Warfarin', 'Aspirin'),
            ('Metformin', 'Warfarin'),
        ]

        band_counts = {'high': 0, 'medium': 0, 'low': 0}
        valid_bands = 0

        for drug_a, drug_b in test_pairs:
            try:
                result = predictor.predict(drug_a, drug_b)
            except Exception as e:
                logger.warning(f'Prediction failed for ({drug_a}, {drug_b}): {e}')
                continue

            # Do not default a missing band to 'low': that would hide the
            # exact defect this check exists to catch.
            band = result.get('confidence_band')
            normalized = band.lower() if isinstance(band, str) else None
            if normalized in band_counts:
                band_counts[normalized] += 1
                valid_bands += 1
            else:
                logger.warning(f'Invalid confidence band for ({drug_a}, {drug_b}): {band!r}')

        report.test(
            'Confidence bands are valid',
            valid_bands == len(test_pairs),
            f"Distribution: HIGH={band_counts['high']}, MEDIUM={band_counts['medium']}, LOW={band_counts['low']}",
        )
    except Exception as e:
        report.test('Confidence band test', False, f'Error: {str(e)}')


def test_response_schema(report: ValidationReport, predictor: HybridDDIPredictor | None) -> None:
    """Phase 7: check the prediction response contains the expected fields.

    ``probabilities`` is only required when the result was not served by the
    exact DDInter lookup.
    """
    _log_phase('PHASE 7: RESPONSE SCHEMA TEST')

    if predictor is None:
        report.test('Response schema test', False, 'Predictor not loaded')
        return

    try:
        result = predictor.predict('Aspirin', 'Warfarin')

        required_fields = [
            'source',
            'confidence',
            'severity',
            'explanation',
            'clinical_advice',
            'drug_a_name',
            'drug_b_name',
        ]

        missing_fields = [field for field in required_fields if field not in result]
        # On failure, report which fields are missing rather than the full list.
        if missing_fields:
            details = f"Missing: {', '.join(missing_fields)}"
        else:
            details = f"Fields: {', '.join(required_fields)}"
        report.test('Response has all required fields', not missing_fields, details)

        # Test optional fields
        report.test(
            'Response includes warning field (safety)',
            'warning' in result,
            'Enables healthcare safety checks',
        )
        report.test(
            'Response includes probabilities (explainability)',
            'probabilities' in result or result.get('source') == 'ddinter_lookup',
            'ML predictions should include probabilities',
        )
    except Exception as e:
        report.test('Response schema test', False, f'Error: {str(e)}')


def run_all_tests() -> ValidationReport:
    """Run all validation phases in order and return the populated report.

    Stops after phase 2 if the predictor cannot be loaded.
    """
    logger.info('\n' + '#' * 70)
    logger.info('# MEDCARE-DDI SYSTEM VALIDATION SUITE')
    logger.info('#' * 70)

    report = ValidationReport()

    # Phase 1: File integrity
    test_file_existence(report)

    # Phase 2: Predictor loading
    predictor = test_predictor_loading(report)
    if predictor is None:
        logger.error('Cannot proceed without predictor')
        return report

    # Phases 3-7: behavioral checks against the loaded predictor
    test_health_check(report, predictor)
    test_exact_lookup(report, predictor)
    test_ml_fallback(report, predictor)
    test_confidence_bands(report, predictor)
    test_response_schema(report, predictor)

    # Final summary
    _log_phase('VALIDATION SUMMARY')
    total = len(report.tests)
    logger.info(f'Passed: {report.passed}/{total}')
    logger.info(f'Failed: {report.failed}/{total}')
    logger.info(f"Pass rate: {report.to_dict()['pass_rate']:.1f}%")

    if report.errors:
        logger.error(f"Failed tests: {', '.join(report.errors)}")

    return report


if __name__ == '__main__':
    report = run_all_tests()

    # Save report
    report_path = MODEL_DIR / 'reports' / 'validation_report.json'
    report_path.parent.mkdir(parents=True, exist_ok=True)
    report.save(report_path)

    # Exit with appropriate code
    sys.exit(0 if report.failed == 0 else 1)