File size: 7,631 Bytes
3b4941f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 | """master experiment runner, trains pivot, runs baselines, assembles every results
table. each table function returns a json-serializable dict, which the script entry
point saves to experiments/results/<dataset>_<table>.json. nothing is hard-coded;
every number comes from these calls.
"""
from __future__ import annotations
import json
import os
import time
import numpy as np
import torch
from src.data.perturb_data import load_dataset
from src.data.splits import load_split
from src.training.train import TrainConfig, train
from src.evaluation.baselines import BASELINES, build_baseline
from src.evaluation.rewards import TargetStateClassifier
from src.experiments.predictors import PivotPredictor, BaselinePredictor
from src.experiments.forward_eval import evaluate_forward
from src.experiments.nomination_eval import evaluate_nomination
from src.experiments.combinatorial_eval import evaluate_combinatorial
from src.utils.common import save_json
# directory that save_table() creates (if needed) and writes every <name>.json into
RESULTS_DIR = "experiments/results"
# device index forwarded to TrainConfig(device_index=...) by ModelCache.get();
# taken from the PIVOT_GPU environment variable, falling back to 3 when it is unset
DEVICE_INDEX = int(os.environ.get("PIVOT_GPU", "3"))
class ModelCache:
    """Train each distinct configuration once and hand back the stored result afterwards.

    The dataset is loaded a single time in the constructor and reused for every
    training run requested through :meth:`get`.
    """

    def __init__(self, dataset, embedding="pca"):
        """Load ``dataset`` with the given ``embedding`` and start with an empty cache."""
        self.data = load_dataset(dataset, embedding=embedding)
        self.dataset = dataset
        self.embedding = embedding
        self._cache = {}

    def get(self, split, **kw):
        """Return ``(model, info, cfg)`` for ``split`` plus extra TrainConfig overrides.

        The cache key is the split, the embedding and the keyword overrides sorted
        by name, so the order in which overrides are passed does not matter.
        Training only happens on a cache miss.
        """
        overrides = tuple(sorted(kw.items()))
        cache_key = (split, self.embedding, overrides)
        try:
            return self._cache[cache_key]
        except KeyError:
            pass  # not trained yet -> fall through and train

        cfg = TrainConfig(
            dataset=self.dataset,
            embedding=self.embedding,
            split=split,
            device_index=DEVICE_INDEX,
            **kw,
        )
        model, info = train(cfg, data=self.data, verbose=False)
        entry = (model, info, cfg)
        self._cache[cache_key] = entry
        return entry
def candidates_singles(data):
    """Return, in ``data.perturbations`` order, every perturbation that
    ``data.parse`` splits into exactly one element."""
    singles = []
    for pert in data.perturbations:
        if len(data.parse(pert)) == 1:
            singles.append(pert)
    return singles
def get_classifier(data, target_perts, control_pool, device, seed=0):
    """Train one target-state classifier (for r_clf / target-clf).

    Positives are the union of the cells of ``target_perts``. Negatives are the
    control cells plus the cells of up to 20 randomly chosen perturbations that
    are not targets. Each side is subsampled, without replacement, to at most
    4000 cells.

    Args:
        data: dataset object; this function reads ``perturbations``,
            ``pert_to_idx``, ``emb`` and ``d`` from it.
        target_perts: perturbations whose cells form the positive class.
        control_pool: array of control-cell indices into ``data.emb``.
        device: device the classifier is moved to and fitted on.
        seed: seed for the subsampling RNG.

    Returns:
        The fitted ``TargetStateClassifier``.
    """
    rng = np.random.default_rng(seed)

    pos = np.concatenate([data.pert_to_idx[p] for p in target_perts])
    pos = pos[rng.choice(len(pos), min(4000, len(pos)), replace=False)]

    # sorted() pins the population order so the draw depends only on `seed`;
    # iterating a set of strings directly varies between interpreter runs.
    # NOTE(review): assumes perturbation names are mutually orderable (e.g. all str).
    others = sorted(set(data.perturbations) - set(target_perts))

    neg_parts = [control_pool]
    if others:
        # cap by the actual population size: sampling without replacement cannot
        # draw more items than there are non-target perturbations
        picked = rng.choice(others, size=min(20, len(others)), replace=False)
        neg_parts.extend(data.pert_to_idx[p] for p in picked)
    # with no non-target perturbation at all, the negatives are the controls only
    neg_pool = np.concatenate(neg_parts)
    neg = neg_pool[rng.choice(len(neg_pool), min(4000, len(neg_pool)), replace=False)]

    clf = TargetStateClassifier(data.d).to(device)
    clf.fit(data.emb[pos], data.emb[neg], device)
    return clf
