# ddi/src/training/ablation_study_corrected.py
# github-actions[bot]
# Deploy from GitHub Actions (fb28c05c54cf19184fc3f14f1bf3297ba5749ea2)
# d29b763
"""Corrected ablation study using fixed feature pipeline."""
from __future__ import annotations
import json
import sys
import gc
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
# Repository root: this file sits in <root>/src/training/, so parents[2] is <root>.
ROOT = Path(__file__).resolve().parents[2]
# Put the repository root on sys.path so the `src.training...` imports below resolve
# when this file is run directly as a script. This must execute before those imports.
sys.path.insert(0, str(ROOT))
from src.training.feature_pipeline_corrected import build_feature_pipeline
from src.training.retrain_production_model import TrainConfig, train_and_evaluate
# Destination for every artifact this script writes (CSV, JSON, PNGs, markdown report).
# Created at import time, so importing this module touches the filesystem.
REPORT_DIR = ROOT / 'models' / 'reports'
REPORT_DIR.mkdir(parents=True, exist_ok=True)
def _mask_groups(X: np.ndarray, group_slices: dict[str, tuple[int, int]], enabled_groups: list[str]) -> np.ndarray:
"""Mask features to enable/disable groups."""
masked = np.zeros_like(X)
for group_name in enabled_groups:
if group_name not in group_slices:
continue
start, end = group_slices[group_name]
masked[:, start:end] = X[:, start:end]
return masked
def _save_confusion_matrix(cm: list[list[int]], labels: list[str], out_path: Path) -> None:
    """Render a confusion matrix as an annotated heatmap and save it as an image.

    Rows are true labels and columns are predicted labels: ``cm[i][j]`` is drawn at
    row ``i``, column ``j`` and annotated with its count. Cells above half of the
    maximum count use white text so they stay readable on the dark end of 'Blues'.

    Args:
        cm: Matrix of integer counts (nested lists or array-like).
        labels: Tick labels applied to both axes, in matrix row/column order.
        out_path: Destination file (saved at 160 dpi). Parent directories are not
            created here.
    """
    matrix = np.asarray(cm, dtype=np.int64)
    fig, ax = plt.subplots(figsize=(6, 5))
    try:
        im = ax.imshow(matrix, cmap='Blues')
        fig.colorbar(im, ax=ax)
        ax.set_xticks(range(len(labels)))
        ax.set_yticks(range(len(labels)))
        ax.set_xticklabels(labels, rotation=45, ha='right')
        ax.set_yticklabels(labels)
        ax.set_xlabel('Predicted')
        ax.set_ylabel('True')
        ax.set_title('Confusion Matrix')
        # Guard against an empty matrix, where .max() would raise.
        threshold = matrix.max() / 2.0 if matrix.size else 0
        for i in range(matrix.shape[0]):
            for j in range(matrix.shape[1]):
                ax.text(j, i, str(matrix[i, j]), ha='center', va='center',
                        color='white' if matrix[i, j] > threshold else 'black')
        fig.tight_layout()
        fig.savefig(out_path, dpi=160)
    finally:
        # Close the figure even if drawing or saving fails, so a failure does not
        # leave an open pyplot figure behind.
        plt.close(fig)
