File size: 2,441 Bytes
acbef3a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import numpy as np
import matplotlib.pyplot as plt
import umap

# ---------- Configuration ----------
# Input archives to compare; each is read for a "features" array and a "gts"
# label array further down.
file1 = "test_features.npz"  # first .npz file (left panel)
file2 = "case_features.npz"  # second .npz file (right panel)

# Subplot titles, one per input file in the same order; edit as needed.
titles = ["Before Finetune", "After Finetune"]

# Keyword arguments forwarded to umap.UMAP.
umap_kwargs = {
    "n_components": 2,
    "random_state": 465,
    "metric": "euclidean",
}

# Per-class styling: class id -> (marker colour, legend label).
_CLASS_STYLES = {
    0: ("#81B9E9", "Not Improved"),
    1: ("#EBB4B1", "Improved"),
}
color_map = {cls_id: colour for cls_id, (colour, _) in _CLASS_STYLES.items()}
label_map = {cls_id: text for cls_id, (_, text) in _CLASS_STYLES.items()}

# ---------- Load data ----------
def _load_npz(path):
    """Read the ``features`` and flattened ``gts`` arrays from an ``.npz`` file.

    The archive is opened in a ``with`` block so its underlying file handle is
    released as soon as both arrays have been read into memory.

    Args:
        path: Path to an ``.npz`` archive containing ``features`` and ``gts``.

    Returns:
        ``(features, gts)`` where ``gts`` has been flattened with ``ravel()``.
        ``features`` is presumably ``(n, d)`` — TODO confirm against producer.

    Raises:
        ValueError: If the number of feature rows differs from the number of
            labels; the per-class boolean masks used for plotting would not
            line up otherwise.
    """
    with np.load(path) as archive:
        features = archive["features"]
        gts = archive["gts"].ravel()
    if len(features) != len(gts):
        raise ValueError(
            f"{path}: {len(features)} feature rows but {len(gts)} labels"
        )
    return features, gts


feats1, gts1 = _load_npz(file1)
feats2, gts2 = _load_npz(file2)

# ---------- UMAP dimensionality reduction ----------
# Set to True to fit ONE reducer on both feature sets stacked together, so the
# two embeddings live in a single shared coordinate space and are directly
# comparable. When False (the default), each feature set gets its own
# independent UMAP fit.
shared_space = False

if shared_space:
    # Requires both feature sets to have the same feature dimension.
    feats_all = np.vstack([feats1, feats2])
    emb_all = umap.UMAP(**umap_kwargs).fit_transform(feats_all)  # (n1 + n2, n_components)
    n1 = feats1.shape[0]
    emb1 = emb_all[:n1]
    emb2 = emb_all[n1:]
else:
    # Use a fresh reducer per dataset: refitting one shared instance would
    # silently throw away the first fit, so a later `transform` on it would
    # map into the second dataset's space.
    emb1 = umap.UMAP(**umap_kwargs).fit_transform(feats1)
    emb2 = umap.UMAP(**umap_kwargs).fit_transform(feats2)

# ---------- Plotting ----------
# One panel per embedding, side by side; axes are not shared between panels.
fig, axes = plt.subplots(1, 2, figsize=(8, 5), sharex=False, sharey=False)
fig.suptitle(" R2 UMAP Visualization", fontsize=26)

plotted_classes = (0, 1)

for ax, emb, gts, title in zip(axes, (emb1, emb2), (gts1, gts2), titles):
    # One scatter call per class so each class gets its own colour and label.
    for cls in plotted_classes:
        mask = gts == cls
        ax.scatter(
            emb[mask, 0], emb[mask, 1],
            c=color_map[cls],
            label=label_map[cls],
            s=50, alpha=0.8,
        )
    ax.set_title(title, fontsize=20)
    # Tick marks are removed from both axes of each panel.
    ax.set_xticks([])
    ax.set_yticks([])

# ---------- Shared legend along the bottom of the figure ----------
# Proxy artists give the figure-level legend one entry per class.
handles = [
    plt.Line2D([], [], marker="o", color=color_map[cls], linestyle="", markersize=6)
    for cls in plotted_classes
]
labels = [label_map[cls] for cls in plotted_classes]

fig.legend(
    handles, labels,
    loc="lower center",
    ncol=2,
    frameon=False,
    bbox_to_anchor=(0.5, -0.02),
    fontsize=20,
    markerscale=2.0,
)

fig.tight_layout()
fig.subplots_adjust(bottom=0.12, wspace=0.08)
# The legend is anchored slightly below the canvas (y = -0.02); without
# bbox_inches="tight" its lower part would be clipped from the saved SVG.
fig.savefig("umap.svg", bbox_inches="tight")
# To display interactively, call plt.show() here, before the figure is closed.
plt.close(fig)