Spaces:
Running
Running
| """Comprehensive smoke and integration tests for MEDCARE-DDI system. | |
| Tests: | |
| 1. Model and pipeline loading | |
| 2. Exact DDInter lookup | |
| 3. ML fallback predictions | |
| 4. Calibration artifacts | |
| 5. Hybrid inference | |
| 6. API responses | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import logging | |
| import sys | |
| from pathlib import Path | |
| from typing import Any, Dict | |
# Route all validation output through one named logger with timestamps.
logging.basicConfig(
    format='%(asctime)s [%(levelname)s] %(name)s: %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger('medcare_ddi.validation')

# The repository root is two directories above the one holding this script;
# prepend its `src` folder so `inference.predictor` can be imported below.
ROOT = Path(__file__).resolve().parents[2]
sys.path.insert(0, str(ROOT.joinpath('src')))
| from inference.predictor import ( | |
| HybridDDIPredictor, | |
| MODEL_DIR, | |
| PRODUCTION_MODEL_PATH, | |
| MODEL_PATH, | |
| CALIBRATION_PATH, | |
| FEATURE_PIPELINE_MULTISOURCE_PATH, | |
| DATA_PATH, | |
| ) | |
class ValidationReport:
    """Accumulate pass/fail results for the validation phases.

    Each call to :meth:`test` logs the outcome and appends a record to
    ``tests``. ``passed`` and ``failed`` are running counters, and
    ``errors`` holds the names of failed checks in the order recorded.
    """

    def __init__(self):
        self.tests = []   # one {'name', 'passed', 'details'} dict per check
        self.passed = 0
        self.failed = 0
        self.errors = []  # names of failed checks, in recording order

    def test(self, name: str, passed: bool, details: str = '') -> None:
        """Record and log one check result.

        Args:
            name: Human-readable check name.
            passed: Outcome of the check. Any truthy/falsy value is accepted
                and stored as a plain ``bool``.
            details: Optional extra line logged beneath the result.
        """
        # Callers pass raw values such as ``health.get('model_loaded')``.
        # Normalise them so the JSON report always contains true/false.
        ok = bool(passed)
        status = '✓' if ok else '✗'
        logger.info(f'{status} {name}')
        if details:
            logger.info(f' {details}')
        self.tests.append({
            'name': name,
            'passed': ok,
            'details': details,
        })
        if ok:
            self.passed += 1
        else:
            self.failed += 1
            self.errors.append(name)

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-ready summary.

        The summary holds the totals, the pass rate as a percentage rounded
        to 2 decimals (0 when nothing was recorded), the per-check records
        and the names of failed checks.
        """
        return {
            'total': len(self.tests),
            'passed': self.passed,
            'failed': self.failed,
            'pass_rate': round(self.passed / len(self.tests) * 100, 2) if self.tests else 0,
            'tests': self.tests,
            'errors': self.errors,
        }

    def save(self, path: Path) -> None:
        """Write :meth:`to_dict` to *path* as indented UTF-8 JSON."""
        with path.open('w', encoding='utf-8') as f:
            json.dump(self.to_dict(), f, indent=2)
        logger.info(f'Report saved to {path}')
def test_file_existence(report: ValidationReport) -> None:
    """Phase 1: check that the artifacts the predictor needs are on disk.

    Records one check each for:

    - the DDInter data file;
    - the multi-source feature pipeline;
    - the model checkpoint (the production checkpoint is preferred, with
      the standard checkpoint as fallback).

    Calibration artifacts are optional. They are recorded as a passing check
    when present and only logged as a warning when missing.

    Args:
        report: Collector that receives each check result.
    """
    logger.info('\n' + '='*70)
    logger.info('PHASE 1: FILE INTEGRITY')
    logger.info('='*70)

    report.test('DDInter data exists', DATA_PATH.exists(), f'Path: {DATA_PATH}')
    report.test(
        'Feature pipeline exists',
        FEATURE_PIPELINE_MULTISOURCE_PATH.exists(),
        f'Path: {FEATURE_PIPELINE_MULTISOURCE_PATH}',
    )

    # Prefer the production checkpoint and fall back to the standard one.
    # Query the filesystem once so the reported name matches the check.
    has_production = PRODUCTION_MODEL_PATH.exists()
    model_used = PRODUCTION_MODEL_PATH if has_production else MODEL_PATH
    report.test(
        'Model checkpoint exists',
        has_production or MODEL_PATH.exists(),
        f'Using: {model_used.name}',
    )

    # Calibration is optional: absence is a warning, not a failure.
    if CALIBRATION_PATH.exists():
        report.test(
            'Calibration artifacts (optional)',
            True,
            f'Path: {CALIBRATION_PATH} (optional)',
        )
    else:
        logger.info('⚠ Calibration artifacts (optional)')
        logger.info(f' Path: {CALIBRATION_PATH} (optional, missing)')


