import warnings

import numpy as np
import matplotlib.pyplot as plt
import umap

# ---------- Configuration ----------
file1 = "test_features.npz"  # first npz file (left panel)
file2 = "case_features.npz"  # second npz file (right panel)
titles = ["Before Finetune", "After Finetune"]  # subplot titles; edit as needed
output_path = "umap.svg"  # where the figure is written

# If True, fit a single UMAP on both feature sets stacked together so the two
# panels live in the same embedding space and share axis limits.
# If False, each feature set is embedded by its own independent UMAP fit.
shared_space = False

# UMAP parameters
umap_kwargs = dict(n_components=2, random_state=465, metric="euclidean")

# Scatter colour and legend label for each ground-truth class
color_map = {0: "#81B9E9", 1: "#EBB4B1"}
label_map = {0: "Not Improved", 1: "Improved"}


def _load_features(path):
    """Read the ``features`` and ``gts`` arrays from the npz archive at *path*.

    Returns ``(features, gts)`` with ``gts`` flattened to 1-D. The archive is
    opened in a ``with`` block so its file handle is released; item access on
    an ``NpzFile`` reads each array fully, so the returned arrays stay valid.

    Raises:
        ValueError: if the number of feature rows differs from the number of
            labels.
    """
    with np.load(path) as data:
        feats = data["features"]  # expected (n, d) — not validated beyond n
        gts = data["gts"].ravel()  # (n,)
    if feats.shape[0] != gts.shape[0]:
        raise ValueError(
            f"{path}: 'features' has {feats.shape[0]} rows but "
            f"'gts' has {gts.shape[0]} labels"
        )
    return feats, gts


def _embed(feats_a, feats_b):
    """Return 2-D UMAP embeddings ``(emb_a, emb_b)`` for two feature sets.

    With ``shared_space`` True, one reducer is fit on the stacked features and
    the result is split back into the two sets. Otherwise each set is fit by
    its own freshly constructed reducer, making the independence of the two
    fits explicit instead of re-fitting one shared instance.
    """
    if shared_space:
        reducer = umap.UMAP(**umap_kwargs)
        emb_all = reducer.fit_transform(np.vstack([feats_a, feats_b]))
        n_a = feats_a.shape[0]
        return emb_all[:n_a], emb_all[n_a:]
    emb_a = umap.UMAP(**umap_kwargs).fit_transform(feats_a)
    emb_b = umap.UMAP(**umap_kwargs).fit_transform(feats_b)
    return emb_a, emb_b


# ---------- Load data ----------
feats1, gts1 = _load_features(file1)
feats2, gts2 = _load_features(file2)

# ---------- UMAP dimensionality reduction ----------
emb1, emb2 = _embed(feats1, feats2)

# ---------- Plot ----------
fig, axes = plt.subplots(1, 2, figsize=(8, 5), sharex=False, sharey=False)
fig.suptitle(" R2 UMAP Visualization", fontsize=26)

known_classes = list(color_map)
for ax, emb, gts, title in zip(axes, (emb1, emb2), (gts1, gts2), titles):
    # Only classes present in color_map are drawn; warn instead of dropping
    # the remaining points silently.
    unknown = np.setdiff1d(np.unique(gts), known_classes)
    if unknown.size:
        warnings.warn(
            f"{title}: labels {unknown.tolist()} have no entry in color_map "
            "and are not plotted"
        )
    for cls, color in color_map.items():
        mask = gts == cls
        ax.scatter(
            emb[mask, 0], emb[mask, 1],
            c=color, label=label_map[cls], s=50, alpha=0.8,
        )
    ax.set_title(title, fontsize=20)
    ax.set_xticks([])
    ax.set_yticks([])

if shared_space:
    # Identical limits on both panels so positions are directly comparable.
    emb_both = np.vstack([emb1, emb2])
    for ax in axes:
        ax.set_xlim(emb_both[:, 0].min(), emb_both[:, 0].max())
        ax.set_ylim(emb_both[:, 1].min(), emb_both[:, 1].max())

# ---------- Shared legend along the bottom ----------
handles = [
    plt.Line2D([], [], marker="o", color=color_map[cls], linestyle="", markersize=6)
    for cls in color_map
]
labels = [label_map[cls] for cls in color_map]
fig.legend(
    handles,
    labels,
    loc="lower center",
    ncol=len(handles),
    frameon=False,
    bbox_to_anchor=(0.5, -0.02),
    fontsize=20,
    markerscale=2.0,
)

fig.tight_layout()
# Applied after tight_layout: reserve room at the bottom for the legend.
fig.subplots_adjust(bottom=0.12, wspace=0.08)
fig.savefig(output_path)
# plt.show()