PIVOT / src /experiments /make_figures.py
bryan7264's picture
pivot: code + trained checkpoints (norman, replogle k562)
3b4941f verified
Raw
History Blame
4.56 kB
"""publication-quality figures. saves each one as both png and pdf to figures/.
reads only computed artifacts (result jsons, trained model).
"""
from __future__ import annotations
import json
import os
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
# Global matplotlib style applied to every figure produced by this module.
_RC_STYLE = {
    # text sizes
    "font.size": 12,
    "axes.labelsize": 14,
    "axes.titlesize": 14,
    "xtick.labelsize": 11,
    "ytick.labelsize": 11,
    "legend.fontsize": 10,
    # default canvas size, on-screen dpi, and saved-file resolution/cropping
    "figure.figsize": (8, 6),
    "figure.dpi": 150,
    "savefig.dpi": 300,
    "savefig.bbox": "tight",
    # hide the top and right axis spines
    "axes.spines.top": False,
    "axes.spines.right": False,
}
plt.rcParams.update(_RC_STYLE)

# Relative directories: figures are written to FIG, result JSONs read from RES.
FIG = "figures"
RES = "experiments/results"
def _save(fig, name, formats=("png", "pdf")):
    """Save *fig* as ``FIG/<name>.<ext>`` for every extension in *formats*, then close it.

    The figure is closed even when a save raises, so a failed figure is not
    left open in pyplot's figure registry.

    Args:
        fig: matplotlib Figure to write.
        name: file stem (no extension) inside the ``FIG`` directory.
        formats: file extensions to write; each extension determines the
            output format chosen by ``savefig``. Defaults to PNG and PDF.
    """
    if isinstance(formats, str):
        # a bare string would otherwise be iterated character by character
        formats = (formats,)
    os.makedirs(FIG, exist_ok=True)
    try:
        for ext in formats:
            fig.savefig(os.path.join(FIG, f"{name}.{ext}"))
    finally:
        plt.close(fig)
    print(f" figure -> {FIG}/{name}.{{{','.join(formats)}}}")
