File size: 6,905 Bytes
acbef3a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
import sys 
sys.path.append("..") 
from dataset import PeptidePairDataset
from torch.utils.data import Subset
from sklearn.model_selection import KFold
import matplotlib.pyplot as plt
import numpy as np

def extract_label(dataset: PeptidePairDataset):
    """Return the labels of ``dataset`` as a plain list.

    Every item yielded by ``dataset`` is unpacked as a two-element pair;
    the first element is ignored and the second is kept as the label.

    Args:
        dataset: Iterable yielding ``(inputs, label)`` pairs.

    Returns:
        A list with one label per item, in iteration order.
    """
    return [target for _, target in dataset]

def extract_labels_cls(dataset: PeptidePairDataset):
    """Map every label in ``dataset`` to one of three class ids.

    Each item is unpacked as a two-element pair whose second element is
    the label. The label is then bucketed as follows:

    * ``2`` when the label is ``>= 0.5``
    * ``1`` when the label is ``<= -0.5``
    * ``0`` otherwise

    Args:
        dataset: Iterable yielding ``(inputs, label)`` pairs.

    Returns:
        A list of class ids (0, 1 or 2), one per item, in iteration order.
    """
    def _bucket(value):
        # The upper threshold is checked first, then the lower one; anything
        # that satisfies neither comparison lands in class 0.
        if value >= 0.5:
            return 2
        if value <= -0.5:
            return 1
        return 0

    return [_bucket(target) for _, target in dataset]

# Load the 'train' split for the regression task; its labels feed the
# "Train Set" histogram further down.
all_set = PeptidePairDataset(mode='train', pad_length=30, task='reg')
# Load the 'r2_case' split (with one_way=True); its labels feed the
# "R2 Set" histogram further down.
test_set = PeptidePairDataset(mode='r2_case', pad_length=30, task='reg', one_way=True)

# Disabled: take the first fold of a 5-fold split of `all_set` to obtain
# separate train / validation subsets.
# kf = KFold(n_splits=5, shuffle=True, random_state=42)
# train_idx, val_idx = next(kf.split(all_set))
# train_set = Subset(all_set, train_idx)
# valid_set = Subset(all_set, val_idx)

# Histogram bin edges: 0.5-wide steps from -6.5 up to and including 6.5
# (the stop value 6.6 is exclusive).
bins = np.arange(-6.5, 6.6, 0.5)

def plot_gradient_bar_histogram(ax, data, bins, cmap_name='Blues'):
    """Draw a histogram on ``ax`` with a colormap gradient across the bars.

    Every bar gets its own colour sampled from ``cmap_name``; the sampled
    position runs linearly from 0.3 (first bin) to 0.8 (last bin). Each
    non-empty bar is annotated with its count just above the bar.

    Args:
        ax: Matplotlib axes to draw on.
        data: Values to bin; passed straight to ``np.histogram``.
        bins: Bin specification accepted by ``np.histogram`` (number of bins
            or a sequence of edges). Edges may be non-uniform.
        cmap_name: Name of the Matplotlib colormap to sample from.

    Returns:
        The ``(counts, bin_edges)`` tuple produced by ``np.histogram``.
    """
    # Compute the histogram once; the bars are drawn manually below.
    counts, bin_edges = np.histogram(data, bins=bins)
    n_bins = len(counts)
    cmap = plt.get_cmap(cmap_name)
    # Use each bin's own width so that non-uniform edges are drawn correctly
    # (for uniform edges every entry is the same value).
    widths = np.diff(bin_edges)

    for i, (left, width, count) in enumerate(zip(bin_edges[:-1], widths, counts)):
        # Position inside the colormap, restricted to [0.3, 0.8]; a single
        # bin falls back to the middle of the colormap.
        ratio = (i / (n_bins - 1)) * 0.5 + 0.3 if n_bins > 1 else 0.5
        ax.bar(left, count, width=width, align='edge',
               color=cmap(ratio), zorder=2)
        # Only annotate bars that actually contain samples.
        if count != 0:
            ax.text(
                left + width / 2, count, f'{count}',
                ha='center', va='bottom', fontweight='bold', fontsize=9
            )
    return counts, bin_edges

