"""Visualise the regression-label distribution of the peptide-pair datasets.

Running this script draws one gradient-coloured histogram per dataset (the
``train`` split and the ``r2_case`` split), saves the figure to
``dataset.svg``, shows it, and prints the sample count and label range of
each dataset.
"""
import sys

# Make the parent directory importable so that ``dataset`` can be found.
# NOTE(review): ".." is resolved against the current working directory, not
# against this file's location, so the script presumably has to be launched
# from its own directory -- confirm.
sys.path.append("..")

from dataset import PeptidePairDataset
from torch.utils.data import Subset
from sklearn.model_selection import KFold
import matplotlib.pyplot as plt
import numpy as np


def extract_label(dataset: PeptidePairDataset) -> list:
    """Collect the label of every sample in *dataset*.

    Args:
        dataset: Iterable yielding ``(features, label)`` pairs.

    Returns:
        The labels, in iteration order.
    """
    return [label for _, label in dataset]


def extract_labels_cls(dataset: PeptidePairDataset) -> list:
    """Bucket every regression label of *dataset* into one of three classes.

    Args:
        dataset: Iterable yielding ``(features, label)`` pairs.

    Returns:
        One class id per sample, in iteration order: ``2`` when
        ``label >= 0.5``, ``1`` when ``label <= -0.5`` and ``0`` otherwise.
    """
    labels = []
    for _, label in dataset:
        if label >= 0.5:
            labels.append(2)
        elif label <= -0.5:
            labels.append(1)
        else:
            labels.append(0)
    return labels


def _gradient_ratios(n_bars: int) -> list:
    """Return one colormap position per bar, evenly spread over [0.3, 0.8].

    The positions stay away from both ends of the colormap. A single bar gets
    the position 0.5; zero bars give an empty list.
    """
    if n_bars > 1:
        return [(i / (n_bars - 1)) * 0.5 + 0.3 for i in range(n_bars)]
    return [0.5] * n_bars


def plot_gradient_bar_histogram(ax, data, bins, cmap_name='Blues'):
    """Draw a histogram whose bars follow a continuous colour gradient.

    Every non-empty bar is annotated with its count.

    Args:
        ax: Matplotlib axes to draw on.
        data: Values to histogram (passed straight to ``np.histogram``).
        bins: Bin specification accepted by ``np.histogram``.
        cmap_name: Name of the Matplotlib colormap supplying the gradient.

    Returns:
        ``(counts, bin_edges)`` exactly as returned by ``np.histogram``.
    """
    counts, bin_edges = np.histogram(data, bins=bins)
    # Use each bin's own width so that non-uniform bins are drawn correctly;
    # for uniform bins this is the same as taking the first width for all.
    widths = np.diff(bin_edges)
    cmap = plt.get_cmap(cmap_name)
    ratios = _gradient_ratios(len(counts))

    for left, width, count, ratio in zip(bin_edges[:-1], widths, counts, ratios):
        ax.bar(left, count, width=width, align='edge',
               color=cmap(ratio), zorder=2)

        # Label only bars that actually contain samples.
        if count != 0:
            ax.text(
                left + width / 2,
                count,
                f'{count}',
                ha='center',
                va='bottom',
                fontweight='bold',
                fontsize=9
            )

    return counts, bin_edges


def plot_gradient_bar_histogram_cls(ax, data, cmap_name='Blues'):
    """Draw a bar chart of class frequencies for the fixed classes 0, 1, 2.

    Every non-empty bar is annotated with its count, and the x ticks are
    labelled ``Unchanged`` (0), ``Lower`` (1) and ``Higher`` (2).

    Args:
        ax: Matplotlib axes to draw on.
        data: Iterable of class ids.
        cmap_name: Name of the Matplotlib colormap supplying the gradient.

    Returns:
        Array with the number of occurrences of class 0, 1 and 2.
    """
    # Fixed class order.
    classes = np.array([0, 1, 2])
    # Materialise once so any iterable (not only a list) can be counted.
    values = list(data)
    counts = np.array([values.count(cls) for cls in classes])
    width = 0.6
    cmap = plt.get_cmap(cmap_name)
    ratios = _gradient_ratios(len(classes))

    for cls, count, ratio in zip(classes, counts, ratios):
        ax.bar(cls, count, width=width, color=cmap(ratio), zorder=2)
        if count != 0:
            ax.text(cls, count, f'{count}', ha='center', va='bottom',
                    fontweight='bold')

    ax.set_xticks(classes)
    ax.set_xticklabels(['Unchanged', 'Lower', 'Higher'])
    return counts


def _draw_regression_panel(ax, labels, title, bins):
    """Draw one fully decorated regression-label histogram panel.

    Args:
        ax: Matplotlib axes to draw on.
        labels: Regression labels of one dataset.
        title: Panel title; the number of labels is appended in parentheses.
        bins: Bin specification forwarded to ``plot_gradient_bar_histogram``.

    Returns:
        The per-bin counts of the histogram.
    """
    counts, bin_edges = plot_gradient_bar_histogram(ax, labels, bins,
                                                    cmap_name='Blues')
    ax.set_xlabel('log2 (MIC ratio)', size=14)
    ax.set_xlim(bin_edges[0], bin_edges[-1])
    ax.set_ylabel('Counts', size=14)
    # Leave 15 % headroom for the count annotations. When every bin is empty
    # the upper limit would equal the lower one, so fall back to 1.
    tallest = counts.max() if counts.size else 0
    ax.set_ylim(0, tallest * 1.15 if tallest > 0 else 1)
    # NOTE(review): the title reports len(labels); labels outside the bin
    # range are not counted by np.histogram, so the bars may sum to less.
    ax.set_title(f'{title} ({len(labels)})', weight='bold', size=16)
    return counts