def _load(name):
    """Load and return the parsed JSON result ``RES/<name>.json``.

    The file is decoded as UTF-8 (the encoding JSON interchange requires)
    rather than the platform's locale encoding.

    Args:
        name: result file stem, without the ``.json`` extension.

    Raises:
        FileNotFoundError: if the result file does not exist.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    with open(os.path.join(RES, f"{name}.json"), encoding="utf-8") as fh:
        return json.load(fh)
def fig_embedding_panels(model, data, targets, device, name="fig1_embedding_panels"):
    """Figure 1: control / observed-perturbed / PIVOT-predicted cells in 2-D.

    Draws one panel per perturbation in *targets*. Each panel scatters a random
    subsample of control cells, the observed cells for that perturbation, and
    the endpoints PIVOT's forward prediction produces from a shared subsample
    of control cells under that perturbation's label encoding.

    Args:
        model: trained PIVOT model, passed to the inference helpers.
        data: dataset object; reads ``data.emb`` (cell-state embedding, rows
            indexed by cell), ``data.control_idx`` (control-cell indices) and
            ``data.pert_to_idx`` (perturbation name -> its cells' indices).
        targets: perturbation names to plot, one panel each.
        device: torch device for the model input tensor.
        name: output file stem passed to ``_save``.

    Raises:
        ValueError: if *targets* is empty.
    """
    import torch
    from src.evaluation import inference as inf

    targets = list(targets)
    if not targets:
        raise ValueError("fig_embedding_panels: targets must be non-empty")

    # 2-D projection: first 2 PCs of the cell-state embedding, i.e. the first
    # two columns of data.emb.
    proj = data.emb[:, :2]
    # squeeze=False always returns a 2-D grid, so a single target needs no special case.
    fig, axes = plt.subplots(
        1, len(targets), figsize=(5 * len(targets), 4.5), squeeze=False
    )
    rng = np.random.default_rng(0)  # fixed seed -> reproducible subsamples
    ctrl = data.control_idx
    n_ctrl = len(ctrl)
    # Control cells fed to the forward model, shared across all panels. Sample
    # sizes are capped at the number of control cells because
    # choice(..., replace=False) raises when asked for more than exist.
    c0e = torch.as_tensor(
        data.emb[rng.choice(ctrl, min(300, n_ctrl), replace=False)],
        dtype=torch.float32,
        device=device,
    )
    for ax, p in zip(axes[0], targets):
        ci = rng.choice(ctrl, min(400, n_ctrl), replace=False)  # background controls
        ti = data.pert_to_idx[p]
        ax.scatter(proj[ci, 0], proj[ci, 1], s=6, c="#bbbbbb", alpha=0.5, label="control")
        ax.scatter(proj[ti, 0], proj[ti, 1], s=8, c="#1f77b4", alpha=0.6, label=f"observed {p}")
        # pivot predicted endpoints for the true perturbation; detach() so that
        # .numpy() also works if the returned tensor still tracks gradients
        e = inf.encode_label(model, data, p, device)
        chat = inf.forward_predict(model, c0e, e).detach().cpu().numpy()
        ax.scatter(chat[:, 0], chat[:, 1], s=10, c="#d62728", alpha=0.7, marker="x", label="PIVOT predicted")
        ax.set_title(p)
        ax.set_xlabel("PC1")
        ax.set_ylabel("PC2")
        ax.legend(loc="best", markerscale=1.5)
    fig.suptitle("Control, observed perturbed, and PIVOT-predicted endpoints (PC space)")
    _save(fig, name)
def fig_guidance_steps(dataset="norman", name="fig_guidance_steps"):
    """Plot inverse-recovery metrics against the number of reward-guidance steps.

    Reads ``<dataset>_ablation_guidance_steps.json`` via ``_load``. Its
    ``"rows"`` mapping is keyed by the step count (as a string); each row
    provides ``"top5"`` and ``"ndcg"`` values.

    Args:
        dataset: dataset prefix of the result file.
        name: output file stem passed to ``_save``.
    """
    rows = _load(f"{dataset}_ablation_guidance_steps")["rows"]
    # JSON key order is whatever the writer emitted; sort numerically so the
    # line runs left-to-right, and index with the original key (no int->str
    # round-trip that could miss keys such as "05").
    keys = sorted(rows, key=int)
    steps = [int(k) for k in keys]
    top5 = [rows[k]["top5"] for k in keys]
    ndcg = [rows[k]["ndcg"] for k in keys]
    fig, ax = plt.subplots()
    ax.plot(steps, top5, "o-", label="Top-5")
    ax.plot(steps, ndcg, "s--", label="nDCG@10")
    ax.set_xlabel("Reward-guidance steps L")
    ax.set_ylabel("Recovery metric")
    ax.set_title("Effect of reward-guidance steps")
    ax.legend()
    _save(fig, name)
def fig_datascale(dataset="norman", name="fig_datascale"):
    """Plot forward DE-corr and inverse Top-5 against the fraction of training data.

    Reads ``<dataset>_ablation_datascale.json`` via ``_load``. Its ``"rows"``
    mapping is keyed by the training fraction (a numeric string); each row has
    ``["forward"]["de_corr"]`` and ``["inverse"]["top5"]``.

    Args:
        dataset: dataset prefix of the result file.
        name: output file stem passed to ``_save``.
    """
    rows = _load(f"{dataset}_ablation_datascale")["rows"]
    # Sort the original keys by numeric value and look rows up with those keys:
    # str(float(k)) does not round-trip (e.g. "1" -> "1.0", "0.10" -> "0.1").
    keys = sorted(rows, key=float)
    pct = [float(k) * 100 for k in keys]  # fraction -> percent for the x-axis
    de = [rows[k]["forward"]["de_corr"] for k in keys]
    t5 = [rows[k]["inverse"]["top5"] for k in keys]
    fig, ax = plt.subplots()
    ax.plot(pct, de, "o-", label="Forward DE-corr")
    ax.plot(pct, t5, "s--", label="Inverse Top-5")
    ax.set_xlabel("Training data (%)")
    ax.set_ylabel("Metric")
    ax.set_title("Data scaling")
    ax.legend()
    _save(fig, name)
def fig_components(dataset="norman", name="fig_components"):
    """Grouped bar chart for the component ablation.

    Reads ``<dataset>_ablation_components.json`` via ``_load`` and draws, for
    each key of its ``"rows"`` mapping (in file order), a forward DE-corr bar
    next to an inverse Top-5 bar.

    Args:
        dataset: dataset prefix of the result file.
        name: output file stem passed to ``_save``.
    """
    results = _load(f"{dataset}_ablation_components")["rows"]
    labels = list(results)
    top5 = [results[label]["inverse"]["top5"] for label in labels]
    de_corr = [results[label]["forward"]["de_corr"] for label in labels]

    centers = np.arange(len(labels))
    bar_w = 0.38
    half = bar_w / 2  # offset of each bar from its group centre
    fig, ax = plt.subplots(figsize=(9, 5))
    ax.bar(centers - half, de_corr, bar_w, label="Forward DE-corr")
    ax.bar(centers + half, top5, bar_w, label="Inverse Top-5")
    ax.set_xticks(centers)
    ax.set_xticklabels(labels, rotation=20, ha="right")
    ax.set_ylabel("Metric")
    ax.set_title("Component ablation")
    ax.legend()
    _save(fig, name)
def fig_loss_curves(history, name="fig_loss_curves", *, keys=("total", "map", "tan", "semi")):
    """Plot per-epoch training losses, one line per loss component, on a log y-axis.

    Args:
        history: iterable of per-epoch records; each record must be indexable
            by every name in *keys* (e.g. a dict of loss values). It is read
            into a list once, so one-shot iterators are plotted for every key.
        name: output file stem passed to ``_save``.
        keys: loss components to plot; each key is also its line's legend
            label. Defaults to ``total``, ``map``, ``tan`` and ``semi``.
    """
    # Materialize once: the loop below walks history once per key, which would
    # leave every curve after the first empty if history were a generator.
    history = list(history)
    fig, ax = plt.subplots()
    for k in keys:
        ax.plot([h[k] for h in history], label=k)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    # note: non-positive loss values cannot be represented on a log axis
    ax.set_yscale("log")
    ax.set_title("PIVOT training losses")
    ax.legend()
    _save(fig, name)