"""Generate blog post graphs for the autoresearch findings."""

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

# Global typography shared by every figure produced by this script.
matplotlib.rcParams.update({
    'font.family': 'sans-serif',
    'font.size': 13,
})

# ---------------------------------------------------------------------------
# Graph 1: transfer-affinity heatmap (pretrain source x target task).
# Writes graph_transfer_heatmap.png.
# ---------------------------------------------------------------------------
fig, ax = plt.subplots(figsize=(7, 5))

sources = ['bc5cdr_chem', 'ncbi_disease', 'jnlpba', 'bc2gm', 'linnaeus']
source_labels = ['Chemicals\n(BC5CDR)', 'Diseases\n(NCBI)', 'Proteins/DNA\n(JNLPBA)', 'Genes\n(BC2GM)', 'Species\n(Linnaeus)']
targets = ['ncbi_disease', 'bc5cdr_chem']
target_labels = ['Disease NER\n(NCBI Disease)', 'Chemical NER\n(BC5CDR-Chem)']

# F1 delta per (source row, target column), in the order of `sources` and
# `targets`. NaN marks the two cells where source and target are the same
# dataset.
data = np.array([
    [+0.044, np.nan],
    [np.nan, -0.001],
    [+0.013, -0.004],
    [-0.007, -0.001],
    [-0.033, +0.001],
])

# Mask the NaN cells so imshow paints them with the colormap's "bad" colour.
mask = np.isnan(data)
masked_data = np.ma.array(data, mask=mask)

# Work on a copy: plt.cm.RdYlGn is the shared, globally registered colormap
# instance, and set_bad() mutates whatever object it is called on.
cmap = plt.cm.RdYlGn.copy()
cmap.set_bad('lightgray')

im = ax.imshow(masked_data, cmap=cmap, aspect='auto', vmin=-0.04, vmax=0.05)

ax.set_xticks(range(len(target_labels)))
ax.set_xticklabels(target_labels, fontsize=12)
ax.set_yticks(range(len(source_labels)))
ax.set_yticklabels(source_labels, fontsize=11)

# NOTE(review): the arrow and dash glyphs in this block were restored from
# mis-encoded placeholder characters; confirm they are the intended symbols.
# They are written as escapes so they survive any future re-encoding.
ax.set_xlabel('Target Task \u2192', fontsize=13, fontweight='bold', labelpad=10)  # rightwards arrow
ax.set_ylabel('\u2190 Pretrain Source', fontsize=13, fontweight='bold', labelpad=10)  # leftwards arrow

# Annotate every cell: a gray dash for masked cells, otherwise the signed
# percentage, in white on the strongly coloured cells and black elsewhere.
for i, row in enumerate(data):
    for j, val in enumerate(row):
        if mask[i, j]:
            ax.text(j, i, '\u2014', ha='center', va='center', fontsize=14, color='gray')  # em dash
        else:
            sign = '+' if val > 0 else ''
            color = 'white' if abs(val) > 0.025 else 'black'
            ax.text(j, i, f'{sign}{val:.1%}', ha='center', va='center',
                    fontsize=14, fontweight='bold', color=color)

cbar = plt.colorbar(im, ax=ax, shrink=0.8, pad=0.02)
cbar.set_label('\u0394F1 vs baseline', fontsize=11)  # Greek capital delta
cbar.ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f'{x:+.0%}'))

ax.set_title('Cross-Dataset Transfer Affinity in Biomedical NER', fontsize=14, fontweight='bold', pad=15)

plt.tight_layout()
plt.savefig('graph_transfer_heatmap.png', dpi=150, bbox_inches='tight', facecolor='white')
plt.close()
print("Saved graph_transfer_heatmap.png")

# ---------------------------------------------------------------------------
# Graph 2: bar chart of val_f1 at selected milestones.
# Writes graph_improvement_timeline.png.
# ---------------------------------------------------------------------------
fig, ax = plt.subplots(figsize=(9, 5))

# Presumably the experiment numbers of the milestones (TODO confirm); only
# the length of this list is used below, to position the bars.
experiments = [0, 2, 9, 13, 15]
f1_scores = [0.8033, 0.8470, 0.8519, 0.8543, 0.8535]
labels = [
    'Baseline\n(disease only)',
    'Chem pretrain\n(50/50)',
    # Arrows restored from mis-encoded placeholders and written as escapes
    # so they survive re-encoding.
    '3-stage\n(chem\u2192jnlpba\u2192disease)',
    'Time split tuning\n(30/20/50)',
    'Final optimized\n(25/15/60)',
]

colors = ['#95a5a6', '#3498db', '#2ecc71', '#27ae60', '#1a8a4a']

bars = ax.bar(range(len(experiments)), f1_scores, color=colors, edgecolor='white', linewidth=1.5, width=0.7)

