# ddi / src /validation /run_complete_workflow.py
# github-actions[bot]
# Deploy from GitHub Actions (fb28c05c54cf19184fc3f14f1bf3297ba5749ea2)
# d29b763
"""Master orchestration script for complete AI system optimization.

Runs the validation phases below in order, each as a separate Python script:
1. Dataset Audit
2. Embedding Benchmarks
3. Hyperparameter Optimization
4. Ensemble Ablation
5. Healthcare Safety Tuning
6. Explainability Validation
7. Comprehensive Benchmarks
8. Production Validation
9. Final Model Selection (listed for completeness; no script is registered
   for it in PHASES yet)

Output (written to models/reports/ by default):
- COMPLETE_WORKFLOW_REPORT.md
- workflow_execution_log.json
"""
from __future__ import annotations

import argparse
import json
import logging
import subprocess
import sys
import time
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List
logging.basicConfig(
    format='%(asctime)s [%(levelname)s] %(name)s: %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger('medcare_ddi.workflow')

# The phase scripts live in <BASE_DIR>/src/validation; BASE_DIR is two levels
# above the directory containing this file.
BASE_DIR = Path(__file__).resolve().parents[2]
SRC_DIR = BASE_DIR.joinpath('src', 'validation')
REPORTS_DIR = BASE_DIR.joinpath('models', 'reports')
# Module-level side effect: the default report directory is created on import.
REPORTS_DIR.mkdir(exist_ok=True, parents=True)

# Ordered phase registry. A phase's 1-based position in this list is the number
# accepted by --phases / --skip-phases. Each entry names a script inside
# SRC_DIR and may carry extra command-line 'args'.
# NOTE(review): the module docstring lists a ninth phase (Final Model
# Selection) that has no entry here.
PHASES = [
    dict(
        name='Dataset Audit',
        script='dataset_audit.py',
        description='Complete audit of DDI sources, duplicates, conflicts, class balance',
    ),
    dict(
        name='Embedding Benchmarks',
        script='embedding_benchmark.py',
        description='Compare BioBERT, PubMedBERT, SapBERT, ChemBERTa',
    ),
    dict(
        name='Hyperparameter Optimization',
        script='optuna_hyperparameter_tune.py',
        description='Optimize 50+ trials with healthcare-aware objective',
        args=['--n-trials', '50'],
    ),
    dict(
        name='Ensemble Ablation',
        script='ensemble_ablation_study.py',
        description='Compare voting, blending, stacking strategies',
    ),
    dict(
        name='Healthcare Safety Tuning',
        script='healthcare_safety_tuning.py',
        description='Optimize severe escalation thresholds, reduce false negatives',
    ),
    dict(
        name='Explainability Validation',
        script='explainability_validation.py',
        description='Feature importance and explanation quality',
    ),
    dict(
        name='Comprehensive Benchmarks',
        script='comprehensive_benchmark.py',
        description='Final metrics, confusion matrices, calibration analysis',
    ),
    dict(
        name='Production Validation',
        script='production_validation.py',
        description='FastAPI compatibility, latency, memory, deployment readiness',
    ),
]
def run_phase(phase: Dict[str, Any]) -> Dict[str, Any]:
    """Execute a single workflow phase as a child Python process.

    Args:
        phase: Phase descriptor with ``name``, ``description`` and ``script``
            (a file name inside ``SRC_DIR``) keys. Optional keys: ``args``
            (extra command-line arguments, converted to ``str``) and
            ``timeout`` (seconds; defaults to 3600).

    Returns:
        A result record with ``phase``, ``status`` and ``elapsed_seconds``.
        ``status`` is ``'success'``, ``'failed'`` (non-zero exit code),
        ``'timeout'`` or ``'error'`` (e.g. the script file is missing).
        On success it also carries ``output``: the last 500 characters of
        stdout. Otherwise it carries ``error``: the last 500 characters of
        stderr, or a description of the problem.
    """
    logger.info('=' * 70)
    logger.info(f'PHASE: {phase["name"]}')
    logger.info(f'Description: {phase["description"]}')
    logger.info('=' * 70)
    timeout = phase.get('timeout', 3600)
    start_time = time.perf_counter()
    try:
        script_path = SRC_DIR / phase['script']
        if not script_path.exists():
            raise FileNotFoundError(f'Script not found: {script_path}')
        # Use the interpreter running this orchestrator so each phase sees the
        # same virtualenv/packages. A bare 'python' is resolved through PATH,
        # which may point to a different interpreter or not exist at all.
        interpreter = sys.executable or 'python'
        extra_args = [str(arg) for arg in phase.get('args') or ()]
        cmd = [interpreter, str(script_path), *extra_args]
        logger.info(f'Running: {" ".join(cmd)}')
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout)
    except subprocess.TimeoutExpired:
        elapsed = time.perf_counter() - start_time
        logger.error(f'✗ Phase timed out ({timeout}s)')
        return {
            'phase': phase['name'],
            'status': 'timeout',
            'elapsed_seconds': elapsed,
            'error': f'Execution timed out after {timeout} seconds',
        }
    except Exception as e:
        # Top-level boundary for one phase: log with traceback and report the
        # failure instead of aborting the remaining phases.
        elapsed = time.perf_counter() - start_time
        logger.exception(f'✗ Phase failed: {e}')
        return {
            'phase': phase['name'],
            'status': 'error',
            'elapsed_seconds': elapsed,
            'error': str(e),
        }

    elapsed = time.perf_counter() - start_time
    if result.returncode == 0:
        logger.info(f'✓ Phase completed in {elapsed:.1f}s')
        return {
            'phase': phase['name'],
            'status': 'success',
            'elapsed_seconds': elapsed,
            'output': result.stdout[-500:] if result.stdout else '',
        }

    logger.error(f'✗ Phase failed with code {result.returncode}')
    logger.error(f'Stderr: {result.stderr[-500:]}')
    return {
        'phase': phase['name'],
        'status': 'failed',
        'elapsed_seconds': elapsed,
        'error': result.stderr[-500:] if result.stderr else 'Unknown error',
    }
