File size: 4,559 Bytes
3b4941f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
"""publication-quality figures. saves each one as both png and pdf to figures/.
reads only computed artifacts (result jsons, trained model).
"""
from __future__ import annotations

import json
import os

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

# global matplotlib styling applied to every figure created after import
plt.rcParams.update(
    {
        # typography
        "font.size": 12,
        "axes.labelsize": 14,
        "axes.titlesize": 14,
        "xtick.labelsize": 11,
        "ytick.labelsize": 11,
        "legend.fontsize": 10,
        # default canvas size and resolution (interactive vs. saved)
        "figure.figsize": (8, 6),
        "figure.dpi": 150,
        "savefig.dpi": 300,
        # crop saved files to the drawn content
        "savefig.bbox": "tight",
        # drop the top and right axes spines
        "axes.spines.top": False,
        "axes.spines.right": False,
    }
)
FIG = "figures"  # directory the rendered figures are written to
RES = "experiments/results"  # directory the result json files are read from


def _save(fig, name):
    """write `fig` under FIG as both <name>.png and <name>.pdf, then close it.

    the FIG directory is created when missing; a one-line note naming the
    written files is printed afterwards.
    """
    os.makedirs(FIG, exist_ok=True)
    stem = os.path.join(FIG, name)
    # png first, then pdf
    for ext in ("png", "pdf"):
        fig.savefig(f"{stem}.{ext}")
    # release the figure so repeated calls do not accumulate open figures
    plt.close(fig)
    print(f"  figure -> {FIG}/{name}.{{png,pdf}}")


def _load(name):
    """read the result file `<name>.json` from the RES directory.

    args:
        name: file stem (without the .json extension) of a file under RES.

    returns:
        the decoded json content.

    raises:
        FileNotFoundError: if the file does not exist.
        json.JSONDecodeError: if the file is not valid json.
    """
    # json text is utf-8; pin the encoding instead of relying on the
    # platform's locale default, which differs between systems
    with open(os.path.join(RES, f"{name}.json"), encoding="utf-8") as fh:
        return json.load(fh)


def fig_embedding_panels(model, data, targets, device, name="fig1_embedding_panels"):
    """figure 1: control / target / pivot-predicted-endpoint / observed cells in 2d.

    draws one panel per entry of `targets`. each panel overlays a random
    subsample of control cells, the cells observed under that perturbation,
    and the endpoints predicted from one fixed batch of control cells.

    args:
        model: trained model, passed unchanged to the inference helpers.
        data: dataset object; `emb`, `control_idx` and `pert_to_idx` are read.
        targets: perturbation labels to draw, one panel each; must be non-empty.
        device: torch device the control batch is placed on.
        name: output file stem handed to `_save`.

    raises:
        ValueError: if `targets` is empty.
    """
    import torch
    from src.evaluation import inference as inf

    targets = list(targets)
    if not targets:
        # fail with a clear message rather than matplotlib's zero-column error
        raise ValueError("targets must contain at least one perturbation")

    # 2d projection: first two columns of the cell-state embedding
    # (presumably principal components, as the axis labels say -- confirm upstream)
    proj = data.emb[:, :2]
    # squeeze=False always returns a 2d axes array, so a single target needs
    # no special casing
    fig, axes = plt.subplots(1, len(targets), figsize=(5 * len(targets), 4.5), squeeze=False)
    rng = np.random.default_rng(0)
    ctrl = data.control_idx
    # sampling without replacement raises when asked for more items than
    # exist, so cap both subsample sizes at the number of control cells
    n_start = min(300, len(ctrl))
    n_background = min(400, len(ctrl))
    # one shared batch of control cells used as starting points for every panel
    c0e = torch.as_tensor(data.emb[rng.choice(ctrl, n_start, replace=False)], dtype=torch.float32, device=device)
    for ax, p in zip(axes[0], targets):
        ci = rng.choice(ctrl, n_background, replace=False)
        ti = data.pert_to_idx[p]
        ax.scatter(proj[ci, 0], proj[ci, 1], s=6, c="#bbbbbb", alpha=0.5, label="control")
        ax.scatter(proj[ti, 0], proj[ti, 1], s=8, c="#1f77b4", alpha=0.6, label=f"observed {p}")
        # pivot predicted endpoints for the true perturbation
        e = inf.encode_label(model, data, p, device)
        # detach before converting: .numpy() raises on a tensor that still
        # tracks gradients
        chat = inf.forward_predict(model, c0e, e).detach().cpu().numpy()
        # NOTE(review): assumes predictions live in the same space as data.emb,
        # so their first two columns are comparable to `proj`
        ax.scatter(chat[:, 0], chat[:, 1], s=10, c="#d62728", alpha=0.7, marker="x", label="PIVOT predicted")
        ax.set_title(p)
        ax.set_xlabel("PC1")
        ax.set_ylabel("PC2")
        ax.legend(loc="best", markerscale=1.5)
    fig.suptitle("Control, observed perturbed, and PIVOT-predicted endpoints (PC space)")
    _save(fig, name)


