Spaces:
Running
Running
| """Ensemble ablation study. | |
| Compares: | |
| - Voting | |
| - Calibrated voting | |
| - Stacking | |
| - Individual models (XGBoost, LightGBM, MLP, RF) | |
| Output: | |
| - ensemble_benchmark.csv | |
| - ensemble_ablation.md | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import csv | |
| import json | |
| import logging | |
| from pathlib import Path | |
| from typing import Any, Dict | |
| import joblib | |
| import numpy as np | |
| import pandas as pd | |
| from preprocessing.artifact_manager import manager | |
| from sklearn.metrics import accuracy_score, f1_score, recall_score, roc_auc_score | |
| from sklearn.model_selection import train_test_split | |
# Root logging is configured at import time.
logging.basicConfig(
    format='%(asctime)s [%(levelname)s] %(name)s: %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger('medcare_ddi.ensemble_ablation')

# Project layout, anchored two levels above the directory containing this file.
BASE_DIR = Path(__file__).resolve().parents[2]
DATA_DIR = BASE_DIR.joinpath('data')
PROCESSED_DIR = DATA_DIR.joinpath('processed')
MODEL_DIR = BASE_DIR.joinpath('models')
REPORTS_DIR = MODEL_DIR.joinpath('reports')
# The reports directory is created as an import-time side effect.
REPORTS_DIR.mkdir(parents=True, exist_ok=True)

# Interaction levels; a label's position in the list is its integer class index.
LABEL_NAMES = ['unknown', 'minor', 'moderate', 'major']
LABEL_TO_INDEX = dict(zip(LABEL_NAMES, range(len(LABEL_NAMES))))
def load_training_data() -> tuple[np.ndarray, np.ndarray]:
    """Load DDInter drug pairs and build an aligned feature matrix and labels.

    The fitted feature pipeline is read from ``MODEL_DIR`` and the DDInter
    table is loaded through the artifact manager. Each ``Level`` value is
    lower-cased and mapped through ``LABEL_TO_INDEX``; values missing from
    the mapping fall back to class 0 (``'unknown'``).

    Rows whose feature extraction raises are logged and dropped *together
    with their label*, so row ``i`` of ``X`` always corresponds to ``y[i]``.

    Returns:
        ``(X, y)``: ``X`` is the ``float32`` stack of the per-pair feature
        vectors and ``y`` the ``int64`` class indices of the same rows.

    Raises:
        FileNotFoundError: If the feature pipeline file or the DDInter
            parquet file does not exist.
        ValueError: If feature extraction failed for every row, or if the
            stacked features do not have exactly one row per kept pair.
    """
    feature_pipeline_path = MODEL_DIR / 'feature_pipeline_multisource.pkl'
    if not feature_pipeline_path.exists():
        raise FileNotFoundError(f'Feature pipeline not found: {feature_pipeline_path}')
    # NOTE(review): joblib.load unpickles arbitrary objects; only point this
    # at pipeline files produced by this project.
    feature_pipeline = joblib.load(feature_pipeline_path)

    ddinter_path = PROCESSED_DIR / 'ddinter_combined.parquet'
    if not ddinter_path.exists():
        raise FileNotFoundError(f'DDInter not found: {ddinter_path}')
    # NOTE(review): the path above is only checked for existence; the data
    # itself comes from the artifact manager -- presumably the same file,
    # confirm against the manager's registry.
    df = manager.load_artifact('ddinter_combined')
    logger.info(f'Loaded {len(df)} DDInter records')

    # Imported lazily, as in the rest of this script's project-local imports.
    from training.feature_pipeline_multisource import transform_pair_features

    features = []
    labels = []
    # Iterating the three columns in lockstep avoids building a Series per
    # row and keeps each label next to the pair it belongs to.
    for drug_a, drug_b, level in zip(df['Drug_A'], df['Drug_B'], df['Level']):
        try:
            vec = transform_pair_features(drug_a, drug_b, feature_pipeline)
        except Exception as e:
            # Best effort: a pair that cannot be featurised is skipped.
            logger.warning(f'Feature extraction failed: {e}')
            continue
        features.append(vec)
        # The label is recorded only for rows that produced features; this
        # is what keeps X and y aligned when some rows are skipped.
        labels.append(LABEL_TO_INDEX.get(str(level).lower(), 0))

    if not features:
        raise ValueError('Feature extraction failed for every DDInter record')

    X = np.vstack(features).astype(np.float32)
    y = np.asarray(labels, dtype=np.int64)
    if X.shape[0] != y.shape[0]:
        # Would happen if transform_pair_features returned more than one row
        # per pair; fail loudly rather than train on misaligned data.
        raise ValueError(
            f'Feature/label count mismatch: {X.shape[0]} feature rows for {y.shape[0]} labels'
        )
    return X, y
def compute_metrics(y_true: np.ndarray, y_pred: np.ndarray, y_proba: np.ndarray) -> Dict[str, float]:
    """Compute the benchmark metrics for one model's predictions.

    Args:
        y_true: Integer class indices of the true labels (indices into
            ``LABEL_NAMES``).
        y_pred: Integer class indices predicted by the model.
        y_proba: Per-class scores for each sample. Assumed to have one
            column per entry of ``LABEL_NAMES``, in that order -- TODO
            confirm against each model's ``classes_``.

    Returns:
        A dict with the keys ``accuracy``, ``macro_f1``, ``severe_recall``
        (recall restricted to the ``'major'`` class), ``auroc`` (macro
        one-vs-rest, ``0.0`` when it cannot be computed) and
        ``healthcare_score`` (weighted sum of the previous three).
    """
    severe_idx = LABEL_TO_INDEX['major']
    accuracy = float(accuracy_score(y_true, y_pred))
    macro_f1 = float(f1_score(y_true, y_pred, average='macro', zero_division=0))
    # Restricting ``labels`` to the severe class makes the macro average the
    # recall of that single class.
    severe_recall = float(
        recall_score(y_true, y_pred, labels=[severe_idx], average='macro', zero_division=0)
    )

    try:
        # One-hot encode the labels so AUROC is computed per class column.
        y_true_ovr = np.eye(len(LABEL_NAMES))[y_true]
        auroc = float(roc_auc_score(y_true_ovr, y_proba, average='macro', multi_class='ovr'))
    except Exception as exc:
        # Deliberately best-effort: AUROC falls back to 0.0, but the reason
        # is logged so a 0.0 in the report can be told apart from a real
        # score.
        logger.warning('AUROC could not be computed, reporting 0.0: %s', exc)
        auroc = 0.0

    # NOTE(review): the weights sum to 0.9, not 1.0 -- confirm whether a
    # fourth term was intended before comparing scores across reports.
    healthcare_score = 0.4 * severe_recall + 0.3 * macro_f1 + 0.2 * auroc
    return {
        'accuracy': accuracy,
        'macro_f1': macro_f1,
        'severe_recall': severe_recall,
        'auroc': auroc,
        'healthcare_score': healthcare_score,
    }
def _score_probabilities(y_val: np.ndarray, probs: np.ndarray) -> Dict[str, float]:
    """Score a probability matrix using its arg-max column as the prediction."""
    preds = np.argmax(probs, axis=1)
    return compute_metrics(y_val, preds, probs)


def benchmark_ensemble_strategies(
    X_train: np.ndarray,
    X_val: np.ndarray,
    y_train: np.ndarray,
    y_val: np.ndarray,
    random_state: int = 2026,
) -> Dict[str, Any]:
    """Train the base models and score every available model on the validation set.

    ``train_base_models`` writes its artifacts into
    ``REPORTS_DIR / 'ensemble_ablation_base'``; whichever of the expected
    ``.joblib`` files exist afterwards are loaded and evaluated. Missing
    artifacts are skipped.

    Args:
        X_train: Training features passed to ``train_base_models``.
        X_val: Validation features every model is scored on.
        y_train: Training labels passed to ``train_base_models``.
        y_val: Validation labels used for scoring.
        random_state: Seed forwarded to ``train_base_models``. Defaults to
            the previously hard-coded value.

    Returns:
        A dict mapping a model name (``'xgb'``, ``'lgbm'``, ``'mlp'``,
        ``'rf'``, ``'voting'``, ``'calibrated_voting'``, ``'stacking'``) to
        the metrics dict returned by ``compute_metrics``. Only models that
        were actually evaluated appear.
    """
    logger.info('Training base models...')
    from training.ensemble import train_base_models

    ensemble_dir = REPORTS_DIR / 'ensemble_ablation_base'
    train_base_models(X_train, y_train, ensemble_dir, random_state=random_state)

    base_names = ('xgb', 'lgbm', 'mlp', 'rf')
    results: Dict[str, Any] = {}
    # Validation probabilities per base model, kept so the stacker can reuse
    # them instead of calling predict_proba a second time.
    base_probs: Dict[str, np.ndarray] = {}

    # Individual models
    for name in base_names:
        path = ensemble_dir / f'{name}.joblib'
        if not path.exists():
            continue
        model = joblib.load(path)
        logger.info(f'Evaluating {name}...')
        if not hasattr(model, 'predict_proba'):
            logger.warning('Skipping %s: model has no predict_proba', name)
            continue
        probs = model.predict_proba(X_val)
        base_probs[name] = probs
        results[name] = _score_probabilities(y_val, probs)

    # Voting and calibrated voting share the same evaluation path.
    for artifact, key, label in (
        ('voting.joblib', 'voting', 'voting ensemble'),
        ('calibrated_voting.joblib', 'calibrated_voting', 'calibrated voting'),
    ):
        path = ensemble_dir / artifact
        if not path.exists():
            continue
        logger.info(f'Evaluating {label}...')
        ensemble = joblib.load(path)
        results[key] = _score_probabilities(y_val, ensemble.predict_proba(X_val))

    # Stacking
    stacker_path = ensemble_dir / 'stacker.joblib'
    if stacker_path.exists():
        logger.info('Evaluating stacker...')
        stacker = joblib.load(stacker_path)
        if len(base_probs) < len(base_names):
            # NOTE(review): the stacker presumably expects the probabilities
            # of all base models side by side; with some missing, the input
            # width may not match what it was fitted on.
            logger.warning(
                'Only %d of %d base models available for stacking',
                len(base_probs),
                len(base_names),
            )
        if base_probs:
            # Concatenate in the fixed base-model order (xgb, lgbm, mlp, rf).
            stacked = np.hstack([base_probs[n] for n in base_names if n in base_probs])
            results['stacking'] = _score_probabilities(y_val, stacker.predict_proba(stacked))

    return results
def _write_csv_report(results: Dict[str, Any], csv_path: Path) -> None:
    """Write one CSV row per model with the five reported metrics."""
    metric_names = ['accuracy', 'macro_f1', 'severe_recall', 'auroc', 'healthcare_score']
    csv_path.parent.mkdir(parents=True, exist_ok=True)
    with csv_path.open('w', newline='', encoding='utf-8') as f:
        writer = csv.DictWriter(f, fieldnames=['model', *metric_names])
        writer.writeheader()
        for model_name, metrics in results.items():
            row = {'model': model_name}
            # A metric missing from a result is written as 0.
            row.update((metric, metrics.get(metric, 0)) for metric in metric_names)
            writer.writerow(row)


def _write_markdown_report(results: Dict[str, Any], md_path: Path) -> None:
    """Write the Markdown summary and the results table sorted by healthcare score."""
    lines = ['# Ensemble Ablation Study\n\n', '## Summary\n\n']
    if results:
        best_name, best = max(results.items(), key=lambda item: item[1].get('severe_recall', 0))
        lines.append(f'**Best by Severe Recall: {best_name}**\n\n')
        lines.append(f'- Severe Recall: {best.get("severe_recall", 0):.4f}\n')
        lines.append(f'- Accuracy: {best.get("accuracy", 0):.4f}\n')
        lines.append(f'- Macro F1: {best.get("macro_f1", 0):.4f}\n')
        lines.append(f'- AUROC: {best.get("auroc", 0):.4f}\n')
        lines.append(f'- Healthcare Score: {best.get("healthcare_score", 0):.4f}\n\n')
    lines.append('## Results\n\n')
    lines.append('| Model | Accuracy | Macro F1 | Severe Recall | AUROC | Healthcare Score |\n')
    lines.append('|-------|----------|----------|---------------|-------|------------------|\n')
    ranked = sorted(
        results.items(),
        key=lambda item: item[1].get('healthcare_score', 0),
        reverse=True,
    )
    for model_name, metrics in ranked:
        lines.append(
            f"| {model_name} | "
            f"{metrics.get('accuracy', 0):.4f} | "
            f"{metrics.get('macro_f1', 0):.4f} | "
            f"{metrics.get('severe_recall', 0):.4f} | "
            f"{metrics.get('auroc', 0):.4f} | "
            f"{metrics.get('healthcare_score', 0):.4f} |\n"
        )
    # Create the target directory here as well, so a custom --output-md in a
    # new directory does not fail after the whole benchmark has already run.
    md_path.parent.mkdir(parents=True, exist_ok=True)
    with md_path.open('w', encoding='utf-8') as f:
        f.writelines(lines)


def main() -> None:
    """Run the ablation: load data, split, benchmark, and write both reports.

    Command-line options:
        --seed: Random state of the train/validation split (default 2026).
        --output-csv: Path of the CSV benchmark file.
        --output-md: Path of the Markdown report.
    """
    parser = argparse.ArgumentParser(description='Ensemble ablation study')
    parser.add_argument('--seed', type=int, default=2026)
    parser.add_argument('--output-csv', type=str, default=str(REPORTS_DIR / 'ensemble_benchmark.csv'))
    parser.add_argument('--output-md', type=str, default=str(REPORTS_DIR / 'ensemble_ablation.md'))
    args = parser.parse_args()

    logger.info('Loading data...')
    X, y = load_training_data()

    # NOTE: --seed only controls this split. The benchmark call below does
    # not receive it, so base-model training inside
    # benchmark_ensemble_strategies uses that function's own random state.
    X_train, X_val, y_train, y_val = train_test_split(
        X, y, test_size=0.2, random_state=args.seed, stratify=y
    )
    logger.info(f'Train: {X_train.shape}, Val: {X_val.shape}')

    # Benchmark
    results = benchmark_ensemble_strategies(X_train, X_val, y_train, y_val)

    # Save CSV
    csv_path = Path(args.output_csv)
    _write_csv_report(results, csv_path)
    logger.info(f'Saved CSV to {csv_path}')

    # Save markdown report
    md_path = Path(args.output_md)
    _write_markdown_report(results, md_path)
    logger.info(f'Saved report to {md_path}')
    logger.info('✓ Ensemble ablation complete')


if __name__ == '__main__':
    main()