def plot_gradient_bar_histogram_cls(ax, data, cmap_name='Blues'):
    """Draw per-class counts on ``ax`` as gradient-coloured, annotated bars.

    The class ids are fixed to ``[0, 1, 2]`` and are counted with
    ``data.count``, so ``data`` must provide a ``count`` method (e.g. a
    list). The x tick labels are set to 'Unchanged', 'Lower' and 'Higher'
    for classes 0, 1 and 2 respectively.

    Args:
        ax: Matplotlib axes to draw on.
        data: Sequence of class ids.
        cmap_name: Name of the Matplotlib colormap to sample from.

    Returns:
        A NumPy array with the number of occurrences of classes 0, 1 and 2.
    """
    # Fixed class order.
    classes = np.array([0, 1, 2])
    num_classes = len(classes)
    counts = np.array([data.count(class_id) for class_id in classes])
    cmap = plt.get_cmap(cmap_name)
    width = 0.6

    for idx, (class_id, n_samples) in enumerate(zip(classes, counts)):
        # Position inside the colormap, restricted to [0.3, 0.8].
        if num_classes > 1:
            ratio = (idx / (num_classes - 1)) * 0.5 + 0.3
        else:
            ratio = 0.5
        ax.bar(class_id, n_samples, width=width, color=cmap(ratio), zorder=2)
        # Only annotate classes that actually occur.
        if n_samples != 0:
            ax.text(class_id, n_samples, f'{n_samples}',
                    ha='center', va='bottom', fontweight='bold')

    ax.set_xticks(classes)
    ax.set_xticklabels(['Unchanged', 'Lower', 'Higher'])
    return counts

# # Set global font sizes (disabled)
# plt.rcParams.update({
#     'font.size': 10,
#     'axes.titlesize': 14,
#     'axes.labelsize': 14,
#     'xtick.labelsize': 13,
#     'ytick.labelsize': 13,
#     'legend.fontsize': 14,
#     'figure.titlesize': 16
# })
plt.style.use('seaborn-v0_8-whitegrid')

# Earlier layout (disabled): 3 rows x 2 columns, with the regression plots in
# the first column and the classification plots in the second column.
# fig, axes = plt.subplots(3, 2, figsize=(12, 9), width_ratios=[2.5, 1])
# Current layout: one row with two panels, so `axes` is indexed as
# axes[0] and axes[1].
fig, axes = plt.subplots(1, 2, figsize=(11.4, 3))

# Data extraction: regression labels
reg_train = extract_label(all_set)
# reg_val = extract_label(valid_set)
reg_test = extract_label(test_set)

# Data extraction: classification labels (disabled)
# cls_train = extract_labels_cls(all_set)
# cls_val = extract_labels_cls(valid_set)
# cls_test = extract_labels_cls(test_set)

# Left panel (axes[0]): training set
counts, _ = plot_gradient_bar_histogram(axes[0], reg_train, bins, cmap_name='Blues')
axes[0].set_xlabel('log2 (MIC ratio)', size=14)
axes[0].set_xlim(-6.5, 6.5)
axes[0].set_ylabel('Counts', size=14)
# Upper y limit is 15% above the tallest bar (presumably to leave room for
# the count labels drawn above the bars).
axes[0].set_ylim(0, max(counts)*1.15)
axes[0].set_title(f'Train Set ({len(reg_train)})', weight='bold', size=16)

# Disabled: classification panel for the training set (earlier 3x2 layout).
# counts_cls = plot_gradient_bar_histogram_cls(axes[0, 1], cls_train, cmap_name='Blues')
# axes[0, 1].set_xlabel('Class')
# axes[0, 1].set_xlim(-0.5, 2.5)
# axes[0, 1].set_ylabel('Counts')
# axes[0, 1].set_ylim(0, max(counts_cls)*1.15)
# axes[0, 1].set_title(f'Train Set (Cls.) ({len(cls_train)})', weight='bold')

