File size: 6,905 Bytes
acbef3a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 |
import sys
sys.path.append("..")
from dataset import PeptidePairDataset
from torch.utils.data import Subset
from sklearn.model_selection import KFold
import matplotlib.pyplot as plt
import numpy as np
def extract_label(dataset: PeptidePairDataset):
    """Return every label in *dataset*, in iteration order.

    Each item yielded by *dataset* is unpacked as a 2-tuple and only its
    second element (the label) is kept.
    """
    return [target for _sample, target in dataset]
def extract_labels_cls(dataset: PeptidePairDataset, *, threshold: float = 0.5):
    """Bucket each label in *dataset* into one of three class ids.

    Each item yielded by *dataset* is unpacked as a 2-tuple whose second
    element is the label, which is mapped to:

    * ``2`` if ``label >= threshold``
    * ``1`` if ``label <= -threshold``
    * ``0`` otherwise

    These ids line up with the 'Unchanged' / 'Lower' / 'Higher' tick labels
    used for classes 0 / 1 / 2 by ``plot_gradient_bar_histogram_cls``.

    Args:
        dataset: Iterable of 2-tuples; the second element must support
            comparison with a float.
        threshold: Non-negative cut-off for the two non-zero classes.
            Defaults to 0.5, the value this function always used before.

    Returns:
        List of class ids, one per item, in iteration order.

    Raises:
        ValueError: If ``threshold`` is negative (the two bands would overlap).
    """
    if threshold < 0:
        raise ValueError(f'threshold must be non-negative, got {threshold!r}')

    def bucket(label):
        if label >= threshold:
            return 2
        if label <= -threshold:
            return 1
        return 0

    return [bucket(label) for _, label in dataset]
# Datasets whose labels are summarised below: `all_set` is shown as the
# "Train Set" panel and `test_set` as the "R2 Set" panel. The semantics of
# `mode`, `pad_length`, `task` and `one_way` live in PeptidePairDataset (not
# visible here); NOTE(review): 'r2_case' presumably selects a held-out case
# set — confirm against dataset.py.
all_set = PeptidePairDataset(mode='train', pad_length=30, task='reg')
test_set = PeptidePairDataset(mode='r2_case', pad_length=30, task='reg', one_way=True)
# Disabled: take the first fold of a shuffled 5-fold KFold split of all_set as
# train/validation subsets. Re-enable together with the commented-out
# validation-set (`valid_set`, `reg_val`, `cls_val`) code later in the script.
# kf = KFold(n_splits=5, shuffle=True, random_state=42)
# train_idx, val_idx = next(kf.split(all_set))
# train_set = Subset(all_set, train_idx)
# valid_set = Subset(all_set, val_idx)
# Histogram bin edges for the log2(MIC ratio) labels: 27 edges from -6.5 to
# 6.5 inclusive, i.e. 26 bins of width 0.5. np.linspace states the endpoints
# and edge count explicitly, instead of relying on a float-step np.arange with
# a stop sentinel (6.6) to pull in the last edge.
bins = np.linspace(-6.5, 6.5, 27)
def plot_gradient_bar_histogram(ax, data, bins, cmap_name='Blues'):
    """Draw a histogram whose bars follow a continuous colour gradient.

    Each bar gets its own colour, sampled from the colormap ``cmap_name`` at
    evenly spaced positions from 0.3 (first bar) to 0.8 (last bar); a single
    bar uses 0.5. Every non-empty bar is annotated with its count.

    Args:
        ax: Matplotlib ``Axes`` to draw on.
        data: Values to bin; anything accepted by ``np.histogram``.
        bins: Bin specification passed to ``np.histogram`` -- a bin count or
            an array of edges. Edges need not be evenly spaced.
        cmap_name: Name of the Matplotlib colormap to sample.

    Returns:
        ``(counts, bin_edges)`` exactly as returned by ``np.histogram``.
    """
    counts, bin_edges = np.histogram(data, bins=bins)
    n_bins = len(counts)
    cmap = plt.get_cmap(cmap_name)

    # One width per bin: reusing the first bin's width for every bar would
    # mis-draw histograms whose edges are not evenly spaced.
    widths = np.diff(bin_edges)

    # Gradient positions in [0.3, 0.8]; fall back to the midpoint for one bin.
    if n_bins > 1:
        ratios = np.arange(n_bins) / (n_bins - 1) * 0.5 + 0.3
    else:
        ratios = np.full(n_bins, 0.5)

    for left, width, count, ratio in zip(bin_edges[:-1], widths, counts, ratios):
        ax.bar(left, count, width=width, align='edge', color=cmap(ratio), zorder=2)
        if count != 0:
            # Count label sits on top of the bar, centred horizontally.
            ax.text(
                left + width / 2, count, f'{count}',
                ha='center', va='bottom', fontweight='bold', fontsize=9,
            )

    return counts, bin_edges