# Print each score just above its bar.
for bar, score in zip(bars, f1_scores):
    ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.002,
            f'{score:.3f}', ha='center', va='bottom', fontsize=12, fontweight='bold')

ax.set_xticks(range(len(experiments)))
ax.set_xticklabels(labels, fontsize=10)
ax.set_ylabel('val_f1 (entity-level, micro)', fontsize=12)
# The y-axis starts at 0.78 rather than 0, so bar heights are not
# proportional to the scores.
ax.set_ylim(0.78, 0.87)
ax.set_title('Disease NER: How the Agent Improved Over 93 Experiments', fontsize=14, fontweight='bold')

# Red arrow from the first bar to the second, labelled with the gain.
ax.annotate('', xy=(1, 0.847), xytext=(0, 0.805),
            arrowprops=dict(arrowstyle='->', color='#e74c3c', lw=2))
ax.text(0.5, 0.823, '+4.4%', ha='center', fontsize=11, color='#e74c3c', fontweight='bold')

# Dashed reference line at the baseline score (first entry of f1_scores).
baseline_f1 = f1_scores[0]
ax.axhline(y=baseline_f1, color='gray', linestyle='--', alpha=0.5, linewidth=1)
ax.text(4.4, 0.8043, 'baseline', fontsize=9, color='gray', alpha=0.7)

ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

plt.tight_layout()
plt.savefig('graph_improvement_timeline.png', dpi=150, bbox_inches='tight', facecolor='white')
plt.close()
print("Saved graph_improvement_timeline.png")

# ---------------------------------------------------------------------------
# Graph 3: side-by-side horizontal bars of F1 deltas, one panel per target.
# Writes graph_asymmetry.png.
# ---------------------------------------------------------------------------
def _draw_delta_panel(ax, deltas, bar_colors, tick_labels, title):
    """Draw one horizontal-bar panel of F1 deltas on *ax*.

    Args:
        ax: Matplotlib axes to draw on.
        deltas: One F1 delta per bar, plotted in list order.
        bar_colors: One colour per bar, parallel to *deltas*.
        tick_labels: One y-axis label per bar, parallel to *deltas*.
        title: Panel title.
    """
    positions = range(len(deltas))
    ax.barh(positions, deltas, color=bar_colors, edgecolor='white', height=0.6)
    ax.set_yticks(positions)
    ax.set_yticklabels(tick_labels, fontsize=11)
    # Delta restored from a mis-encoded character; written as an escape so
    # it survives re-encoding.
    ax.set_xlabel('\u0394F1 vs baseline', fontsize=11)
    ax.set_title(title, fontsize=13, fontweight='bold')
    ax.axvline(x=0, color='gray', linewidth=0.8)
    # Fixed x-limits, identical for every panel drawn with this helper.
    ax.set_xlim(-0.05, 0.06)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    # Place each value label just past the end of its bar, on the side the
    # bar points to.
    for row, delta in enumerate(deltas):
        sign = '+' if delta > 0 else ''
        points_right = delta >= 0
        ax.text(delta + (0.002 if points_right else -0.002), row, f'{sign}{delta:.1%}',
                ha='left' if points_right else 'right', va='center',
                fontsize=11, fontweight='bold')


fig, axes = plt.subplots(1, 2, figsize=(10, 4.5), sharey=False)

# Left panel: disease NER as the target. Green for a positive delta, red
# otherwise.
sources_disease = ['bc5cdr_chem', 'jnlpba', 'bc2gm', 'linnaeus']
deltas_disease = [+0.044, +0.013, -0.007, -0.033]
colors_disease = ['#27ae60' if d > 0 else '#e74c3c' for d in deltas_disease]
src_labels = ['Chemicals', 'Proteins/DNA', 'Genes', 'Species']
_draw_delta_panel(axes[0], deltas_disease, colors_disease, src_labels,
                  'Target: Disease NER')

# Right panel: chemical NER as the target. Deltas within +/-0.002 are drawn
# in gray; larger ones are green (positive) or red (negative).
sources_chem = ['ncbi_disease', 'jnlpba', 'bc2gm', 'linnaeus']
deltas_chem = [-0.001, -0.004, -0.001, +0.001]
colors_chem = ['#27ae60' if d > 0.002 else '#e74c3c' if d < -0.002 else '#95a5a6' for d in deltas_chem]
src_labels_chem = ['Diseases', 'Proteins/DNA', 'Genes', 'Species']
_draw_delta_panel(axes[1], deltas_chem, colors_chem, src_labels_chem,
                  'Target: Chemical NER')

fig.suptitle('The Asymmetry: Chemicals Help Diseases, But Not Vice Versa',
             fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
plt.savefig('graph_asymmetry.png', dpi=150, bbox_inches='tight', facecolor='white')
plt.close()
print("Saved graph_asymmetry.png")

print("\nAll graphs generated!")