| """figures for the paper, helvetica-ish font (nimbus sans). |
| figure 1: effect-axis embedding panels (control / observed / pivot-predicted). |
| figure 2: quantitative results (forward bars, gears head-to-head, dist-loss, reward).""" |
| import sys, os, glob, json |
| sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) |
| import numpy as np |
| import matplotlib |
| matplotlib.use("Agg") |
| from matplotlib import font_manager as fm |
| import matplotlib.pyplot as plt |
| from matplotlib.patches import Patch |
| from matplotlib.lines import Line2D |
|
|
| |
# Register Helvetica-like fonts when they are installed on this machine.
# Registration is best-effort: a font file that cannot be loaded is skipped,
# and matplotlib falls back through the "font.sans-serif" list below.
_FONT_PATTERNS = (
    "/usr/share/fonts/opentype/urw-base35/NimbusSans-*.otf",
    "/usr/share/fonts/truetype/liberation2/LiberationSans-*.ttf",
)
for _pattern in _FONT_PATTERNS:
    for _font_path in glob.glob(_pattern):
        try:
            fm.fontManager.addfont(_font_path)
        except Exception:
            pass

# Global matplotlib style shared by every figure produced by this script.
_RC_STYLE = {
    # typography
    "font.family": "sans-serif",
    "font.sans-serif": ["Nimbus Sans", "Helvetica", "Arial", "Liberation Sans", "DejaVu Sans"],
    "mathtext.fontset": "dejavusans",
    "font.size": 11,
    "axes.labelsize": 12,
    "axes.titlesize": 12.5,
    "xtick.labelsize": 10,
    "ytick.labelsize": 10,
    "legend.fontsize": 10,
    # axes, ticks and text colours
    "axes.linewidth": 0.9,
    "axes.edgecolor": "#444444",
    "xtick.color": "#444444",
    "ytick.color": "#444444",
    "axes.labelcolor": "#222222",
    "text.color": "#222222",
    # output resolution / cropping
    "figure.dpi": 150,
    "savefig.dpi": 320,
    "savefig.bbox": "tight",
    # open frame (no top/right spine) with a light grid drawn under the data
    "axes.spines.top": False,
    "axes.spines.right": False,
    "axes.grid": True,
    "grid.color": "#E6E8EB",
    "grid.linewidth": 0.8,
    "axes.axisbelow": True,
    "legend.frameon": False,
}
plt.rcParams.update(_RC_STYLE)
| |
# Colour palette shared by both figures.
C_CTRL = "#B9C2CC"   # control cells
C_OBS = "#2D6FB3"    # observed perturbed cells
C_PRED = "#E4572E"   # model predictions
C_PIVOT = "#1B9E77"  # highlighted (PIVOT) bars
C_BASE = "#9AA6B2"   # baseline bars
C_ACC = "#6A4C93"    # accent (GEARS bar)

FIG = "figures"              # output directory for rendered figures
RES = "experiments/results"  # directory holding the result JSON files
os.makedirs(FIG, exist_ok=True)


def L(n):
    """Load and return the parsed contents of ``{RES}/{n}.json``.

    The file is opened in a ``with`` block so the handle is closed as soon as
    parsing finishes, and it is decoded explicitly as UTF-8 instead of using
    the platform's default encoding.
    """
    with open(f"{RES}/{n}.json", encoding="utf-8") as fh:
        return json.load(fh)
def save(fig, name):
    """Write ``fig`` into FIG as ``name.png`` and ``name.pdf``, then close it.

    Prints ``wrote <name>`` once both files have been written.
    """
    for ext in ("png", "pdf"):
        fig.savefig(f"{FIG}/{name}.{ext}")
    plt.close(fig)
    print("wrote", name)