def fig_guidance_steps(dataset="norman", name="fig_guidance_steps"):
    """plot the top5 and ndcg recovery metrics against reward-guidance steps.

    reads `<dataset>_ablation_guidance_steps.json` through `_load`. its "rows"
    mapping is keyed by the step count (as a string) and every entry provides
    "top5" and "ndcg".

    args:
        dataset: dataset prefix of the result file.
        name: output file stem handed to `_save`.
    """
    rows = _load(f"{dataset}_ablation_guidance_steps")["rows"]
    # sort numerically so the line plot runs left to right regardless of the
    # order the keys were written in; look entries up by their original key
    # rather than a str(int(key)) round trip
    keys = sorted(rows, key=int)
    steps = [int(k) for k in keys]
    top5 = [rows[k]["top5"] for k in keys]
    ndcg = [rows[k]["ndcg"] for k in keys]
    fig, ax = plt.subplots()
    ax.plot(steps, top5, "o-", label="Top-5")
    ax.plot(steps, ndcg, "s--", label="nDCG@10")
    ax.set_xlabel("Reward-guidance steps L")
    ax.set_ylabel("Recovery metric")
    ax.set_title("Effect of reward-guidance steps")
    ax.legend()
    _save(fig, name)


def fig_datascale(dataset="norman", name="fig_datascale"):
    """plot forward de-corr and inverse top5 against the training-data fraction.

    reads `<dataset>_ablation_datascale.json` through `_load`. its "rows"
    mapping is keyed by the data fraction (as a string); every entry provides
    ["forward"]["de_corr"] and ["inverse"]["top5"].

    args:
        dataset: dataset prefix of the result file.
        name: output file stem handed to `_save`.
    """
    rows = _load(f"{dataset}_ablation_datascale")["rows"]
    # sort the original string keys by numeric value and index with those
    # keys: str(float(key)) does not reproduce keys such as "1" or "0.50",
    # so a float round trip can fail the lookup
    keys = sorted(rows, key=float)
    percent = [float(k) * 100 for k in keys]
    de = [rows[k]["forward"]["de_corr"] for k in keys]
    t5 = [rows[k]["inverse"]["top5"] for k in keys]
    fig, ax = plt.subplots()
    ax.plot(percent, de, "o-", label="Forward DE-corr")
    ax.plot(percent, t5, "s--", label="Inverse Top-5")
    ax.set_xlabel("Training data (%)")
    ax.set_ylabel("Metric")
    ax.set_title("Data scaling")
    ax.legend()
    _save(fig, name)


def fig_components(dataset="norman", name="fig_components"):
    """grouped bar chart of forward de-corr and inverse top5 per ablation variant.

    reads `<dataset>_ablation_components.json` through `_load`; each key of its
    "rows" mapping becomes one pair of bars, in the order stored in the file.
    """
    table = _load(f"{dataset}_ablation_components")["rows"]
    labels = list(table)
    inverse_scores = [table[label]["inverse"]["top5"] for label in labels]
    forward_scores = [table[label]["forward"]["de_corr"] for label in labels]

    positions = np.arange(len(labels))
    bar_width = 0.38
    half = bar_width / 2

    fig, ax = plt.subplots(figsize=(9, 5))
    # forward bar sits left of each tick, inverse bar right of it
    ax.bar(positions - half, forward_scores, bar_width, label="Forward DE-corr")
    ax.bar(positions + half, inverse_scores, bar_width, label="Inverse Top-5")
    ax.set_xticks(positions)
    ax.set_xticklabels(labels, rotation=20, ha="right")
    ax.set_ylabel("Metric")
    ax.set_title("Component ablation")
    ax.legend()
    _save(fig, name)


def fig_loss_curves(history, name="fig_loss_curves"):
    """plot the "total", "map", "tan" and "semi" loss terms on a log y-axis.

    `history` is iterated once per loss term; every element is indexed by the
    term name. the x position of a point is its index within `history`.
    """
    fig, ax = plt.subplots()
    for term in ("total", "map", "tan", "semi"):
        curve = [record[term] for record in history]
        ax.plot(curve, label=term)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.set_yscale("log")
    ax.set_title("PIVOT training losses")
    ax.legend()
    _save(fig, name)