File size: 9,729 Bytes
3b4941f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
"""figures for the paper, helvetica-ish font (nimbus sans).
figure 1: effect-axis embedding panels (control / observed / pivot-predicted).
figure 2: quantitative results (forward bars, gears head-to-head, dist-loss, reward)."""
import sys, os, glob, json
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import numpy as np
import matplotlib
matplotlib.use("Agg")
from matplotlib import font_manager as fm
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from matplotlib.lines import Line2D

# helvetica-like typography: register locally installed nimbus sans (the urw
# helvetica clone) and liberation sans files with matplotlib's font manager.
for patt in ("/usr/share/fonts/opentype/urw-base35/NimbusSans-*.otf",
             "/usr/share/fonts/truetype/liberation2/LiberationSans-*.ttf"):
    for f in glob.glob(patt):
        try:
            fm.fontManager.addfont(f)
        except Exception:
            # a font file that cannot be registered is simply skipped
            continue

# global plot style, applied in themed groups (same keys and order as one dict)
plt.rcParams.update({  # typography
    "font.family": "sans-serif",
    "font.sans-serif": ["Nimbus Sans", "Helvetica", "Arial", "Liberation Sans", "DejaVu Sans"],
    "mathtext.fontset": "dejavusans",
    "font.size": 11,
    "axes.labelsize": 12,
    "axes.titlesize": 12.5,
    "xtick.labelsize": 10,
    "ytick.labelsize": 10,
    "legend.fontsize": 10,
})
plt.rcParams.update({  # line weights and colours
    "axes.linewidth": 0.9,
    "axes.edgecolor": "#444444",
    "xtick.color": "#444444",
    "ytick.color": "#444444",
    "axes.labelcolor": "#222222",
    "text.color": "#222222",
})
plt.rcParams.update({  # resolution and export
    "figure.dpi": 150,
    "savefig.dpi": 320,
    "savefig.bbox": "tight",
})
plt.rcParams.update({  # spines, grid and legend
    "axes.spines.top": False,
    "axes.spines.right": False,
    "axes.grid": True,
    "grid.color": "#E6E8EB",
    "grid.linewidth": 0.8,
    "axes.axisbelow": True,
    "legend.frameon": False,
})
# palette (hex colours), one constant per plotted role
C_CTRL = "#B9C2CC"   # control cells
C_OBS = "#2D6FB3"    # observed perturbed cells / de-corr curve
C_PRED = "#E4572E"   # predicted cells / mmd curve
C_PIVOT = "#1B9E77"  # highlighted bars
C_BASE = "#9AA6B2"   # non-highlighted bars
C_ACC = "#6A4C93"    # gears bar in the head-to-head panel

# output directory for rendered figures, input directory for result json files
FIG, RES = "figures", "experiments/results"
os.makedirs(FIG, exist_ok=True)
def L(n):
    """load and return the parsed json result file ``{RES}/{n}.json``.

    n: file stem of the result file (name without the ``.json`` extension).

    the file is opened in a ``with`` block so the handle is closed as soon as
    parsing finishes, and it is decoded as utf-8 regardless of the locale.
    """
    with open(f"{RES}/{n}.json", encoding="utf-8") as fh:
        return json.load(fh)
def save(fig, name):
    """write `fig` into the figure directory as png and pdf, close it, and log the name."""
    png_path = f"{FIG}/{name}.png"
    pdf_path = f"{FIG}/{name}.pdf"
    fig.savefig(png_path)
    fig.savefig(pdf_path)
    # release the figure once both files are written
    plt.close(fig)
    print("wrote", name)