# Right panel (axes[1]): labels of `test_set` (the 'r2_case' split), titled
# "R2 Set". `counts` is reused and now refers to this panel's histogram.
counts, _ = plot_gradient_bar_histogram(axes[1], reg_test, bins, cmap_name='Blues')
axes[1].set_xlabel('log2 (MIC ratio)', size=14)
axes[1].set_xlim(-6.5, 6.5)
axes[1].set_ylabel('Counts', size=14)
axes[1].set_ylim(0, max(counts)*1.15)
axes[1].set_title(f'R2 Set ({len(reg_test)})', weight='bold', size=16)

# Disabled: classification panel for the validation set (earlier 3x2 layout).
# counts_cls = plot_gradient_bar_histogram_cls(axes[1, 1], cls_val, cmap_name='Blues')
# axes[1, 1].set_xlabel('Class')
# axes[1, 1].set_xlim(-0.5, 2.5)
# axes[1, 1].set_ylabel('Counts')
# axes[1, 1].set_ylim(0, max(counts_cls)*1.15)
# axes[1, 1].set_title(f'Validation Set (Cls.) ({len(cls_val)})', weight='bold')

# # Third row: test set (disabled; belongs to the earlier 3x2 layout)
# counts, _ = plot_gradient_bar_histogram(axes[2, 0], reg_test, bins, cmap_name='Blues')
# axes[2, 0].set_xlabel('log2 (MIC ratio)')
# axes[2, 0].set_xlim(-6.5, 6.5)
# axes[2, 0].set_ylabel('Counts')
# axes[2, 0].set_ylim(0, max(counts)*1.15)
# axes[2, 0].set_title(f'Test Set (Reg.) ({len(reg_test)})', weight='bold')

# counts_cls = plot_gradient_bar_histogram_cls(axes[2, 1], cls_test, cmap_name='Blues')
# axes[2, 1].set_xlabel('Class')
# axes[2, 1].set_xlim(-0.5, 2.5)
# axes[2, 1].set_ylabel('Counts')
# axes[2, 1].set_ylim(0, max(counts_cls)*1.15)
# axes[2, 1].set_title(f'Test Set (Cls.) ({len(cls_test)})', weight='bold')

# Finalise the layout with extra horizontal padding between the two panels,
# write the figure to 'dataset.svg' (relative path), then display it.
plt.tight_layout(w_pad=3.2)
plt.savefig('dataset.svg')
plt.show()

# Print dataset statistics (regression labels): sample count and the
# minimum / maximum label of the training set.
print("\n训练集回归统计信息:")
print(f"样本总数: {len(reg_train)}")
print(f"范围: {min(reg_train)} - {max(reg_train)}")

# Disabled: the same statistics for the validation set.
# print("\n验证集回归统计信息:")
# print(f"样本总数: {len(reg_val)}")
# print(f"范围: {min(reg_val)} - {max(reg_val)}")

# The same statistics for `test_set` (the 'r2_case' split).
print("\n测试集回归统计信息:")
print(f"样本总数: {len(reg_test)}")
print(f"范围: {min(reg_test)} - {max(reg_test)}")

# Print dataset statistics (classification labels): sample count and
# per-class counts for each split -- all disabled.
# print("\n训练集分类统计信息:")
# print(f"样本总数: {len(cls_train)}")
# print(f"各类别统计: 0类: {cls_train.count(0)}, 1类: {cls_train.count(1)}, 2类: {cls_train.count(2)}")

# print("\n验证集分类统计信息:")
# print(f"样本总数: {len(cls_val)}")
# print(f"各类别统计: 0类: {cls_val.count(0)}, 1类: {cls_val.count(1)}, 2类: {cls_val.count(2)}")

# print("\n测试集分类统计信息:")
# print(f"样本总数: {len(cls_test)}")
# print(f"各类别统计: 0类: {cls_test.count(0)}, 1类: {cls_test.count(1)}, 2类: {cls_test.count(2)}")