def _select_phases(
    selected: List[int] | None,
    skipped: List[int] | None,
) -> List[tuple[int, Dict[str, Any]]]:
    """Resolve --phases / --skip-phases into an ordered list of (number, phase).

    Phase numbers are 1-based positions in ``PHASES``. ``selected`` keeps the
    caller's order; out-of-range numbers are logged and ignored. ``skipped`` is
    applied on top of the selection (or on top of all phases when nothing was
    selected), so both options can be combined.
    """
    count = len(PHASES)
    if selected:
        invalid = [n for n in selected if not 1 <= n <= count]
        if invalid:
            logger.warning(f'Ignoring unknown phase numbers {invalid}; valid range is 1-{count}')
        chosen = [(n, PHASES[n - 1]) for n in selected if 1 <= n <= count]
    else:
        chosen = list(enumerate(PHASES, 1))
    if skipped:
        skip_set = set(skipped)
        chosen = [(n, phase) for n, phase in chosen if n not in skip_set]
    return chosen


def _write_report(
    report_path: Path,
    execution_log: List[Dict[str, Any]],
    workflow_elapsed: float,
) -> None:
    """Render the Markdown workflow report for ``execution_log`` to ``report_path``.

    The file is written as UTF-8 because the status column contains the
    non-ASCII marks '✓' / '✗', which a platform default encoding such as
    cp1252 cannot encode. Missing parent directories are created so a custom
    --output location works the same way --log does.
    """
    report_files = (
        'dataset_audit_report.json',
        'class_balance_report.json',
        'conflict_analysis.csv',
        'embedding_benchmark_results.csv',
        'embedding_ablation_report.md',
        'optuna_trials.json',
        'optuna_best_params.json',
        'hyperparameter_optimization_report.md',
        'ensemble_benchmark.csv',
        'ensemble_ablation.md',
        'safety_analysis_report.md',
        'severe_case_review.csv',
        'threshold_optimization.json',
        'explainability_examples.md',
        'feature_importance.csv',
        'final_benchmark_report.md',
        'benchmark_metrics.json',
        'production_validation_report.md',
        'production_readiness_report.md',
        'final_model_card.md',
    )
    total = len(execution_log)
    successful = sum(1 for r in execution_log if r['status'] == 'success')

    lines: List[str] = [
        '# MEDCARE-DDI AI System Optimization - Complete Workflow Report\n\n',
        f'**Timestamp:** {datetime.now().isoformat()}\n\n',
        f'**Total Execution Time:** {workflow_elapsed / 60:.1f} minutes\n\n',
        '## Workflow Summary\n\n',
        f'- Phases Completed: {successful}/{total}\n',
        f'- Phases Failed: {total - successful}/{total}\n\n',
        '## Phase Execution Results\n\n',
        '| # | Phase | Status | Duration (s) |\n',
        '|---|-------|--------|---------------|\n',
    ]
    for idx, result in enumerate(execution_log, 1):
        status_icon = '✓' if result['status'] == 'success' else '✗'
        # Show the phase's own number so filtered runs line up with PHASES;
        # fall back to the run position for records without one.
        number = result.get('phase_number', idx)
        elapsed = result.get('elapsed_seconds', 0)
        lines.append(
            f"| {number} | {result['phase']} | {status_icon} {result['status']} | "
            f"{elapsed:.1f} |\n"
        )

    lines.append('\n## Detailed Results\n\n')
    for result in execution_log:
        lines.append(f'### {result["phase"]}\n\n')
        lines.append(f'**Status:** {result["status"]}\n\n')
        if result['status'] == 'success':
            lines.append(f'**Elapsed:** {result.get("elapsed_seconds", 0):.1f}s\n\n')
            if result.get('output'):
                lines.extend(['**Output:**\n\n```\n', result['output'], '\n```\n\n'])
        elif result.get('error'):
            lines.extend(['**Error:**\n\n```\n', result['error'], '\n```\n\n'])

    lines.append('## Output Reports\n\n')
    # Phases that failed may not have produced their files, so only list them
    # as expected outputs instead of claiming they were all generated.
    lines.append('The phase scripts are expected to write the following reports to `models/reports/`:\n\n')
    lines.extend(f'- `{name}`\n' for name in report_files)
    lines.extend([
        '\n## Next Steps\n\n',
        '1. Review all generated reports in `models/reports/`\n',
        '2. Identify best-performing model based on healthcare metrics\n',
        '3. Verify production readiness via `production_readiness_report.md`\n',
        '4. Deploy final model using deployment instructions\n',
        '5. Monitor live predictions and calibration drift\n',
    ])

    report_path.parent.mkdir(parents=True, exist_ok=True)
    report_path.write_text(''.join(lines), encoding='utf-8')