|
|
|
|
def kde_contour(ax, X, color, levels=4):
    """Overlay Gaussian-KDE density contours of a 2-column point cloud on ``ax``.

    The contours are decorative, so this is best-effort: any failure (scipy
    missing, degenerate data, a plotting error) skips the overlay instead of
    aborting the figure. Unlike a silent ``pass``, the failure is reported
    through a ``RuntimeWarning`` so real bugs stay visible.

    Args:
        ax: object exposing a matplotlib-style ``contour`` method.
        X: array with one row per point and two columns (x, y).
        color: contour line colour, forwarded as ``colors``.
        levels: number (or list) of contour levels, forwarded to ``contour``.

    Returns:
        None. Nothing is drawn when ``X`` has fewer than 10 rows.
    """
    import warnings  # local import: only needed on the failure path

    try:
        from scipy.stats import gaussian_kde  # lazy: scipy is only needed here

        if len(X) < 10:
            return  # too few points for a meaningful density estimate
        density = gaussian_kde(X.T)
        (x_lo, y_lo), (x_hi, y_hi) = X.min(0), X.max(0)
        # 80 x 80 evaluation grid spanning the bounding box of the points.
        grid_x, grid_y = np.mgrid[x_lo:x_hi:80j, y_lo:y_hi:80j]
        grid_z = density(np.vstack([grid_x.ravel(), grid_y.ravel()])).reshape(grid_x.shape)
        ax.contour(grid_x, grid_y, grid_z, levels=levels, colors=color,
                   linewidths=1.0, alpha=0.55)
    except Exception as exc:
        warnings.warn(f"kde_contour skipped: {exc!r}", RuntimeWarning, stacklevel=2)
|
|
|
|
def figure1(model, data, targets, device):
    """Draw one effect-axis embedding panel per target perturbation and save it.

    For each perturbation in ``targets`` the panel shows control cells,
    observed perturbed cells and model-predicted cells in a 2-D projection:

    * x ("effect axis"): projection onto the unit vector pointing from the
      control mean to the perturbed mean;
    * y ("orthogonal axis"): projection onto the leading right-singular vector
      of the perturbed cells after removing their mean and the effect-axis
      component.

    Args:
        model: trained model, passed through to ``inf.encode_label`` and
            ``inf.forward_predict``.
        data: dataset object providing ``emb``, ``control_idx`` and
            ``pert_to_idx`` (assumes ``emb`` is a NumPy array with one row per
            cell; confirm against the loader).
        targets: perturbation labels, one panel each.
        device: torch device on which the control start cells are placed.

    The figure is written through ``save`` as ``fig1_embedding_panels``.
    """
    from src.evaluation import inference as inf
    import torch

    rng = np.random.default_rng(0)  # fixed seed keeps the subsamples reproducible
    ctrl_idx = data.control_idx
    ctrl_mean = data.emb[ctrl_idx].mean(0)

    def _control_subsample(k):
        # Clamp the sample size so datasets with fewer than k control cells
        # still render instead of raising inside rng.choice(replace=False).
        return rng.choice(ctrl_idx, min(k, len(ctrl_idx)), replace=False)

    # Control cells that the model transports forward for every target.
    start_cells = torch.as_tensor(data.emb[_control_subsample(400)],
                                  dtype=torch.float32, device=device)

    n_panels = len(targets)
    # squeeze=False always yields a 2-D axes array, so a single panel needs
    # no special case.
    fig, axes = plt.subplots(1, n_panels, figsize=(4.3 * n_panels, 4.1), squeeze=False)
    for ax, pert in zip(axes[0], targets):
        pert_emb = data.emb[data.pert_to_idx[pert]]
        pert_mean = pert_emb.mean(0)

        # Effect axis: unit vector from control mean to perturbed mean
        # (epsilon guards against a zero-length difference).
        effect = pert_mean - ctrl_mean
        effect = effect / (np.linalg.norm(effect) + 1e-9)

        # Orthogonal axis: first principal direction of the perturbed cells
        # once the effect-axis component has been projected out.
        resid = pert_emb - pert_mean
        resid = resid - np.outer(resid @ effect, effect)
        ortho = np.linalg.svd(resid, full_matrices=False)[2][0]

        def proj(M):
            centered = M - ctrl_mean
            return np.c_[centered @ effect, centered @ ortho]

        Pc = proj(data.emb[_control_subsample(500)])
        Po = proj(pert_emb)
        label = inf.encode_label(model, data, pert, device)
        # detach() so the conversion also works if the prediction tracks gradients.
        Ppred = proj(inf.forward_predict(model, start_cells, label).detach().cpu().numpy())

        ax.scatter(Pc[:, 0], Pc[:, 1], s=9, c=C_CTRL, alpha=0.7, linewidths=0,
                   label="control", rasterized=True)
        ax.scatter(Po[:, 0], Po[:, 1], s=11, c=C_OBS, alpha=0.55, linewidths=0,
                   label="observed perturbed", rasterized=True)
        ax.scatter(Ppred[:, 0], Ppred[:, 1], s=22, marker="X", c=C_PRED, alpha=0.9,
                   edgecolors="white", linewidths=0.4, label="PIVOT predicted")
        kde_contour(ax, Po, C_OBS)
        kde_contour(ax, Ppred, C_PRED)

        # Arrow from the control centroid to the observed perturbed centroid.
        cc, oc = proj(ctrl_mean[None])[0], proj(pert_mean[None])[0]
        ax.annotate("", xy=(oc[0], oc[1]), xytext=(cc[0], cc[1]),
                    arrowprops=dict(arrowstyle="-|>", color="#333333", lw=1.6, alpha=0.8))
        ax.set_title(pert.replace("_", "+"), fontweight="bold")
        ax.set_xlabel("effect axis")
        ax.set_ylabel("orthogonal axis")
        ax.tick_params(length=0)

    # One shared legend below the panels, built from proxy handles.
    handles = [
        Line2D([], [], marker='o', ls='', mfc=C_CTRL, mec='none', ms=7, label='control'),
        Line2D([], [], marker='o', ls='', mfc=C_OBS, mec='none', ms=7, label='observed perturbed'),
        Line2D([], [], marker='X', ls='', mfc=C_PRED, mec='white', ms=8, label='PIVOT predicted'),
    ]
    fig.legend(handles=handles, loc="lower center", ncol=3, bbox_to_anchor=(0.5, -0.04))
    fig.suptitle("Control cells transported toward the perturbed population", y=1.02,
                 fontsize=13, fontweight="bold")
    fig.tight_layout()
    save(fig, "fig1_embedding_panels")