def kde_contour(ax, X, color, levels=4):
    """overlay gaussian-kde density contours of a 2-d point cloud on `ax`.

    ax:     object exposing a matplotlib-style ``contour`` method.
    X:      array of points, one row per point and exactly two columns (x, y).
    color:  contour colour, forwarded to ``ax.contour`` as ``colors``.
    levels: forwarded to ``ax.contour``.

    best-effort: with fewer than 10 points nothing is drawn, and any failure
    (scipy unavailable, kde cannot be fitted, ...) is reported as a
    RuntimeWarning instead of being raised, so a missing overlay never aborts
    the figure but is no longer silent either.
    """
    import warnings
    try:
        from scipy.stats import gaussian_kde
        if len(X) < 10:
            return
        density = gaussian_kde(X.T)
        xmin, ymin = X.min(0)
        xmax, ymax = X.max(0)
        # 80 x 80 evaluation grid spanning the bounding box of the points
        xs, ys = np.mgrid[xmin:xmax:80j, ymin:ymax:80j]
        z = density(np.vstack([xs.ravel(), ys.ravel()])).reshape(xs.shape)
        ax.contour(xs, ys, z, levels=levels, colors=color, linewidths=1.0, alpha=0.55)
    except Exception as exc:
        # keep the plot alive, but surface why the overlay was skipped
        warnings.warn(f"kde_contour skipped: {exc!r}", RuntimeWarning, stacklevel=2)


def figure1(model, data, targets, device):
    """effect-axis panels: x = projection on (target-control) direction, y = orthogonal pc.

    draws one panel per entry of `targets`, each showing sampled control cells,
    the observed perturbed cells and the model's forward prediction projected
    onto the 2-d (effect, orthogonal) plane, then writes the figure through
    `save` under the name "fig1_embedding_panels".

    model:   passed through to `inference.encode_label` and `inference.forward_predict`.
    data:    dataset object; `control_idx`, `emb` and `pert_to_idx` are read here.
    targets: perturbation labels (keys of `data.pert_to_idx`), one per panel.
    device:  torch device on which the sampled control embeddings are placed.
    """
    from src.evaluation import inference as inf
    import torch
    rng = np.random.default_rng(0)
    ctrl_idx = data.control_idx
    n_ctrl = len(ctrl_idx)
    cmean = data.emb[ctrl_idx].mean(0)
    # sampling without replacement raises if the request exceeds the pool, so
    # clamp the sample sizes; with a large enough pool the draws are unchanged
    start_idx = rng.choice(ctrl_idx, min(400, n_ctrl), replace=False)
    c0e = torch.as_tensor(data.emb[start_idx], dtype=torch.float32, device=device)
    n = len(targets)
    fig, axes = plt.subplots(1, n, figsize=(4.3 * n, 4.1))
    if n == 1:
        axes = [axes]  # plt.subplots returns a bare Axes for a single panel
    for ax, p in zip(axes, targets):
        ti = data.pert_to_idx[p]
        tmean = data.emb[ti].mean(0)
        d = tmean - cmean
        d = d / (np.linalg.norm(d) + 1e-9)          # effect direction (epsilon avoids 0/0)
        # orthogonal axis = top pc of perturbed cells with effect-dir removed
        Y = data.emb[ti] - tmean
        Y = Y - np.outer(Y @ d, d)
        o = np.linalg.svd(Y, full_matrices=False)[2][0]

        def proj(M):
            # coordinates relative to the control centroid along (effect, orthogonal)
            return np.c_[(M - cmean) @ d, (M - cmean) @ o]

        ci = rng.choice(ctrl_idx, min(500, n_ctrl), replace=False)
        Pc, Po = proj(data.emb[ci]), proj(data.emb[ti])
        e = inf.encode_label(model, data, p, device)
        # detach() so the numpy conversion also works if the prediction still
        # carries autograd history
        Ppred = proj(inf.forward_predict(model, c0e, e).detach().cpu().numpy())
        ax.scatter(Pc[:, 0], Pc[:, 1], s=9, c=C_CTRL, alpha=0.7, linewidths=0, label="control", rasterized=True)
        ax.scatter(Po[:, 0], Po[:, 1], s=11, c=C_OBS, alpha=0.55, linewidths=0, label="observed perturbed", rasterized=True)
        ax.scatter(Ppred[:, 0], Ppred[:, 1], s=22, marker="X", c=C_PRED, alpha=0.9,
                   edgecolors="white", linewidths=0.4, label="PIVOT predicted")
        kde_contour(ax, Po, C_OBS)
        kde_contour(ax, Ppred, C_PRED)
        # arrow control-centroid -> observed-centroid (the transport)
        cc, oc = proj(cmean[None])[0], proj(tmean[None])[0]
        ax.annotate("", xy=(oc[0], oc[1]), xytext=(cc[0], cc[1]),
                    arrowprops=dict(arrowstyle="-|>", color="#333333", lw=1.6, alpha=0.8))
        ax.set_title(p.replace("_", "+"), fontweight="bold")
        ax.set_xlabel("effect axis")
        ax.set_ylabel("orthogonal axis")
        ax.tick_params(length=0)
    # one shared legend built from proxy handles instead of per-panel legends
    handles = [Line2D([], [], marker='o', ls='', mfc=C_CTRL, mec='none', ms=7, label='control'),
               Line2D([], [], marker='o', ls='', mfc=C_OBS, mec='none', ms=7, label='observed perturbed'),
               Line2D([], [], marker='X', ls='', mfc=C_PRED, mec='white', ms=8, label='PIVOT predicted')]
    fig.legend(handles=handles, loc="lower center", ncol=3, bbox_to_anchor=(0.5, -0.04))
    fig.suptitle("Control cells transported toward the perturbed population", y=1.02,
                 fontsize=13, fontweight="bold")
    fig.tight_layout()
    save(fig, "fig1_embedding_panels")