def main() -> None:
    """Parse CLI options, run the selected phases, then write the log and report.

    Phases run sequentially via ``run_phase``. A phase that does not succeed is
    logged as a warning, and the workflow continues with the next one. The
    JSON execution log goes to ``--log`` and the Markdown summary to ``--output``.
    """
    parser = argparse.ArgumentParser(description='Master workflow orchestration')
    parser.add_argument(
        '--phases', type=int, nargs='+',
        help=f'Specific phases to run (1-{len(PHASES)})',
    )
    parser.add_argument('--skip-phases', type=int, nargs='+', help='Phases to skip')
    parser.add_argument('--output', type=str, default=str(REPORTS_DIR / 'COMPLETE_WORKFLOW_REPORT.md'))
    parser.add_argument('--log', type=str, default=str(REPORTS_DIR / 'workflow_execution_log.json'))
    args = parser.parse_args()

    logger.info('Starting MEDCARE-DDI AI System Optimization Workflow')
    logger.info(f'Timestamp: {datetime.now().isoformat()}')

    # Determine which phases to run (--phases and --skip-phases compose).
    selection = _select_phases(args.phases, args.skip_phases)
    logger.info(f'Phases to run: {[phase["name"] for _, phase in selection]}')

    # Execute phases
    execution_log: List[Dict[str, Any]] = []
    workflow_start = time.perf_counter()
    for number, phase in selection:
        result = run_phase(phase)
        # Record the phase's position in PHASES so filtered runs stay traceable.
        result.setdefault('phase_number', number)
        execution_log.append(result)
        if result['status'] != 'success':
            # Continue with next phase even if one fails
            logger.warning(f'Phase {phase["name"]} did not complete successfully')
    workflow_elapsed = time.perf_counter() - workflow_start

    # Save execution log
    log_path = Path(args.log)
    log_path.parent.mkdir(parents=True, exist_ok=True)
    log_path.write_text(json.dumps({
        'timestamp': datetime.now().isoformat(),
        'total_elapsed_seconds': workflow_elapsed,
        'phases': execution_log,
    }, indent=2), encoding='utf-8')
    logger.info(f'Saved execution log to {log_path}')

    # Generate comprehensive report
    report_path = Path(args.output)
    _write_report(report_path, execution_log, workflow_elapsed)
    logger.info(f'Saved workflow report to {report_path}')
    logger.info('✓ Complete workflow finished')
    logger.info(f'Total time: {workflow_elapsed / 60:.1f} minutes')


if __name__ == '__main__':
    main()