def _print_regression_stats(header, labels):
    """Print *header*, then the sample count and min/max of *labels*.

    Raises:
        ValueError: If *labels* is empty (raised by ``min``/``max``).
    """
    print(header)
    print(f"样本总数: {len(labels)}")
    print(f"范围: {min(labels)} - {max(labels)}")


all_set = PeptidePairDataset(mode='train', pad_length=30, task='reg')
test_set = PeptidePairDataset(mode='r2_case', pad_length=30, task='reg', one_way=True)

# Disabled: 5-fold train/validation split of ``all_set``.
# kf = KFold(n_splits=5, shuffle=True, random_state=42)
# train_idx, val_idx = next(kf.split(all_set))
# train_set = Subset(all_set, train_idx)
# valid_set = Subset(all_set, val_idx)

# Histogram bin edges: -6.5 to 6.5 in steps of 0.5.
bins = np.arange(-6.5, 6.6, 0.5)

# Disabled: global font sizes.
# plt.rcParams.update({
#     'font.size': 10,
#     'axes.titlesize': 14,
#     'axes.labelsize': 14,
#     'xtick.labelsize': 13,
#     'ytick.labelsize': 13,
#     'legend.fontsize': 14,
#     'figure.titlesize': 16
# })

plt.style.use('seaborn-v0_8-whitegrid')

# Current layout: one row with two regression panels.
# Disabled alternative: 3 rows x 2 columns, with regression plots in the first
# column and classification plots in the second.
# fig, axes = plt.subplots(3, 2, figsize=(12, 9), width_ratios=[2.5, 1])
fig, axes = plt.subplots(1, 2, figsize=(11.4, 3))

# Regression labels.
reg_train = extract_label(all_set)
# reg_val = extract_label(valid_set)
reg_test = extract_label(test_set)

# Disabled: classification labels.
# cls_train = extract_labels_cls(all_set)
# cls_val = extract_labels_cls(valid_set)
# cls_test = extract_labels_cls(test_set)

# Left panel: training set.
_draw_regression_panel(axes[0], reg_train, 'Train Set', bins)

# Disabled: classification panel for the training set (3x2 layout).
# counts_cls = plot_gradient_bar_histogram_cls(axes[0, 1], cls_train, cmap_name='Blues')
# axes[0, 1].set_xlabel('Class')
# axes[0, 1].set_xlim(-0.5, 2.5)
# axes[0, 1].set_ylabel('Counts')
# axes[0, 1].set_ylim(0, max(counts_cls)*1.15)
# axes[0, 1].set_title(f'Train Set (Cls.) ({len(cls_train)})', weight='bold')

# Right panel: R2 set (the ``r2_case`` dataset).
_draw_regression_panel(axes[1], reg_test, 'R2 Set', bins)

# Disabled: classification panel for the validation set (3x2 layout).
# counts_cls = plot_gradient_bar_histogram_cls(axes[1, 1], cls_val, cmap_name='Blues')
# axes[1, 1].set_xlabel('Class')
# axes[1, 1].set_xlim(-0.5, 2.5)
# axes[1, 1].set_ylabel('Counts')
# axes[1, 1].set_ylim(0, max(counts_cls)*1.15)
# axes[1, 1].set_title(f'Validation Set (Cls.) ({len(cls_val)})', weight='bold')

# Disabled: third row with the test set (3x2 layout).
# counts, _ = plot_gradient_bar_histogram(axes[2, 0], reg_test, bins, cmap_name='Blues')
# axes[2, 0].set_xlabel('log2 (MIC ratio)')
# axes[2, 0].set_xlim(-6.5, 6.5)
# axes[2, 0].set_ylabel('Counts')
# axes[2, 0].set_ylim(0, max(counts)*1.15)
# axes[2, 0].set_title(f'Test Set (Reg.) ({len(reg_test)})', weight='bold')

# counts_cls = plot_gradient_bar_histogram_cls(axes[2, 1], cls_test, cmap_name='Blues')
# axes[2, 1].set_xlabel('Class')
# axes[2, 1].set_xlim(-0.5, 2.5)
# axes[2, 1].set_ylabel('Counts')
# axes[2, 1].set_ylim(0, max(counts_cls)*1.15)
# axes[2, 1].set_title(f'Test Set (Cls.) ({len(cls_test)})', weight='bold')

plt.tight_layout(w_pad=3.2)
plt.savefig('dataset.svg')
plt.show()

# Dataset statistics (regression labels).
_print_regression_stats("\n训练集回归统计信息:", reg_train)

# Disabled: validation-set statistics.
# print("\n验证集回归统计信息:")
# print(f"样本总数: {len(reg_val)}")
# print(f"范围: {min(reg_val)} - {max(reg_val)}")

_print_regression_stats("\n测试集回归统计信息:", reg_test)

# Disabled: dataset statistics (classification labels).
# print("\n训练集分类统计信息:")
# print(f"样本总数: {len(cls_train)}")
# print(f"各类别统计: 0类: {cls_train.count(0)}, 1类: {cls_train.count(1)}, 2类: {cls_train.count(2)}")

# print("\n验证集分类统计信息:")
# print(f"样本总数: {len(cls_val)}")
# print(f"各类别统计: 0类: {cls_val.count(0)}, 1类: {cls_val.count(1)}, 2类: {cls_val.count(2)}")

# print("\n测试集分类统计信息:")
# print(f"样本总数: {len(cls_test)}")
# print(f"各类别统计: 0类: {cls_test.count(0)}, 1类: {cls_test.count(1)}, 2类: {cls_test.count(2)}")