# This phase function is driven by run_all_tests(). Its parameter is not a
# pytest fixture, so opt the `test_*` name out of pytest collection.
test_file_existence.__test__ = False
def test_predictor_loading(report: ValidationReport) -> HybridDDIPredictor | None:
    """Phase 2: construct the hybrid predictor from its default paths.

    Calls ``HybridDDIPredictor.from_default_paths(use_production=True)`` and
    records the outcome under 'Predictor loads successfully'.

    Args:
        report: Collector that receives the check result.

    Returns:
        The loaded predictor, or ``None`` if construction (or reading its
        ``model_version``) raised.
    """
    logger.info('\n' + '='*70)
    logger.info('PHASE 2: PREDICTOR INITIALIZATION')
    logger.info('='*70)
    try:
        logger.info('Loading predictor...')
        predictor = HybridDDIPredictor.from_default_paths(use_production=True)
        report.test(
            'Predictor loads successfully',
            True,
            f'Model version: {predictor.model_version}'
        )
        return predictor
    except Exception as e:
        # This phase is a failure boundary. Keep the traceback in the log
        # (the report only stores the message) and record a failed check.
        logger.exception('Predictor failed to load')
        report.test(
            'Predictor loads successfully',
            False,
            f'Error: {str(e)}'
        )
        return None


# This phase function is driven by run_all_tests(). Its parameter is not a
# pytest fixture, so opt the `test_*` name out of pytest collection.
test_predictor_loading.__test__ = False
def test_health_check(report: ValidationReport, predictor: HybridDDIPredictor) -> None:
    """Phase 3: validate the payload returned by ``predictor.health()``.

    Records checks for:

    - the reported status (``'ok'`` or ``'healthy'``);
    - the model and pipeline load flags;
    - a positive DDInter pair count.

    Calibration is optional. It is recorded as a passing check when
    ``calibration_loaded`` is truthy and otherwise only logged as a warning.

    Args:
        report: Collector that receives each check result.
        predictor: Loaded predictor, or ``None`` (recorded as a failure).
    """
    logger.info('\n' + '='*70)
    logger.info('PHASE 3: HEALTH CHECK')
    logger.info('='*70)
    if predictor is None:
        report.test('Health check', False, 'Predictor not loaded')
        return
    try:
        health = predictor.health()
        status = health.get('status')
        report.test(
            'Health check returns valid response',
            status in ('ok', 'healthy'),
            f'Status: {status}',
        )
        report.test(
            'Model is loaded',
            health.get('model_loaded', False),
            f"Model version: {health.get('model_version')}",
        )
        report.test(
            'Pipeline is loaded',
            health.get('pipeline_loaded', False),
            f"Type: {health.get('model_type')}",
        )
        # `pairs_loaded` may be absent or None. Formatting None with `:,`
        # raises TypeError, so normalise first. Otherwise this check (and
        # the calibration one) would be lost behind a generic failure.
        pairs = health.get('pairs_loaded') or 0
        report.test('DDInter pairs loaded', pairs > 0, f'Pairs: {pairs:,}')
        if health.get('calibration_loaded', False):
            report.test(
                'Calibration loaded (optional)',
                True,
                'Optional - improves confidence calibration',
            )
        else:
            logger.info('⚠ Calibration loaded (optional)')
            logger.info(' Optional - improves confidence calibration')
    except Exception as e:
        # An unexpected payload shape (e.g. not a mapping) fails the phase.
        report.test('Health check', False, f'Error: {str(e)}')


# This phase function is driven by run_all_tests(). Its parameters are not
# pytest fixtures, so opt the `test_*` name out of pytest collection.
test_health_check.__test__ = False
def test_exact_lookup(report: ValidationReport, predictor: HybridDDIPredictor) -> None:
    """Phase 4: check that well-known pairs resolve via exact DDInter lookup.

    Each pair answered with ``source == 'ddinter_lookup'`` is recorded as a
    passing per-pair check. An aggregate check is then always recorded, so
    the phase fails when none of the pairs hit the lookup table instead of
    silently recording nothing.

    Args:
        report: Collector that receives each check result.
        predictor: Loaded predictor, or ``None`` (recorded as a failure).
    """
    logger.info('\n' + '='*70)
    logger.info('PHASE 4: EXACT LOOKUP TEST')
    logger.info('='*70)
    if predictor is None:
        report.test('Exact lookup test', False, 'Predictor not loaded')
        return
    # Pairs expected to be present in the DDInter table.
    test_pairs = [
        ('Aspirin', 'Warfarin'),
        ('Metformin', 'Ibuprofen'),
        ('Omeprazole', 'Lisinopril'),
    ]
    exact_count = 0
    for drug_a, drug_b in test_pairs:
        try:
            result = predictor.predict(drug_a, drug_b)
            source = result.get('source')
        except Exception as e:
            logger.warning(f'Exact lookup failed for ({drug_a}, {drug_b}): {e}')
            continue
        if source != 'ddinter_lookup':
            logger.info(f'{drug_a} + {drug_b} not resolved by exact lookup (source: {source})')
            continue
        exact_count += 1
        report.test(
            f'Exact lookup: {drug_a} + {drug_b}',
            True,
            f"Severity: {result.get('severity')}, Confidence: {result.get('confidence')}"
        )
    report.test(
        'Exact DDInter lookup works',
        exact_count > 0,
        f'{exact_count}/{len(test_pairs)} pairs resolved via DDInter lookup',
    )