def _panel_forward_bars(a):
    """panel a: horizontal bars of forward de-correlation per method, sorted ascending."""
    bf = L("norman_benchmark")["forward"]
    order = ["PIVOT", "LinearResponse", "kNN-latent", "Additive", "NearestPerturbationCentroid",
             "ConditionalMLP", "EndpointMLP", "AvgPerturbationEffect", "MeanControl", "Random"]
    pretty = {"PIVOT": "PIVOT", "LinearResponse": "Linear", "kNN-latent": "kNN-latent",
              "Additive": "Additive", "NearestPerturbationCentroid": "Nearest centroid",
              "ConditionalMLP": "Conditional MLP", "EndpointMLP": "Endpoint MLP",
              "AvgPerturbationEffect": "Avg. effect", "MeanControl": "Mean control", "Random": "Random"}
    # methods missing from the benchmark file are skipped
    vals = sorted(((pretty[m], bf[m]["de_corr"]) for m in order if m in bf), key=lambda kv: kv[1])
    names = [name for name, _ in vals]
    de = [score for _, score in vals]
    cols = [C_PIVOT if name == "PIVOT" else C_BASE for name in names]
    a.barh(names, de, color=cols, edgecolor="white", height=0.74)
    for i, v in enumerate(de):
        a.text(v + 0.01, i, f"{v:.2f}", va="center", fontsize=9, color="#333")
    a.set_xlim(0, 1.0)
    a.set_xlabel("DE correlation $\\uparrow$")
    a.grid(axis="y", visible=False)
    a.set_title("a  Forward direction, held-out perturbations", loc="left", fontweight="bold")


def _panel_gears(b):
    """panel b: two-bar head-to-head of the pivot and gears pearson scores."""
    pg = L("pivot_vs_gears")
    scores = [pg["pivot_pearson_de_expr"], pg["gears_pearson_de_expr"]]
    bars = b.bar(["PIVOT", "GEARS"], scores, color=[C_PIVOT, C_ACC], edgecolor="white", width=0.55)
    for r, v in zip(bars, scores):
        b.text(r.get_x() + r.get_width()/2, v + 0.012, f"{v:.3f}", ha="center", fontsize=11, fontweight="bold")
    b.set_ylim(0, 1.08)
    b.set_ylabel("Top-20 DE-gene Pearson $\\uparrow$")
    b.grid(axis="x", visible=False)
    b.set_title("b  Head-to-head vs GEARS (matched perts)", loc="left", fontweight="bold")


