import sys
from typing import Iterable

import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Patch

# Make `dataset` importable from the parent of the current working directory.
sys.path.append("..")
from dataset import PeptidePairDataset, SimplePairClsDataset  # noqa: E402
|
|
def extract_label(dataset: Iterable[tuple]) -> list:
    """Collect the label from every ``(sample, label)`` pair in *dataset*.

    Only iteration is required, so any dataset object (or plain iterable)
    that yields 2-tuples works; the sample half of each pair is ignored.

    Args:
        dataset: Iterable yielding ``(sample, label)`` pairs.

    Returns:
        A list of the labels, in iteration order.
    """
    return [label for _, label in dataset]
|
|
# Build every split up front. The module-level names are kept exactly as
# before (including the `stablility_*` spelling) in case other code uses them.
train_set = SimplePairClsDataset()
train_set_llm = SimplePairClsDataset(llm=True)
test_set = PeptidePairDataset(mode='test', task='cls')
stablility_set_1 = PeptidePairDataset(mode='125fbs', task='cls')
stablility_set_2 = PeptidePairDataset(mode='25fbs', task='cls')

# Pull the label out of each split, in the same order as the splits above.
(
    train_labels,
    train_labels_llm,
    test_labels,
    stability_labels_1,
    stability_labels_2,
) = [
    extract_label(split)
    for split in (
        train_set,
        train_set_llm,
        test_set,
        stablility_set_1,
        stablility_set_2,
    )
]
|
|
# One pie panel per entry, left to right, as (panel title, labels) pairs.
# The first two entries are the training splits, the rest the test splits.
datasets = list({
    "Original": train_labels,
    "LLM Augmented": train_labels_llm,
    "MHB": test_labels,
    "12.5% FBS": stability_labels_1,
    "25% FBS": stability_labels_2,
}.items())

# Wedge colours and legend captions, indexed by label value (0 first, then 1).
colors = ["#BDD7EE", "#EBB4B1"]
legend_labels = ["Not Improved", "Improved"]

# Global Matplotlib font sizes, picked up by figures created after this call.
plt.rcParams.update(
    {
        "font.size": 12,
        "axes.titlesize": 14,
        "axes.labelsize": 14,
        "xtick.labelsize": 13,
        "ytick.labelsize": 13,
        "legend.fontsize": 14,
        "figure.titlesize": 18,
    }
)
|
|
def make_autopct(total):
    """Return an ``autopct`` callback for ``Axes.pie``.

    The callback takes a wedge's percentage ``pct`` and renders it as the
    percentage with one decimal place, followed on a second line by the
    absolute count (rounded to an int) that percentage represents out of
    ``total``.
    """
    def _wedge_label(pct):
        count = int(round(pct / 100.0 * total))
        return f"{pct:.1f}%\n({count:d})"

    return _wedge_label
|
|
# Layout: a thin header row carrying the "Train Set" / "Test Set" captions,
# above a row with one pie chart per entry of `datasets`. The panel count is
# derived from `datasets` so a new entry gets its own panel instead of being
# silently dropped by `zip` below.
n_panels = len(datasets)
n_train = 2  # leading entries of `datasets` grouped under "Train Set"

fig = plt.figure(figsize=(2 * n_panels, 4))
gs = gridspec.GridSpec(2, n_panels, height_ratios=[0.2, 1], hspace=0)

# Header row: axes with their frame hidden, used only to hold a caption
# spanning each group of panels.
ax_train_title = fig.add_subplot(gs[0, 0:n_train])
ax_train_title.axis('off')
ax_train_title.set_title('Train Set', pad=10, fontsize=16)

ax_test_title = fig.add_subplot(gs[0, n_train:n_panels])
ax_test_title.axis('off')
ax_test_title.set_title('Test Set', pad=10, fontsize=16)

# Bottom row: one pie per dataset comparing how many labels are 0 vs 1.
axes = [fig.add_subplot(gs[1, i]) for i in range(n_panels)]
for ax, (title, data) in zip(axes, datasets):
    count_0 = data.count(0)
    count_1 = data.count(1)
    totals = [count_0, count_1]
    total_num = sum(totals)
    ax.pie(totals, colors=colors,
           autopct=make_autopct(total_num),
           startangle=90)
    ax.set_title(title)
    ax.axis('equal')  # keep each pie circular

# Shared legend under all panels; colours and captions are paired by index.
patches = [Patch(color=color, label=label)
           for color, label in zip(colors, legend_labels)]
fig.legend(handles=patches,
           loc='lower center',
           ncol=len(patches),
           bbox_to_anchor=(0.5, 0.01),
           frameon=False)

# rect = [left, bottom, right, top] in figure coordinates: the bottom 10% is
# left free for the legend, with a small margin on the other sides.
plt.tight_layout(rect=[0.015, 0.10, 0.985, 0.95])
plt.show()