# DAminoMuta / vis / umap_plot.py
# auralray's picture
# Upload folder using huggingface_hub
# acbef3a verified
import torch
import numpy as np
import pandas as pd
from umap import UMAP
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
# Select which run to visualise. base_path is the directory that holds
# features.pth and feature_preds.csv; img_path is the output image.
# The commented-out pair reads the base run directory and writes
# umap_before.png; the active pair reads its uda_r2 sub-directory and
# writes umap_after.png.
# base_path = '../run-cls/rn18-diff-16-mamba-pcs-768-ce-32-0.001-50'
# img_path = 'umap_before.png'
base_path = '../run-cls/rn18-diff-16-mamba-pcs-768-ce-32-0.001-50' + '/uda_r2'
img_path = 'umap_after.png'
# —— 0. Data-loading helpers —— #
def get_label_dict(path='../dataset/r2_case.xlsx'):
    """Map each labelled sequence to a binary label.

    Reads the Excel sheet at *path* and drops rows whose ``TAG`` is missing.
    It then maps the ``'SEQUENCE - D-type amino acid substitution'`` column
    as follows:

    * ``TAG == 'improved'``     -> ``1``
    * ``TAG == 'not improved'`` -> ``0``

    Tags are compared case-insensitively and with surrounding whitespace
    stripped. Rows with any other tag are skipped. If a sequence appears more
    than once, the last recognised row wins.

    Args:
        path: Excel file to read. Defaults to the original hard-coded location,
            so existing calls ``get_label_dict()`` behave as before.

    Returns:
        dict mapping sequence description -> ``0`` or ``1``.
    """
    label_of = {'improved': 1, 'not improved': 0}
    df = pd.read_excel(path)
    df = df.dropna(subset=['TAG'])
    tags = {}
    for desc, tag in zip(df['SEQUENCE - D-type amino acid substitution'], df['TAG']):
        # str() guards against non-string cells (e.g. numbers); strip() makes
        # stray spaces typed into the sheet harmless.
        label = label_of.get(str(tag).strip().lower())
        if label is not None:
            tags[desc] = label
    return tags
def get_score_dict(path=None):
    """Load model predictions and min-max normalise them to ``[0, 1]``.

    The normalised scores are used to colour the background points.

    Args:
        path: CSV file with ``seq`` and ``pred`` columns. Defaults to
            ``f'{base_path}/feature_preds.csv'``, using the module-level
            ``base_path``, so existing calls ``get_score_dict()`` are unchanged.

    Returns:
        dict mapping each ``seq`` value to its normalised score.

        * If every prediction is identical, the range is zero. Every score is
          then 0.5 (the colour-map midpoint) instead of NaN from 0/0.
        * An empty CSV yields an empty dict.
    """
    if path is None:
        path = f'{base_path}/feature_preds.csv'
    df = pd.read_csv(path)
    seqs = df['seq'].values
    # Force float so the constant-range fallback below can hold 0.5.
    preds = df['pred'].to_numpy(dtype=float)
    if preds.size == 0:
        return {}
    lo, hi = preds.min(), preds.max()
    span = hi - lo
    if span > 0:
        scores = (preds - lo) / span
    else:
        # All predictions equal: avoid 0/0 -> NaN, which would render as the
        # colour map's "bad" colour.
        scores = np.full_like(preds, 0.5)
    return dict(zip(seqs, scores))
# —— 1. Load the per-sequence feature vectors —— #
# NOTE(review): torch.load unpickles the file; only point base_path at
# feature files you produced yourself.
data_dict = torch.load(f'{base_path}/features.pth', map_location='cpu')
if not data_dict:
    # np.vstack would otherwise fail with an opaque "need at least one array".
    raise ValueError(f'No features found in {base_path}/features.pth')
descs = list(data_dict.keys())
# Assumes each value is one feature vector for its key (so rows align with
# descs) — TODO confirm against the code that writes features.pth.
# detach() lets tensors saved with requires_grad=True convert to NumPy.
features = np.vstack([
    data_dict[d].detach().cpu().numpy() if isinstance(data_dict[d], torch.Tensor)
    else np.array(data_dict[d])
    for d in descs
])
# Report the number of samples and the feature dimension.
print(f"共 {features.shape[0]} 个样本,特征维度 {features.shape[1]}")
# —— 2. Look up labels and prediction scores —— #
label_dict = get_label_dict()
score_dict = get_score_dict()
# —— 3. Project the features to 2-D with UMAP —— #
# NOTE(review): no random_state is passed, so the layout may differ between runs.
umap = UMAP(metric='euclidean', n_components=2)
points_2d = umap.fit_transform(features)
# —— 4. Partition sample indices by label —— #
# 1 = "improved", 0 = "not improved"; unlabelled samples form the background.
idx1 = [i for i, d in enumerate(descs) if label_dict.get(d) == 1]
idx0 = [i for i, d in enumerate(descs) if label_dict.get(d) == 0]
idx_rest = [i for i, d in enumerate(descs) if d not in label_dict]
# —— 5. Colour map for the background points —— #
# Linear gradient from blue "#6EB1EC" (score 0) to rose "#E69C98" (score 1).
cmap = LinearSegmentedColormap.from_list('bg_cmap', ["#6EB1EC", "#E69C98"])
# Background colour follows the normalised score; a missing score uses 0.5.
bg_scores = np.array([score_dict.get(descs[k], 0.5) for k in idx_rest])
bg_colors = cmap(bg_scores)
# —— 6. Plot —— #
plt.figure(figsize=(3, 3))
# Later layers are drawn on top, so the labelled points are drawn after
# (and larger than) the score-coloured background.
_layers = [
    # (indices, marker size, colour, alpha, legend label)
    (idx_rest, 20, bg_colors, 0.8, 'Background'),
    (idx0, 60, "#218BE7", 1.0, 'Not Improved'),
    (idx1, 60, "#E76760", 1.0, 'Improved'),
]
for _idx, _size, _colour, _alpha, _label in _layers:
    plt.scatter(
        points_2d[_idx, 0],
        points_2d[_idx, 1],
        s=_size,
        c=_colour,
        alpha=_alpha,
        label=_label,
    )
# Optional: annotate each highlighted point with its sequence description.
# for i in idx0 + idx1:
#     plt.annotate(
#         descs[i],
#         (points_2d[i, 0], points_2d[i, 1]),
#         textcoords='offset points',
#         xytext=(2, 2),
#         fontsize=8,
#         color='black'
#     )
# plt.title('UMAP Visualization', fontsize=18)
plt.axis('off')
# plt.legend(loc='upper left', frameon=False, fontsize=14)
plt.tight_layout()
# Save at 300 dpi with the surrounding whitespace trimmed.
plt.savefig(img_path, bbox_inches='tight', pad_inches=0, dpi=300)
# plt.show()