Spaces:
Running
Running
| """Multisource ablation study for DDInter + DrugBank + TWOSIDES fusion.""" | |
from __future__ import annotations

import gc
import json
import sys
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
# Project root: parents[2] is two directory levels above the folder holding this
# file. It is prepended to sys.path, presumably so that project packages imported
# absolutely by the sibling modules resolve -- TODO confirm.
ROOT = Path(__file__).resolve().parents[2]
sys.path.insert(0, str(ROOT))

# NOTE(review): these are relative imports, so this module must be executed as
# part of its package (``python -m <package>.<module>``); running the file
# directly as a script raises ImportError before ``main`` is ever reached.
from .feature_pipeline_multisource import build_feature_pipeline
from .retrain_production_model import TrainConfig, train_and_evaluate

# Destination for every ablation artifact (CSV/JSON/Markdown/PNG); it is
# created eagerly, at import time.
REPORT_DIR = ROOT.joinpath('models', 'reports')
REPORT_DIR.mkdir(exist_ok=True, parents=True)
def _mask_groups(
    X: np.ndarray,
    group_slices: dict[str, tuple[int, int]],
    enabled_groups: list[str],
    *,
    strict: bool = False,
) -> np.ndarray:
    """Return a copy of ``X`` in which only the enabled feature groups are kept.

    Every column outside the slices of ``enabled_groups`` is set to zero. The
    result therefore has the same shape and dtype as ``X``, and each ablation arm
    keeps the same feature width.

    Args:
        X: 2-D feature matrix; ``group_slices`` index its columns.
        group_slices: Maps a group name to a half-open ``(start, end)`` column range.
        enabled_groups: Names of the groups whose columns are copied through.
        strict: If True, raise ``KeyError`` when an enabled group is not in
            ``group_slices``. Otherwise emit a ``RuntimeWarning`` and skip it.

    Returns:
        A new array; ``X`` itself is not modified.

    Raises:
        KeyError: Only when ``strict`` is True and a group name is unknown.
    """
    missing = [name for name in enabled_groups if name not in group_slices]
    if missing:
        message = (
            f'Unknown feature group(s) {missing}; '
            f'known groups: {sorted(group_slices)}'
        )
        if strict:
            raise KeyError(message)
        # Silently dropping a group would make this arm indistinguishable from a
        # smaller one and quietly invalidate the ablation comparison.
        warnings.warn(message, RuntimeWarning, stacklevel=2)

    masked = np.zeros_like(X)
    for group_name in enabled_groups:
        if group_name in group_slices:
            start, end = group_slices[group_name]
            masked[:, start:end] = X[:, start:end]
    return masked
def _markdown_table(df: pd.DataFrame) -> str:
    """Render ``df`` as a Markdown pipe table with space-padded columns.

    Each header and cell is ``str(value)``, with literal ``|`` escaped as ``\\|``
    so that it cannot split a cell. The DataFrame index is not rendered.

    Args:
        df: Frame to render. Column labels of any type are stringified.

    Returns:
        Header row, dash separator row and one row per record, joined by
        newlines (no trailing newline).
    """
    def _cell(value: object) -> str:
        return str(value).replace('|', '\\|')

    headers = [_cell(label) for label in df.columns]
    # itertuples keeps each column's own dtype. iterrows builds one Series per
    # row and upcasts ints to floats in all-numeric frames (300 -> "300.0").
    body = [
        [_cell(value) for value in record]
        for record in df.itertuples(index=False, name=None)
    ]
    widths = [max(len(cell) for cell in column) for column in zip(headers, *body)]

    def _row(cells: list[str]) -> str:
        return '| ' + ' | '.join(cell.ljust(width) for cell, width in zip(cells, widths)) + ' |'

    lines = [_row(headers), '| ' + ' | '.join('-' * width for width in widths) + ' |']
    lines.extend(_row(cells) for cells in body)
    return '\n'.join(lines)
