Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| PHASE 9: Final Model Governance & Production Selection | |
| ====================================================== | |
| Selects production model based on healthcare criteria: | |
| 1. Severe Recall ≥ 90% (absolute requirement) | |
| 2. Calibration Quality (ECE ≤ 0.05) | |
| 3. AUROC (discrimination quality) | |
| 4. Stability (generalization across folds) | |
| 5. Latency (p99 < 200ms) | |
| 6. Explainability consistency | |
| Author: MEDCARE-DDI AI Research Team | |
| Date: May 2026 | |
| """ | |
import argparse
import csv
import json
import logging
import sys
from datetime import datetime
from pathlib import Path
# Configure the root logger for CLI runs: INFO and above, timestamped lines.
_LOG_FORMAT = '%(asctime)s [%(levelname)s] %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module logger used by the governance engine and the CLI entry point.
logger = logging.getLogger(__name__)
class ModelGovernanceEngine:
    """Scientific model selection based on healthcare criteria.

    Reads artefacts written by earlier pipeline phases from ``reports_dir``
    (``benchmark_metrics.json``, ``threshold_optimization.json`` and
    ``latency_summary.csv``), checks the hard constraints that gate
    production approval, computes a weighted "healthcare score" and renders
    Markdown reports (model card and governance summary).
    """

    # Selection weights (healthcare-aware); they sum to 1.0.
    SELECTION_WEIGHTS = {
        'severe_recall': 0.40,  # CRITICAL: minimize false negatives
        'calibration': 0.20,    # CRITICAL: confidence must be trustworthy
        'auroc': 0.15,          # Discrimination ability
        'macro_f1': 0.10,       # Overall balance
        'latency': 0.10,        # Production feasibility
        'stability': 0.05,      # Generalization
    }

    # Hard constraints (MUST be met for approval); both bounds are inclusive.
    HARD_CONSTRAINTS = {
        'severe_recall': 0.90,  # Minimum recall on severe class
        'calibration': 0.05,    # Maximum ECE
    }

    # Targets reported on the model card but NOT used to gate approval.
    ADVISORY_TARGETS = {
        'auroc': 0.90,            # Minimum AUROC
        'p99_latency_ms': 200.0,  # Maximum p99 latency; also the latency-score ceiling
    }

    # Placeholders substituted when a metric is missing or non-numeric. These
    # are NOT measurements. The severe_recall and ece placeholders both fail
    # HARD_CONSTRAINTS, so a missing recall or ECE can never pass the gate.
    DEFAULT_METRICS = {
        'severe_recall': 0.85,
        'auroc': 0.88,
        'macro_f1': 0.80,
        'ece': 0.06,
        'p99_latency_ms': 150,
        'stability': 0.92,
    }

    def __init__(self, reports_dir: Path):
        """Create an engine reading artefacts from ``reports_dir`` (str or Path)."""
        self.reports_dir = Path(reports_dir)
        # Public attributes kept for external callers; no method in this
        # class populates them.
        self.models = {}
        self.selected_model = None

    # ------------------------------------------------------------------ #
    # Artefact loading
    # ------------------------------------------------------------------ #
    def _load_json_report(self, filename: str, label: str):
        """Decode ``reports_dir/filename`` as UTF-8 JSON, or return ``{}`` if absent.

        Raises:
            ValueError: (including ``json.JSONDecodeError``) if the file exists
                but is not valid UTF-8 JSON.
        """
        path = self.reports_dir / filename
        if not path.exists():
            logger.warning(f"{label} file not found: {path}")
            return {}
        with path.open('r', encoding='utf-8') as f:
            return json.load(f)

    def load_benchmark_results(self) -> dict:
        """Load results from Phase 7 (Comprehensive Benchmarking).

        Returns:
            The object stored in ``benchmark_metrics.json``, or ``{}`` when the
            file is missing or its top level is not a JSON object.
        """
        metrics = self._load_json_report('benchmark_metrics.json', 'Benchmark metrics')
        if not isinstance(metrics, dict):
            # evaluate_models() calls .get() on this; a list/scalar would crash.
            logger.warning("Benchmark metrics file does not contain a JSON object; ignoring it")
            return {}
        return metrics

    def load_safety_analysis(self) -> dict:
        """Load results from Phase 5 (Healthcare Safety).

        Returns:
            The decoded ``threshold_optimization.json`` or ``{}`` if missing.
        """
        return self._load_json_report('threshold_optimization.json', 'Safety analysis')

    def load_latency_metrics(self) -> dict:
        """Load latency metrics from Phase 8 (``latency_summary.csv``).

        After a header row, each row is read as ``<metric_name>,<value>``
        (extra columns ignored, surrounding whitespace stripped). Rows with
        fewer than two fields are skipped; rows whose value is not numeric are
        skipped with a warning instead of aborting the whole parse.

        Returns:
            Mapping of metric name to float. When the file is missing, the
            placeholder ``{'p99_ms': 100}`` is returned (an estimate, not a
            measurement). If reading fails part-way, the rows parsed so far
            are returned and the error is logged.
        """
        latency_file = self.reports_dir / 'latency_summary.csv'
        if not latency_file.exists():
            logger.warning(f"Latency file not found: {latency_file}")
            return {'p99_ms': 100}  # Default estimate

        latency_metrics = {}
        try:
            with latency_file.open('r', encoding='utf-8', newline='') as f:
                reader = csv.reader(f)
                next(reader, None)  # Skip header
                for row in reader:
                    if len(row) < 2:
                        continue
                    name = row[0].strip()
                    try:
                        latency_metrics[name] = float(row[1])
                    except ValueError:
                        logger.warning(
                            f"Skipping latency row {reader.line_num}: "
                            f"non-numeric value {row[1]!r}"
                        )
        except (OSError, UnicodeDecodeError, csv.Error) as e:
            logger.error(f"Error parsing latency file: {e}")
        return latency_metrics

    # ------------------------------------------------------------------ #
    # Evaluation
    # ------------------------------------------------------------------ #
    @staticmethod
    def _is_number(value) -> bool:
        """True for int/float values; bool is excluded even though it subclasses int."""
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    def _check_hard_constraints(self, metrics: dict) -> dict:
        """Return ``{constraint: passed}`` for each entry of ``HARD_CONSTRAINTS``.

        ``metrics`` must provide ``severe_recall`` and ``ece`` (the keys of
        ``individual_metrics`` produced by :meth:`evaluate_models`).
        """
        return {
            'severe_recall': metrics['severe_recall'] >= self.HARD_CONSTRAINTS['severe_recall'],
            'calibration': metrics['ece'] <= self.HARD_CONSTRAINTS['calibration'],
        }

    def evaluate_models(self) -> dict:
        """Evaluate the candidate model against the governance criteria.

        Loads the benchmark, safety and latency artefacts, substitutes the
        ``DEFAULT_METRICS`` placeholder for any missing or non-numeric metric
        (logging each substitution), checks ``HARD_CONSTRAINTS`` and computes
        the weighted healthcare score.

        Returns:
            dict with ``constraints_met`` (bool), ``individual_metrics``,
            ``normalized_scores``, ``healthcare_score`` (float),
            ``defaulted_metrics`` (names that used placeholders), and the raw
            ``benchmark_results`` / ``safety_analysis`` artefacts.
        """
        logger.info("Loading benchmark results...")
        benchmark_results = self.load_benchmark_results()
        logger.info("Loading safety analysis...")
        safety_data = self.load_safety_analysis()
        logger.info("Loading latency metrics...")
        latency_data = self.load_latency_metrics()

        # p99 comes from the latency CSV; everything else from the benchmark JSON.
        raw_metrics = {
            'severe_recall': benchmark_results.get('severe_recall'),
            'auroc': benchmark_results.get('auroc'),
            'macro_f1': benchmark_results.get('macro_f1'),
            'ece': benchmark_results.get('ece'),
            'p99_latency_ms': latency_data.get('p99_ms'),
            'stability': benchmark_results.get('stability'),
        }
        metrics = {}
        defaulted = []
        for name, value in raw_metrics.items():
            if self._is_number(value):
                metrics[name] = value
                continue
            # Missing / null / non-numeric: use a placeholder, but make the
            # substitution visible rather than reporting it as a measurement.
            metrics[name] = self.DEFAULT_METRICS[name]
            defaulted.append(name)
            logger.warning(
                f"Metric '{name}' missing or non-numeric; "
                f"using placeholder value {metrics[name]}"
            )

        # Calibration quality is 1 - ECE, floored at 0.
        calibration_quality = max(0.0, 1.0 - metrics['ece'])
        # Latency score is 1 at 0 ms and falls linearly to 0 at the ceiling.
        latency_ceiling = self.ADVISORY_TARGETS['p99_latency_ms']
        latency_score = max(0.0, 1.0 - min(metrics['p99_latency_ms'] / latency_ceiling, 1.0))

        constraints_met = self._check_hard_constraints(metrics)
        all_constraints_met = all(constraints_met.values())
        if not all_constraints_met:
            logger.warning("HARD CONSTRAINTS NOT MET:")
            for constraint, met in constraints_met.items():
                logger.warning(f"  [{'PASS' if met else 'FAIL'}] {constraint}")

        normalized_scores = {
            'severe_recall': metrics['severe_recall'],
            'calibration': calibration_quality,
            'auroc': metrics['auroc'],
            'macro_f1': metrics['macro_f1'],
            'latency': latency_score,
            'stability': metrics['stability'],
        }
        healthcare_score = sum(
            weight * normalized_scores[name]
            for name, weight in self.SELECTION_WEIGHTS.items()
        )

        return {
            'constraints_met': all_constraints_met,
            'individual_metrics': metrics,
            'normalized_scores': normalized_scores,
            'healthcare_score': healthcare_score,
            'defaulted_metrics': defaulted,
            'benchmark_results': benchmark_results,
            'safety_analysis': safety_data,
        }

    def select_production_model(self) -> dict:
        """Evaluate the candidate model and log a readable summary.

        Returns:
            The evaluation dict from :meth:`evaluate_models`, unchanged.
        """
        logger.info("\n" + "=" * 70)
        logger.info("PHASE 9: FINAL MODEL GOVERNANCE")
        logger.info("=" * 70)
        logger.info("Evaluating candidate models...")
        model_eval = self.evaluate_models()

        logger.info("\n" + "-" * 70)
        logger.info("MODEL EVALUATION RESULTS")
        logger.info("-" * 70)
        logger.info("\nIndividual Metrics:")
        for metric, value in model_eval['individual_metrics'].items():
            # Millisecond values read better with one decimal place.
            precision = 1 if metric.endswith('_ms') else 4
            logger.info(f"  {metric}: {value:.{precision}f}")
        logger.info("\nNormalized Scores:")
        for metric, score in model_eval['normalized_scores'].items():
            logger.info(f"  {metric}: {score:.4f}")
        logger.info(f"\nHealthcare Score: {model_eval['healthcare_score']:.4f}")
        logger.info(f"\nConstraints Met: {model_eval['constraints_met']}")

        if not model_eval['constraints_met']:
            recall_min = self.HARD_CONSTRAINTS['severe_recall']
            ece_max = self.HARD_CONSTRAINTS['calibration']
            logger.warning("\nMODEL DOES NOT MEET HARD CONSTRAINTS")
            logger.warning(f"  Severe Recall >= {recall_min:.0%}: required for safety")
            logger.warning(f"  ECE <= {ece_max:.2f}: required for trustworthy confidence")
            logger.warning("\nRECOMMENDATION: Re-run Phase 3 (hyperparameter tuning)")
            logger.warning("  with more trials or higher focal_gamma")
        return model_eval

    # ------------------------------------------------------------------ #
    # Reporting
    # ------------------------------------------------------------------ #
    @staticmethod
    def _mark(ok: bool) -> str:
        """Markdown status label for a pass/fail check."""
        return "✅ PASS" if ok else "❌ FAIL"

    def generate_model_card(self, model_eval: dict) -> str:
        """Render the production model card as Markdown.

        Constraint lines report the model's actual PASS/FAIL status: hard
        constraints (which gate approval) are listed separately from advisory
        targets (reported only).

        Args:
            model_eval: dict from :meth:`evaluate_models`; reads
                ``individual_metrics``, ``constraints_met``,
                ``healthcare_score`` and, if present, ``defaulted_metrics``.

        Returns:
            The card text, ending with a newline.
        """
        metrics = model_eval['individual_metrics']
        hard = self._check_hard_constraints(metrics)
        recall_min = self.HARD_CONSTRAINTS['severe_recall']
        ece_max = self.HARD_CONSTRAINTS['calibration']
        auroc_min = self.ADVISORY_TARGETS['auroc']
        latency_max = self.ADVISORY_TARGETS['p99_latency_ms']
        approved = model_eval['constraints_met']

        lines = [
            "# MEDCARE-DDI v2.1 Production Model Card",
            "",
            "## Model Specification",
            "",
            f"- **Generated:** {datetime.now().isoformat()}",
            "- **Purpose:** Drug-Drug Interaction Severity Prediction",
            "- **Target:** Clinical decision support (NOT autonomous)",
            f"- **Governance Status:** {'APPROVED' if approved else 'REJECTED'}",
            "",
            "## Performance Metrics",
            "",
            "### Primary (Healthcare-Critical)",
            "",
            f"- **Severe Recall:** {metrics['severe_recall']:.4f} (≥{recall_min:.2f} required)",
            f"- **Calibration (ECE):** {metrics['ece']:.4f} (≤{ece_max:.2f} required)",
            f"- **AUROC:** {metrics['auroc']:.4f} (discrimination quality)",
            "",
            "### Secondary",
            "",
            f"- **Macro F1:** {metrics['macro_f1']:.4f}",
            f"- **p99 Latency:** {metrics['p99_latency_ms']:.1f}ms (<{latency_max:.0f}ms target)",
            f"- **Stability:** {metrics['stability']:.4f}",
            "",
        ]
        defaulted = model_eval.get('defaulted_metrics') or []
        if defaulted:
            lines += [
                "> **Warning:** no measured value was available for "
                f"{', '.join(defaulted)}; placeholder defaults are shown above.",
                "",
            ]

        lines += ["## Selection Criteria Weights", ""]
        lines += [f"- {name}: {weight:.0%}" for name, weight in self.SELECTION_WEIGHTS.items()]
        lines += [
            "",
            f"**Overall Healthcare Score:** {model_eval['healthcare_score']:.4f}",
            "",
            "## Safety Constraints",
            "",
            "Hard constraints (gate production approval):",
            "",
            f"- {self._mark(hard['severe_recall'])}: Severe Recall ≥ {recall_min:.0%} "
            "(minimize false negatives on dangerous interactions)",
            f"- {self._mark(hard['calibration'])}: ECE ≤ {ece_max:.2f} "
            "(confidence scores must be trustworthy)",
            "",
            "Advisory targets (reported, not enforced):",
            "",
            f"- {self._mark(metrics['auroc'] >= auroc_min)}: AUROC ≥ {auroc_min:.2f} "
            "(good discrimination across classes)",
            f"- {self._mark(metrics['p99_latency_ms'] < latency_max)}: p99 Latency < "
            f"{latency_max:.0f}ms (real-time clinical use)",
            "",
            "## Deployment Instructions",
            "",
        ]
        if not approved:
            lines += ["> **Do not deploy:** this model failed the hard constraints above.", ""]
        lines += [
            "```bash",
            "export MODEL_PATH=models/ddi_mlp_production.pt",
            "export CALIBRATION_PATH=models/calibration_artifacts_production.pkl",
            "uvicorn src.inference.app_production:app --host 0.0.0.0 --port 8000 --workers 4",
            "```",
            "",
            "## Healthcare Safety Guarantees",
            "",
            "- ✓ Exact DDInter lookup prioritized (trusted evidence first)",
            "- ✓ Conservative severe escalation (when uncertain)",
            "- ✓ Confidence bands (LOW/MEDIUM/HIGH) for clinical context",
            "- ✓ Temperature-scaled calibration (learned confidence adjustment)",
            "- ✓ Explainability (SHAP features for interpretability)",
            "- ✓ Not autonomous: supports clinical decision-making only",
            "",
            "## Monitoring Recommendations",
            "",
            "1. **Daily**: Health check endpoint /health",
            "2. **Continuous**: Latency tracking (alert if p99 > 200ms)",
            "3. **Weekly**: Calibration drift monitoring",
            "4. **Monthly**: Severe recall tracking (if ground truth available)",
            "5. **Quarterly**: Model retraining with new data",
            "",
            "## References",
            "",
            "- OPTIMIZATION_FRAMEWORK.py: Complete methodology",
            "- final_benchmark_report.md: Detailed metrics",
            "- safety_analysis_report.md: Healthcare safety analysis",
            "- production_readiness_report.md: Deployment checklist",
        ]
        return "\n".join(lines) + "\n"

    def generate_governance_summary(self, model_eval: dict) -> str:
        """Render the approval decision and next steps as Markdown.

        Failed constraints are listed using the ``HARD_CONSTRAINTS``
        thresholds. Deployment is only recommended when the model is approved.

        Args:
            model_eval: dict from :meth:`evaluate_models`; reads
                ``individual_metrics``, ``constraints_met`` and, if present,
                ``defaulted_metrics``.

        Returns:
            The summary text, ending with a newline.
        """
        metrics = model_eval['individual_metrics']
        lines = [
            "# Model Governance Summary",
            "",
            f"**Decision Date:** {datetime.now().isoformat()}",
            "",
            "## Selection Decision",
            "",
        ]
        if model_eval['constraints_met']:
            lines += [
                "✅ **APPROVED FOR PRODUCTION**",
                "",
                "The selected model meets all hard constraints:",
                "",
                f"- Severe Recall: {metrics['severe_recall']:.1%}",
                f"- Calibration (ECE): {metrics['ece']:.4f}",
                f"- AUROC: {metrics['auroc']:.4f}",
            ]
            next_steps = [
                "Review safety_analysis_report.md for healthcare safety findings",
                "Review explainability_validation.md for model interpretability",
                "Deploy using production_readiness_report.md instructions",
                "Set up monitoring dashboard for calibration drift",
                "Plan quarterly retraining schedule",
            ]
        else:
            checks = self._check_hard_constraints(metrics)
            recall_min = self.HARD_CONSTRAINTS['severe_recall']
            ece_max = self.HARD_CONSTRAINTS['calibration']
            failures = []
            if not checks['severe_recall']:
                failures.append(
                    f"- Severe Recall {metrics['severe_recall']:.1%} < {recall_min:.0%} (CRITICAL)"
                )
            if not checks['calibration']:
                failures.append(f"- ECE {metrics['ece']:.4f} > {ece_max:.2f} (CRITICAL)")
            lines += [
                "❌ **NOT APPROVED FOR PRODUCTION**",
                "",
                "The model FAILS hard constraints:",
                "",
                *failures,
                "",
                "RECOMMENDATION: Re-run optimization phases",
            ]
            # A rejected model must not be pointed at deployment instructions.
            next_steps = [
                "Review safety_analysis_report.md for healthcare safety findings",
                "Review explainability_validation.md for model interpretability",
                "Do NOT deploy: re-run the optimization phases, then repeat Phase 9 governance",
            ]

        defaulted = model_eval.get('defaulted_metrics') or []
        if defaulted:
            lines += [
                "",
                f"Note: placeholder values were used for missing metrics: {', '.join(defaulted)}",
            ]
        lines += ["", "## Recommendation for Next Steps", ""]
        lines += [f"{i}. {step}" for i, step in enumerate(next_steps, start=1)]
        lines += ["", "## Model Card Location", "", "See: final_model_card.md"]
        return "\n".join(lines) + "\n"
def main(argv=None):
    """Run Phase 9 governance and write the report artefacts.

    Writes ``final_model_card.md``, ``governance_summary.md`` and
    ``final_governance_decision.json`` (UTF-8) into ``--reports-dir``,
    creating the directory if needed.

    Args:
        argv: Optional argument list; ``None`` means ``sys.argv[1:]``.

    Returns:
        Process exit code: 0 when the model is approved, 1 when rejected.
    """
    parser = argparse.ArgumentParser(description='Phase 9: Final Model Governance')
    parser.add_argument('--reports-dir', type=str, default='models/reports',
                        help='Reports directory')
    # Accepted for pipeline CLI compatibility; not used by this script.
    parser.add_argument('--seed', type=int, default=2026, help='Random seed')
    parser.add_argument('--experiment-id', type=str, help='Experiment ID (for logging)')
    args = parser.parse_args(argv)

    reports_dir = Path(args.reports_dir)
    # The outputs below are written into this directory; make sure it exists.
    reports_dir.mkdir(parents=True, exist_ok=True)
    if args.experiment_id:
        logger.info(f"Experiment ID: {args.experiment_id}")

    engine = ModelGovernanceEngine(reports_dir)
    model_eval = engine.select_production_model()

    # Reports contain non-ASCII text, so never rely on the locale encoding.
    model_card_file = reports_dir / 'final_model_card.md'
    model_card_file.write_text(engine.generate_model_card(model_eval), encoding='utf-8')
    logger.info(f"\nModel card saved: {model_card_file}")

    governance_file = reports_dir / 'governance_summary.md'
    governance_file.write_text(engine.generate_governance_summary(model_eval), encoding='utf-8')
    logger.info(f"Governance summary saved: {governance_file}")

    # Structured, machine-readable record of the decision.
    decision = {
        'timestamp': datetime.now().isoformat(),
        'experiment_id': args.experiment_id,
        'constraints_met': model_eval['constraints_met'],
        'approval_status': 'APPROVED' if model_eval['constraints_met'] else 'REJECTED',
        'healthcare_score': model_eval['healthcare_score'],
        'metrics': model_eval['individual_metrics'],
        'defaulted_metrics': model_eval.get('defaulted_metrics', []),
        'weights': engine.SELECTION_WEIGHTS,
        'constraints': engine.HARD_CONSTRAINTS,
    }
    decision_file = reports_dir / 'final_governance_decision.json'
    decision_file.write_text(json.dumps(decision, indent=2) + "\n", encoding='utf-8')
    logger.info(f"Decision record saved: {decision_file}")

    logger.info("\n" + "=" * 70)
    logger.info("PHASE 9 COMPLETE")
    logger.info("=" * 70)
    return 0 if model_eval['constraints_met'] else 1


if __name__ == '__main__':
    sys.exit(main())