File size: 7,631 Bytes
3b4941f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
"""master experiment runner, trains pivot, runs baselines, assembles every results
table. each table function returns a json-serializable dict, which is saved to
experiments/results/<table>.json. nothing is hard-coded; every number comes from
these calls.
"""
from __future__ import annotations

import json
import os
import time

import numpy as np
import torch

from src.data.perturb_data import load_dataset
from src.data.splits import load_split
from src.training.train import TrainConfig, train
from src.evaluation.baselines import BASELINES, build_baseline
from src.evaluation.rewards import TargetStateClassifier
from src.experiments.predictors import PivotPredictor, BaselinePredictor
from src.experiments.forward_eval import evaluate_forward
from src.experiments.nomination_eval import evaluate_nomination
from src.experiments.combinatorial_eval import evaluate_combinatorial
from src.utils.common import save_json

# directory that save_table writes its <name>.json files into
RESULTS_DIR = "experiments/results"
# device index passed to TrainConfig as device_index; taken from the PIVOT_GPU
# environment variable and falling back to 3 when it is unset
DEVICE_INDEX = int(os.environ.get("PIVOT_GPU", "3"))


class ModelCache:
    """memoizes training runs for one dataset/embedding pair.

    the dataset is loaded once at construction; every distinct
    (split, embedding, sorted training kwargs) combination is trained at most once
    and its (model, info, cfg) triple is handed back on later requests.
    """

    def __init__(self, dataset, embedding="pca"):
        self.dataset = dataset
        self.embedding = embedding
        self._cache = {}
        self.data = load_dataset(dataset, embedding=embedding)

    def get(self, split, **kw):
        """return (model, info, cfg) for `split`, training only on a cache miss.

        extra keyword arguments are forwarded to TrainConfig and take part in the
        cache key, so their values have to be hashable.
        """
        cache_key = (split, self.embedding, tuple(sorted(kw.items())))
        try:
            return self._cache[cache_key]
        except KeyError:
            pass

        cfg = TrainConfig(split=split, dataset=self.dataset, embedding=self.embedding,
                          device_index=DEVICE_INDEX, **kw)
        model, info = train(cfg, data=self.data, verbose=False)
        entry = (model, info, cfg)
        self._cache[cache_key] = entry
        return entry


def candidates_singles(data):
    """list, in dataset order, the perturbations that data.parse splits into exactly one part."""
    singles = []
    for pert in data.perturbations:
        if len(data.parse(pert)) == 1:
            singles.append(pert)
    return singles


def get_classifier(data, target_perts, control_pool, device, seed=0):
    """train one target-state classifier (for r_clf / target-clf).

    positives are the union of the cells of `target_perts`; negatives are drawn from
    `control_pool` plus the cells of up to 20 randomly chosen non-target
    perturbations. both sides are subsampled without replacement to at most 4000
    cells before fitting.

    args:
        data: dataset object; this function reads pert_to_idx, perturbations, emb and d.
        target_perts: perturbation names defining the positive class; must be
            non-empty (np.concatenate raises ValueError otherwise).
        control_pool: index array of control cells added to the negative pool.
        device: device the classifier is moved to and fitted on.
        seed: seed for the numpy generator that drives all sampling here.

    returns:
        the fitted TargetStateClassifier.
    """
    rng = np.random.default_rng(seed)

    pos = np.concatenate([data.pert_to_idx[p] for p in target_perts])
    pos = pos[rng.choice(len(pos), min(4000, len(pos)), replace=False)]

    # sorted() pins the candidate order: iterating a set of strings depends on hash
    # randomisation, which would make the draw differ between runs for the same seed.
    # assumes perturbation names are mutually comparable (e.g. all str).
    others = sorted(set(data.perturbations) - set(target_perts))
    # cap by the size of the pool actually sampled from; capping by
    # len(data.perturbations) asks for more items than exist once targets are removed
    # and rng.choice(..., replace=False) then raises ValueError.
    n_other = min(20, len(others))

    neg_parts = [control_pool]
    if n_other:  # with no non-target perturbation left, negatives are controls only
        picked = rng.choice(len(others), size=n_other, replace=False)
        neg_parts.extend(data.pert_to_idx[others[i]] for i in picked)
    neg_pool = np.concatenate(neg_parts)
    neg = neg_pool[rng.choice(len(neg_pool), min(4000, len(neg_pool)), replace=False)]

    clf = TargetStateClassifier(data.d).to(device)
    clf.fit(data.emb[pos], data.emb[neg], device)
    return clf


def save_table(name, obj):
    """write `obj` to RESULTS_DIR/<name>.json via save_json, creating the directory if needed."""
    os.makedirs(RESULTS_DIR, exist_ok=True)
    target = os.path.join(RESULTS_DIR, f"{name}.json")
    save_json(obj, target)
    print(f"  saved {RESULTS_DIR}/{name}.json")


# table 2 / 3 / 10 / 11: forward prediction
def table_forward(cache: ModelCache, split: str, max_perts=80, baseline_names=None):
    """evaluate pivot and the requested baselines on forward prediction for one split.

    the "cell" split is scored on every single perturbation from candidates_singles;
    any other split is scored on that split's own test perturbations. returns
    {"split": split, "models": {name: evaluate_forward result}} with "PIVOT" first.
    """
    data = cache.data
    model, _info, _cfg = cache.get(split)
    device = next(model.parameters()).device
    sp = load_split(data.dir, split)

    if split == "cell":
        test_perts = candidates_singles(data)
    else:
        test_perts = list(sp["test_perts"])

    # test-split indices flagged by data.is_control; when fewer than 50 remain,
    # use data.control_idx instead
    test_idx = sp["test_idx"]
    control_pool = test_idx[data.is_control[test_idx]]
    if len(control_pool) < 50:
        control_pool = data.control_idx

    def score(predictor):
        return evaluate_forward(predictor, data, test_perts, control_pool, max_perts=max_perts)

    models = {"PIVOT": score(PivotPredictor(model, data, device))}

    default_baselines = ["MeanControl", "AvgPerturbationEffect", "Additive",
                         "LinearResponse", "NearestPerturbationCentroid", "EndpointMLP", "Random"]
    for bname in (baseline_names or default_baselines):
        fitted = build_baseline(bname).fit(data, sp["train_perts"], sp["train_idx"])
        models[bname] = score(BaselinePredictor(fitted))

    return {"split": split, "models": models}