def _save_confusion_matrix(cm: list[list[int]], labels: list[str], out_path: Path) -> None:
    """Draw ``cm`` as an annotated heatmap and save it as an image at ``out_path``.

    Args:
        cm: Count matrix indexed as ``cm[row][col]``. Rows are drawn on the y
            axis ("True") and columns on the x axis ("Predicted"). The same
            ``labels`` are used for both axes, so ``cm`` is expected to be square.
        labels: Class names in row/column order.
        out_path: Destination image file, written at 160 dpi.

    The figure is always closed, even if drawing or saving raises.
    """
    matrix = np.asarray(cm, dtype=np.int64)
    fig, ax = plt.subplots(figsize=(6, 5))
    try:
        im = ax.imshow(matrix, cmap='Blues')
        fig.colorbar(im, ax=ax)
        ax.set_xticks(range(len(labels)))
        ax.set_yticks(range(len(labels)))
        ax.set_xticklabels(labels, rotation=45, ha='right')
        ax.set_yticklabels(labels)
        ax.set_xlabel('Predicted')
        ax.set_ylabel('True')
        ax.set_title('Confusion Matrix')
        # Counts above half the maximum sit on the dark end of 'Blues', so they are
        # written in white for contrast. The size guard avoids max() on an empty array.
        threshold = matrix.max() / 2.0 if matrix.size else 0
        for (i, j), count in np.ndenumerate(matrix):
            ax.text(j, i, str(count), ha='center', va='center', color='white' if count > threshold else 'black')
        fig.tight_layout()
        fig.savefig(out_path, dpi=160)
    finally:
        # pyplot keeps every figure alive until it is closed. Release this one even
        # on failure so that repeated calls (one per ablation arm) cannot pile up.
        plt.close(fig)
def _json_default(value: object) -> object:
    """Convert NumPy values for ``json.dumps(default=...)``.

    The concrete types inside the training report are not visible here. NumPy
    scalars such as ``np.int64`` or ``np.float32``, if present, would otherwise
    make ``json.dumps`` raise after every arm had already been trained.

    Raises:
        TypeError: For any other unsupported object, mirroring ``json``'s own
            error so that genuinely unexpected types still fail loudly.
    """
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, np.ndarray):
        return value.tolist()
    raise TypeError(f'Object of type {type(value).__name__} is not JSON serializable')