def save_table(name, obj):
    """Write ``obj`` to ``RESULTS_DIR/<name>.json`` via ``save_json`` and report it.

    The results directory is created first if it does not exist yet.
    """
    os.makedirs(RESULTS_DIR, exist_ok=True)
    target = os.path.join(RESULTS_DIR, f"{name}.json")
    save_json(obj, target)
    print(f" saved {RESULTS_DIR}/{name}.json")
# table 2 / 3 / 10 / 11: forward prediction
def table_forward(cache: ModelCache, split: str, max_perts=80, baseline_names=None):
    """Build the forward-prediction table for one split.

    PIVOT (the model obtained from ``cache.get(split)``) is evaluated first, then
    every requested baseline is fitted on the split's training perturbations and
    evaluated on the same test perturbations and control pool.

    Args:
        cache: model cache that also holds the loaded dataset.
        split: split name; for ``"cell"`` the test perturbations are all
            single-element perturbations of the dataset instead of the split's
            own ``test_perts``.
        max_perts: forwarded to ``evaluate_forward``.
        baseline_names: baselines to run; a falsy value selects the default list.

    Returns:
        ``{"split": split, "models": {name: evaluate_forward(...) result}}``.
    """
    dataset = cache.data
    pivot_model, _info, _cfg = cache.get(split)
    device = next(pivot_model.parameters()).device
    split_info = load_split(dataset.dir, split)

    if split == "cell":
        test_perts = candidates_singles(dataset)
    else:
        test_perts = list(split_info["test_perts"])

    # control cells among the test indices; too few of them (< 50) -> use all controls
    test_idx = split_info["test_idx"]
    control_pool = test_idx[dataset.is_control[test_idx]]
    if len(control_pool) < 50:
        control_pool = dataset.control_idx

    def score(predictor):
        return evaluate_forward(predictor, dataset, test_perts, control_pool,
                                max_perts=max_perts)

    out = {"split": split, "models": {}}
    out["models"]["PIVOT"] = score(PivotPredictor(pivot_model, dataset, device))

    names = baseline_names or [
        "MeanControl",
        "AvgPerturbationEffect",
        "Additive",
        "LinearResponse",
        "NearestPerturbationCentroid",
        "EndpointMLP",
        "Random",
    ]
    for name in names:
        fitted = build_baseline(name).fit(
            dataset, split_info["train_perts"], split_info["train_idx"])
        out["models"][name] = score(BaselinePredictor(fitted))
    return out
# table 4 / 6: single-gene desired-state nomination & recovery
def table_nomination(cache: ModelCache, split: str, max_targets=40, reward_kind="centroid",
                     with_guidance=True, baseline_names=None):
    """Build the single-gene desired-state nomination table for one split.

    Candidates are all single-element perturbations of the dataset. Targets are
    the first ``max_targets`` single-element candidates among the split's test
    perturbations (or among the candidates themselves for the ``"cell"`` split).
    PIVOT is evaluated with the "ranking" method and, if ``with_guidance`` is set,
    with the "guidance" method; each baseline is evaluated with "ranking".

    Args:
        cache: model cache that also holds the loaded dataset.
        split: split name.
        max_targets: upper bound on the number of targets.
        reward_kind: forwarded to ``evaluate_nomination``.
        with_guidance: whether to add the "PIVOT-guidance" entry.
        baseline_names: baselines to run; a falsy value selects the default list.

    Returns:
        Dict with ``split``, ``reward``, ``n_candidates``, ``n_targets`` and a
        ``methods`` mapping of method name to ``evaluate_nomination`` result.
    """
    dataset = cache.data
    pivot_model, _info, cfg = cache.get(split)
    device = next(pivot_model.parameters()).device
    split_info = load_split(dataset.dir, split)
    cands = candidates_singles(dataset)

    if split == "cell":
        pool = cands
    else:
        pool = list(split_info["test_perts"])
    eligible = []
    for pert in pool:
        if len(dataset.parse(pert)) == 1 and pert in cands:
            eligible.append(pert)
    targets = eligible[:max_targets]

    control_pool = dataset.control_idx
    gene_cluster = dataset.functional_clusters(seed=cfg.seed)

    def nominate(predictor, method, **extra):
        return evaluate_nomination(
            predictor, dataset, targets, cands, control_pool,
            reward_kind=reward_kind, method=method, gene_cluster=gene_cluster,
            device=device, **extra)

    out = {
        "split": split,
        "reward": reward_kind,
        "n_candidates": len(cands),
        "n_targets": len(targets),
        "methods": {},
    }
    methods = out["methods"]

    pivot = PivotPredictor(pivot_model, dataset, device)
    methods["PIVOT-ranking"] = nominate(pivot, "ranking")
    if with_guidance:
        # the guidance method additionally receives the model itself
        methods["PIVOT-guidance"] = nominate(pivot, "guidance", model=pivot_model)

    names = baseline_names or [
        "Additive",
        "LinearResponse",
        "NearestPerturbationCentroid",
        "EndpointMLP",
        "Random",
    ]
    for name in names:
        fitted = build_baseline(name).fit(
            dataset, split_info["train_perts"], split_info["train_idx"])
        methods[f"{name}+ranking"] = nominate(BaselinePredictor(fitted), "ranking")
    return out