# table 4 / 6: single-gene desired-state nomination & recovery
def table_nomination(cache: ModelCache, split: str, max_targets=40, reward_kind="centroid",
                     with_guidance=True, baseline_names=None):
    """run single-perturbation nomination for pivot and the requested baselines.

    candidates are all single perturbations; targets are the first `max_targets`
    singles among the split's test perturbations (or among the candidates themselves
    for the "cell" split). pivot is scored with method="ranking" and, when
    `with_guidance` is set, also with method="guidance"; each baseline is scored
    with ranking only. returns a dict with split, reward, the candidate/target
    counts and a "methods" mapping of evaluate_nomination results.
    """
    data = cache.data
    model, _info, cfg = cache.get(split)
    device = next(model.parameters()).device
    sp = load_split(data.dir, split)
    cands = candidates_singles(data)

    pool = cands if split == "cell" else list(sp["test_perts"])
    eligible = []
    for pert in pool:
        if len(data.parse(pert)) == 1 and pert in cands:
            eligible.append(pert)
    targets = eligible[:max_targets]

    control_pool = data.control_idx
    gene_cluster = data.functional_clusters(seed=cfg.seed)

    # counts are taken before any evaluation runs, as the header of the result
    n_candidates, n_targets = len(cands), len(targets)

    def nominate(predictor, method, **extra):
        return evaluate_nomination(predictor, data, targets, cands, control_pool,
                                   reward_kind=reward_kind, method=method,
                                   gene_cluster=gene_cluster, device=device, **extra)

    pivot = PivotPredictor(model, data, device)
    methods = {"PIVOT-ranking": nominate(pivot, "ranking")}
    if with_guidance:
        methods["PIVOT-guidance"] = nominate(pivot, "guidance", model=model)

    default_baselines = ["Additive", "LinearResponse", "NearestPerturbationCentroid",
                         "EndpointMLP", "Random"]
    for bname in (baseline_names or default_baselines):
        fitted = build_baseline(bname).fit(data, sp["train_perts"], sp["train_idx"])
        methods[f"{bname}+ranking"] = nominate(BaselinePredictor(fitted), "ranking")

    return {"split": split, "reward": reward_kind, "n_candidates": n_candidates,
            "n_targets": n_targets, "methods": methods}


# table 7: combinatorial nomination
def table_combinatorial(cache: ModelCache, split="combination", max_targets=26):
    """run combinatorial nomination with pivot, passing a fitted Additive baseline along.

    targets are the first `max_targets` test perturbations that data.parse splits
    into two parts; combo candidates come from data.combos and the gene pool is the
    first parsed element of every single perturbation. returns a dict with split,
    the target/candidate counts and the evaluate_combinatorial output under "result".
    """
    data = cache.data
    model, _info, _cfg = cache.get(split)
    device = next(model.parameters()).device
    sp = load_split(data.dir, split)

    pairs = []
    for pert in sp["test_perts"]:
        if len(data.parse(pert)) == 2:
            pairs.append(pert)
    targets = pairs[:max_targets]

    combo_cands = data.combos
    gene_pool = []
    for single in candidates_singles(data):
        gene_pool.append(data.parse(single)[0])
    control_pool = data.control_idx

    additive = build_baseline("Additive").fit(data, sp["train_perts"], sp["train_idx"])
    pred = PivotPredictor(model, data, device)

    # counts are recorded before the evaluation call
    summary = {"split": split, "n_targets": len(targets), "n_combo_candidates": len(combo_cands)}
    summary["result"] = evaluate_combinatorial(pred, data, targets, combo_cands, gene_pool,
                                               control_pool, model, device, additive=additive)
    return summary


if __name__ == "__main__":
    # cli entry point: build the requested tables for one dataset and save each as
    # <dataset>_<table>.json. table names are forward_<split>, nom_<split> or
    # combinatorial; the <split> suffix is passed straight to the table function.
    import argparse

    ap = argparse.ArgumentParser()
    ap.add_argument("--dataset", default="norman")
    ap.add_argument("--tables", nargs="+", default=["forward_cell"])
    args = ap.parse_args()

    # reject unrecognised names up front, before the dataset is loaded: a misspelt
    # table would otherwise print its header and be skipped without any output file
    unknown = [t for t in args.tables
               if not (t.startswith(("forward_", "nom_")) or t == "combinatorial")]
    if unknown:
        ap.error("unknown table name(s): " + ", ".join(unknown)
                 + " (expected forward_<split>, nom_<split> or combinatorial)")

    cache = ModelCache(args.dataset)
    t0 = time.time()
    for tname in args.tables:
        print(f"=== {tname} ===")
        if tname.startswith("forward_"):
            save_table(f"{args.dataset}_{tname}", table_forward(cache, tname.split("_", 1)[1]))
        elif tname.startswith("nom_"):
            save_table(f"{args.dataset}_{tname}", table_nomination(cache, tname.split("_", 1)[1]))
        elif tname == "combinatorial":
            save_table(f"{args.dataset}_combinatorial", table_combinatorial(cache))
    print(f"done in {time.time()-t0:.1f}s")