# This phase function is driven by run_all_tests(). Its parameters are not
# pytest fixtures, so opt the `test_*` name out of pytest collection.
test_exact_lookup.__test__ = False
def test_ml_fallback(report: ValidationReport, predictor: HybridDDIPredictor) -> None:
    """Phase 5: check that the ML model answers at least one pair.

    Predictions for a few pairs are requested. The phase passes when at
    least one reports ``source == 'deep_learning_prediction'``. Predictions
    answered by exact lookup are logged but do not count toward the check.

    Args:
        report: Collector that receives the check result.
        predictor: Loaded predictor, or ``None`` (recorded as a failure).
    """
    logger.info('\n' + '='*70)
    logger.info('PHASE 5: ML FALLBACK TEST')
    logger.info('='*70)
    if predictor is None:
        report.test('ML fallback test', False, 'Predictor not loaded')
        return
    # Pairs chosen in the hope that they are absent from DDInter, so the ML path runs.
    test_pairs = [
        ('Caffeine', 'Aspirin'),
        ('Vitamin D', 'Calcium'),
        ('Magnesium', 'Iron'),
    ]
    ml_count = 0
    for drug_a, drug_b in test_pairs:
        try:
            result = predictor.predict(drug_a, drug_b)
            severity = result.get('severity')
            confidence = result.get('confidence')
            confidence_band = result.get('confidence_band')
            source = result.get('source')
        except Exception as e:
            logger.warning(f'Prediction failed for ({drug_a}, {drug_b}): {e}')
            continue
        if source == 'deep_learning_prediction':
            ml_count += 1
        # The f-string below is evaluated even when DEBUG is off. A missing
        # or None confidence must not raise and be misreported as a failure.
        try:
            confidence_text = f'{confidence:.3f}'
        except (TypeError, ValueError):
            confidence_text = str(confidence)
        logger.debug(
            f'Prediction: {drug_a} + {drug_b} = {severity} '
            f'(confidence: {confidence_text}, band: {confidence_band})'
        )
    report.test(
        'ML fallback predictions work',
        ml_count > 0,
        f'{ml_count} predictions used ML model'
    )


# This phase function is driven by run_all_tests(). Its parameters are not
# pytest fixtures, so opt the `test_*` name out of pytest collection.
test_ml_fallback.__test__ = False
def test_confidence_bands(report: ValidationReport, predictor: HybridDDIPredictor) -> None:
    """Phase 6: check that every prediction carries a recognised confidence band.

    A band is valid when it is a string equal, case-insensitively, to
    ``'high'``, ``'medium'`` or ``'low'``. The check passes only when every
    test pair yields a valid band. A missing band, or a prediction that
    raises, makes it fail.

    Args:
        report: Collector that receives the check result.
        predictor: Loaded predictor, or ``None`` (recorded as a failure).
    """
    logger.info('\n' + '='*70)
    logger.info('PHASE 6: CONFIDENCE BAND TEST')
    logger.info('='*70)
    if predictor is None:
        report.test('Confidence band test', False, 'Predictor not loaded')
        return
    test_pairs = [
        ('Aspirin', 'Ibuprofen'),
        ('Warfarin', 'Aspirin'),
        ('Metformin', 'Warfarin'),
    ]
    band_counts = {'high': 0, 'medium': 0, 'low': 0}
    valid_bands = 0
    for drug_a, drug_b in test_pairs:
        try:
            result = predictor.predict(drug_a, drug_b)
            band = result.get('confidence_band')
        except Exception as e:
            logger.warning(f'Prediction failed for ({drug_a}, {drug_b}): {e}')
            continue
        # Do not default a missing band to 'low': that would hide a schema
        # regression. A None or non-string band is also invalid.
        normalized = band.lower() if isinstance(band, str) else None
        if normalized in band_counts:
            band_counts[normalized] += 1
            valid_bands += 1
        else:
            logger.warning(f'Invalid confidence band for ({drug_a}, {drug_b}): {band!r}')
    report.test(
        'Confidence bands are valid',
        valid_bands == len(test_pairs),
        f"Distribution: HIGH={band_counts['high']}, MEDIUM={band_counts['medium']}, LOW={band_counts['low']}"
    )


