Spaces:
Running
Running
| """Master orchestration script for complete AI system optimization. | |
| Runs all 9 phases: | |
| 1. Dataset Audit | |
| 2. Embedding Benchmarks | |
| 3. Hyperparameter Optimization | |
| 4. Ensemble Ablation | |
| 5. Healthcare Safety Tuning | |
| 6. Explainability Validation | |
| 7. Comprehensive Benchmarks | |
| 8. Production Validation | |
| 9. Final Model Selection | |
| Output: | |
| - COMPLETE_WORKFLOW_REPORT.md | |
| - workflow_execution_log.json | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import logging | |
| import subprocess | |
| import time | |
| from datetime import datetime | |
| from pathlib import Path | |
| from typing import Any, Dict, List | |
# Configure root logging once for the whole orchestration run; every message
# is timestamped and tagged with its level and logger name.
logging.basicConfig(
    format='%(asctime)s [%(levelname)s] %(name)s: %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger('medcare_ddi.workflow')

# Directory two levels above the folder that contains this file.
BASE_DIR = Path(__file__).resolve().parents[2]
# Folder holding the per-phase scripts that run_phase() launches.
SRC_DIR = BASE_DIR.joinpath('src', 'validation')
# Default destination for the workflow report and execution log.
REPORTS_DIR = BASE_DIR.joinpath('models', 'reports')
# Created at import time so the default --output/--log paths are writable.
REPORTS_DIR.mkdir(parents=True, exist_ok=True)
# Ordered workflow definition. Position in this list (1-based) is the phase
# number accepted by --phases / --skip-phases.
#
# Each entry has:
#   name         human-readable title used in logs and the report
#   script       file name of the phase script, resolved under SRC_DIR
#   description  one-line summary logged when the phase starts
#   args         (optional) extra command-line arguments for the script
#
# NOTE(review): the module docstring lists a ninth phase, "Final Model
# Selection", but no entry for it exists here -- confirm whether a script is
# missing or the docstring is out of date.
PHASES: List[Dict[str, Any]] = [
    # 1
    {
        "name": "Dataset Audit",
        "script": "dataset_audit.py",
        "description": "Complete audit of DDI sources, duplicates, conflicts, class balance",
    },
    # 2
    {
        "name": "Embedding Benchmarks",
        "script": "embedding_benchmark.py",
        "description": "Compare BioBERT, PubMedBERT, SapBERT, ChemBERTa",
    },
    # 3
    {
        "name": "Hyperparameter Optimization",
        "script": "optuna_hyperparameter_tune.py",
        "description": "Optimize 50+ trials with healthcare-aware objective",
        "args": ["--n-trials", "50"],
    },
    # 4
    {
        "name": "Ensemble Ablation",
        "script": "ensemble_ablation_study.py",
        "description": "Compare voting, blending, stacking strategies",
    },
    # 5
    {
        "name": "Healthcare Safety Tuning",
        "script": "healthcare_safety_tuning.py",
        "description": "Optimize severe escalation thresholds, reduce false negatives",
    },
    # 6
    {
        "name": "Explainability Validation",
        "script": "explainability_validation.py",
        "description": "Feature importance and explanation quality",
    },
    # 7
    {
        "name": "Comprehensive Benchmarks",
        "script": "comprehensive_benchmark.py",
        "description": "Final metrics, confusion matrices, calibration analysis",
    },
    # 8
    {
        "name": "Production Validation",
        "script": "production_validation.py",
        "description": "FastAPI compatibility, latency, memory, deployment readiness",
    },
]
def run_phase(phase: Dict[str, Any]) -> Dict[str, Any]:
    """Execute a single workflow phase as a child Python process.

    The phase script is looked up under ``SRC_DIR`` and launched with the
    interpreter that is running this orchestrator, so the child sees the same
    environment and installed packages.

    Args:
        phase: Phase definition with ``'name'``, ``'script'`` and
            ``'description'`` keys, plus an optional ``'args'`` list of extra
            command-line arguments for the script.

    Returns:
        A result dict that always contains ``'phase'``, ``'status'`` and
        ``'elapsed_seconds'``. ``'status'`` is one of ``'success'``,
        ``'failed'`` (non-zero exit code), ``'timeout'`` or ``'error'`` (the
        script is missing or launching it raised). Successful runs carry the
        last 500 characters of stdout under ``'output'``; every other status
        carries a message under ``'error'``.

    Problems while locating or running the script are reported through the
    returned dict rather than raised, so the caller can continue with the
    next phase.
    """
    # Function-scope import keeps this block self-contained.
    import sys

    logger.info('=' * 70)
    logger.info(f'PHASE: {phase["name"]}')
    logger.info(f'Description: {phase["description"]}')
    logger.info('=' * 70)

    timeout_seconds = 3600
    start_time = time.perf_counter()

    try:
        script_path = SRC_DIR / phase['script']
        if not script_path.exists():
            raise FileNotFoundError(f'Script not found: {script_path}')

        # A bare 'python' may resolve to a different interpreter (or to
        # nothing at all); sys.executable is the one running this script.
        cmd = [sys.executable or 'python', str(script_path)]
        cmd.extend(phase.get('args', []))
        logger.info(f'Running: {" ".join(cmd)}')

        # errors='replace' stops undecodable bytes in the child's output from
        # turning an otherwise successful phase into an 'error' result.
        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            errors='replace',
            timeout=timeout_seconds,
        )
    except subprocess.TimeoutExpired:
        logger.error('✗ Phase timed out (1 hour)')
        return {
            'phase': phase['name'],
            'status': 'timeout',
            # Recorded so the report shows how long the phase actually ran.
            'elapsed_seconds': time.perf_counter() - start_time,
            'error': 'Execution timed out after 1 hour',
        }
    except Exception as e:
        logger.error(f'✗ Phase failed: {e}', exc_info=True)
        return {
            'phase': phase['name'],
            'status': 'error',
            'elapsed_seconds': time.perf_counter() - start_time,
            'error': str(e),
        }

    elapsed = time.perf_counter() - start_time

    if result.returncode == 0:
        logger.info(f'✓ Phase completed in {elapsed:.1f}s')
        return {
            'phase': phase['name'],
            'status': 'success',
            'elapsed_seconds': elapsed,
            # Only the tail is kept so the log and report stay small.
            'output': result.stdout[-500:] if result.stdout else '',
        }

    logger.error(f'✗ Phase failed with code {result.returncode}')
    logger.error(f'Stderr: {result.stderr[-500:]}')
    return {
        'phase': phase['name'],
        'status': 'failed',
        'elapsed_seconds': elapsed,
        'error': result.stderr[-500:] if result.stderr else 'Unknown error',
    }
def _select_phases(
    requested: List[int] | None,
    skipped: List[int] | None,
) -> List[Dict[str, Any]]:
    """Return the phases to run, honouring both --phases and --skip-phases.

    Phase numbers are 1-based positions in ``PHASES``. Requested numbers keep
    the order given on the command line; out-of-range numbers are ignored
    with a warning. Skipped numbers are removed from whatever was selected.
    """
    total = len(PHASES)
    if requested:
        invalid = [n for n in requested if not 1 <= n <= total]
        if invalid:
            logger.warning(
                f'Ignoring out-of-range phase numbers {invalid}; '
                f'valid range is 1-{total}'
            )
        numbers = [n for n in requested if 1 <= n <= total]
    else:
        numbers = list(range(1, total + 1))

    # Filter the current selection, not the full PHASES list, so that
    # combining --phases with --skip-phases keeps the --phases restriction.
    skip_set = set(skipped or ())
    return [PHASES[n - 1] for n in numbers if n not in skip_set]


def _write_report(
    report_path: Path,
    execution_log: List[Dict[str, Any]],
    workflow_elapsed: float,
) -> None:
    """Write the Markdown workflow report for the executed phases."""
    report_path.parent.mkdir(parents=True, exist_ok=True)

    successful = sum(1 for r in execution_log if r['status'] == 'success')
    failed = len(execution_log) - successful

    # The file names below are a fixed list; this function does not check
    # that each file actually exists on disk.
    report_files = [
        'dataset_audit_report.json',
        'class_balance_report.json',
        'conflict_analysis.csv',
        'embedding_benchmark_results.csv',
        'embedding_ablation_report.md',
        'optuna_trials.json',
        'optuna_best_params.json',
        'hyperparameter_optimization_report.md',
        'ensemble_benchmark.csv',
        'ensemble_ablation.md',
        'safety_analysis_report.md',
        'severe_case_review.csv',
        'threshold_optimization.json',
        'explainability_examples.md',
        'feature_importance.csv',
        'final_benchmark_report.md',
        'benchmark_metrics.json',
        'production_validation_report.md',
        'production_readiness_report.md',
        'final_model_card.md',
    ]

    # Explicit UTF-8: the report contains '✓'/'✗', which cannot be encoded by
    # some platform default encodings (e.g. cp1252 on Windows).
    with report_path.open('w', encoding='utf-8') as f:
        f.write('# MEDCARE-DDI AI System Optimization - Complete Workflow Report\n\n')
        f.write(f'**Timestamp:** {datetime.now().isoformat()}\n\n')
        f.write(f'**Total Execution Time:** {workflow_elapsed / 60:.1f} minutes\n\n')

        f.write('## Workflow Summary\n\n')
        f.write(f'- Phases Completed: {successful}/{len(execution_log)}\n')
        f.write(f'- Phases Failed: {failed}/{len(execution_log)}\n\n')

        f.write('## Phase Execution Results\n\n')
        f.write('| # | Phase | Status | Duration (s) |\n')
        f.write('|---|-------|--------|---------------|\n')
        for idx, result in enumerate(execution_log, 1):
            status_icon = '✓' if result['status'] == 'success' else '✗'
            elapsed = result.get('elapsed_seconds', 0)
            f.write(
                f"| {idx} | {result['phase']} | {status_icon} {result['status']} | "
                f"{elapsed:.1f} |\n"
            )

        f.write('\n## Detailed Results\n\n')
        for result in execution_log:
            f.write(f'### {result["phase"]}\n\n')
            f.write(f'**Status:** {result["status"]}\n\n')
            if result['status'] == 'success':
                f.write(f'**Elapsed:** {result.get("elapsed_seconds", 0):.1f}s\n\n')
                if result.get('output'):
                    f.write('**Output:**\n\n```\n')
                    f.write(result['output'])
                    f.write('\n```\n\n')
            elif result.get('error'):
                f.write('**Error:**\n\n```\n')
                f.write(result['error'])
                f.write('\n```\n\n')

        f.write('## Output Reports\n\n')
        f.write('The following reports have been generated in `models/reports/`:\n\n')
        f.writelines(f'- `{report}`\n' for report in report_files)

        f.write('\n## Next Steps\n\n')
        f.write('1. Review all generated reports in `models/reports/`\n')
        f.write('2. Identify best-performing model based on healthcare metrics\n')
        f.write('3. Verify production readiness via `production_readiness_report.md`\n')
        f.write('4. Deploy final model using deployment instructions\n')
        f.write('5. Monitor live predictions and calibration drift\n')


def main() -> None:
    """Run the selected workflow phases, then write the JSON log and report.

    Command-line options:
        --phases       1-based phase numbers to run (default: all).
        --skip-phases  1-based phase numbers to exclude from the selection.
        --output       Path of the Markdown report.
        --log          Path of the JSON execution log.

    A failing phase does not stop the workflow; its status is recorded and
    the remaining phases still run.
    """
    parser = argparse.ArgumentParser(description='Master workflow orchestration')
    parser.add_argument(
        '--phases', type=int, nargs='+',
        # Derived from PHASES so the advertised range matches what is accepted.
        help=f'Specific phases to run (1-{len(PHASES)})',
    )
    parser.add_argument('--skip-phases', type=int, nargs='+', help='Phases to skip')
    parser.add_argument('--output', type=str, default=str(REPORTS_DIR / 'COMPLETE_WORKFLOW_REPORT.md'))
    parser.add_argument('--log', type=str, default=str(REPORTS_DIR / 'workflow_execution_log.json'))
    args = parser.parse_args()

    logger.info('Starting MEDCARE-DDI AI System Optimization Workflow')
    logger.info(f'Timestamp: {datetime.now().isoformat()}')

    # Determine which phases to run
    phases_to_run = _select_phases(args.phases, args.skip_phases)
    if not phases_to_run:
        logger.warning('No phases selected; nothing to run')
    logger.info(f'Phases to run: {[p["name"] for p in phases_to_run]}')

    # Execute phases
    execution_log: List[Dict[str, Any]] = []
    workflow_start = time.perf_counter()
    for phase in phases_to_run:
        result = run_phase(phase)
        execution_log.append(result)
        if result['status'] != 'success':
            # Continue with next phase even if one fails
            logger.warning(f'Phase {phase["name"]} did not complete successfully')
    workflow_elapsed = time.perf_counter() - workflow_start

    # Save execution log
    log_path = Path(args.log)
    log_path.parent.mkdir(parents=True, exist_ok=True)
    log_path.write_text(json.dumps({
        'timestamp': datetime.now().isoformat(),
        'total_elapsed_seconds': workflow_elapsed,
        'phases': execution_log,
    }, indent=2), encoding='utf-8')
    logger.info(f'Saved execution log to {log_path}')

    # Generate comprehensive report
    report_path = Path(args.output)
    _write_report(report_path, execution_log, workflow_elapsed)
    logger.info(f'Saved workflow report to {report_path}')

    logger.info('✓ Complete workflow finished')
    logger.info(f'Total time: {workflow_elapsed / 60:.1f} minutes')


if __name__ == '__main__':
    main()