def plot_gradient_bar_histogram_cls(ax, data, cmap_name='Blues'):
    """Draw a bar chart of class frequencies for the fixed classes 0, 1, 2.

    Bars are coloured with a gradient sampled from ``cmap_name`` at positions
    0.3, 0.55 and 0.8. Non-empty bars are annotated with their count. The x ticks
    are labelled 'Unchanged', 'Lower', 'Higher' for classes 0, 1, 2, matching
    the ids produced by ``extract_labels_cls``.

    Args:
        ax: Matplotlib ``Axes`` to draw on.
        data: Iterable of class ids (list, tuple, NumPy array, generator...).
            Ids other than 0, 1 and 2 are not counted.
        cmap_name: Name of the Matplotlib colormap to sample.

    Returns:
        NumPy integer array with the count of each class, in the order 0, 1, 2.
    """
    classes = np.array([0, 1, 2])
    # Materialise once and compare element-wise. ``list.count`` would only
    # work for lists/tuples, not NumPy arrays or generators.
    labels = np.asarray(list(data))
    counts = np.array([np.count_nonzero(labels == cls) for cls in classes])

    n_classes = len(classes)
    width = 0.6
    cmap = plt.get_cmap(cmap_name)
    for i, (cls, count) in enumerate(zip(classes, counts)):
        # Gradient position in [0.3, 0.8], same scheme as the regression plot.
        ratio = (i / (n_classes - 1)) * 0.5 + 0.3 if n_classes > 1 else 0.5
        ax.bar(cls, count, width=width, color=cmap(ratio), zorder=2)
        if count != 0:
            ax.text(cls, count, f'{count}', ha='center', va='bottom', fontweight='bold')

    ax.set_xticks(classes)
    ax.set_xticklabels(['Unchanged', 'Lower', 'Higher'])
    return counts
# Global font-size overrides (disabled; uncomment to apply):
# plt.rcParams.update({
#     'font.size': 10,
#     'axes.titlesize': 14,
#     'axes.labelsize': 14,
#     'xtick.labelsize': 13,
#     'ytick.labelsize': 13,
#     'legend.fontsize': 14,
#     'figure.titlesize': 16
# })
plt.style.use('seaborn-v0_8-whitegrid')

# A single row of two panels: training-set labels on the left, R2-set labels
# on the right. The disabled 3x2 variant put regression histograms in column 0
# and classification bar charts in column 1.
# fig, axes = plt.subplots(3, 2, figsize=(12, 9), width_ratios=[2.5, 1])
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(11.4, 3))

# Regression labels, extracted in order: training set first, then the R2 set.
reg_train, reg_test = (extract_label(ds) for ds in (all_set, test_set))
# reg_val = extract_label(valid_set)

# Classification labels (disabled; only needed by the 3x2 layout):
# cls_train = extract_labels_cls(all_set)
# cls_val = extract_labels_cls(valid_set)
# cls_test = extract_labels_cls(test_set)
# One regression-label histogram per panel: left = training set, right = R2 set.
for ax, panel_labels, panel_title in (
    (axes[0], reg_train, f'Train Set ({len(reg_train)})'),
    (axes[1], reg_test, f'R2 Set ({len(reg_test)})'),
):
    counts, _ = plot_gradient_bar_histogram(ax, panel_labels, bins, cmap_name='Blues')
    # np.histogram does not count values outside [bins[0], bins[-1]] (or
    # non-finite ones). Report them so the sample count in the title is not
    # silently larger than the total shown by the bars.
    n_outside = len(panel_labels) - int(counts.sum())
    if n_outside:
        print(f'Warning: {n_outside} label(s) in "{panel_title}" are outside '
              f'[{bins[0]}, {bins[-1]}] and not shown in the histogram')
    ax.set_xlabel('log2 (MIC ratio)', size=14)
    ax.set_xlim(bins[0], bins[-1])
    ax.set_ylabel('Counts', size=14)
    # 15% headroom above the tallest bar for its count label. Use at least 1
    # so an all-empty histogram does not get a zero-height y-range.
    ax.set_ylim(0, max(int(counts.max()), 1) * 1.15)
    ax.set_title(panel_title, weight='bold', size=16)

