| """publication-quality figures. saves each one as both png and pdf to figures/. |
| reads only computed artifacts (result jsons, trained model). |
| """ |
| from __future__ import annotations |
|
|
| import json |
| import os |
|
|
| import matplotlib |
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
| import numpy as np |
|
|
# Global matplotlib style applied to every figure created after this point.
plt.rcParams.update({
    # typography
    "font.size": 12,
    "axes.labelsize": 14,
    "axes.titlesize": 14,
    "xtick.labelsize": 11,
    "ytick.labelsize": 11,
    "legend.fontsize": 10,
    # default canvas size and resolution (interactive vs. saved)
    "figure.figsize": (8, 6),
    "figure.dpi": 150,
    "savefig.dpi": 300,
    "savefig.bbox": "tight",
    # hide the top and right axes spines
    "axes.spines.top": False,
    "axes.spines.right": False,
})
FIG = "figures"  # output directory used by _save
RES = "experiments/results"  # input directory used by _load
|
|
|
|
def _save(fig, name):
    """write *fig* into ``FIG`` as ``<name>.png`` and ``<name>.pdf``, then close it.

    creates ``FIG`` if it does not exist yet and prints a one-line confirmation.
    """
    os.makedirs(FIG, exist_ok=True)
    for ext in ("png", "pdf"):
        fig.savefig(os.path.join(FIG, f"{name}.{ext}"))
    # release the figure so repeated calls do not accumulate open figures
    plt.close(fig)
    print(f" figure -> {FIG}/{name}.{{png,pdf}}")
|
|
|
|
def _load(name):
    """return the parsed contents of ``<RES>/<name>.json``.

    the file is opened explicitly as utf-8 so the result does not depend on the
    platform's locale default encoding. errors from ``open`` and ``json.load``
    (missing file, malformed json) propagate unchanged.
    """
    path = os.path.join(RES, f"{name}.json")
    with open(path, encoding="utf-8") as fh:
        return json.load(fh)
|
|
|
|
def fig_embedding_panels(model, data, targets, device, name="fig1_embedding_panels"):
    """figure 1: control / observed-perturbed / pivot-predicted cells in 2d.

    draws one panel per entry of *targets*. each panel shows a random sample of
    control cells, the observed cells of that perturbation, and the model's
    predicted endpoints for a fixed random sample of control cells.

    reads ``data.emb``, ``data.control_idx`` and ``data.pert_to_idx`` and uses the
    first two embedding columns as plot coordinates (the axes are labelled
    PC1/PC2 -- presumably the embedding is a pca; confirm against the data loader).

    args:
        model: trained model, passed through to ``inference.encode_label`` and
            ``inference.forward_predict``.
        data: dataset object exposing the attributes listed above.
        targets: sequence of perturbation labels, one panel each; must be non-empty.
        device: torch device the control embeddings are moved to.
        name: file stem handed to ``_save``.

    raises:
        ValueError: if *targets* is empty.
    """
    import torch
    from src.evaluation import inference as inf

    n_panels = len(targets)
    if n_panels == 0:
        # matplotlib would also raise ValueError here, just with a less helpful message
        raise ValueError("fig_embedding_panels needs at least one target perturbation")

    proj = data.emb[:, :2]
    # squeeze=False always yields a 2d axes array, so one panel needs no special case
    fig, axes = plt.subplots(1, n_panels, figsize=(5 * n_panels, 4.5), squeeze=False)
    rng = np.random.default_rng(0)  # fixed seed: the figure is reproducible
    ctrl = data.control_idx
    n_ctrl = len(ctrl)
    # clamp the sample sizes: sampling without replacement raises when fewer
    # control cells exist than requested
    start_idx = rng.choice(ctrl, min(300, n_ctrl), replace=False)
    c0e = torch.as_tensor(data.emb[start_idx], dtype=torch.float32, device=device)
    for ax, p in zip(axes[0], targets):
        ci = rng.choice(ctrl, min(400, n_ctrl), replace=False)
        ti = data.pert_to_idx[p]
        ax.scatter(proj[ci, 0], proj[ci, 1], s=6, c="#bbbbbb", alpha=0.5, label="control")
        ax.scatter(proj[ti, 0], proj[ti, 1], s=8, c="#1f77b4", alpha=0.6, label=f"observed {p}")

        e = inf.encode_label(model, data, p, device)
        # detach first: .numpy() refuses tensors that are still attached to the autograd graph
        chat = inf.forward_predict(model, c0e, e).detach().cpu().numpy()
        ax.scatter(chat[:, 0], chat[:, 1], s=10, c="#d62728", alpha=0.7, marker="x", label="PIVOT predicted")
        ax.set_title(p)
        ax.set_xlabel("PC1")
        ax.set_ylabel("PC2")
        ax.legend(loc="best", markerscale=1.5)
    fig.suptitle("Control, observed perturbed, and PIVOT-predicted endpoints (PC space)")
    _save(fig, name)
