# openmed-autoresearch / make_graphs.py
# AutoResearch Agent
# add blog post with graphs (heatmap, timeline, asymmetry)
# 64fd1eb
"""Generate blog post graphs for the autoresearch findings."""
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
# Global text defaults shared by every figure produced below.
matplotlib.rcParams.update({
    'font.family': 'sans-serif',
    'font.size': 13,
})
# ── Graph 1: Transfer Affinity Heatmap ──────────────────────────────────────
# Heatmap with one row per pretraining source and one column per target task.
# Each cell is annotated with its ΔF1 value; self-transfer cells are masked
# and drawn in light gray with a dash.
fig, ax = plt.subplots(figsize=(7, 5))

# Row order of `data` (pretraining sources) and the tick labels shown for them.
sources = ['bc5cdr_chem', 'ncbi_disease', 'jnlpba', 'bc2gm', 'linnaeus']
source_labels = [
    'Chemicals\n(BC5CDR)',
    'Diseases\n(NCBI)',
    'Proteins/DNA\n(JNLPBA)',
    'Genes\n(BC2GM)',
    'Species\n(Linnaeus)',
]

# Column order of `data` (target tasks) and the tick labels shown for them.
targets = ['ncbi_disease', 'bc5cdr_chem']
target_labels = ['Disease NER\n(NCBI Disease)', 'Chemical NER\n(BC5CDR-Chem)']

# ΔF1 values from experiments (50/50 split); NaN marks a self-transfer cell.
data = np.array([
    [+0.044, np.nan],  # bc5cdr_chem → disease, bc5cdr_chem → chem (self)
    [np.nan, -0.001],  # ncbi_disease → disease (self), ncbi_disease → chem
    [+0.013, -0.004],  # jnlpba
    [-0.007, -0.001],  # bc2gm
    [-0.033, +0.001],  # linnaeus
])

# Fail early if the table and the row/column lists drift apart; otherwise the
# tick labels would silently be attached to the wrong cells.
if data.shape != (len(sources), len(targets)):
    raise ValueError(
        f'data has shape {data.shape}, expected '
        f'({len(sources)}, {len(targets)})'
    )

# Mask self-transfers
mask = np.isnan(data)
masked_data = np.ma.array(data, mask=mask)

# Work on a copy: calling set_bad() on plt.cm.RdYlGn itself would modify the
# shared colormap object rather than a colormap private to this figure.
cmap = plt.cm.RdYlGn.copy()
cmap.set_bad('lightgray')
im = ax.imshow(masked_data, cmap=cmap, aspect='auto', vmin=-0.04, vmax=0.05)

ax.set_xticks(range(len(target_labels)))
ax.set_xticklabels(target_labels, fontsize=12)
ax.set_yticks(range(len(source_labels)))
ax.set_yticklabels(source_labels, fontsize=11)
ax.set_xlabel('Target Task →', fontsize=13, fontweight='bold', labelpad=10)
ax.set_ylabel('← Pretrain Source', fontsize=13, fontweight='bold', labelpad=10)

# Add text annotations (imshow puts column j on the x axis and row i on y).
for (i, j), val in np.ndenumerate(data):
    if mask[i, j]:
        ax.text(j, i, '—', ha='center', va='center', fontsize=14, color='gray')
        continue
    sign = '+' if val > 0 else ''
    # White text for cells with |ΔF1| > 0.025, black for the rest.
    color = 'white' if abs(val) > 0.025 else 'black'
    ax.text(j, i, f'{sign}{val:.1%}', ha='center', va='center',
            fontsize=14, fontweight='bold', color=color)

cbar = plt.colorbar(im, ax=ax, shrink=0.8, pad=0.02)
cbar.set_label('ΔF1 vs baseline', fontsize=11)
# Show colorbar ticks as signed whole percentages.
cbar.ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f'{x:+.0%}'))

ax.set_title('Cross-Dataset Transfer Affinity in Biomedical NER',
             fontsize=14, fontweight='bold', pad=15)

plt.tight_layout()
plt.savefig('graph_transfer_heatmap.png', dpi=150, bbox_inches='tight', facecolor='white')
plt.close()
print("Saved graph_transfer_heatmap.png")
# ── Graph 2: Improvement Timeline (Disease Target) ─────────────────────────
# Bar chart of val_f1 for a handful of selected runs, with an arrow marking
# the jump from the first bar to the second and a dashed baseline reference.
fig, ax = plt.subplots(figsize=(9, 5))

# presumably experiment indices; only the number of entries is used below
experiments = [0, 2, 9, 13, 15]
f1_scores = [0.8033, 0.8470, 0.8519, 0.8543, 0.8535]
labels = [
    'Baseline\n(disease only)',
    'Chem pretrain\n(50/50)',
    '3-stage\n(chem→jnlpba→disease)',
    'Time split tuning\n(30/20/50)',
    'Final optimized\n(25/15/60)',
]
colors = ['#95a5a6', '#3498db', '#2ecc71', '#27ae60', '#1a8a4a']