|
|
|
|
def _panel_forward(a):
    """Panel a: horizontal bars of forward-direction DE correlation per method."""
    bf = L("norman_benchmark")["forward"]
    # Insertion order is the display order used to break ties between equal scores.
    pretty = {"PIVOT": "PIVOT", "LinearResponse": "Linear", "kNN-latent": "kNN-latent",
              "Additive": "Additive", "NearestPerturbationCentroid": "Nearest centroid",
              "ConditionalMLP": "Conditional MLP", "EndpointMLP": "Endpoint MLP",
              "AvgPerturbationEffect": "Avg. effect", "MeanControl": "Mean control",
              "Random": "Random"}
    # Methods missing from the results file are skipped; sort ascending so the
    # best method ends up at the top of the horizontal bar chart.
    vals = sorted(((label, bf[m]["de_corr"]) for m, label in pretty.items() if m in bf),
                  key=lambda kv: kv[1])
    names = [label for label, _ in vals]
    de = [score for _, score in vals]
    cols = [C_PIVOT if label == "PIVOT" else C_BASE for label in names]
    a.barh(names, de, color=cols, edgecolor="white", height=0.74)
    for i, v in enumerate(de):
        a.text(v + 0.01, i, f"{v:.2f}", va="center", fontsize=9, color="#333")
    a.set_xlim(0, 1.0)
    a.set_xlabel("DE correlation $\\uparrow$")
    a.grid(axis="y", visible=False)
    a.set_title("a Forward direction, held-out perturbations", loc="left", fontweight="bold")


def _panel_gears(b):
    """Panel b: PIVOT vs GEARS Pearson score as two annotated bars."""
    pg = L("pivot_vs_gears")
    scores = [pg["pivot_pearson_de_expr"], pg["gears_pearson_de_expr"]]
    bars = b.bar(["PIVOT", "GEARS"], scores, color=[C_PIVOT, C_ACC],
                 edgecolor="white", width=0.55)
    for r, v in zip(bars, scores):
        b.text(r.get_x() + r.get_width()/2, v + 0.012, f"{v:.3f}", ha="center",
               fontsize=11, fontweight="bold")
    b.set_ylim(0, 1.08)
    b.set_ylabel("Top-20 DE-gene Pearson $\\uparrow$")
    b.grid(axis="x", visible=False)
    b.set_title("b Head-to-head vs GEARS (matched perts)", loc="left", fontweight="bold")


