# Source: DAminoMuta/vis/2cls_dataset.py
# (uploaded by auralray via huggingface_hub, commit acbef3a)
import sys
sys.path.append("..")
from dataset import PeptidePairDataset, SimplePairClsDataset
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
import matplotlib.gridspec as gridspec
import numpy as np
def extract_label(dataset: PeptidePairDataset):
    """Collect the label of every item in *dataset*, in iteration order.

    Each item the dataset yields is unpacked as a 2-tuple; the first element
    is discarded and the second is kept.

    Returns:
        A list with one label per item.
    """
    return [label for _, label in dataset]
# Training data: the original pairs plus the LLM-augmented variant.
train_set = SimplePairClsDataset()
train_set_llm = SimplePairClsDataset(llm=True)

# Evaluation data: the test split and the two stability splits
# ('125fbs' / '25fbs', displayed below as "12.5% FBS" / "25% FBS").
test_set = PeptidePairDataset(mode='test', task='cls')
# NOTE(review): "stablility" is a misspelling; the names are kept unchanged
# in case other code refers to these module-level variables.
stablility_set_1, stablility_set_2 = (
    PeptidePairDataset(mode=split, task='cls') for split in ('125fbs', '25fbs')
)

# Pull out the label lists, in the same order the datasets were built.
(
    train_labels,
    train_labels_llm,
    test_labels,
    stability_labels_1,
    stability_labels_2,
) = [
    extract_label(ds)
    for ds in (train_set, train_set_llm, test_set, stablility_set_1, stablility_set_2)
]
# The five label lists, each paired with the display name used as its pie-chart title.
# The first two are training sets; the last three are test sets.
datasets = list(zip(
    ("Original", "LLM Augmented", "MHB", "12.5% FBS", "25% FBS"),
    (train_labels, train_labels_llm, test_labels, stability_labels_1, stability_labels_2),
))
# Wedge colours indexed by label value:
#   label 0 -> light blue   (#BDD7EE)
#   label 1 -> pale salmon  (#EBB4B1)
colors = ["#BDD7EE", "#EBB4B1"]
# Legend text for each label, in the same order as `colors`.
legend_labels = ["Not Improved", "Improved"]

# Global font sizes applied to every text element in the figure.
_font_sizes = {
    "font.size": 12,
    "axes.titlesize": 14,
    "axes.labelsize": 14,
    "xtick.labelsize": 13,
    "ytick.labelsize": 13,
    "legend.fontsize": 14,
    "figure.titlesize": 18,
}
plt.rcParams.update(_font_sizes)
def make_autopct(total):
    """Build an ``autopct`` callback for ``Axes.pie``.

    The callback receives a wedge's percentage and renders it on two lines:
    the percentage to one decimal place, then the absolute count it
    represents, reconstructed from *total* and rounded to the nearest integer.

    Args:
        total: Sum of all wedge values in the pie.

    Returns:
        A function mapping ``pct`` to a label such as ``"50.0%\\n(5)"``.
    """
    def _label(pct):
        # int() is kept because round() on a NumPy scalar may not return a
        # Python int, and the ":d" format spec requires an integer.
        count = int(round(pct / 100. * total))
        return f"{pct:.1f}%\n({count:d})"

    return _label
# Layout: a 2 x 5 GridSpec. The short top row carries the group headings;
# the bottom row holds one pie chart per dataset.
fig = plt.figure(figsize=(10, 4))
gs = gridspec.GridSpec(2, 5, height_ratios=[0.2, 1], hspace=0)

# Row 1: invisible axes used only to display the group titles.
# "Train Set" spans columns 0-1 and "Test Set" spans columns 2-4.
ax_train_title = fig.add_subplot(gs[0, :2])
ax_test_title = fig.add_subplot(gs[0, 2:])
for heading_ax, heading in ((ax_train_title, 'Train Set'), (ax_test_title, 'Test Set')):
    heading_ax.axis('off')
    heading_ax.set_title(heading, pad=10, fontsize=16)
# Row 2: one pie per dataset showing how many items carry label 0 vs label 1.
axes = [fig.add_subplot(gs[1, col]) for col in range(5)]
for ax, (title, data) in zip(axes, datasets):
    # Wedge order matches `colors`: label 0 first, then label 1.
    # Values other than 0 and 1 are not counted.
    totals = [data.count(value) for value in (0, 1)]
    count_0, count_1 = totals
    total_num = count_0 + count_1
    ax.pie(
        totals,
        colors=colors,
        autopct=make_autopct(total_num),
        startangle=90,
    )
    ax.set_title(title)
    # An equal aspect ratio keeps each pie circular.
    ax.axis('equal')
# One legend shared by all pies, centred below the charts with no frame.
patches = [Patch(color=shade, label=text) for shade, text in zip(colors, legend_labels)]
fig.legend(
    handles=patches,
    loc='lower center',
    bbox_to_anchor=(0.5, 0.01),
    ncol=2,
    frameon=False,
)
# rect = (left, bottom, right, top) in figure coordinates. The bottom 10% is
# reserved for the legend, with small side and top margins.
plt.tight_layout(rect=[0.015, 0.10, 0.985, 0.95])
plt.show()