# ddi/src/validation/smoke_test.py
# Last commit by: github-actions[bot]
# Deploy from GitHub Actions (fb28c05c54cf19184fc3f14f1bf3297ba5749ea2)
# Commit: d29b763
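"""Comprehensive smoke and integration tests for the MEDCARE-DDI system.

Runs seven validation phases against the hybrid predictor:
1. File integrity (DDInter data, feature pipeline, model checkpoint,
   optional calibration artifacts)
2. Predictor initialization
3. Health check
4. Exact DDInter lookup
5. ML fallback predictions
6. Confidence bands
7. Response schema
"""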
from __future__ import annotations
import json
import logging
import sys
from pathlib import Path
from typing import Any, Dict
# Configure root logging for the whole validation run: INFO level, with
# timestamped records tagged by level and logger name.
logging.basicConfig(
    format='%(asctime)s [%(levelname)s] %(name)s: %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger('medcare_ddi.validation')

# Make the project's ``src`` directory importable so the
# ``inference.predictor`` import resolves when this file is run directly.
# ROOT is the third ancestor of this file (presumably the repository root,
# given a src/validation/<file> layout).
ROOT = Path(__file__).resolve().parents[2]
sys.path.insert(0, str(ROOT.joinpath('src')))
from inference.predictor import (
HybridDDIPredictor,
MODEL_DIR,
PRODUCTION_MODEL_PATH,
MODEL_PATH,
CALIBRATION_PATH,
FEATURE_PIPELINE_MULTISOURCE_PATH,
DATA_PATH,
)
class ValidationReport:
    """Accumulate pass/fail results for the validation phases.

    Each call to :meth:`test` logs the outcome and appends a record. The
    aggregate can be exported with :meth:`to_dict` or written to disk with
    :meth:`save`.
    """

    def __init__(self):
        # One dict per recorded check: {'name', 'passed', 'details'}.
        self.tests: list[Dict[str, Any]] = []
        self.passed = 0
        self.failed = 0
        # Names of failed checks, in the order they were recorded.
        self.errors: list[str] = []

    def test(self, name: str, passed: bool, details: str = '') -> None:
        """Record and log one check result.

        Args:
            name: Human-readable check name. It is appended to ``errors`` on
                failure.
            passed: Outcome of the check. It is coerced with ``bool()`` because
                callers pass raw values (e.g. straight from a health dict).
                Coercion keeps the stored flag a real boolean and the report
                JSON-serialisable.
            details: Optional extra line, logged under the result and stored.
        """
        ok = bool(passed)
        status = '✓' if ok else '✗'
        logger.info(f'{status} {name}')
        if details:
            logger.info(f' {details}')
        self.tests.append({
            'name': name,
            'passed': ok,
            'details': details,
        })
        if ok:
            self.passed += 1
        else:
            self.failed += 1
            self.errors.append(name)

    def to_dict(self) -> Dict[str, Any]:
        """Return totals, pass rate (percent, 2 decimals) and all records.

        ``pass_rate`` is 0 when no checks were recorded, which avoids a
        division by zero.
        """
        return {
            'total': len(self.tests),
            'passed': self.passed,
            'failed': self.failed,
            'pass_rate': round(self.passed / len(self.tests) * 100, 2) if self.tests else 0,
            'tests': self.tests,
            'errors': self.errors,
        }

    def save(self, path: Path) -> None:
        """Write :meth:`to_dict` as indented JSON to *path*.

        The parent directory must already exist.
        """
        with path.open('w', encoding='utf-8') as f:
            json.dump(self.to_dict(), f, indent=2)
        logger.info(f'Report saved to {path}')
def test_file_existence(report: ValidationReport) -> None:
    """Phase 1: check that the on-disk artifacts the predictor needs exist.

    Records checks for the DDInter data file, the multi-source feature
    pipeline and a model checkpoint. The production checkpoint is preferred;
    the standard checkpoint is the fallback. Calibration artifacts are
    optional: if present they are recorded as a passing check, and if absent
    this is only logged and never fails the run.
    """
    logger.info('\n' + '='*70)
    logger.info('PHASE 1: FILE INTEGRITY')
    logger.info('='*70)

    report.test(
        'DDInter data exists',
        DATA_PATH.exists(),
        f'Path: {DATA_PATH}',
    )
    report.test(
        'Feature pipeline exists',
        FEATURE_PIPELINE_MULTISOURCE_PATH.exists(),
        f'Path: {FEATURE_PIPELINE_MULTISOURCE_PATH}',
    )

    # Prefer the production checkpoint and fall back to the standard one.
    # Each path is stat'ed at most once. When neither exists, the details
    # name both candidates so the failure is actionable rather than claiming
    # one of them is "used".
    if PRODUCTION_MODEL_PATH.exists():
        model_exists = True
        model_details = f'Using: {PRODUCTION_MODEL_PATH.name}'
    elif MODEL_PATH.exists():
        model_exists = True
        model_details = f'Using: {MODEL_PATH.name}'
    else:
        model_exists = False
        model_details = f'Missing: {PRODUCTION_MODEL_PATH} and {MODEL_PATH}'
    report.test('Model checkpoint exists', model_exists, model_details)

    # Calibration is optional: only its presence is recorded as a check.
    if CALIBRATION_PATH.exists():
        report.test(
            'Calibration artifacts (optional)',
            True,
            f'Path: {CALIBRATION_PATH} (optional)',
        )
    else:
        logger.info('⚠ Calibration artifacts (optional)')
        logger.info(f' Path: {CALIBRATION_PATH} (optional, missing)')
def test_predictor_loading(report: ValidationReport) -> HybridDDIPredictor | None:
    """Phase 2: build the hybrid predictor from its default artifact paths.

    The outcome is recorded on *report* under 'Predictor loads successfully'.
    Any exception raised while constructing the predictor, reading its model
    version or recording the success is caught and recorded as a failure.

    Returns:
        The loaded predictor, or ``None`` when loading failed.
    """
    divider = '=' * 70
    logger.info('\n' + divider)
    logger.info('PHASE 2: PREDICTOR INITIALIZATION')
    logger.info(divider)

    check_name = 'Predictor loads successfully'
    try:
        logger.info('Loading predictor...')
        loaded = HybridDDIPredictor.from_default_paths(use_production=True)
        report.test(check_name, True, f'Model version: {loaded.model_version}')
    except Exception as exc:
        report.test(check_name, False, f'Error: {str(exc)}')
        return None
    return loaded
def test_health_check(report: ValidationReport, predictor: HybridDDIPredictor) -> None:
    """Phase 3: validate the dict returned by ``predictor.health()``.

    Records checks for:
    - overall status ('ok' or 'healthy');
    - the model and pipeline load flags;
    - a positive DDInter pair count.

    A loaded calibration is recorded as a passing optional check; a missing
    one is only logged. If ``health()`` raises, or its result cannot be
    inspected, a failed 'Health check' is recorded instead.
    """
    logger.info('\n' + '='*70)
    logger.info('PHASE 3: HEALTH CHECK')
    logger.info('='*70)
    if predictor is None:
        report.test('Health check', False, 'Predictor not loaded')
        return
    try:
        health = predictor.health()

        status = health.get('status')
        report.test(
            'Health check returns valid response',
            status in ('ok', 'healthy'),
            f'Status: {status}',
        )
        report.test(
            'Model is loaded',
            bool(health.get('model_loaded', False)),
            f"Model version: {health.get('model_version')}",
        )
        report.test(
            'Pipeline is loaded',
            bool(health.get('pipeline_loaded', False)),
            f"Type: {health.get('model_type')}",
        )

        # A missing or None pair count must fail this one check. It must not
        # blow up the ``:,`` format spec, which would abort the remaining
        # checks and hide the real cause.
        pairs = health.get('pairs_loaded') or 0
        report.test(
            'DDInter pairs loaded',
            pairs > 0,
            f'Pairs: {pairs:,}',
        )

        if health.get('calibration_loaded', False):
            report.test(
                'Calibration loaded (optional)',
                True,
                'Optional - improves confidence calibration',
            )
        else:
            logger.info('⚠ Calibration loaded (optional)')
            logger.info(' Optional - improves confidence calibration')
    except Exception as e:
        report.test('Health check', False, f'Error: {str(e)}')
def test_exact_lookup(report: ValidationReport, predictor: HybridDDIPredictor) -> None:
    """Phase 4: check that known drug pairs resolve via the DDInter table.

    A passing check is recorded for each pair whose result has
    ``source == 'ddinter_lookup'``. An aggregate check then requires at
    least one such hit. Without it this phase could never fail: a lookup that
    resolves nothing, or a predictor that always raises, would record zero
    results.
    """
    logger.info('\n' + '='*70)
    logger.info('PHASE 4: EXACT LOOKUP TEST')
    logger.info('='*70)
    if predictor is None:
        report.test('Exact lookup test', False, 'Predictor not loaded')
        return
    # Pairs expected to be present in DDInter.
    test_pairs = [
        ('Aspirin', 'Warfarin'),
        ('Metformin', 'Ibuprofen'),
        ('Omeprazole', 'Lisinopril'),
    ]
    exact_count = 0
    for drug_a, drug_b in test_pairs:
        try:
            result = predictor.predict(drug_a, drug_b)
            source = result.get('source')
            if source != 'ddinter_lookup':
                logger.debug(
                    f'Not an exact DDInter hit: {drug_a} + {drug_b} (source: {source})'
                )
                continue
            exact_count += 1
            report.test(
                f'Exact lookup: {drug_a} + {drug_b}',
                True,
                f"Severity: {result.get('severity')}, Confidence: {result.get('confidence')}",
            )
        except Exception as e:
            # Same visibility as the ML fallback phase: a failing predict
            # should not be hidden at DEBUG level.
            logger.warning(f'Exact lookup failed for ({drug_a}, {drug_b}): {e}')
    report.test(
        'Exact DDInter lookup returns results',
        exact_count > 0,
        f'{exact_count}/{len(test_pairs)} pairs resolved via DDInter lookup',
    )
def test_ml_fallback(report: ValidationReport, predictor: HybridDDIPredictor) -> None:
    """Phase 5: check that pairs outside the lookup table reach the ML model.

    Each pair's result is inspected. The phase passes when at least one
    result reports ``source == 'deep_learning_prediction'``. Individual
    prediction failures are logged as warnings and do not stop the loop.
    """
    logger.info('\n' + '='*70)
    logger.info('PHASE 5: ML FALLBACK TEST')
    logger.info('='*70)
    if predictor is None:
        report.test('ML fallback test', False, 'Predictor not loaded')
        return
    # Pairs intended to miss the exact lookup so the ML model is used.
    test_pairs = [
        ('Caffeine', 'Aspirin'),
        ('Vitamin D', 'Calcium'),
        ('Magnesium', 'Iron'),
    ]
    ml_count = 0
    for drug_a, drug_b in test_pairs:
        try:
            result = predictor.predict(drug_a, drug_b)
            severity = result.get('severity')
            confidence = result.get('confidence')
            confidence_band = result.get('confidence_band')
            if result.get('source') == 'deep_learning_prediction':
                ml_count += 1
            # Apply the float format spec only to real numbers. A missing or
            # None confidence must not raise here, or it would be misreported
            # below as a failed prediction.
            if isinstance(confidence, (int, float)):
                confidence_text = f'{confidence:.3f}'
            else:
                confidence_text = str(confidence)
            logger.debug(
                f'Prediction: {drug_a} + {drug_b} = {severity} '
                f'(confidence: {confidence_text}, band: {confidence_band})'
            )
        except Exception as e:
            logger.warning(f'Prediction failed for ({drug_a}, {drug_b}): {e}')
    report.test(
        'ML fallback predictions work',
        ml_count > 0,
        f'{ml_count} predictions used ML model',
    )
def test_confidence_bands(report: ValidationReport, predictor: HybridDDIPredictor) -> None:
    """Phase 6: check that every prediction carries a recognised confidence band.

    A band is valid when it is a string equal to 'high', 'medium' or 'low',
    ignoring case. The check passes only if every test pair produced a valid
    band. The details report the HIGH/MEDIUM/LOW distribution.
    """
    logger.info('\n' + '='*70)
    logger.info('PHASE 6: CONFIDENCE BAND TEST')
    logger.info('='*70)
    if predictor is None:
        report.test('Confidence band test', False, 'Predictor not loaded')
        return
    try:
        test_pairs = [
            ('Aspirin', 'Ibuprofen'),
            ('Warfarin', 'Aspirin'),
            ('Metformin', 'Warfarin'),
        ]
        band_counts = {'high': 0, 'medium': 0, 'low': 0}
        valid_bands = 0
        for drug_a, drug_b in test_pairs:
            try:
                result = predictor.predict(drug_a, drug_b)
                band = result.get('confidence_band')
                # A missing or non-string band is invalid. It must not be
                # silently counted as 'low', which would let this check pass
                # when the field is absent.
                normalized = band.lower() if isinstance(band, str) else None
                if normalized in band_counts:
                    band_counts[normalized] += 1
                    valid_bands += 1
                else:
                    logger.warning(
                        f'Invalid confidence band for {drug_a} + {drug_b}: {band!r}'
                    )
            except Exception as e:
                logger.debug(f'Prediction failed: {e}')
        report.test(
            'Confidence bands are valid',
            valid_bands == len(test_pairs),
            f"Distribution: HIGH={band_counts['high']}, MEDIUM={band_counts['medium']}, LOW={band_counts['low']}",
        )
    except Exception as e:
        report.test('Confidence band test', False, f'Error: {str(e)}')
def test_response_schema(report: ValidationReport, predictor: HybridDDIPredictor) -> None:
    """Phase 7: check the fields present in a single prediction response.

    Uses the Aspirin + Warfarin result and records three checks:
    - all required fields are present (on failure the details list the
      missing ones);
    - a 'warning' field is present;
    - 'probabilities' is present, unless the result came from the DDInter
      lookup.
    """
    logger.info('\n' + '='*70)
    logger.info('PHASE 7: RESPONSE SCHEMA TEST')
    logger.info('='*70)
    if predictor is None:
        report.test('Response schema test', False, 'Predictor not loaded')
        return
    try:
        result = predictor.predict('Aspirin', 'Warfarin')
        required_fields = [
            'source',
            'confidence',
            'severity',
            'explanation',
            'clinical_advice',
            'drug_a_name',
            'drug_b_name',
        ]
        missing_fields = [field for field in required_fields if field not in result]
        # On failure, report what is missing rather than the full required list.
        if missing_fields:
            schema_details = f"Missing: {', '.join(missing_fields)}"
        else:
            schema_details = f"Fields: {', '.join(required_fields)}"
        report.test(
            'Response has all required fields',
            not missing_fields,
            schema_details,
        )

        report.test(
            'Response includes warning field (safety)',
            'warning' in result,
            'Enables healthcare safety checks',
        )
        # Exact DDInter lookups are exempt from the probabilities requirement.
        report.test(
            'Response includes probabilities (explainability)',
            'probabilities' in result or result.get('source') == 'ddinter_lookup',
            'ML predictions should include probabilities',
        )
    except Exception as e:
        report.test('Response schema test', False, f'Error: {str(e)}')
def run_all_tests() -> ValidationReport:
    """Execute every validation phase in order and return the filled report.

    Phase 1 (file integrity) and phase 2 (predictor loading) always run. If
    the predictor cannot be built, an error is logged and the report is
    returned at once, skipping the remaining phases and the summary.
    Otherwise phases 3-7 run against the loaded predictor and a pass/fail
    summary is logged.
    """
    hashes = '#' * 70
    logger.info('\n' + hashes)
    logger.info('# MEDCARE-DDI SYSTEM VALIDATION SUITE')
    logger.info(hashes)

    report = ValidationReport()
    test_file_existence(report)

    predictor = test_predictor_loading(report)
    if predictor is None:
        logger.error('Cannot proceed without predictor')
        return report

    # Phases 3-7 share the (report, predictor) signature; run them in order.
    for phase in (
        test_health_check,
        test_exact_lookup,
        test_ml_fallback,
        test_confidence_bands,
        test_response_schema,
    ):
        phase(report, predictor)

    rule = '=' * 70
    total = len(report.tests)
    logger.info('\n' + rule)
    logger.info('VALIDATION SUMMARY')
    logger.info(rule)
    logger.info(f'Passed: {report.passed}/{total}')
    logger.info(f'Failed: {report.failed}/{total}')
    logger.info(f"Pass rate: {report.to_dict()['pass_rate']:.1f}%")
    if report.errors:
        logger.error(f"Failed tests: {', '.join(report.errors)}")
    return report
if __name__ == '__main__':
    # Run every phase and write the JSON report under MODEL_DIR/reports
    # (created if needed). Exit with status 0 only when no check failed, so
    # the result can gate a CI step.
    report = run_all_tests()
    reports_dir = MODEL_DIR / 'reports'
    reports_dir.mkdir(parents=True, exist_ok=True)
    report.save(reports_dir / 'validation_report.json')
    sys.exit(1 if report.failed else 0)