|
|
import sys |
|
|
sys.path.append("..") |
|
|
from dataset import PeptidePairDataset |
|
|
from torch.utils.data import Subset |
|
|
from sklearn.model_selection import KFold |
|
|
import matplotlib.pyplot as plt |
|
|
import numpy as np |
|
|
|
|
|
def extract_label(dataset: PeptidePairDataset):
    """Collect the label of every sample in ``dataset``.

    Args:
        dataset: Iterable whose items unpack into ``(features, label)`` pairs.

    Returns:
        list: The labels in iteration order, exactly as the dataset yields
        them (no conversion or copying is applied).
    """
    # Each item is a 2-tuple; only the second element (the label) is kept.
    return [label for _, label in dataset]
|
|
|
|
|
def extract_labels_cls(dataset: PeptidePairDataset, threshold: float = 0.5):
    """Bucket each regression label of ``dataset`` into a three-way class id.

    Args:
        dataset: Iterable whose items unpack into ``(features, label)`` pairs.
        threshold: Non-negative cut-off applied symmetrically around zero.
            Defaults to ``0.5``, the value that was previously hard-coded.

    Returns:
        list[int]: One class id per sample, in iteration order:
        ``2`` if ``label >= threshold``, ``1`` if ``label <= -threshold``,
        otherwise ``0``.

    Raises:
        ValueError: If ``threshold`` is negative, because the "higher" and
            "lower" ranges would then overlap.
    """
    if threshold < 0:
        raise ValueError(f"threshold must be non-negative, got {threshold!r}")

    def _to_class(value):
        # The upper bucket is checked first, matching the original if/elif order.
        if value >= threshold:
            return 2
        if value <= -threshold:
            return 1
        return 0

    return [_to_class(label) for _, label in dataset]
|
|
|
|
|
# Constructor arguments common to both dataset splits.
_shared_dataset_kwargs = dict(pad_length=30, task='reg')

# Training split and the 'r2_case' split (the latter built with one_way=True).
all_set = PeptidePairDataset(mode='train', **_shared_dataset_kwargs)
test_set = PeptidePairDataset(mode='r2_case', **_shared_dataset_kwargs, one_way=True)

# Histogram bin edges: 27 evenly spaced edges from -6.5 to 6.5 (step 0.5).
bins = np.linspace(-6.5, 6.5, 27)
|
|
|
|
|
def plot_gradient_bar_histogram(ax, data, bins, cmap_name='Blues'):
    """Draw a histogram with gradient-coloured bars and a count label on each bar.

    Every bar gets its own colour, sampled at evenly spaced positions from a
    sub-range of the colormap, so the colours vary continuously across bins.
    Non-empty bars are annotated with their count.

    Args:
        ax: Matplotlib axes to draw on.
        data: Values to bin; passed straight to ``np.histogram``.
        bins: Bin specification accepted by ``np.histogram`` (for example an
            array of bin edges). Values outside the bin range are not counted
            by ``np.histogram``.
        cmap_name: Name of the matplotlib colormap to sample colours from.

    Returns:
        tuple: ``(counts, bin_edges)`` as returned by ``np.histogram``.
    """
    counts, bin_edges = np.histogram(data, bins=bins)
    n_bins = len(counts)
    cmap = plt.get_cmap(cmap_name)

    # Use each bin's own width rather than the first bin's width, so that
    # non-uniform bin edges are drawn (and labelled) at the right place.
    widths = np.diff(bin_edges)

    for i, (left, width, count) in enumerate(zip(bin_edges[:-1], widths, counts)):
        # Map the bin index onto the [0.3, 0.8] portion of the colormap; a
        # single bin falls back to the midpoint to avoid dividing by zero.
        ratio = (i / (n_bins - 1)) * 0.5 + 0.3 if n_bins > 1 else 0.5

        ax.bar(left, count, width=width, align='edge',
               color=cmap(ratio), zorder=2)

        # Only annotate non-empty bins to keep the plot uncluttered.
        if count != 0:
            ax.text(
                left + width / 2, count, f'{count}',
                ha='center', va='bottom', fontweight='bold', fontsize=9
            )

    return counts, bin_edges
|
|
|
|
|
def plot_gradient_bar_histogram_cls(ax, data, cmap_name='Blues'):
    """Count class ids in ``data`` and draw one annotated, gradient-coloured bar per class.

    The class ids are fixed to ``[0, 1, 2]`` and the x tick labels are set to
    'Unchanged', 'Lower' and 'Higher' respectively.

    Args:
        ax: Matplotlib axes to draw on.
        data: Object with a ``count(value)`` method (e.g. a list of class ids).
        cmap_name: Name of the matplotlib colormap to sample colours from.

    Returns:
        numpy.ndarray: Number of occurrences of class 0, 1 and 2, in that order.
    """
    classes = np.array([0, 1, 2])
    num_classes = len(classes)
    counts = np.array([data.count(cls) for cls in classes])

    cmap = plt.get_cmap(cmap_name)
    width = 0.6

    for i, (cls, n_samples) in enumerate(zip(classes, counts)):
        # Spread the bars over the [0.3, 0.8] portion of the colormap.
        if num_classes > 1:
            ratio = (i / (num_classes - 1)) * 0.5 + 0.3
        else:
            ratio = 0.5

        ax.bar(cls, n_samples, width=width, color=cmap(ratio), zorder=2)

        # Skip the annotation for empty classes.
        if n_samples != 0:
            ax.text(cls, n_samples, f'{n_samples}',
                    ha='center', va='bottom', fontweight='bold')

    ax.set_xticks(classes)
    ax.set_xticklabels(['Unchanged', 'Lower', 'Higher'])
    return counts
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
plt.style.use('seaborn-v0_8-whitegrid')

# One row, two panels: training set on the left, R2 set on the right.
fig, axes = plt.subplots(1, 2, figsize=(11.4, 3))

reg_train = extract_label(all_set)
reg_test = extract_label(test_set)

# Both panels share the same binning, axis labels and limits; only the data
# and the title differ.
for panel, values, title in (
    (axes[0], reg_train, f'Train Set ({len(reg_train)})'),
    (axes[1], reg_test, f'R2 Set ({len(reg_test)})'),
):
    counts, _ = plot_gradient_bar_histogram(panel, values, bins, cmap_name='Blues')
    panel.set_xlabel('log2 (MIC ratio)', size=14)
    panel.set_xlim(-6.5, 6.5)
    panel.set_ylabel('Counts', size=14)
    # Leave headroom above the tallest bar for its count label.
    panel.set_ylim(0, max(counts)*1.15)
    panel.set_title(title, weight='bold', size=16)

plt.tight_layout(w_pad=3.2)
plt.savefig('dataset.svg')
plt.show()

# Summary of the training-set regression labels: sample count and value range.
print("\n训练集回归统计信息:")
print(f"样本总数: {len(reg_train)}")
print(f"范围: {min(reg_train)} - {max(reg_train)}")

# Same summary for the R2 (test) set.
print("\n测试集回归统计信息:")
print(f"样本总数: {len(reg_test)}")
print(f"范围: {min(reg_test)} - {max(reg_test)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|