def _panel_distloss(c):
    """Panel c: MMD (left axis) and DE correlation (right axis) vs loss weight."""
    dl = L("norman_distloss")["rows"]
    # Sort the ORIGINAL keys by numeric value and index with them directly.
    # Rebuilding a key from the float (str(x) or f"{x:.1f}") raises KeyError
    # for spellings such as "0", "1" or "1e-3".
    keys = sorted(dl, key=float)
    lam = [float(k) for k in keys]
    mmd = [dl[k]["mmd"] for k in keys]
    de2 = [dl[k]["de_corr"] for k in keys]
    c.plot(lam, mmd, "o-", color=C_PRED, lw=2.2, ms=7, label="MMD $\\downarrow$")
    c.set_xlabel("distributional-loss weight $\\lambda_{\\mathrm{dist}}$")
    c.set_ylabel("population MMD $\\downarrow$", color=C_PRED)
    c.tick_params(axis="y", colors=C_PRED)
    c.set_ylim(0, max(mmd)*1.15)
    c2 = c.twinx()  # second y-axis sharing the same x-axis
    c2.plot(lam, de2, "s--", color=C_OBS, lw=2.0, ms=6, label="DE-corr $\\uparrow$")
    c2.set_ylabel("DE correlation $\\uparrow$", color=C_OBS)
    c2.tick_params(axis="y", colors=C_OBS)
    c2.set_ylim(0.5, 0.95)
    c2.grid(False)
    c2.spines["top"].set_visible(False)
    # NOTE(review): the "6x" in the title is hard-coded, not computed from the data.
    c.set_title("c Distributional flow loss: MMD $6\\times$ lower, direction kept",
                loc="left", fontweight="bold")


def _panel_reward(d):
    """Panel d: nomination Top-5 per reward variant, with "Cosine" highlighted."""
    rw = L("norman_ablation_reward")["rows"]
    rmap = {"centroid": "Centroid", "nn_target": "NN-target", "mmd": "MMD",
            "wasserstein": "Wasserstein", "cosine": "Cosine"}
    rn = [rmap.get(k, k) for k in rw]  # unknown reward names are shown as-is
    t5 = [rw[k]["top5"] for k in rw]
    cols2 = [C_PIVOT if rmap.get(k) == "Cosine" else C_BASE for k in rw]
    bars = d.bar(rn, t5, color=cols2, edgecolor="white", width=0.66)
    for r, v in zip(bars, t5):
        d.text(r.get_x()+r.get_width()/2, v+0.006, f"{v:.2f}", ha="center", fontsize=9.5)
    d.set_ylabel("nomination Top-5 $\\uparrow$")
    d.set_ylim(0, max(t5)*1.25)
    d.grid(axis="x", visible=False)
    d.tick_params(axis="x", rotation=18)
    d.set_title("d Direction-aware reward wins at nomination", loc="left", fontweight="bold")


def figure2_results():
    """Render the 2x2 quantitative-results figure and save it as ``fig2_results``.

    Each panel reads its own JSON file through ``L``:
    a) ``norman_benchmark``, b) ``pivot_vs_gears``, c) ``norman_distloss``,
    d) ``norman_ablation_reward``.
    """
    fig, ax = plt.subplots(2, 2, figsize=(11, 7.4))
    _panel_forward(ax[0, 0])
    _panel_gears(ax[0, 1])
    _panel_distloss(ax[1, 0])
    _panel_reward(ax[1, 1])
    fig.tight_layout()
    save(fig, "fig2_results")
|
|
|
|
if __name__ == "__main__":
    # Script entry point: train a model on the "norman" dataset, then render
    # both paper figures.
    from src.data.perturb_data import load_dataset
    from src.training.train import TrainConfig, train

    data = load_dataset("norman")

    # GPU index comes from the PIVOT_GPU environment variable (default "3").
    cfg = TrainConfig(
        dataset="norman",
        split="cell",
        epochs=60,
        device_index=int(os.environ.get("PIVOT_GPU", "3")),
    )
    model, _info = train(cfg, data=data, verbose=False)
    dev = next(model.parameters()).device  # run inference where the model lives

    # Split perturbations by how many components data.parse reports for each.
    singles = [pert for pert in data.perturbations if len(data.parse(pert)) == 1]
    combos = [pert for pert in data.perturbations if len(data.parse(pert)) == 2]

    # Figure 1 shows two single perturbations and one combination.
    panel_targets = [singles[0], singles[7], combos[0]]
    figure1(model, data, panel_targets, dev)
    figure2_results()
    print("FIGURES_V2_DONE")
|
|