File size: 3,638 Bytes
acbef3a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
import torch
import numpy as np
import pandas as pd
from umap import UMAP
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
# Run selection.
#   base_path: directory the script reads `features.pth` and
#              `feature_preds.csv` from.
#   img_path:  file the rendered figure is written to.
# The active pair below renders the "after" figure from the `uda_r2`
# sub-directory. To render the "before" figure, use instead:
#   base_path = '../run-cls/rn18-diff-16-mamba-pcs-768-ce-32-0.001-50'
#   img_path = 'umap_before.png'
base_path, img_path = (
    '../run-cls/rn18-diff-16-mamba-pcs-768-ce-32-0.001-50/uda_r2',
    'umap_after.png',
)
# —— 0. Helper functions: label lookup and score lookup —— #
def get_label_dict(xlsx_path='../dataset/r2_case.xlsx'):
    """Map each manually tagged sequence to a binary label.

    Reads the spreadsheet at ``xlsx_path`` and uses two of its columns:
    ``'SEQUENCE - D-type amino acid substitution'`` (the dictionary key) and
    ``'TAG'`` (the outcome).

    Args:
        xlsx_path: Excel file to read. Defaults to the case sheet this
            script was written against.

    Returns:
        dict of ``{sequence: label}`` where label is ``1`` for a TAG of
        "improved" and ``0`` for "not improved". Matching ignores case and
        surrounding whitespace. Rows whose TAG is empty or is any other
        value are left out, e.g.::

            {"<sequence A>": 1, "<sequence B>": 0}
    """
    df = pd.read_excel(xlsx_path)
    df = df[df['TAG'].notna()]

    label_of_tag = {'improved': 1, 'not improved': 0}
    tags = {}
    for desc, tag in zip(df['SEQUENCE - D-type amino acid substitution'], df['TAG']):
        # str() keeps a stray numeric/boolean cell from raising on .lower();
        # strip() keeps "Improved " from being silently dropped.
        label = label_of_tag.get(str(tag).strip().lower())
        if label is not None:
            tags[desc] = label
    return tags
def get_score_dict(csv_path=None):
    """Map each sequence to its prediction, min-max rescaled to [0, 1].

    Reads the ``seq`` and ``pred`` columns of a CSV. The rescaled value is
    used to pick a colormap colour for the background points.

    Args:
        csv_path: CSV file to read. When omitted, ``feature_preds.csv``
            inside the module-level ``base_path`` directory is used.

    Returns:
        dict of ``{seq: float in [0, 1]}``. Predictions that are NaN/inf,
        and every prediction when all finite values are equal, map to the
        neutral value ``0.5`` instead of NaN. ``0.5`` is the same default the
        plotting code uses for sequences that have no score at all. A CSV
        with no rows yields an empty dict.
    """
    if csv_path is None:
        csv_path = f'{base_path}/feature_preds.csv'
    df = pd.read_csv(csv_path)
    seqs = df['seq'].to_numpy()
    preds = df['pred'].to_numpy(dtype=float)

    # Start from the neutral value and overwrite only what can be scaled.
    scaled = np.full(preds.shape, 0.5)
    usable = np.isfinite(preds)
    if usable.any():
        lowest = preds[usable].min()
        highest = preds[usable].max()
        # A zero span would divide by zero and turn every score into NaN.
        if highest > lowest:
            scaled[usable] = (preds[usable] - lowest) / (highest - lowest)

    return dict(zip(seqs, scaled.tolist()))
# —— 1. Load features —— #
# `features.pth` is loaded as a mapping; its keys are later matched against
# the keys of the label and score dictionaries. Each value is either a
# torch.Tensor or something np.array() can convert.
# NOTE(review): torch.load may unpickle arbitrary objects depending on the
# torch version / weights_only setting; only point base_path at trusted runs.
def _as_numpy(value):
    """Return one stored feature as a NumPy array (tensors go via CPU)."""
    if isinstance(value, torch.Tensor):
        return value.cpu().numpy()
    return np.array(value)


data_dict = torch.load(f'{base_path}/features.pth', map_location='cpu')
descs = list(data_dict.keys())
features = np.vstack([_as_numpy(data_dict[d]) for d in descs])
# Message reads: "<N> samples in total, feature dimension <D>".
print(f"共 {features.shape[0]} 个样本,特征维度 {features.shape[1]}")

# —— 2. Fetch the label and score dictionaries —— #
label_dict = get_label_dict()
score_dict = get_score_dict()

# —— 3. Project to 2-D with UMAP —— #
# NOTE(review): no random_state is passed, so the layout presumably differs
# from run to run; confirm whether that matters for before/after figures.
umap = UMAP(n_components=2, metric='euclidean')
points_2d = umap.fit_transform(features)

# —— 4. Partition sample indices by label —— #
# idx1: labelled 1, idx0: labelled 0, idx_rest: no label at all.
idx1, idx0, idx_rest = [], [], []
for i, d in enumerate(descs):
    if d not in label_dict:
        idx_rest.append(i)
    elif label_dict[d] == 1:
        idx1.append(i)
    elif label_dict[d] == 0:
        idx0.append(i)

# —— 5. Build the background gradient —— #
# Two-stop colormap: blue ('#6EB1EC') at score 0 → salmon ('#E69C98') at 1.
cmap = LinearSegmentedColormap.from_list('bg_cmap', ["#6EB1EC", "#E69C98"])
# Unlabelled points are coloured by their score; a missing score falls back
# to 0.5, the midpoint of the gradient.
bg_scores = np.array([score_dict.get(descs[i], 0.5) for i in idx_rest])
bg_colors = cmap(bg_scores)

# —— 6. Draw —— #
plt.figure(figsize=(3,3))
# Layers are drawn in this order so labelled points sit on top of the
# background: (indices, marker size, colour, alpha, legend label).
for idx, size, colour, opacity, name in (
    (idx_rest, 20, bg_colors, 0.8, 'Background'),
    (idx0, 60, "#218BE7", 1.0, 'Not Improved'),
    (idx1, 60, "#E76760", 1.0, 'Improved'),
):
    plt.scatter(
        points_2d[idx,0],
        points_2d[idx,1],
        s=size,
        c=colour,
        alpha=opacity,
        label=name
    )

# Optional: annotate the highlighted (labelled) points with their key.
# for i in idx0 + idx1:
#     plt.annotate(
#         descs[i],
#         (points_2d[i,0], points_2d[i,1]),
#         textcoords='offset points',
#         xytext=(2,2),
#         fontsize=8,
#         color='black'
#     )
# plt.title('UMAP Visualization', fontsize=18)
plt.axis('off')
# plt.legend(loc='upper left', frameon=False, fontsize=14)
plt.tight_layout()

# —— Save the figure —— #
plt.savefig(img_path, bbox_inches='tight', pad_inches=0, dpi=300)
# plt.show()
|