# table 7: combinatorial nomination
def table_combinatorial(cache: ModelCache, split="combination", max_targets=26):
    """Build the combinatorial nomination table.

    Targets are the first ``max_targets`` test perturbations of the split that
    ``data.parse`` splits into exactly two elements. Combination candidates come
    from ``data.combos``, and the gene pool is the first parsed element of every
    single-element perturbation. An "Additive" baseline fitted on the split's
    training data is passed to the evaluation alongside PIVOT.

    Returns:
        Dict with ``split``, ``n_targets``, ``n_combo_candidates`` and ``result``
        (the ``evaluate_combinatorial`` output).
    """
    dataset = cache.data
    pivot_model, _info, _cfg = cache.get(split)
    device = next(pivot_model.parameters()).device
    split_info = load_split(dataset.dir, split)

    pairs = []
    for pert in split_info["test_perts"]:
        if len(dataset.parse(pert)) == 2:
            pairs.append(pert)
    targets = pairs[:max_targets]

    combo_cands = dataset.combos
    gene_pool = []
    for single in candidates_singles(dataset):
        gene_pool.append(dataset.parse(single)[0])
    control_pool = dataset.control_idx

    additive = build_baseline("Additive").fit(
        dataset, split_info["train_perts"], split_info["train_idx"])
    pivot = PivotPredictor(pivot_model, dataset, device)

    out = {
        "split": split,
        "n_targets": len(targets),
        "n_combo_candidates": len(combo_cands),
    }
    out["result"] = evaluate_combinatorial(
        pivot, dataset, targets, combo_cands, gene_pool, control_pool,
        pivot_model, device, additive=additive)
    return out
if __name__ == "__main__":
    # script entry point: build the requested tables for one dataset and save each
    # one as experiments/results/<dataset>_<table>.json
    import argparse
    import sys

    ap = argparse.ArgumentParser()
    ap.add_argument("--dataset", default="norman")
    # accepted table names: "forward_<split>", "nom_<split>" and "combinatorial"
    ap.add_argument("--tables", nargs="+", default=["forward_cell"])
    args = ap.parse_args()

    # one cache for the whole run, so tables sharing a split reuse the trained model
    cache = ModelCache(args.dataset)
    t0 = time.time()
    for tname in args.tables:
        print(f"=== {tname} ===")
        if tname.startswith("forward_"):
            # everything after the first underscore is the split name
            save_table(f"{args.dataset}_{tname}", table_forward(cache, tname.split("_", 1)[1]))
        elif tname.startswith("nom_"):
            save_table(f"{args.dataset}_{tname}", table_nomination(cache, tname.split("_", 1)[1]))
        elif tname == "combinatorial":
            save_table(f"{args.dataset}_combinatorial", table_combinatorial(cache))
        else:
            # previously a mistyped name was skipped without any message; keep going
            # with the remaining tables but make the skip visible
            print(f"warning: unknown table {tname!r} skipped "
                  "(expected forward_<split>, nom_<split> or combinatorial)",
                  file=sys.stderr)
    print(f"done in {time.time()-t0:.1f}s")
|