|
|
|
|
def fig_guidance_steps(dataset="norman", name="fig_guidance_steps"):
    """plot top-5 and ndcg@10 against the number of reward-guidance steps.

    reads ``<dataset>_ablation_guidance_steps.json`` through ``_load``. its
    ``"rows"`` mapping is keyed by the step count (as a string) and each row
    provides ``"top5"`` and ``"ndcg"``.

    rows are sorted by their numeric step count so the lines run left to right
    whatever the key order in the json file, and each row is read directly
    instead of being looked up again through ``str(int(key))`` (which fails for
    keys such as ``"05"``).

    args:
        dataset: dataset prefix of the result file.
        name: file stem handed to ``_save``.
    """
    rows = _load(f"{dataset}_ablation_guidance_steps")["rows"]
    # sort on the numeric key only, so rows (dicts) are never compared
    ordered = sorted(((int(k), row) for k, row in rows.items()), key=lambda kv: kv[0])
    steps = [s for s, _ in ordered]
    top5 = [row["top5"] for _, row in ordered]
    ndcg = [row["ndcg"] for _, row in ordered]

    fig, ax = plt.subplots()
    ax.plot(steps, top5, "o-", label="Top-5")
    ax.plot(steps, ndcg, "s--", label="nDCG@10")
    ax.set_xlabel("Reward-guidance steps L")
    ax.set_ylabel("Recovery metric")
    ax.set_title("Effect of reward-guidance steps")
    ax.legend()
    _save(fig, name)
|
|
|
|
def fig_datascale(dataset="norman", name="fig_datascale"):
    """plot forward de-corr and inverse top-5 against the training-data fraction.

    reads ``<dataset>_ablation_datascale.json`` through ``_load``. its ``"rows"``
    mapping is keyed by the training fraction (as a string) and each row provides
    ``["forward"]["de_corr"]`` and ``["inverse"]["top5"]``.

    rows are paired with their parsed fraction and sorted once. looking a row up
    again via ``str(float(key))`` is not safe: a key written as ``"1"`` or
    ``"0.10"`` does not survive that round trip and would raise ``KeyError``.

    args:
        dataset: dataset prefix of the result file.
        name: file stem handed to ``_save``.
    """
    rows = _load(f"{dataset}_ablation_datascale")["rows"]
    # sort on the numeric key only, so rows (dicts) are never compared
    ordered = sorted(((float(k), row) for k, row in rows.items()), key=lambda kv: kv[0])
    percent = [frac * 100 for frac, _ in ordered]
    de = [row["forward"]["de_corr"] for _, row in ordered]
    t5 = [row["inverse"]["top5"] for _, row in ordered]

    fig, ax = plt.subplots()
    ax.plot(percent, de, "o-", label="Forward DE-corr")
    ax.plot(percent, t5, "s--", label="Inverse Top-5")
    ax.set_xlabel("Training data (%)")
    ax.set_ylabel("Metric")
    ax.set_title("Data scaling")
    ax.legend()
    _save(fig, name)
|
|
|
|
def fig_components(dataset="norman", name="fig_components"):
    """grouped bar chart of the component ablation.

    reads ``<dataset>_ablation_components.json`` through ``_load``. every key of
    its ``"rows"`` mapping becomes one x-axis group with two bars:
    ``["forward"]["de_corr"]`` and ``["inverse"]["top5"]``. groups keep the key
    order of the json file.

    args:
        dataset: dataset prefix of the result file.
        name: file stem handed to ``_save``.
    """
    results = _load(f"{dataset}_ablation_components")["rows"]
    labels = list(results)
    top5 = [results[label]["inverse"]["top5"] for label in labels]
    de_corr = [results[label]["forward"]["de_corr"] for label in labels]

    width = 0.38
    pos = np.arange(len(labels))
    fig, ax = plt.subplots(figsize=(9, 5))
    # the two bars of a group sit side by side, centred on the tick
    ax.bar(pos - width / 2, de_corr, width, label="Forward DE-corr")
    ax.bar(pos + width / 2, top5, width, label="Inverse Top-5")
    ax.set_xticks(pos)
    ax.set_xticklabels(labels, rotation=20, ha="right")
    ax.set_ylabel("Metric")
    ax.set_title("Component ablation")
    ax.legend()
    _save(fig, name)
|
|
|
|
def fig_loss_curves(history, name="fig_loss_curves"):
    """plot the training-loss components on a log-scaled y axis.

    args:
        history: re-iterable sequence (it is traversed once per curve) whose
            entries can be indexed by ``"total"``, ``"map"``, ``"tan"`` and
            ``"semi"``; the x axis is the position within the sequence.
        name: file stem handed to ``_save``.
    """
    fig, ax = plt.subplots()
    for key in ("total", "map", "tan", "semi"):
        series = [entry[key] for entry in history]
        ax.plot(series, label=key)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.set_yscale("log")
    ax.set_title("PIVOT training losses")
    ax.legend()
    _save(fig, name)
|
|