Spaces:
Running
Running
| """Ablation study for the enriched MEDCARE-DDI feature pipeline. | |
| Runs four arms: | |
| - DDInter only | |
| - DDInter + DrugBank | |
| - DDInter + TWOSIDES | |
| - Full pipeline | |
| Outputs: | |
| - CSV summary | |
| - Markdown report | |
| - bar chart | |
| - confusion matrix PNG per arm | |
| - JSON summary | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import sys | |
| import gc | |
| from pathlib import Path | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import pandas as pd | |
| from sklearn.model_selection import train_test_split | |
| ROOT = Path(__file__).resolve().parents[2] | |
| sys.path.insert(0, str(ROOT / 'src')) | |
| from training.feature_pipeline import build_feature_pipeline | |
| from training.retrain_production_model import TrainConfig, train_and_evaluate | |
| REPORT_DIR = ROOT / 'models' / 'reports' | |
| REPORT_DIR.mkdir(parents=True, exist_ok=True) | |
| def _stratified_sample(df: pd.DataFrame, sample_size: int, seed: int = 2026) -> pd.DataFrame: | |
| if sample_size <= 0 or sample_size >= len(df): | |
| return df.copy() | |
| group_cols = ['label'] | |
| grouped = df.groupby(group_cols, group_keys=False) | |
| fractions = min(1.0, sample_size / float(len(df))) | |
| sampled = grouped.apply(lambda part: part.sample(frac=fractions, random_state=seed)) | |
| if len(sampled) > sample_size: | |
| sampled = sampled.sample(n=sample_size, random_state=seed) | |
| return sampled.reset_index(drop=True) | |
| def _mask_groups(X: np.ndarray, group_slices: dict[str, tuple[int, int]], enabled_groups: list[str]) -> np.ndarray: | |
| masked = np.zeros_like(X) | |
| for group_name in enabled_groups: | |
| if group_name not in group_slices: | |
| continue | |
| start, end = group_slices[group_name] | |
| masked[:, start:end] = X[:, start:end] | |
| return masked | |
def _save_confusion_matrix(cm: list[list[int]], labels: list[str], out_path: Path) -> None:
    """Render a confusion matrix as an annotated heatmap and save it to *out_path*.

    Args:
        cm: Confusion matrix as nested lists; rows are drawn on the "True"
            axis and columns on the "Predicted" axis.
        labels: Tick labels, used for both axes.
        out_path: Destination image file.
    """
    matrix = np.asarray(cm, dtype=np.int64)
    fig, ax = plt.subplots(figsize=(6, 5))
    try:
        im = ax.imshow(matrix, cmap='Blues')
        fig.colorbar(im, ax=ax)
        ax.set_xticks(range(len(labels)))
        ax.set_yticks(range(len(labels)))
        ax.set_xticklabels(labels, rotation=45, ha='right')
        ax.set_yticklabels(labels)
        ax.set_xlabel('Predicted')
        ax.set_ylabel('True')
        ax.set_title('Confusion Matrix')

        # Cells above half the maximum count get white text so the number
        # stays readable on the darker part of the colormap.
        threshold = matrix.max() / 2.0 if matrix.size else 0
        for i in range(matrix.shape[0]):
            for j in range(matrix.shape[1]):
                text_color = 'white' if matrix[i, j] > threshold else 'black'
                ax.text(j, i, str(matrix[i, j]), ha='center', va='center', color=text_color)

        fig.tight_layout()
        fig.savefig(out_path, dpi=160)
    finally:
        # Close even when drawing or saving fails; otherwise the figure stays
        # registered with pyplot for the rest of the run.
        plt.close(fig)
| def _markdown_table(df: pd.DataFrame) -> str: | |
| headers = list(df.columns) | |
| rows = [headers] | |
| for _, row in df.iterrows(): | |
| rows.append([str(row[col]) for col in headers]) | |
| widths = [max(len(cell) for cell in column) for column in zip(*rows)] | |
| lines = [] | |
| lines.append('| ' + ' | '.join(header.ljust(widths[idx]) for idx, header in enumerate(headers)) + ' |') | |
| lines.append('| ' + ' | '.join('-' * widths[idx] for idx in range(len(headers))) + ' |') | |
| for row in rows[1:]: | |
| lines.append('| ' + ' | '.join(cell.ljust(widths[idx]) for idx, cell in enumerate(row)) + ' |') | |
| return '\n'.join(lines) | |
def main() -> None:
    """Run the four-arm feature ablation and write every report artifact.

    The feature pipeline is built once. For each arm, the feature columns of
    the groups that arm does not enable are zeroed, the data is split
    80/10/10 (train/val/test, stratified on ``label``) and a model is trained
    and evaluated.

    Files written under ``REPORT_DIR``:
        - ``ablation_confusion_matrix_<arm>.png`` (one per arm)
        - ``ablation_summary.csv`` and ``ablation_summary.json``
        - ``ablation_metrics.png``
        - ``ablation_report.md``
    """
    print('Building shared feature pipeline...', flush=True)
    pairs_df, artifacts = build_feature_pipeline(save_artifacts=True, sample_size=3000, seed=2026)
    group_slices = artifacts.group_slices
    X = np.asarray(list(pairs_df['_X'].values), dtype=np.float32)

    base_groups = ['pair_encoding', 'semantic_embeddings', 'normalized_numeric']
    arms = {
        'ddinter_only': base_groups,
        'ddinter_drugbank': base_groups + ['drugbank_active', 'drugbank_atc', 'drugbank_category'],
        'ddinter_twosides': base_groups + ['twosides_signal'],
        'full': base_groups + ['drugbank_active', 'drugbank_atc', 'drugbank_category', 'twosides_signal', 'polypharmacy'],
    }

    results: list[dict[str, object]] = []
    for arm_name, enabled_groups in arms.items():
        # A group the pipeline did not produce is silently zeroed by
        # _mask_groups, which could make an arm identical to the baseline
        # without anyone noticing.  Surface it.
        missing_groups = [name for name in enabled_groups if name not in group_slices]
        if missing_groups:
            print(f'WARNING: arm={arm_name} requests unknown feature groups {missing_groups}; their columns stay zeroed.', flush=True)

        arm_X = _mask_groups(X, group_slices, enabled_groups)
        arm_df = pairs_df[['drug_a', 'drug_b', 'label', 'support']].copy()
        arm_df['_X'] = arm_X.tolist()

        # 80% train, then the remaining 20% halved into validation and test.
        train_df, temp_df = train_test_split(arm_df, test_size=0.2, stratify=arm_df['label'], random_state=2026)
        val_df, test_df = train_test_split(temp_df, test_size=0.5, stratify=temp_df['label'], random_state=2026)

        config = TrainConfig(
            seed=2026,
            embedding_dim=32,
            hidden_dim=64,
            dropout=0.15,
            lr=1e-3,
            batch_size=128,
            weight_decay=1e-5,
            epochs=2,
            loss_type='focal',
            sampler='weighted',
            class_weights=[],
        )
        print(f'Running arm={arm_name} with groups={enabled_groups} and samples={len(arm_df)}', flush=True)
        report = train_and_evaluate(config, train_df, val_df, test_df, vocab={})

        results.append({
            'arm': arm_name,
            'accuracy': report['accuracy'],
            'macro_f1': report['macro_f1'],
            'severe_recall': report['severe_recall'],
            'num_test_examples': report['num_test_examples'],
            'enabled_groups': enabled_groups,
        })

        cm_path = REPORT_DIR / f'ablation_confusion_matrix_{arm_name}.png'
        _save_confusion_matrix(report['confusion_matrix'], report['label_names'], cm_path)

        # Only the scalar summary above is kept; drop the per-arm data and the
        # full report before the next arm allocates its own copies.
        del arm_X, arm_df, train_df, val_df, test_df, report
        gc.collect()

    # Best arm first: rank by severe-class recall, then macro F1.
    summary_df = pd.DataFrame(results).sort_values(by=['severe_recall', 'macro_f1'], ascending=False)
    summary_csv = REPORT_DIR / 'ablation_summary.csv'
    summary_df.to_csv(summary_csv, index=False)
    summary_json = REPORT_DIR / 'ablation_summary.json'
    summary_json.write_text(json.dumps(results, indent=2), encoding='utf-8')

    chart_path = REPORT_DIR / 'ablation_metrics.png'
    fig, axes = plt.subplots(1, 3, figsize=(14, 4), sharex=True)
    for ax, metric in zip(axes, ['accuracy', 'macro_f1', 'severe_recall']):
        # One colour per bar position, i.e. per rank in summary_df.
        ax.bar(summary_df['arm'], summary_df[metric], color=['#4C78A8', '#72B7B2', '#F58518', '#54A24B'])
        ax.set_title(metric.replace('_', ' ').title())
        ax.set_ylim(0, 1)
        ax.tick_params(axis='x', rotation=20)
    fig.tight_layout()
    fig.savefig(chart_path, dpi=160)
    plt.close(fig)

    report_md = REPORT_DIR / 'ablation_report.md'
    lines = [
        '# Ablation Study Report',
        '',
        '## Summary',
        '',
        _markdown_table(summary_df),
        '',
        '## Interpretation',
        '',
        '- DDInter-only is the baseline arm.',
        '- DrugBank arm adds active ingredients, ATC codes, and categories.',
        '- TWOSIDES arm adds side-effect signal overlap.',
        '- Full pipeline includes all feature groups and should perform best if the feature integration is correct.',
        '',
        '## Artifacts',
        '',
        f'- CSV: {summary_csv}',
        f'- JSON: {summary_json}',
        f'- Chart: {chart_path}',
    ]
    report_md.write_text('\n'.join(lines), encoding='utf-8')

    print('Ablation complete.')
    print(f'Summary CSV: {summary_csv}')
    print(f'Summary JSON: {summary_json}')
    print(f'Report: {report_md}')
    print(f'Chart: {chart_path}')


# Run the study only when executed as a script, not when imported.
if __name__ == '__main__':
    main()