import torch
import numpy as np
import pandas as pd
from umap import UMAP
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

# Alternative configuration (writes 'umap_before.png' instead):
# base_path = '../run-cls/rn18-diff-16-mamba-pcs-768-ce-32-0.001-50'
# img_path = 'umap_before.png'
base_path = '../run-cls/rn18-diff-16-mamba-pcs-768-ce-32-0.001-50/uda_r2'
img_path = 'umap_after.png'

# Seed forwarded to UMAP. None keeps UMAP's default (non-deterministic)
# behaviour; set an int to get the same layout on every run.
UMAP_RANDOM_STATE = None


# —— 0. Label / score lookups —— #
def get_label_dict():
    """Return a dict mapping a sequence description to a binary label.

    Reads '../dataset/r2_case.xlsx', drops rows whose 'TAG' cell is empty and
    maps each 'SEQUENCE - D-type amino acid substitution' value to 1 when the
    tag is 'improved' and to 0 when it is 'not improved'. The comparison is
    case-insensitive and ignores surrounding whitespace. Rows carrying any
    other tag are skipped.

    Example of the returned shape:
        {"image_0001": 1, "image_0234": 0}
    """
    df = pd.read_excel('../dataset/r2_case.xlsx')
    df = df[~pd.isna(df['TAG'])]

    tag_to_label = {'improved': 1, 'not improved': 0}
    tags = {}
    for desc, tag in zip(
        df['SEQUENCE - D-type amino acid substitution'], df['TAG']
    ):
        # str() keeps a stray non-string cell (e.g. a number) from raising
        # AttributeError; strip() tolerates padded spreadsheet cells.
        label = tag_to_label.get(str(tag).strip().lower())
        if label is not None:
            tags[desc] = label
    return tags


def get_score_dict():
    """Return a dict mapping a sequence description to a score in [0, 1].

    Reads f'{base_path}/feature_preds.csv' and min-max normalises the 'pred'
    column; the result is keyed by the 'seq' column and is used to colour the
    background points. An empty file yields an empty dict. When every
    prediction is identical the normalisation is undefined, so all scores are
    set to 0.5 (the same neutral value used for missing entries).
    """
    df = pd.read_csv(f'{base_path}/feature_preds.csv')
    seqs = df['seq'].values
    scores = df['pred'].values

    if scores.size == 0:
        return {}

    minimum, maximum = np.min(scores), np.max(scores)
    span = maximum - minimum
    if span == 0:
        # Avoid 0/0 -> NaN, which would make every background colour invalid.
        scores = np.full(scores.shape, 0.5)
    else:
        scores = (scores - minimum) / span

    return dict(zip(seqs, scores))


# —— 1. Load features —— #
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code. Only load feature files that come from a trusted source.
data_dict = torch.load(f'{base_path}/features.pth', map_location='cpu')
descs = list(data_dict.keys())
features = np.vstack([
    # detach() so tensors that still require grad can be converted to NumPy.
    data_dict[d].detach().cpu().numpy()
    if isinstance(data_dict[d], torch.Tensor)
    else np.array(data_dict[d])
    for d in descs
])
print(f"共 {features.shape[0]} 个样本,特征维度 {features.shape[1]}")

# —— 2. Build the label and score dictionaries —— #
label_dict = get_label_dict()
score_dict = get_score_dict()

# —— 3. Reduce to 2 dimensions with UMAP —— #
umap = UMAP(
    n_components=2,
    metric='euclidean',
    random_state=UMAP_RANDOM_STATE,
)
points_2d = umap.fit_transform(features)

# —— 4. Collect the row indices of each group —— #
idx1 = [i for i, d in enumerate(descs) if label_dict.get(d) == 1]
idx0 = [i for i, d in enumerate(descs) if label_dict.get(d) == 0]
idx_rest = [i for i, d in enumerate(descs) if d not in label_dict]

# —— 5. Build the gradient colormap —— #
# Blue ('#6EB1EC') for low scores → salmon ('#E69C98') for high scores.
cmap = LinearSegmentedColormap.from_list('bg_cmap', ["#6EB1EC", "#E69C98"])
# Background points are coloured by their score; missing scores fall back
# to 0.5 (the middle of the gradient).
bg_scores = np.array([score_dict.get(descs[i], 0.5) for i in idx_rest])
bg_colors = cmap(bg_scores)

# —— 6. Plot —— #
plt.figure(figsize=(3, 3))

# Background points (gradient colour); drawn first so the labelled points
# end up on top.
plt.scatter(
    points_2d[idx_rest, 0], points_2d[idx_rest, 1],
    s=20,
    c=bg_colors,
    alpha=0.8,
    label='Background'
)

# Label 0
plt.scatter(
    points_2d[idx0, 0], points_2d[idx0, 1],
    s=60,
    c="#218BE7",
    alpha=1.0,
    label='Not Improved'
)

# Label 1
plt.scatter(
    points_2d[idx1, 0], points_2d[idx1, 1],
    s=60,
    c="#E76760",
    alpha=1.0,
    label='Improved'
)

# Optional: annotate the highlighted points.
# for i in idx0 + idx1:
#     plt.annotate(
#         descs[i],
#         (points_2d[i,0], points_2d[i,1]),
#         textcoords='offset points',
#         xytext=(2,2),
#         fontsize=8,
#         color='black'
#     )

# plt.title('UMAP Visualization', fontsize=18)
plt.axis('off')
# plt.legend(loc='upper left', frameon=False, fontsize=14)
plt.tight_layout()

# —— Optional: save the figure —— #
plt.savefig(img_path, bbox_inches='tight', pad_inches=0, dpi=300)
# plt.show()