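# PIVOT / scripts / figures.py
# (repository web-viewer header captured together with the source; not part of the script)
# bryan7264's picture
# pivot: code + trained checkpoints (norman, replogle k562)
# 3b4941f verified
# Raw
# History Blame
# 9.73 kB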
"""figures for the paper, helvetica-ish font (nimbus sans).
figure 1: effect-axis embedding panels (control / observed / pivot-predicted).
figure 2: quantitative results (forward bars, gears head-to-head, dist-loss, reward)."""
import sys, os, glob, json
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import numpy as np
import matplotlib
matplotlib.use("Agg")
from matplotlib import font_manager as fm
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from matplotlib.lines import Line2D
# Register Helvetica-family font files with matplotlib when they are installed
# (Nimbus Sans is the URW Helvetica clone; Liberation Sans is also picked up).
_FONT_GLOBS = (
    "/usr/share/fonts/opentype/urw-base35/NimbusSans-*.otf",
    "/usr/share/fonts/truetype/liberation2/LiberationSans-*.ttf",
)
for _font_path in (hit for pattern in _FONT_GLOBS for hit in glob.glob(pattern)):
    try:
        fm.fontManager.addfont(_font_path)
    except Exception:
        # Best effort: a font file that cannot be registered is simply skipped.
        continue
# Global matplotlib style, assembled from themed groups so each concern is easy
# to locate; the merged mapping is applied in one rcParams update.
_RC_FONTS = {
    "font.family": "sans-serif",
    "font.sans-serif": ["Nimbus Sans", "Helvetica", "Arial", "Liberation Sans", "DejaVu Sans"],
    "mathtext.fontset": "dejavusans",
}
_RC_SIZES = {
    "font.size": 11,
    "axes.labelsize": 12,
    "axes.titlesize": 12.5,
    "xtick.labelsize": 10,
    "ytick.labelsize": 10,
    "legend.fontsize": 10,
}
_RC_INK = {
    "axes.linewidth": 0.9,
    "axes.edgecolor": "#444444",
    "xtick.color": "#444444",
    "ytick.color": "#444444",
    "axes.labelcolor": "#222222",
    "text.color": "#222222",
}
_RC_OUTPUT = {
    "figure.dpi": 150,
    "savefig.dpi": 320,
    "savefig.bbox": "tight",
}
_RC_FRAME = {
    "axes.spines.top": False,
    "axes.spines.right": False,
    "axes.grid": True,
    "grid.color": "#E6E8EB",
    "grid.linewidth": 0.8,
    "axes.axisbelow": True,
    "legend.frameon": False,
}
plt.rcParams.update({**_RC_FONTS, **_RC_SIZES, **_RC_INK, **_RC_OUTPUT, **_RC_FRAME})
# Colour palette shared by both figures (hex RGB).
C_CTRL = "#B9C2CC"   # control cells
C_OBS = "#2D6FB3"    # observed perturbed cells / DE-correlation curve
C_PRED = "#E4572E"   # PIVOT predictions / MMD curve
C_PIVOT = "#1B9E77"  # highlighted (PIVOT) bars
C_BASE = "#9AA6B2"   # baseline bars
C_ACC = "#6A4C93"    # accent bar (GEARS)

# Output directory for rendered figures and input directory holding the result
# JSON files; both are relative to the current working directory.
FIG = "figures"
RES = "experiments/results"
os.makedirs(FIG, exist_ok=True)
def L(n):
    """Load the result file ``{RES}/{n}.json`` and return the parsed JSON.

    Args:
        n: base name of the result file, without directory or ``.json`` suffix.

    Returns:
        Whatever ``json.load`` produces for the file's content.

    Raises:
        Errors from ``open`` (e.g. a missing file) and ``json.load``
        (malformed JSON) propagate unchanged.
    """
    # The context manager closes the handle deterministically instead of
    # leaving it to garbage collection; JSON text is read as UTF-8 rather
    # than in the platform's locale encoding.
    with open(f"{RES}/{n}.json", encoding="utf-8") as fh:
        return json.load(fh)
def save(fig, name):
    """Write ``fig`` to ``{FIG}/{name}.png`` and ``{FIG}/{name}.pdf``, then close it.

    The PNG is written before the PDF; afterwards the figure is closed through
    pyplot and ``wrote <name>`` is printed.
    """
    for ext in ("png", "pdf"):
        fig.savefig(f"{FIG}/{name}.{ext}")
    plt.close(fig)
    print("wrote", name)
def kde_contour(ax, X, color, levels=4):
    """Overlay Gaussian-KDE density contours of a 2-D point cloud on ``ax``.

    The contours are decoration only: when they cannot be drawn (SciPy not
    importable, degenerate point cloud, ...) the panel is left without them
    rather than aborting the whole figure. Unlike a silent ``pass``, the
    reason is reported as a ``RuntimeWarning`` so a missing overlay in a
    rendered figure can be traced.

    Args:
        ax: axes-like object providing ``contour``.
        X: array of points, one row per sample and two columns (x, y).
        color: colour handed to ``ax.contour`` as ``colors``.
        levels: handed to ``ax.contour`` as ``levels``.

    Returns:
        None. Nothing is drawn when ``X`` has fewer than 10 rows.
    """
    import warnings

    try:
        # Cheap guard first: too few points for a meaningful density estimate.
        if len(X) < 10:
            return
        from scipy.stats import gaussian_kde

        density = gaussian_kde(X.T)
        xmin, ymin = X.min(0)
        xmax, ymax = X.max(0)
        # 80 x 80 evaluation grid spanning the bounding box of the points.
        xs, ys = np.mgrid[xmin:xmax:80j, ymin:ymax:80j]
        z = density(np.vstack([xs.ravel(), ys.ravel()])).reshape(xs.shape)
        ax.contour(xs, ys, z, levels=levels, colors=color, linewidths=1.0, alpha=0.55)
    except Exception as exc:
        # Deliberately broad: the overlay is best-effort, but say why it was skipped.
        warnings.warn(
            f"kde_contour: density contours skipped ({exc!r})",
            RuntimeWarning,
            stacklevel=2,
        )
def figure1(model, data, targets, device):
    """Draw figure 1: one effect-axis embedding panel per target perturbation.

    Every panel is a 2-D projection of the embedding, centred on the control
    mean. The x coordinate is taken along the unit vector pointing from the
    control mean to the target's mean (the "effect axis"); the y coordinate
    is taken along the top principal direction of the target's cells after
    the effect axis has been projected out. Control cells, observed perturbed
    cells and the model's predictions for a fixed batch of control cells are
    overlaid, with an arrow from the control centroid to the observed
    centroid.

    Args:
        model: trained model, passed through to ``inf.encode_label`` and
            ``inf.forward_predict``.
        data: dataset object; this function reads ``control_idx``, ``emb``
            and ``pert_to_idx``.
        targets: perturbation names (keys of ``data.pert_to_idx``); one panel
            is drawn per entry.
        device: torch device on which the control batch is placed.

    Returns:
        None. The figure is written as ``fig1_embedding_panels`` via ``save``.
    """
    from src.evaluation import inference as inf
    import torch

    rng = np.random.default_rng(0)
    ctrl_idx = data.control_idx
    n_ctrl = len(ctrl_idx)
    cmean = data.emb[ctrl_idx].mean(0)

    # Control cells pushed through the model. Sampling without replacement
    # raises when asked for more items than exist, so the sample size is
    # clamped; with >= 400 controls the draw is identical to a fixed 400.
    model_idx = rng.choice(ctrl_idx, min(400, n_ctrl), replace=False)
    c0e = torch.as_tensor(data.emb[model_idx], dtype=torch.float32, device=device)

    n = len(targets)
    fig, axes = plt.subplots(1, n, figsize=(4.3 * n, 4.1))
    if n == 1:
        # plt.subplots returns a bare Axes for a single panel.
        axes = [axes]

    for ax, p in zip(axes, targets):
        # presumably the row indices of the cells carrying perturbation p -- confirm in the data loader
        ti = data.pert_to_idx[p]
        tmean = data.emb[ti].mean(0)

        # Unit effect direction: control mean -> target mean.
        d = tmean - cmean
        d = d / (np.linalg.norm(d) + 1e-9)

        # Orthogonal axis: top principal direction of the perturbed cells once
        # their component along the effect direction has been removed.
        Y = data.emb[ti] - data.emb[ti].mean(0)
        Y = Y - np.outer(Y @ d, d)
        _, _, vt = np.linalg.svd(Y, full_matrices=False)
        o = vt[0]

        def proj(M):
            # Coordinates of the rows of M in the (effect, orthogonal) plane.
            return np.c_[(M - cmean) @ d, (M - cmean) @ o]

        # Background sample of control cells, clamped like the model batch.
        ci = rng.choice(ctrl_idx, min(500, n_ctrl), replace=False)
        Pc, Po = proj(data.emb[ci]), proj(data.emb[ti])

        e = inf.encode_label(model, data, p, device)
        Ppred = proj(inf.forward_predict(model, c0e, e).cpu().numpy())

        ax.scatter(Pc[:, 0], Pc[:, 1], s=9, c=C_CTRL, alpha=0.7, linewidths=0,
                   label="control", rasterized=True)
        ax.scatter(Po[:, 0], Po[:, 1], s=11, c=C_OBS, alpha=0.55, linewidths=0,
                   label="observed perturbed", rasterized=True)
        ax.scatter(Ppred[:, 0], Ppred[:, 1], s=22, marker="X", c=C_PRED, alpha=0.9,
                   edgecolors="white", linewidths=0.4, label="PIVOT predicted")
        kde_contour(ax, Po, C_OBS)
        kde_contour(ax, Ppred, C_PRED)

        # Arrow control-centroid -> observed-centroid (the transport).
        cc, oc = proj(cmean[None])[0], proj(tmean[None])[0]
        ax.annotate("", xy=(oc[0], oc[1]), xytext=(cc[0], cc[1]),
                    arrowprops=dict(arrowstyle="-|>", color="#333333", lw=1.6, alpha=0.8))

        ax.set_title(p.replace("_", "+"), fontweight="bold")
        ax.set_xlabel("effect axis")
        ax.set_ylabel("orthogonal axis")
        ax.tick_params(length=0)

    # One shared legend below the panels, built from proxy handles so the
    # entries are not repeated per panel.
    handles = [
        Line2D([], [], marker='o', ls='', mfc=C_CTRL, mec='none', ms=7, label='control'),
        Line2D([], [], marker='o', ls='', mfc=C_OBS, mec='none', ms=7, label='observed perturbed'),
        Line2D([], [], marker='X', ls='', mfc=C_PRED, mec='white', ms=8, label='PIVOT predicted'),
    ]
    fig.legend(handles=handles, loc="lower center", ncol=3, bbox_to_anchor=(0.5, -0.04))
    fig.suptitle("Control cells transported toward the perturbed population", y=1.02,
                 fontsize=13, fontweight="bold")
    fig.tight_layout()
    save(fig, "fig1_embedding_panels")
def _fig2_forward_panel(ax):
    """Panel a: horizontal bars of forward DE correlation per method, sorted ascending."""
    bf = L("norman_benchmark")["forward"]
    order = ["PIVOT", "LinearResponse", "kNN-latent", "Additive", "NearestPerturbationCentroid",
             "ConditionalMLP", "EndpointMLP", "AvgPerturbationEffect", "MeanControl", "Random"]
    pretty = {"PIVOT": "PIVOT", "LinearResponse": "Linear", "kNN-latent": "kNN-latent",
              "Additive": "Additive", "NearestPerturbationCentroid": "Nearest centroid",
              "ConditionalMLP": "Conditional MLP", "EndpointMLP": "Endpoint MLP",
              "AvgPerturbationEffect": "Avg. effect", "MeanControl": "Mean control", "Random": "Random"}
    # Methods missing from the benchmark file are skipped rather than failing.
    vals = sorted(((pretty[m], bf[m]["de_corr"]) for m in order if m in bf),
                  key=lambda kv: kv[1])
    names = [name for name, _ in vals]
    de = [score for _, score in vals]
    cols = [C_PIVOT if name == "PIVOT" else C_BASE for name in names]
    ax.barh(names, de, color=cols, edgecolor="white", height=0.74)
    for i, v in enumerate(de):
        ax.text(v + 0.01, i, f"{v:.2f}", va="center", fontsize=9, color="#333")
    ax.set_xlim(0, 1.0)
    ax.set_xlabel("DE correlation $\\uparrow$")
    ax.grid(axis="y", visible=False)
    ax.set_title("a Forward direction, held-out perturbations", loc="left", fontweight="bold")


def _fig2_gears_panel(ax):
    """Panel b: two bars comparing PIVOT and GEARS on DE-gene Pearson correlation."""
    pg = L("pivot_vs_gears")
    scores = [pg["pivot_pearson_de_expr"], pg["gears_pearson_de_expr"]]
    bars = ax.bar(["PIVOT", "GEARS"], scores, color=[C_PIVOT, C_ACC],
                  edgecolor="white", width=0.55)
    for r, v in zip(bars, scores):
        ax.text(r.get_x() + r.get_width()/2, v + 0.012, f"{v:.3f}", ha="center",
                fontsize=11, fontweight="bold")
    ax.set_ylim(0, 1.08)
    ax.set_ylabel("Top-20 DE-gene Pearson $\\uparrow$")
    ax.grid(axis="x", visible=False)
    ax.set_title("b Head-to-head vs GEARS (matched perts)", loc="left", fontweight="bold")


def _fig2_distloss_panel(ax):
    """Panel c: population MMD (left axis) and DE correlation (right axis) vs loss weight."""
    rows = L("norman_distloss")["rows"]
    # Sort the ORIGINAL string keys by numeric value and index with them.
    # Rebuilding keys from floats breaks for spellings such as "0", "1" or
    # "1e-3", whose float round-trip no longer matches the stored key.
    keys = sorted(rows, key=float)
    lam = [float(k) for k in keys]
    mmd = [rows[k]["mmd"] for k in keys]
    de2 = [rows[k]["de_corr"] for k in keys]

    ax.plot(lam, mmd, "o-", color=C_PRED, lw=2.2, ms=7, label="MMD $\\downarrow$")
    ax.set_xlabel("distributional-loss weight $\\lambda_{\\mathrm{dist}}$")
    ax.set_ylabel("population MMD $\\downarrow$", color=C_PRED)
    ax.tick_params(axis="y", colors=C_PRED)
    ax.set_ylim(0, max(mmd)*1.15)

    twin = ax.twinx()
    twin.plot(lam, de2, "s--", color=C_OBS, lw=2.0, ms=6, label="DE-corr $\\uparrow$")
    twin.set_ylabel("DE correlation $\\uparrow$", color=C_OBS)
    twin.tick_params(axis="y", colors=C_OBS)
    # NOTE(review): fixed range; DE-corr values outside [0.5, 0.95] would be clipped.
    twin.set_ylim(0.5, 0.95)
    twin.grid(False)
    twin.spines["top"].set_visible(False)
    # NOTE(review): the "6x" in the title is hard-coded, not derived from `mmd`;
    # re-check it whenever the experiment is re-run.
    ax.set_title("c Distributional flow loss: MMD $6\\times$ lower, direction kept",
                 loc="left", fontweight="bold")


def _fig2_reward_panel(ax):
    """Panel d: nomination Top-5 per reward variant, with the cosine reward highlighted."""
    rw = L("norman_ablation_reward")["rows"]
    rmap = {"centroid": "Centroid", "nn_target": "NN-target", "mmd": "MMD",
            "wasserstein": "Wasserstein", "cosine": "Cosine"}
    # Unknown reward keys keep their raw name as the tick label.
    rn = [rmap.get(k, k) for k in rw]
    t5 = [rw[k]["top5"] for k in rw]
    cols2 = [C_PIVOT if rmap.get(k) == "Cosine" else C_BASE for k in rw]
    bars = ax.bar(rn, t5, color=cols2, edgecolor="white", width=0.66)
    for r, v in zip(bars, t5):
        ax.text(r.get_x()+r.get_width()/2, v+0.006, f"{v:.2f}", ha="center", fontsize=9.5)
    ax.set_ylabel("nomination Top-5 $\\uparrow$")
    ax.set_ylim(0, max(t5)*1.25)
    ax.grid(axis="x", visible=False)
    ax.tick_params(axis="x", rotation=18)
    ax.set_title("d Direction-aware reward wins at nomination", loc="left", fontweight="bold")


def figure2_results():
    """Draw figure 2: a 2 x 2 grid of quantitative result panels.

    Panels, each fed from a JSON file loaded through ``L``:
      a  ``norman_benchmark``       -> ``forward[method]["de_corr"]``
      b  ``pivot_vs_gears``         -> ``pivot_pearson_de_expr`` / ``gears_pearson_de_expr``
      c  ``norman_distloss``        -> ``rows[weight]["mmd" | "de_corr"]``
      d  ``norman_ablation_reward`` -> ``rows[reward]["top5"]``

    Returns:
        None. The figure is written as ``fig2_results`` via ``save``.
    """
    fig, ax = plt.subplots(2, 2, figsize=(11, 7.4))
    _fig2_forward_panel(ax[0, 0])
    _fig2_gears_panel(ax[0, 1])
    _fig2_distloss_panel(ax[1, 0])
    _fig2_reward_panel(ax[1, 1])
    fig.tight_layout()
    save(fig, "fig2_results")
if __name__ == "__main__":
    # Project imports are kept local to script execution.
    from src.data.perturb_data import load_dataset
    from src.training.train import TrainConfig, train

    data = load_dataset("norman")

    # GPU index comes from PIVOT_GPU, falling back to "3".
    cfg = TrainConfig(
        dataset="norman",
        split="cell",
        epochs=60,
        device_index=int(os.environ.get("PIVOT_GPU", "3")),
    )
    model, _train_info = train(cfg, data=data, verbose=False)
    dev = next(model.parameters()).device

    def _perts_of_size(k):
        # Perturbations whose parsed form has exactly k components.
        return [p for p in data.perturbations if len(data.parse(p)) == k]

    singles = _perts_of_size(1)
    combos = _perts_of_size(2)

    # Figure 1 shows two single perturbations and one combination.
    panel_targets = [singles[0], singles[7], combos[0]]
    figure1(model, data, panel_targets, dev)
    figure2_results()
    print("FIGURES_V2_DONE")