"""Corrected ablation study using fixed feature pipeline.""" from __future__ import annotations import json import sys import gc from pathlib import Path import matplotlib.pyplot as plt import numpy as np import pandas as pd from sklearn.model_selection import train_test_split ROOT = Path(__file__).resolve().parents[2] sys.path.insert(0, str(ROOT)) from src.training.feature_pipeline_corrected import build_feature_pipeline from src.training.retrain_production_model import TrainConfig, train_and_evaluate REPORT_DIR = ROOT / 'models' / 'reports' REPORT_DIR.mkdir(parents=True, exist_ok=True) def _mask_groups(X: np.ndarray, group_slices: dict[str, tuple[int, int]], enabled_groups: list[str]) -> np.ndarray: """Mask features to enable/disable groups.""" masked = np.zeros_like(X) for group_name in enabled_groups: if group_name not in group_slices: continue start, end = group_slices[group_name] masked[:, start:end] = X[:, start:end] return masked def _save_confusion_matrix(cm: list[list[int]], labels: list[str], out_path: Path) -> None: """Save confusion matrix as PNG.""" matrix = np.asarray(cm, dtype=np.int64) fig, ax = plt.subplots(figsize=(6, 5)) im = ax.imshow(matrix, cmap='Blues') fig.colorbar(im, ax=ax) ax.set_xticks(range(len(labels))) ax.set_yticks(range(len(labels))) ax.set_xticklabels(labels, rotation=45, ha='right') ax.set_yticklabels(labels) ax.set_xlabel('Predicted') ax.set_ylabel('True') ax.set_title('Confusion Matrix') threshold = matrix.max() / 2.0 if matrix.size else 0 for i in range(matrix.shape[0]): for j in range(matrix.shape[1]): ax.text(j, i, str(matrix[i, j]), ha='center', va='center', color='white' if matrix[i, j] > threshold else 'black') fig.tight_layout() fig.savefig(out_path, dpi=160) plt.close(fig) def _markdown_table(df: pd.DataFrame) -> str: """Render DataFrame as markdown table.""" headers = list(df.columns) rows = [headers] for _, row in df.iterrows(): rows.append([str(row[col]) for col in headers]) widths = [max(len(cell) for cell in column) for column in zip(*rows)] lines = [] lines.append('| ' + ' | '.join(header.ljust(widths[idx]) for idx, header in enumerate(headers)) + ' |') lines.append('| ' + ' | '.join('-' * widths[idx] for idx in range(len(headers))) + ' |') for row in rows[1:]: lines.append('| ' + ' | '.join(cell.ljust(widths[idx]) for idx, cell in enumerate(row)) + ' |') return '\n'.join(lines) def main() -> None: """Run corrected ablation study.""" print('Building corrected feature pipeline (DDInter-only)...', flush=True) pairs_df, artifacts = build_feature_pipeline(save_artifacts=True, sample_size=3000, seed=2026) group_slices = artifacts.group_slices X = np.asarray(list(pairs_df['_X'].values), dtype=np.float32) base_groups = ['pair_encoding', 'semantic_embeddings'] arms = { 'pair_encoding_only': ['pair_encoding'], 'pair_encoding_semantic': ['pair_encoding', 'semantic_embeddings'], 'pair_encoding_support': ['pair_encoding', 'pair_support'], 'full': ['pair_encoding', 'semantic_embeddings', 'pair_support'], } results: list[dict[str, object]] = [] summary_by_arm: dict[str, dict[str, object]] = {} for arm_name, enabled_groups in arms.items(): arm_X = _mask_groups(X, group_slices, enabled_groups) arm_df = pairs_df[['drug_a', 'drug_b', 'label', 'pair_id']].copy() arm_df['_X'] = list(arm_X.tolist()) train_df, temp_df = train_test_split(arm_df, test_size=0.2, stratify=arm_df['label'], random_state=2026) val_df, test_df = train_test_split(temp_df, test_size=0.5, stratify=temp_df['label'], random_state=2026) config = TrainConfig( seed=2026, embedding_dim=32, hidden_dim=64, dropout=0.15, lr=1e-3, batch_size=128, weight_decay=1e-5, epochs=2, loss_type='focal', sampler='weighted', class_weights=[], ) print(f'Running arm={arm_name} with groups={enabled_groups} and samples={len(arm_df)}', flush=True) report = train_and_evaluate(config, train_df, val_df, test_df, vocab={}) summary = { 'arm': arm_name, 'accuracy': report['accuracy'], 'macro_f1': report['macro_f1'], 'severe_recall': report['severe_recall'], 'num_test_examples': report['num_test_examples'], 'enabled_groups': enabled_groups, } results.append(summary) summary_by_arm[arm_name] = report cm_path = REPORT_DIR / f'ablation_confusion_matrix_{arm_name}.png' _save_confusion_matrix(report['confusion_matrix'], report['label_names'], cm_path) del arm_X, arm_df, train_df, val_df, test_df, report gc.collect() summary_df = pd.DataFrame(results).sort_values(by=['severe_recall', 'macro_f1'], ascending=False) summary_csv = REPORT_DIR / 'ablation_summary_corrected.csv' summary_df.to_csv(summary_csv, index=False) summary_json = REPORT_DIR / 'ablation_summary_corrected.json' summary_json.write_text(json.dumps(results, indent=2), encoding='utf-8') chart_path = REPORT_DIR / 'ablation_metrics_corrected.png' fig, axes = plt.subplots(1, 3, figsize=(14, 4), sharex=True) for ax, metric in zip(axes, ['accuracy', 'macro_f1', 'severe_recall']): ax.bar(summary_df['arm'], summary_df[metric], color=['#4C78A8', '#72B7B2', '#F58518', '#54A24B']) ax.set_title(metric.replace('_', ' ').title()) ax.set_ylim(0, 1) ax.tick_params(axis='x', rotation=20) fig.tight_layout() fig.savefig(chart_path, dpi=160) plt.close(fig) report_md = REPORT_DIR / 'ablation_report_corrected.md' lines = [ '# Corrected Ablation Study Report', '', '## Summary', '', _markdown_table(summary_df), '', '## Interpretation', '', '- **pair_encoding_only**: Baseline using only hashed pair names.', '- **pair_encoding_semantic**: Adds drug name n-gram embeddings.', '- **pair_encoding_support**: Adds frequency of pair occurrence.', '- **full**: All three feature groups combined.', '', 'If ablation shows meaningful differences now, the features are working correctly.', '', '## Artifacts', '', f'- CSV: {summary_csv}', f'- JSON: {summary_json}', f'- Chart: {chart_path}', ] report_md.write_text('\n'.join(lines), encoding='utf-8') print('Corrected ablation complete.') print(f'Summary CSV: {summary_csv}') print(f'Summary JSON: {summary_json}') print(f'Report: {report_md}') if __name__ == '__main__': main()