def _markdown_table(df: pd.DataFrame) -> str:
"""Render DataFrame as markdown table."""
headers = list(df.columns)
rows = [headers]
for _, row in df.iterrows():
rows.append([str(row[col]) for col in headers])
widths = [max(len(cell) for cell in column) for column in zip(*rows)]
lines = []
lines.append('| ' + ' | '.join(header.ljust(widths[idx]) for idx, header in enumerate(headers)) + ' |')
lines.append('| ' + ' | '.join('-' * widths[idx] for idx in range(len(headers))) + ' |')
for row in rows[1:]:
lines.append('| ' + ' | '.join(cell.ljust(widths[idx]) for idx, cell in enumerate(row)) + ' |')
return '\n'.join(lines)
def _split_positions(labels: np.ndarray, seed: int) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Return row positions of a stratified ~80/10/10 train/val/test split.

    First 20% of rows are held out (stratified on ``labels``), then that holdout is
    halved into validation and test (stratified on its own labels), both with
    ``random_state=seed``. The partition depends only on the labels, the row count
    and the seed, so it can be computed once and reused for every ablation arm.
    """
    positions = np.arange(len(labels))
    train_pos, temp_pos = train_test_split(positions, test_size=0.2, stratify=labels, random_state=seed)
    val_pos, test_pos = train_test_split(temp_pos, test_size=0.5, stratify=labels[temp_pos], random_state=seed)
    return train_pos, val_pos, test_pos


def _write_metrics_chart(summary_df: pd.DataFrame, chart_path: Path) -> None:
    """Save one bar chart per metric (accuracy, macro F1, severe recall), one bar per arm."""
    fig, axes = plt.subplots(1, 3, figsize=(14, 4), sharex=True)
    try:
        for ax, metric in zip(axes, ['accuracy', 'macro_f1', 'severe_recall']):
            ax.bar(summary_df['arm'], summary_df[metric], color=['#4C78A8', '#72B7B2', '#F58518', '#54A24B'])
            ax.set_title(metric.replace('_', ' ').title())
            ax.set_ylim(0, 1)
            ax.tick_params(axis='x', rotation=20)
        fig.tight_layout()
        fig.savefig(chart_path, dpi=160)
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)


def _write_markdown_report(report_md: Path, summary_df: pd.DataFrame, summary_csv: Path,
                           summary_json: Path, chart_path: Path) -> None:
    """Write the human-readable markdown report (summary table, notes, artifact paths)."""
    lines = [
        '# Corrected Ablation Study Report',
        '',
        '## Summary',
        '',
        _markdown_table(summary_df),
        '',
        '## Interpretation',
        '',
        '- **pair_encoding_only**: Baseline using only hashed pair names.',
        '- **pair_encoding_semantic**: Adds drug name n-gram embeddings.',
        '- **pair_encoding_support**: Adds frequency of pair occurrence.',
        '- **full**: All three feature groups combined.',
        '',
        'If ablation shows meaningful differences now, the features are working correctly.',
        '',
        '## Artifacts',
        '',
        f'- CSV: {summary_csv}',
        f'- JSON: {summary_json}',
        f'- Chart: {chart_path}',
    ]
    report_md.write_text('\n'.join(lines), encoding='utf-8')


def main() -> None:
    """Run the corrected ablation study end to end.

    Builds the feature pipeline once, then for each arm (a fixed set of enabled
    feature groups; every other group is zeroed) trains and evaluates a model on the
    same stratified train/val/test split. Writes per-arm confusion matrices, a
    CSV/JSON summary, a metrics chart and a markdown report into ``REPORT_DIR``.
    """
    seed = 2026
    print('Building corrected feature pipeline (DDInter-only)...', flush=True)
    pairs_df, artifacts = build_feature_pipeline(save_artifacts=True, sample_size=3000, seed=seed)
    group_slices = artifacts.group_slices
    X = np.asarray(list(pairs_df['_X'].values), dtype=np.float32)
    meta_df = pairs_df[['drug_a', 'drug_b', 'label', 'pair_id']]
    # Identical for every arm (same labels, same seed), so compute the split once.
    train_pos, val_pos, test_pos = _split_positions(meta_df['label'].to_numpy(), seed)

    arms = {
        'pair_encoding_only': ['pair_encoding'],
        'pair_encoding_semantic': ['pair_encoding', 'semantic_embeddings'],
        'pair_encoding_support': ['pair_encoding', 'pair_support'],
        'full': ['pair_encoding', 'semantic_embeddings', 'pair_support'],
    }
    results: list[dict[str, object]] = []
    for arm_name, enabled_groups in arms.items():
        arm_df = meta_df.copy()
        # Disabled groups are zeroed rather than removed, so input width is constant.
        arm_df['_X'] = _mask_groups(X, group_slices, enabled_groups).tolist()
        # take() returns real copies (no SettingWithCopy issues downstream).
        train_df = arm_df.take(train_pos)
        val_df = arm_df.take(val_pos)
        test_df = arm_df.take(test_pos)
        # A fresh config per arm: nothing here shows whether the trainer mutates it
        # (e.g. the class_weights list), so it is never shared between arms.
        config = TrainConfig(
            seed=seed,
            embedding_dim=32,
            hidden_dim=64,
            dropout=0.15,
            lr=1e-3,
            batch_size=128,
            weight_decay=1e-5,
            epochs=2,
            loss_type='focal',
            sampler='weighted',
            class_weights=[],
        )
        print(f'Running arm={arm_name} with groups={enabled_groups} and samples={len(arm_df)}', flush=True)
        report = train_and_evaluate(config, train_df, val_df, test_df, vocab={})
        results.append({
            'arm': arm_name,
            'accuracy': report['accuracy'],
            'macro_f1': report['macro_f1'],
            'severe_recall': report['severe_recall'],
            'num_test_examples': report['num_test_examples'],
            'enabled_groups': enabled_groups,
        })
        cm_path = REPORT_DIR / f'ablation_confusion_matrix_{arm_name}.png'
        _save_confusion_matrix(report['confusion_matrix'], report['label_names'], cm_path)
        # Nothing else references these objects now, so this actually frees them
        # before the next arm trains.
        del arm_df, train_df, val_df, test_df, report
        gc.collect()

    summary_df = pd.DataFrame(results).sort_values(by=['severe_recall', 'macro_f1'], ascending=False)
    summary_csv = REPORT_DIR / 'ablation_summary_corrected.csv'
    summary_df.to_csv(summary_csv, index=False)
    summary_json = REPORT_DIR / 'ablation_summary_corrected.json'
    summary_json.write_text(json.dumps(results, indent=2), encoding='utf-8')
    chart_path = REPORT_DIR / 'ablation_metrics_corrected.png'
    _write_metrics_chart(summary_df, chart_path)
    report_md = REPORT_DIR / 'ablation_report_corrected.md'
    _write_markdown_report(report_md, summary_df, summary_csv, summary_json, chart_path)
    print('Corrected ablation complete.')
    print(f'Summary CSV: {summary_csv}')
    print(f'Summary JSON: {summary_json}')
    print(f'Report: {report_md}')


if __name__ == '__main__':
    main()