Spaces:
Running
Running
| #!/usr/bin/env python3 | |
"""Comprehensive production integration test and validation suite.

Tests:
    1. Backend API startup and health checks
    2. Model inference on known and unknown pairs
    3. Frontend API contract compliance
    4. Healthcare safety features
    5. Confidence calibration
    6. Error handling
"""
| from __future__ import annotations | |
| import asyncio | |
| import json | |
| import logging | |
| import subprocess | |
| import sys | |
| import time | |
| from pathlib import Path | |
| from typing import Any, Dict | |
# Root logging setup: INFO level, with timestamp, level and logger name
# in front of every message.
logging.basicConfig(
    format='%(asctime)s [%(levelname)s] %(name)s: %(message)s',
    level=logging.INFO,
)

# Named logger used by everything in this module.
logger = logging.getLogger('medcare_ddi.integration_test')
class IntegrationTest:
    """Integration test suite that exercises the API over HTTP.

    Every ``test_*`` method sends requests to ``base_url``, records one or
    more named pass/fail entries through :meth:`log_test` and returns a
    single boolean for the whole test. Request, decoding and validation
    errors are recorded as failed entries instead of being raised.

    Args:
        base_url: Root URL of the API under test; a trailing ``/`` is
            stripped so endpoint paths can be appended directly.
        strict_severity: When true, a known pair whose returned severity
            differs (case-insensitively) from the expected one is recorded
            as a failure. When false (default) the mismatch is only
            reported in the entry's details.
    """

    # Distinct markers so passes and failures can be told apart in the log.
    PASS_MARK = '✓'
    FAIL_MARK = '✗'

    # Keys every successful /predict response is required to contain.
    REQUIRED_PREDICTION_FIELDS = (
        'drug_a',
        'drug_b',
        'severity',
        'confidence',
        'confidence_band',
        'source',
        'explanation',
        'clinical_advice',
        'latency_ms',
    )

    # Values accepted for the ``confidence_band`` response field.
    VALID_CONFIDENCE_BANDS = ('high', 'medium', 'low')

    def __init__(
        self,
        base_url: str = 'http://localhost:8000',
        *,
        strict_severity: bool = False,
    ) -> None:
        # Maps test entry name -> whether that entry passed.
        self.results: Dict[str, bool] = {}
        self.base_url = base_url.rstrip('/')
        self.strict_severity = strict_severity

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _summarize_checks(checks) -> tuple:
        """Collapse ``(ok, label)`` pairs into ``(all_ok, details)``.

        On success the details list every label; on failure they list only
        the labels of the checks that failed, prefixed with ``failed:``.
        """
        failed = [label for ok, label in checks if not ok]
        if failed:
            return False, 'failed: ' + ', '.join(failed)
        return True, ', '.join(label for _, label in checks)

    @classmethod
    def _missing_fields(cls, data) -> list:
        """Return the required prediction fields absent from ``data``."""
        return [field for field in cls.REQUIRED_PREDICTION_FIELDS if field not in data]

    @staticmethod
    def _severity_matches(actual, expected: str) -> bool:
        """Compare severities ignoring case and surrounding whitespace."""
        return (
            isinstance(actual, str)
            and actual.strip().lower() == expected.strip().lower()
        )

    def _post_predict(self, payload, timeout):
        """POST ``payload`` as JSON to ``/predict`` and return the response."""
        # Imported lazily so a missing ``requests`` package surfaces as a
        # recorded test failure in the caller rather than an import error
        # when this module is loaded.
        import requests

        return requests.post(f'{self.base_url}/predict', json=payload, timeout=timeout)

    # ------------------------------------------------------------------
    # Result recording
    # ------------------------------------------------------------------
    def log_test(self, name: str, passed: bool, details: str = '') -> None:
        """Log one test entry and store its outcome in ``self.results``.

        Args:
            name: Entry name; reusing a name overwrites the earlier outcome.
            passed: Whether the entry passed.
            details: Optional extra line logged below the status line.
        """
        if passed:
            status = f'{self.PASS_MARK} PASS'
        else:
            status = f'{self.FAIL_MARK} FAIL'
        logger.info('%s - %s', status, name)
        if details:
            logger.info(' %s', details)
        self.results[name] = passed

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------
    def test_health_endpoint(self) -> bool:
        """Check that ``GET /health`` returns 200 with a loaded model.

        Returns:
            True when the response has ``status``, ``model_loaded`` and
            ``pairs_loaded`` keys, ``model_loaded`` is exactly ``True`` and
            ``pairs_loaded`` is a positive number.
        """
        name = 'Health Endpoint'
        try:
            import requests

            response = requests.get(f'{self.base_url}/health', timeout=5)
            if response.status_code != 200:
                self.log_test(name, False, f'Status {response.status_code}')
                return False

            data = response.json()
            pairs_loaded = data.get('pairs_loaded')
            # Guard the comparison so a null / non-numeric value is reported
            # as a failed check instead of raising TypeError.
            pairs_ok = isinstance(pairs_loaded, (int, float)) and pairs_loaded > 0
            checks = [
                ('status' in data, 'status field'),
                ('model_loaded' in data, 'model_loaded field'),
                ('pairs_loaded' in data, 'pairs_loaded field'),
                (data.get('model_loaded') is True, 'model_loaded is True'),
                (pairs_ok, f'pairs_loaded > 0 (got {pairs_loaded})'),
            ]
            all_passed, details = self._summarize_checks(checks)
            self.log_test(name, all_passed, details)
            return all_passed
        except Exception as e:
            self.log_test(name, False, str(e))
            return False

    def test_known_interactions(self) -> bool:
        """Check ``POST /predict`` on drug pairs with an expected severity.

        An entry passes when the response is 200, contains every field in
        ``REQUIRED_PREDICTION_FIELDS`` and has a numeric ``confidence``. A
        severity that differs from the expected one is noted in the
        details and fails the entry only when ``strict_severity`` is set.

        Returns:
            True when every pair's entry passed.
        """
        test_pairs = [
            ('Aspirin', 'Warfarin', 'major'),
            ('Metformin', 'Insulin', 'moderate'),
        ]
        all_passed = True
        for drug_a, drug_b, expected_severity in test_pairs:
            name = f'Known DDI: {drug_a} + {drug_b}'
            try:
                response = self._post_predict(
                    {'drug_a': drug_a, 'drug_b': drug_b}, timeout=10
                )
                if response.status_code != 200:
                    self.log_test(name, False, f'Status {response.status_code}')
                    all_passed = False
                    continue

                data = response.json()
                missing_fields = self._missing_fields(data)
                if missing_fields:
                    self.log_test(name, False, f'Missing fields: {missing_fields}')
                    all_passed = False
                    continue

                severity = data.get('severity')
                confidence = data.get('confidence', 0)
                source = data.get('source')
                if not isinstance(confidence, (int, float)):
                    # Would otherwise fail inside the ':.2f' format below
                    # with a message that does not name the field.
                    self.log_test(name, False, f'Non-numeric confidence: {confidence!r}')
                    all_passed = False
                    continue

                details = f'{severity} (conf={confidence:.2f}, src={source})'
                severity_ok = self._severity_matches(severity, expected_severity)
                if not severity_ok:
                    details += f' [expected severity: {expected_severity}]'
                passed = severity_ok or not self.strict_severity
                self.log_test(name, passed, details)
                all_passed = all_passed and passed
            except Exception as e:
                self.log_test(name, False, str(e))
                all_passed = False
        return all_passed

    def test_unseen_pairs(self) -> bool:
        """Check that ``POST /predict`` answers 200 for made-up drug names.

        Returns:
            True when every pair returned status 200 with a JSON body.
        """
        test_pairs = [
            ('UnknownDrugA', 'UnknownDrugB'),
            ('TestDrug1', 'TestDrug2'),
        ]
        all_passed = True
        for drug_a, drug_b in test_pairs:
            name = f'Unseen pair: {drug_a} + {drug_b}'
            try:
                response = self._post_predict(
                    {'drug_a': drug_a, 'drug_b': drug_b}, timeout=10
                )
                if response.status_code != 200:
                    self.log_test(name, False, f'Status {response.status_code}')
                    all_passed = False
                    continue

                data = response.json()
                severity = data.get('severity')
                source = data.get('source')
                self.log_test(name, True, f'{severity} (source={source})')
            except Exception as e:
                self.log_test(name, False, str(e))
                all_passed = False
        return all_passed

    def test_error_handling(self) -> bool:
        """Check that invalid ``/predict`` payloads are rejected.

        Returns:
            True when every invalid payload produced a status code of 400
            or above.
        """
        test_cases = [
            ({}, 'Missing both drugs'),
            ({'drug_a': ''}, 'Empty drug names'),
            ({'drug_a': None, 'drug_b': None}, 'None drugs'),
        ]
        all_passed = True
        for payload, desc in test_cases:
            name = f'Error Handling: {desc}'
            try:
                response = self._post_predict(payload, timeout=5)
                if response.status_code >= 400:
                    self.log_test(name, True, f'Status {response.status_code}')
                else:
                    self.log_test(name, False, 'Should have failed')
                    all_passed = False
            except Exception as e:
                self.log_test(name, False, str(e))
                all_passed = False
        return all_passed

    def test_confidence_bands(self) -> bool:
        """Check that predictions carry a recognised ``confidence_band``.

        ``/health`` is queried first; if it does not return 200 the test
        fails immediately. Pairs whose request fails or returns a non-200
        status are skipped (request errors are logged as warnings).

        Returns:
            True when at least one prediction returned a band listed in
            ``VALID_CONFIDENCE_BANDS``.
        """
        try:
            import requests

            response = requests.get(f'{self.base_url}/health', timeout=5)
            if response.status_code != 200:
                self.log_test('Confidence Bands', False, 'Could not get health info')
                return False

            test_pairs = [('Aspirin', 'Warfarin'), ('Drug1', 'Drug2'), ('Drug3', 'Drug4')]
            bands_found = set()
            unexpected = []
            for drug_a, drug_b in test_pairs:
                try:
                    response = self._post_predict(
                        {'drug_a': drug_a, 'drug_b': drug_b}, timeout=10
                    )
                    if response.status_code != 200:
                        continue
                    band = response.json().get('confidence_band')
                except Exception as e:
                    # Best effort per pair, but leave a trace and never
                    # swallow KeyboardInterrupt / SystemExit.
                    logger.warning(
                        'Confidence Bands: skipping %s + %s: %s', drug_a, drug_b, e
                    )
                    continue
                if band in self.VALID_CONFIDENCE_BANDS:
                    bands_found.add(band)
                else:
                    unexpected.append(repr(band))

            valid_bands = len(bands_found) > 0
            if bands_found:
                # Sorted so the log line is stable between runs.
                details = f'Bands found: {sorted(bands_found)}'
            else:
                details = 'No valid bands found'
            if unexpected:
                details += f'; unexpected values: {", ".join(unexpected)}'
            self.log_test('Confidence Bands', valid_bands, details)
            return valid_bands
        except Exception as e:
            self.log_test('Confidence Bands Test', False, str(e))
            return False

    # ------------------------------------------------------------------
    # Runner
    # ------------------------------------------------------------------
    def run_all(self) -> bool:
        """Run every test, log a summary and return the overall verdict.

        A test method that raises is logged with its traceback and recorded
        as a failed entry named ``'<method name> (crashed)'``.

        Returns:
            True only when at least one entry was recorded and every
            recorded entry passed.
        """
        banner_width = 68
        title = ' MEDCARE-DDI INTEGRATION TEST SUITE'
        logger.info('')
        logger.info('╔' + '═' * banner_width + '╗')
        logger.info('║' + title.ljust(banner_width) + '║')
        logger.info('╚' + '═' * banner_width + '╝')
        logger.info('')

        tests = [
            self.test_health_endpoint,
            self.test_known_interactions,
            self.test_unseen_pairs,
            self.test_error_handling,
            self.test_confidence_bands,
        ]
        for test in tests:
            try:
                test()
            except Exception as e:
                logger.error(f'Test {test.__name__} crashed: {e}', exc_info=True)
                # Without this a crashed test leaves no entry and the run
                # could still be reported as fully passing.
                self.results[f'{test.__name__} (crashed)'] = False

        # Summary
        logger.info('')
        logger.info('=' * 70)
        logger.info('TEST SUMMARY')
        logger.info('=' * 70)
        passed_count = sum(1 for ok in self.results.values() if ok)
        total = len(self.results)
        pass_rate = (passed_count / total * 100) if total > 0 else 0
        logger.info(f'Passed: {passed_count}/{total} ({pass_rate:.0f}%)')
        for test_name, ok in self.results.items():
            mark = self.PASS_MARK if ok else self.FAIL_MARK
            logger.info(f'{mark} {test_name}')
        logger.info('')

        # An empty result set must not count as a green run.
        all_passed = bool(self.results) and all(self.results.values())
        status = 'READY FOR DEPLOYMENT' if all_passed else 'NEEDS_ATTENTION'
        logger.info(f'Overall: {status}')
        logger.info('')
        return all_passed
def main():
    """Run the integration suite and report whether it fully passed.

    Returns:
        ``False`` (after logging an install hint) when the ``requests``
        package cannot be imported; otherwise whatever
        ``IntegrationTest().run_all()`` returns.
    """
    try:
        # Availability probe only; the test methods import it themselves.
        import requests  # noqa: F401
    except ImportError:
        logger.error('requests module not found. Install with: pip install requests')
        return False
    else:
        return IntegrationTest().run_all()
if __name__ == '__main__':
    # Process exit status: 0 when main() reports success, 1 otherwise.
    sys.exit(0 if main() else 1)