# One bar per entry, placed at integer x positions 0..n-1.
bar_positions = range(len(experiments))
bars = ax.bar(
    bar_positions,
    f1_scores,
    color=colors,
    edgecolor='white',
    linewidth=1.5,
    width=0.7,
)

# Add value labels on bars, centred horizontally and just above each top.
for rect, f1 in zip(bars, f1_scores):
    label_x = rect.get_x() + rect.get_width()/2
    label_y = rect.get_height() + 0.002
    ax.text(label_x, label_y, f'{f1:.3f}',
            ha='center', va='bottom', fontsize=12, fontweight='bold')

ax.set_xticks(bar_positions)
ax.set_xticklabels(labels, fontsize=10)
ax.set_ylabel('val_f1 (entity-level, micro)', fontsize=12)
ax.set_ylim(0.78, 0.87)
ax.set_title('Disease NER: How the Agent Improved Over 93 Experiments',
             fontsize=14, fontweight='bold')

# Add improvement arrows: an empty-text annotation draws only the arrow.
highlight = '#e74c3c'
ax.annotate(
    '',
    xy=(1, 0.847),
    xytext=(0, 0.805),
    arrowprops={'arrowstyle': '->', 'color': highlight, 'lw': 2},
)
ax.text(0.5, 0.823, '+4.4%', ha='center', fontsize=11, color=highlight, fontweight='bold')

# Dashed horizontal reference line at the baseline score, with a small tag.
ax.axhline(y=0.8033, color='gray', linestyle='--', alpha=0.5, linewidth=1)
ax.text(4.4, 0.8043, 'baseline', fontsize=9, color='gray', alpha=0.7)

# Hide the top and right frame lines.
for side in ('top', 'right'):
    ax.spines[side].set_visible(False)

plt.tight_layout()
plt.savefig('graph_improvement_timeline.png', dpi=150, bbox_inches='tight', facecolor='white')
plt.close()
print("Saved graph_improvement_timeline.png")
# ── Graph 3: Asymmetry Comparison ──────────────────────────────────────────
def _draw_delta_panel(ax, labels, deltas, colors, title):
    """Draw one horizontal-bar panel of ΔF1 values on *ax*.

    Args:
        ax: Matplotlib axes to draw on.
        labels: Y tick label for each bar.
        deltas: ΔF1 value for each bar; sets the bar length and its
            percentage annotation.
        colors: Fill colour for each bar.
        title: Panel title.

    Raises:
        ValueError: If labels, deltas and colors differ in length.
    """
    if not (len(labels) == len(deltas) == len(colors)):
        raise ValueError('labels, deltas and colors must have the same length')

    rows = range(len(deltas))
    ax.barh(rows, deltas, color=colors, edgecolor='white', height=0.6)
    ax.set_yticks(rows)
    ax.set_yticklabels(labels, fontsize=11)
    ax.set_xlabel('ΔF1 vs baseline', fontsize=11)
    ax.set_title(title, fontsize=13, fontweight='bold')
    ax.axvline(x=0, color='gray', linewidth=0.8)
    # Both panels use the same fixed x range so bar lengths are comparable.
    ax.set_xlim(-0.05, 0.06)
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)

    # Put each value just beyond the end of its bar: to the right of
    # non-negative bars, to the left of negative ones.
    for row, delta in enumerate(deltas):
        sign = '+' if delta > 0 else ''
        non_negative = delta >= 0
        ax.text(delta + (0.002 if non_negative else -0.002), row,
                f'{sign}{delta:.1%}',
                ha='left' if non_negative else 'right', va='center',
                fontsize=11, fontweight='bold')


fig, axes = plt.subplots(1, 2, figsize=(10, 4.5), sharey=False)

# Left: Disease target
sources_disease = ['bc5cdr_chem', 'jnlpba', 'bc2gm', 'linnaeus']
deltas_disease = [+0.044, +0.013, -0.007, -0.033]
# Green for a positive delta, red otherwise.
colors_disease = ['#27ae60' if d > 0 else '#e74c3c' for d in deltas_disease]
src_labels = ['Chemicals', 'Proteins/DNA', 'Genes', 'Species']
_draw_delta_panel(axes[0], src_labels, deltas_disease, colors_disease,
                  'Target: Disease NER')

# Right: Chemical target
sources_chem = ['ncbi_disease', 'jnlpba', 'bc2gm', 'linnaeus']
deltas_chem = [-0.001, -0.004, -0.001, +0.001]
# Deltas within ±0.002 are drawn gray; larger ones green/red by sign.
colors_chem = ['#27ae60' if d > 0.002 else '#e74c3c' if d < -0.002 else '#95a5a6'
               for d in deltas_chem]
src_labels_chem = ['Diseases', 'Proteins/DNA', 'Genes', 'Species']
_draw_delta_panel(axes[1], src_labels_chem, deltas_chem, colors_chem,
                  'Target: Chemical NER')

fig.suptitle('The Asymmetry: Chemicals Help Diseases, But Not Vice Versa',
             fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
plt.savefig('graph_asymmetry.png', dpi=150, bbox_inches='tight', facecolor='white')
plt.close()
print("Saved graph_asymmetry.png")
print("\nAll graphs generated!")