"""Multisource ablation study for DDInter + DrugBank + TWOSIDES fusion.""" from __future__ import annotations import gc import json import sys from pathlib import Path import matplotlib.pyplot as plt import numpy as np import pandas as pd from sklearn.model_selection import train_test_split ROOT = Path(__file__).resolve().parents[2] sys.path.insert(0, str(ROOT)) from .feature_pipeline_multisource import build_feature_pipeline from .retrain_production_model import TrainConfig, train_and_evaluate REPORT_DIR = ROOT / 'models' / 'reports' REPORT_DIR.mkdir(parents=True, exist_ok=True) def _mask_groups(X: np.ndarray, group_slices: dict[str, tuple[int, int]], enabled_groups: list[str]) -> np.ndarray: masked = np.zeros_like(X) for group_name in enabled_groups: if group_name not in group_slices: continue start, end = group_slices[group_name] masked[:, start:end] = X[:, start:end] return masked def _markdown_table(df: pd.DataFrame) -> str: headers = list(df.columns) rows = [headers] for _, row in df.iterrows(): rows.append([str(row[col]) for col in headers]) widths = [max(len(cell) for cell in column) for column in zip(*rows)] lines = [] lines.append('| ' + ' | '.join(header.ljust(widths[idx]) for idx, header in enumerate(headers)) + ' |') lines.append('| ' + ' | '.join('-' * widths[idx] for idx in range(len(headers))) + ' |') for row in rows[1:]: lines.append('| ' + ' | '.join(cell.ljust(widths[idx]) for idx, cell in enumerate(row)) + ' |') return '\n'.join(lines) def _save_confusion_matrix(cm: list[list[int]], labels: list[str], out_path: Path) -> None: matrix = np.asarray(cm, dtype=np.int64) fig, ax = plt.subplots(figsize=(6, 5)) im = ax.imshow(matrix, cmap='Blues') fig.colorbar(im, ax=ax) ax.set_xticks(range(len(labels))) ax.set_yticks(range(len(labels))) ax.set_xticklabels(labels, rotation=45, ha='right') ax.set_yticklabels(labels) ax.set_xlabel('Predicted') ax.set_ylabel('True') ax.set_title('Confusion Matrix') threshold = matrix.max() / 2.0 if matrix.size else 0 for i in range(matrix.shape[0]): for j in range(matrix.shape[1]): ax.text(j, i, str(matrix[i, j]), ha='center', va='center', color='white' if matrix[i, j] > threshold else 'black') fig.tight_layout() fig.savefig(out_path, dpi=160) plt.close(fig) def main() -> None: print('Building multisource feature pipeline...', flush=True) pairs_df, artifacts = build_feature_pipeline(save_artifacts=True, sample_size=3000, seed=2026) X = np.asarray(list(pairs_df['_X'].values), dtype=np.float32) group_slices = artifacts.group_slices base_groups = ['semantic_embeddings', 'graph_topology', 'source_flags'] arms = { 'ddinter_only': base_groups, 'ddinter_drugbank': base_groups + ['drugbank_metadata'], 'ddinter_twosides': base_groups + ['twosides_signals'], 'full_fusion': base_groups + ['drugbank_metadata', 'twosides_signals'], 'full_no_semantic': ['drugbank_metadata', 'twosides_signals', 'graph_topology', 'source_flags'], 'full_no_graph': ['semantic_embeddings', 'drugbank_metadata', 'twosides_signals', 'source_flags'], } config = TrainConfig( seed=2026, embedding_dim=64, hidden_dim=128, dropout=0.2, lr=1e-3, batch_size=128, weight_decay=1e-5, epochs=3, loss_type='focal', sampler='weighted', class_weights=[], ) results: list[dict[str, object]] = [] for arm_name, enabled_groups in arms.items(): arm_X = _mask_groups(X, group_slices, enabled_groups) arm_df = pairs_df[['drug_a', 'drug_b', 'label']].copy() arm_df['_X'] = list(arm_X.tolist()) train_df, temp_df = train_test_split(arm_df, test_size=0.2, stratify=arm_df['label'], random_state=2026) val_df, test_df = train_test_split(temp_df, test_size=0.5, stratify=temp_df['label'], random_state=2026) print(f'Running arm={arm_name} with groups={enabled_groups}', flush=True) report = train_and_evaluate(config, train_df, val_df, test_df, vocab={}) result = { 'arm': arm_name, 'accuracy': report['accuracy'], 'macro_f1': report['macro_f1'], 'severe_recall': report['severe_recall'], 'num_test_examples': report['num_test_examples'], 'enabled_groups': enabled_groups, } results.append(result) _save_confusion_matrix(report['confusion_matrix'], report['label_names'], REPORT_DIR / f'ablation_confusion_matrix_{arm_name}.png') del arm_X, arm_df, train_df, val_df, test_df, report gc.collect() summary_df = pd.DataFrame(results).sort_values(by=['severe_recall', 'macro_f1'], ascending=False) summary_csv = REPORT_DIR / 'ablation_summary_multisource.csv' summary_json = REPORT_DIR / 'ablation_summary_multisource.json' summary_md = REPORT_DIR / 'ablation_report_multisource.md' chart_path = REPORT_DIR / 'ablation_metrics_multisource.png' summary_df.to_csv(summary_csv, index=False) summary_json.write_text(json.dumps(results, indent=2), encoding='utf-8') fig, axes = plt.subplots(1, 3, figsize=(15, 4), sharex=True) for ax, metric in zip(axes, ['accuracy', 'macro_f1', 'severe_recall']): ax.bar(summary_df['arm'], summary_df[metric], color=['#4C78A8', '#72B7B2', '#F58518', '#54A24B', '#E45756', '#B279A2']) ax.set_title(metric.replace('_', ' ').title()) ax.set_ylim(0, 1) ax.tick_params(axis='x', rotation=20) fig.tight_layout() fig.savefig(chart_path, dpi=160) plt.close(fig) md_lines = [ '# Multisource Ablation Study', '', _markdown_table(summary_df), '', '## Arms', '', '- `ddinter_only`: semantic_embeddings + graph_topology + source_flags', '- `ddinter_drugbank`: ddinter_only + drugbank_metadata', '- `ddinter_twosides`: ddinter_only + twosides_signals', '- `full_fusion`: all groups', '- `full_no_semantic`: full fusion without semantic_embeddings', '- `full_no_graph`: full fusion without graph_topology', '', '## Artifacts', '', f'- CSV: {summary_csv}', f'- JSON: {summary_json}', f'- Chart: {chart_path}', ] summary_md.write_text('\n'.join(md_lines), encoding='utf-8') print('Ablation complete.') print(f'Summary CSV: {summary_csv}') print(f'Summary JSON: {summary_json}') print(f'Report: {summary_md}') if __name__ == '__main__': main()