# This phase function is driven by run_all_tests(). Its parameters are not
# pytest fixtures, so opt the `test_*` name out of pytest collection.
test_confidence_bands.__test__ = False
def test_response_schema(report: ValidationReport, predictor: HybridDDIPredictor) -> None:
    """Phase 7: check the fields of a prediction response for Aspirin + Warfarin.

    Records three checks:

    - all required fields are present (the details name the missing ones
      on failure);
    - a ``warning`` key is present;
    - ``probabilities`` is present, unless the answer came from
      ``ddinter_lookup``.

    Args:
        report: Collector that receives each check result.
        predictor: Loaded predictor, or ``None`` (recorded as a failure).
    """
    logger.info('\n' + '='*70)
    logger.info('PHASE 7: RESPONSE SCHEMA TEST')
    logger.info('='*70)
    if predictor is None:
        report.test('Response schema test', False, 'Predictor not loaded')
        return
    try:
        result = predictor.predict('Aspirin', 'Warfarin')
        required_fields = [
            'source',
            'confidence',
            'severity',
            'explanation',
            'clinical_advice',
            'drug_a_name',
            'drug_b_name',
        ]
        missing_fields = [field for field in required_fields if field not in result]
        # On failure, say which fields are absent rather than echoing the full list.
        if missing_fields:
            fields_detail = f"Missing: {', '.join(missing_fields)}"
        else:
            fields_detail = f"Fields: {', '.join(required_fields)}"
        report.test(
            'Response has all required fields',
            not missing_fields,
            fields_detail
        )
        # Safety/explainability fields. These are recorded as checks, so a
        # missing field fails the run.
        has_warning = 'warning' in result
        has_probs = 'probabilities' in result
        report.test(
            'Response includes warning field (safety)',
            has_warning,
            'Enables healthcare safety checks'
        )
        report.test(
            'Response includes probabilities (explainability)',
            has_probs or result.get('source') == 'ddinter_lookup',
            'ML predictions should include probabilities'
        )
    except Exception as e:
        report.test('Response schema test', False, f'Error: {str(e)}')


# This phase function is driven by run_all_tests(). Its parameters are not
# pytest fixtures, so opt the `test_*` name out of pytest collection.
test_response_schema.__test__ = False
def run_all_tests() -> ValidationReport:
    """Execute every validation phase in order and return the populated report.

    Phases 1 and 2 always run. If phase 2 cannot build a predictor, the
    remaining phases and the summary are skipped and the report is returned
    as-is. Otherwise phases 3-7 run, followed by a logged summary.
    """
    hash_rule = '#' * 70
    logger.info('\n' + hash_rule)
    logger.info('# MEDCARE-DDI SYSTEM VALIDATION SUITE')
    logger.info(hash_rule)

    report = ValidationReport()

    test_file_existence(report)                   # phase 1: file integrity
    predictor = test_predictor_loading(report)    # phase 2: predictor loading
    if predictor is None:
        logger.error('Cannot proceed without predictor')
        return report

    # Phases 3-7 share the same (report, predictor) signature.
    predictor_phases = (
        test_health_check,       # phase 3
        test_exact_lookup,       # phase 4
        test_ml_fallback,        # phase 5
        test_confidence_bands,   # phase 6
        test_response_schema,    # phase 7
    )
    for phase in predictor_phases:
        phase(report, predictor)

    eq_rule = '=' * 70
    logger.info('\n' + eq_rule)
    logger.info('VALIDATION SUMMARY')
    logger.info(eq_rule)
    total = len(report.tests)
    logger.info(f"Passed: {report.passed}/{total}")
    logger.info(f"Failed: {report.failed}/{total}")
    logger.info(f"Pass rate: {report.to_dict()['pass_rate']:.1f}%")
    if report.errors:
        logger.error(f"Failed tests: {', '.join(report.errors)}")
    return report
if __name__ == '__main__':
    report = run_all_tests()

    # Persist the JSON report under the model directory, creating it if needed.
    report_path = MODEL_DIR.joinpath('reports', 'validation_report.json')
    report_path.parent.mkdir(exist_ok=True, parents=True)
    report.save(report_path)

    # Exit status: 0 when every recorded check passed, 1 otherwise.
    sys.exit(1 if report.failed else 0)