"""publication-quality figures. saves each one as both png and pdf to figures/. reads only computed artifacts (result jsons, trained model). """ from __future__ import annotations import json import os import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt import numpy as np plt.rcParams.update({ "font.size": 12, "axes.labelsize": 14, "axes.titlesize": 14, "xtick.labelsize": 11, "ytick.labelsize": 11, "legend.fontsize": 10, "figure.figsize": (8, 6), "figure.dpi": 150, "savefig.dpi": 300, "savefig.bbox": "tight", "axes.spines.top": False, "axes.spines.right": False, }) FIG = "figures" RES = "experiments/results" def _save(fig, name): os.makedirs(FIG, exist_ok=True) fig.savefig(os.path.join(FIG, f"{name}.png")) fig.savefig(os.path.join(FIG, f"{name}.pdf")) plt.close(fig) print(f" figure -> {FIG}/{name}.{{png,pdf}}") def _load(name): with open(os.path.join(RES, f"{name}.json")) as fh: return json.load(fh) def fig_embedding_panels(model, data, targets, device, name="fig1_embedding_panels"): """figure 1: control / target / pivot-predicted-endpoint / observed cells in 2d.""" import torch from src.evaluation import inference as inf # 2d projection: first 2 pcs of the cell-state embedding proj = data.emb[:, :2] fig, axes = plt.subplots(1, len(targets), figsize=(5 * len(targets), 4.5)) if len(targets) == 1: axes = [axes] rng = np.random.default_rng(0) ctrl = data.control_idx c0e = torch.as_tensor(data.emb[rng.choice(ctrl, 300, replace=False)], dtype=torch.float32, device=device) for ax, p in zip(axes, targets): ci = rng.choice(ctrl, 400, replace=False) ti = data.pert_to_idx[p] ax.scatter(proj[ci, 0], proj[ci, 1], s=6, c="#bbbbbb", alpha=0.5, label="control") ax.scatter(proj[ti, 0], proj[ti, 1], s=8, c="#1f77b4", alpha=0.6, label=f"observed {p}") # pivot predicted endpoints for the true perturbation e = inf.encode_label(model, data, p, device) chat = inf.forward_predict(model, c0e, e).cpu().numpy() ax.scatter(chat[:, 0], chat[:, 1], s=10, c="#d62728", alpha=0.7, marker="x", label="PIVOT predicted") ax.set_title(p) ax.set_xlabel("PC1"); ax.set_ylabel("PC2") ax.legend(loc="best", markerscale=1.5) fig.suptitle("Control, observed perturbed, and PIVOT-predicted endpoints (PC space)") _save(fig, name) def fig_guidance_steps(dataset="norman", name="fig_guidance_steps"): obj = _load(f"{dataset}_ablation_guidance_steps") steps = [int(k) for k in obj["rows"]] top5 = [obj["rows"][str(s)]["top5"] for s in steps] ndcg = [obj["rows"][str(s)]["ndcg"] for s in steps] fig, ax = plt.subplots() ax.plot(steps, top5, "o-", label="Top-5") ax.plot(steps, ndcg, "s--", label="nDCG@10") ax.set_xlabel("Reward-guidance steps L"); ax.set_ylabel("Recovery metric") ax.set_title("Effect of reward-guidance steps") ax.legend() _save(fig, name) def fig_datascale(dataset="norman", name="fig_datascale"): obj = _load(f"{dataset}_ablation_datascale") fr = sorted(float(k) for k in obj["rows"]) de = [obj["rows"][str(f)]["forward"]["de_corr"] for f in fr] t5 = [obj["rows"][str(f)]["inverse"]["top5"] for f in fr] fig, ax = plt.subplots() ax.plot([f * 100 for f in fr], de, "o-", label="Forward DE-corr") ax.plot([f * 100 for f in fr], t5, "s--", label="Inverse Top-5") ax.set_xlabel("Training data (%)"); ax.set_ylabel("Metric") ax.set_title("Data scaling") ax.legend() _save(fig, name) def fig_components(dataset="norman", name="fig_components"): obj = _load(f"{dataset}_ablation_components") rows = list(obj["rows"]) t5 = [obj["rows"][r]["inverse"]["top5"] for r in rows] de = [obj["rows"][r]["forward"]["de_corr"] for r in rows] x = np.arange(len(rows)); w = 0.38 fig, ax = plt.subplots(figsize=(9, 5)) ax.bar(x - w / 2, de, w, label="Forward DE-corr") ax.bar(x + w / 2, t5, w, label="Inverse Top-5") ax.set_xticks(x); ax.set_xticklabels(rows, rotation=20, ha="right") ax.set_ylabel("Metric"); ax.set_title("Component ablation") ax.legend() _save(fig, name) def fig_loss_curves(history, name="fig_loss_curves"): fig, ax = plt.subplots() for k in ["total", "map", "tan", "semi"]: ax.plot([h[k] for h in history], label=k) ax.set_xlabel("Epoch"); ax.set_ylabel("Loss"); ax.set_yscale("log") ax.set_title("PIVOT training losses"); ax.legend() _save(fig, name)