# Disabled 3x2 layout. It needs the KFold split, the cls_* label lists and the
# 3x2 subplots that are commented out earlier in the script. Column 0 holds the
# regression histograms, column 1 the classification bar charts.
# counts_cls = plot_gradient_bar_histogram_cls(axes[0, 1], cls_train, cmap_name='Blues')
# axes[0, 1].set_xlabel('Class')
# axes[0, 1].set_xlim(-0.5, 2.5)
# axes[0, 1].set_ylabel('Counts')
# axes[0, 1].set_ylim(0, max(counts_cls)*1.15)
# axes[0, 1].set_title(f'Train Set (Cls.) ({len(cls_train)})', weight='bold')
# counts_cls = plot_gradient_bar_histogram_cls(axes[1, 1], cls_val, cmap_name='Blues')
# axes[1, 1].set_xlabel('Class')
# axes[1, 1].set_xlim(-0.5, 2.5)
# axes[1, 1].set_ylabel('Counts')
# axes[1, 1].set_ylim(0, max(counts_cls)*1.15)
# axes[1, 1].set_title(f'Validation Set (Cls.) ({len(cls_val)})', weight='bold')
# counts, _ = plot_gradient_bar_histogram(axes[2, 0], reg_test, bins, cmap_name='Blues')
# axes[2, 0].set_xlabel('log2 (MIC ratio)')
# axes[2, 0].set_xlim(-6.5, 6.5)
# axes[2, 0].set_ylabel('Counts')
# axes[2, 0].set_ylim(0, max(counts)*1.15)
# axes[2, 0].set_title(f'Test Set (Reg.) ({len(reg_test)})', weight='bold')
# counts_cls = plot_gradient_bar_histogram_cls(axes[2, 1], cls_test, cmap_name='Blues')
# axes[2, 1].set_xlabel('Class')
# axes[2, 1].set_xlim(-0.5, 2.5)
# axes[2, 1].set_ylabel('Counts')
# axes[2, 1].set_ylim(0, max(counts_cls)*1.15)
# axes[2, 1].set_title(f'Test Set (Cls.) ({len(cls_test)})', weight='bold')

plt.tight_layout(w_pad=3.2)
plt.savefig('dataset.svg')
plt.show()
# Summary statistics of the regression labels for each plotted set. The R2 set
# is reported under the "测试集" (test set) heading.
for stats_header, stats_labels in (
    ("\n训练集回归统计信息:", reg_train),
    ("\n测试集回归统计信息:", reg_test),
):
    print(stats_header)
    print(f"样本总数: {len(stats_labels)}")
    # min()/max() raise ValueError on an empty sequence, so report that case
    # instead of crashing after the figure has been saved.
    if stats_labels:
        print(f"范围: {min(stats_labels)} - {max(stats_labels)}")
    else:
        print("范围: N/A")

# Disabled: validation-set and classification statistics (need the commented-out
# valid_set / reg_val / cls_* code earlier in the script).
# print("\n验证集回归统计信息:")
# print(f"样本总数: {len(reg_val)}")
# print(f"范围: {min(reg_val)} - {max(reg_val)}")
# print("\n训练集分类统计信息:")
# print(f"样本总数: {len(cls_train)}")
# print(f"各类别统计: 0类: {cls_train.count(0)}, 1类: {cls_train.count(1)}, 2类: {cls_train.count(2)}")
# print("\n验证集分类统计信息:")
# print(f"样本总数: {len(cls_val)}")
# print(f"各类别统计: 0类: {cls_val.count(0)}, 1类: {cls_val.count(1)}, 2类: {cls_val.count(2)}")
# print("\n测试集分类统计信息:")
# print(f"样本总数: {len(cls_test)}")
# print(f"各类别统计: 0类: {cls_test.count(0)}, 1类: {cls_test.count(1)}, 2类: {cls_test.count(2)}")