# Origin: DAminoMuta / vis / umap_plot copy.py
# Uploaded by auralray via huggingface_hub (commit acbef3a, verified)
import numpy as np
import matplotlib.pyplot as plt
import umap
# ---------- Configuration ----------
# Feature archives to compare. Each one is read for a "features" array and a
# "gts" (ground-truth label) array.
file1 = "test_features.npz"  # drawn in the left panel
file2 = "case_features.npz"  # drawn in the right panel
titles = ["Before Finetune", "After Finetune"]  # panel titles, left to right

# Keyword arguments passed through unchanged to umap.UMAP.
umap_kwargs = {
    "n_components": 2,
    "random_state": 465,
    "metric": "euclidean",
}

# Per-class marker colour and legend text, keyed by ground-truth value.
color_map = {
    0: "#81B9E9",
    1: "#EBB4B1",
}
label_map = {
    0: "Not Improved",
    1: "Improved",
}
# ---------- Load data ----------
def _load_features(path):
    """Read the feature matrix and ground-truth labels from an ``.npz`` file.

    Args:
        path: Path to an ``.npz`` archive containing a ``"features"`` array
            (one row per sample) and a ``"gts"`` label array.

    Returns:
        A ``(features, gts)`` tuple, with ``gts`` flattened to 1-D.

    Raises:
        ValueError: if the number of labels differs from the number of
            feature rows. Without this check the mismatch would only show up
            later as a boolean-mask IndexError while plotting.
    """
    # NpzFile keeps the archive's file handle open until closed. The context
    # manager releases it, and arrays already read via indexing stay valid.
    with np.load(path) as data:
        features = data["features"]
        gts = data["gts"].ravel()
    if features.shape[0] != gts.shape[0]:
        raise ValueError(
            f"{path}: {features.shape[0]} feature rows but {gts.shape[0]} labels"
        )
    return features, gts


feats1, gts1 = _load_features(file1)  # shape (n1, d), (n1,)
feats2, gts2 = _load_features(file2)  # shape (n2, d), (n2,)
# ---------- UMAP dimensionality reduction ----------
# Each feature set is embedded by its own fit of the same reducer, so the two
# panels do not share a coordinate system.
reducer = umap.UMAP(**umap_kwargs)
emb1, emb2 = (reducer.fit_transform(feats) for feats in (feats1, feats2))

# Disabled alternative: fit once on the stacked features so both panels share
# one embedding space, and compute common axis limits for the plots:
#     feats_all = np.vstack([feats1, feats2])
#     n1 = feats1.shape[0]
#     emb_all = umap.UMAP(**umap_kwargs).fit_transform(feats_all)  # (n1+n2, 2)
#     emb1, emb2 = emb_all[:n1], emb_all[n1:]
#     x_min, x_max = emb_all[:, 0].min(), emb_all[:, 0].max()
#     y_min, y_max = emb_all[:, 1].min(), emb_all[:, 1].max()
# ---------- Plotting ----------
# Two side-by-side panels. Each panel keeps its own axis limits
# (sharex/sharey are off).
fig, axes = plt.subplots(1, 2, figsize=(8, 5), sharex=False, sharey=False)
fig.suptitle(" R2 UMAP Visualization", fontsize=26)

panels = zip(axes, (emb1, emb2), (gts1, gts2), titles)
for ax, emb, gts, title in panels:
    # Class 1 is drawn after class 0, so its points sit on top where they
    # overlap.
    for cls in (0, 1):
        mask = gts == cls
        xs, ys = emb[mask, 0], emb[mask, 1]
        ax.scatter(
            xs,
            ys,
            c=color_map[cls],
            label=label_map[cls],
            s=50,
            alpha=0.8,
        )
    ax.set_title(title, fontsize=20)
    # For shared limits, call ax.set_xlim / ax.set_ylim here with bounds
    # computed over both embeddings.
    ax.set_xticks([])
    ax.set_yticks([])

# ---------- Shared legend along the bottom edge ----------
# Proxy artists provide one marker per class for the figure-level legend.
legend_classes = (0, 1)
handles = [
    plt.Line2D([], [], marker="o", color=color_map[c], linestyle="", markersize=6)
    for c in legend_classes
]
labels = [label_map[c] for c in legend_classes]
fig.legend(
    handles,
    labels,
    loc="lower center",
    bbox_to_anchor=(0.5, -0.02),
    ncol=2,
    frameon=False,
    fontsize=20,
    markerscale=2.0,
)

fig.tight_layout()
# Reserve room at the bottom for the legend and narrow the gap between panels.
# These values override what tight_layout chose for bottom/wspace.
fig.subplots_adjust(bottom=0.12, wspace=0.08)
plt.savefig('umap.svg')
# plt.show()