def _split_positions(labels: np.ndarray, seed: int) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Return stratified 80/10/10 train/val/test row positions into ``labels``.

    The second split halves the 20% hold-out and is stratified on the held-out
    labels. Both splits use ``seed`` as ``random_state``.
    """
    positions = np.arange(len(labels))
    train_pos, temp_pos = train_test_split(positions, test_size=0.2, stratify=labels, random_state=seed)
    val_pos, test_pos = train_test_split(temp_pos, test_size=0.5, stratify=labels[temp_pos], random_state=seed)
    return train_pos, val_pos, test_pos


def _save_metric_chart(summary_df: pd.DataFrame, out_path: Path) -> None:
    """Save one bar subplot per metric (accuracy, macro F1, severe recall), one bar per arm."""
    fig, axes = plt.subplots(1, 3, figsize=(15, 4), sharex=True)
    try:
        for ax, metric in zip(axes, ['accuracy', 'macro_f1', 'severe_recall']):
            ax.bar(summary_df['arm'], summary_df[metric], color=['#4C78A8', '#72B7B2', '#F58518', '#54A24B', '#E45756', '#B279A2'])
            ax.set_title(metric.replace('_', ' ').title())
            ax.set_ylim(0, 1)
            ax.tick_params(axis='x', rotation=20)
        fig.tight_layout()
        fig.savefig(out_path, dpi=160)
    finally:
        # Close even on failure; pyplot otherwise keeps the figure alive.
        plt.close(fig)


def main() -> None:
    """Run the multisource feature-group ablation and write its reports.

    The steps are:

    1. Build the feature matrix once.
    2. Train and evaluate one model per arm. Each arm zeroes the feature groups
       it does not enable, and every arm uses the same stratified split.
    3. Write a CSV, JSON and Markdown summary, per-arm confusion matrices and a
       metrics bar chart to ``REPORT_DIR``.
    """
    seed = 2026
    print('Building multisource feature pipeline...', flush=True)
    pairs_df, artifacts = build_feature_pipeline(save_artifacts=True, sample_size=3000, seed=seed)
    X = np.asarray(list(pairs_df['_X'].values), dtype=np.float32)
    group_slices = artifacts.group_slices

    base_groups = ['semantic_embeddings', 'graph_topology', 'source_flags']
    arms = {
        'ddinter_only': base_groups,
        'ddinter_drugbank': base_groups + ['drugbank_metadata'],
        'ddinter_twosides': base_groups + ['twosides_signals'],
        'full_fusion': base_groups + ['drugbank_metadata', 'twosides_signals'],
        'full_no_semantic': ['drugbank_metadata', 'twosides_signals', 'graph_topology', 'source_flags'],
        'full_no_graph': ['semantic_embeddings', 'drugbank_metadata', 'twosides_signals', 'source_flags'],
    }
    config = TrainConfig(
        seed=seed,
        embedding_dim=64,
        hidden_dim=128,
        dropout=0.2,
        lr=1e-3,
        batch_size=128,
        weight_decay=1e-5,
        epochs=3,
        loss_type='focal',
        sampler='weighted',
        class_weights=[],
    )

    # Split once, outside the loop. Every arm then trains and is scored on exactly
    # the same rows, so metric differences come only from the feature masking.
    base_df = pairs_df[['drug_a', 'drug_b', 'label']]
    train_pos, val_pos, test_pos = _split_positions(base_df['label'].to_numpy(), seed)

    results: list[dict[str, object]] = []
    for arm_name, enabled_groups in arms.items():
        print(f'Running arm={arm_name} with groups={enabled_groups}', flush=True)
        arm_df = base_df.copy()
        # Disabled groups are zeroed rather than dropped, so the width stays fixed.
        arm_df['_X'] = _mask_groups(X, group_slices, enabled_groups).tolist()
        report = train_and_evaluate(
            config,
            arm_df.iloc[train_pos],
            arm_df.iloc[val_pos],
            arm_df.iloc[test_pos],
            vocab={},
        )
        results.append({
            'arm': arm_name,
            'accuracy': report['accuracy'],
            'macro_f1': report['macro_f1'],
            'severe_recall': report['severe_recall'],
            'num_test_examples': report['num_test_examples'],
            'enabled_groups': enabled_groups,
        })
        _save_confusion_matrix(report['confusion_matrix'], report['label_names'], REPORT_DIR / f'ablation_confusion_matrix_{arm_name}.png')
        # Drop the per-arm copies before the next arm to keep peak memory down.
        del arm_df, report
        gc.collect()

    summary_df = pd.DataFrame(results).sort_values(by=['severe_recall', 'macro_f1'], ascending=False)
    summary_csv = REPORT_DIR / 'ablation_summary_multisource.csv'
    summary_json = REPORT_DIR / 'ablation_summary_multisource.json'
    summary_md = REPORT_DIR / 'ablation_report_multisource.md'
    chart_path = REPORT_DIR / 'ablation_metrics_multisource.png'

    summary_df.to_csv(summary_csv, index=False)
    summary_json.write_text(json.dumps(results, indent=2, default=_json_default), encoding='utf-8')
    _save_metric_chart(summary_df, chart_path)

    md_lines = [
        '# Multisource Ablation Study',
        '',
        _markdown_table(summary_df),
        '',
        '## Arms',
        '',
        '- `ddinter_only`: semantic_embeddings + graph_topology + source_flags',
        '- `ddinter_drugbank`: ddinter_only + drugbank_metadata',
        '- `ddinter_twosides`: ddinter_only + twosides_signals',
        '- `full_fusion`: all groups',
        '- `full_no_semantic`: full fusion without semantic_embeddings',
        '- `full_no_graph`: full fusion without graph_topology',
        '',
        '## Artifacts',
        '',
        f'- CSV: {summary_csv}',
        f'- JSON: {summary_json}',
        f'- Chart: {chart_path}',
    ]
    summary_md.write_text('\n'.join(md_lines), encoding='utf-8')
    print('Ablation complete.')
    print(f'Summary CSV: {summary_csv}')
    print(f'Summary JSON: {summary_json}')
    print(f'Report: {summary_md}')


if __name__ == '__main__':
    main()