def _panel_distloss(c):
    """panel c: mmd (left axis) and de-correlation (right axis) against the dist-loss weight."""
    dl = L("norman_distloss")["rows"]
    # sort the original string keys numerically and index with them directly;
    # rebuilding a key from the float fails for spellings such as "1" or "1e-3"
    keys = sorted(dl, key=float)
    lam = [float(k) for k in keys]
    mmd = [dl[k]["mmd"] for k in keys]
    de2 = [dl[k]["de_corr"] for k in keys]
    c.plot(lam, mmd, "o-", color=C_PRED, lw=2.2, ms=7, label="MMD $\\downarrow$")
    c.set_xlabel("distributional-loss weight $\\lambda_{\\mathrm{dist}}$")
    c.set_ylabel("population MMD $\\downarrow$", color=C_PRED)
    c.tick_params(axis="y", colors=C_PRED)
    c.set_ylim(0, max(mmd)*1.15)
    c2 = c.twinx()
    c2.plot(lam, de2, "s--", color=C_OBS, lw=2.0, ms=6, label="DE-corr $\\uparrow$")
    c2.set_ylabel("DE correlation $\\uparrow$", color=C_OBS)
    c2.tick_params(axis="y", colors=C_OBS)
    c2.set_ylim(0.5, 0.95)
    c2.grid(False)
    c2.spines["top"].set_visible(False)
    # NOTE(review): the "6x" in the title is hard-coded, not derived from the
    # plotted values; confirm it still matches the results file
    c.set_title("c  Distributional flow loss: MMD $6\\times$ lower, direction kept", loc="left", fontweight="bold")


def _panel_reward(d):
    """panel d: top-5 nomination score per reward variant, cosine highlighted."""
    rw = L("norman_ablation_reward")["rows"]
    rmap = {"centroid": "Centroid", "nn_target": "NN-target", "mmd": "MMD", "wasserstein": "Wasserstein", "cosine": "Cosine"}
    rn = [rmap.get(k, k) for k in rw]  # unknown reward keys keep their raw name
    t5 = [rw[k]["top5"] for k in rw]
    cols2 = [C_PIVOT if rmap.get(k) == "Cosine" else C_BASE for k in rw]
    bars = d.bar(rn, t5, color=cols2, edgecolor="white", width=0.66)
    for r, v in zip(bars, t5):
        d.text(r.get_x()+r.get_width()/2, v+0.006, f"{v:.2f}", ha="center", fontsize=9.5)
    d.set_ylabel("nomination Top-5 $\\uparrow$")
    d.set_ylim(0, max(t5)*1.25)
    d.grid(axis="x", visible=False)
    d.tick_params(axis="x", rotation=18)
    d.set_title("d  Direction-aware reward wins at nomination", loc="left", fontweight="bold")


def figure2_results():
    """2x2 quantitative-results figure built from the json files under `RES`.

    panels: (a) forward de-correlation per method, (b) pivot vs gears,
    (c) distributional-loss sweep, (d) reward ablation top-5. reads
    norman_benchmark, pivot_vs_gears, norman_distloss and
    norman_ablation_reward through `L`, and writes the result through `save`
    under the name "fig2_results".
    """
    fig, ax = plt.subplots(2, 2, figsize=(11, 7.4))
    _panel_forward_bars(ax[0, 0])
    _panel_gears(ax[0, 1])
    _panel_distloss(ax[1, 0])
    _panel_reward(ax[1, 1])
    fig.tight_layout()
    save(fig, "fig2_results")


if __name__ == "__main__":
    # project modules are imported only when the file is run as a script
    from src.data.perturb_data import load_dataset
    from src.training.train import TrainConfig, train

    data = load_dataset("norman")

    # gpu index is read from the PIVOT_GPU environment variable (default "3")
    gpu = int(os.environ.get("PIVOT_GPU", "3"))
    cfg = TrainConfig(
        dataset="norman",
        split="cell",
        epochs=60,
        device_index=gpu,
    )
    model, info = train(cfg, data=data, verbose=False)
    # plot on whichever device the model's parameters ended up on
    dev = next(model.parameters()).device

    # split perturbation labels by how many parts `data.parse` returns
    singles = [p for p in data.perturbations if len(data.parse(p)) == 1]
    combos = [p for p in data.perturbations if len(data.parse(p)) == 2]

    # figure 1 panels: two single perturbations and one combination
    panel_targets = [singles[0], singles[7], combos[0]]
    figure1(model, data, panel_targets, dev)
    figure2_results()
    print("FIGURES_V2_DONE")