"""Master experiment runner: trains PIVOT, runs the baselines, and assembles every
results table. Each table function returns a JSON-serializable dict, which is saved
to experiments/results/<name>.json. Nothing is hard-coded; every number comes from
these calls.
"""
from __future__ import annotations
import json
import os
import time
import numpy as np
import torch
from src.data.perturb_data import load_dataset
from src.data.splits import load_split
from src.training.train import TrainConfig, train
from src.evaluation.baselines import BASELINES, build_baseline
from src.evaluation.rewards import TargetStateClassifier
from src.experiments.predictors import PivotPredictor, BaselinePredictor
from src.experiments.forward_eval import evaluate_forward
from src.experiments.nomination_eval import evaluate_nomination
from src.experiments.combinatorial_eval import evaluate_combinatorial
from src.utils.common import save_json
# Directory that save_table() writes its <name>.json files into.
RESULTS_DIR = "experiments/results"
# GPU index handed to every TrainConfig; set PIVOT_GPU to override (defaults to 3).
DEVICE_INDEX = int(os.getenv("PIVOT_GPU", "3"))
class ModelCache:
    """Train-once cache of PIVOT models for a single dataset/embedding.

    The dataset is loaded once at construction and reused by every training
    run. Models are memoized per (split, embedding, extra TrainConfig kwargs).
    """

    def __init__(self, dataset, embedding="pca"):
        self.data = load_dataset(dataset, embedding=embedding)
        self.dataset = dataset
        self.embedding = embedding
        self._cache = {}  # cache key -> (model, info, cfg)

    def get(self, split, **kw):
        """Return ``(model, info, cfg)`` for ``split``, training on a cache miss.

        Extra keyword arguments are forwarded to ``TrainConfig`` and become
        part of the cache key (sorted by name), so their values must be
        hashable.
        """
        cache_key = (split, self.embedding, tuple(sorted(kw.items())))
        try:
            return self._cache[cache_key]
        except KeyError:
            pass
        cfg = TrainConfig(
            dataset=self.dataset,
            embedding=self.embedding,
            split=split,
            device_index=DEVICE_INDEX,
            **kw,
        )
        model, info = train(cfg, data=self.data, verbose=False)
        entry = (model, info, cfg)
        self._cache[cache_key] = entry
        return entry
def candidates_singles(data):
    """Return the single-gene perturbations of ``data``, in dataset order.

    A perturbation counts as single when ``data.parse`` splits it into exactly
    one component.
    """
    singles = []
    for pert in data.perturbations:
        if len(data.parse(pert)) == 1:
            singles.append(pert)
    return singles
def get_classifier(data, target_perts, control_pool, device, seed=0):
    """Train one target-state classifier (for r_clf / target-clf).

    Positives are up to 4000 cells sampled from the union of the target
    perturbations' cells. Negatives are up to 4000 cells sampled from
    ``control_pool`` plus the cells of up to 20 randomly chosen non-target
    perturbations.

    Args:
        data: dataset exposing ``perturbations``, ``pert_to_idx``, ``emb``
            and ``d``.
        target_perts: perturbation names whose cells form the positive class.
        control_pool: cell indices of control cells added to the negatives.
        device: device the classifier is moved to and fitted on.
        seed: seed for every random draw, making the classifier reproducible.

    Returns:
        The fitted ``TargetStateClassifier``.

    Raises:
        ValueError: if ``target_perts`` is empty.
    """
    if len(target_perts) == 0:
        raise ValueError("get_classifier needs at least one target perturbation")
    rng = np.random.default_rng(seed)
    pos = np.concatenate([data.pert_to_idx[p] for p in target_perts])
    pos = pos[rng.choice(len(pos), min(4000, len(pos)), replace=False)]
    # Sort the non-target pool: set iteration order of str depends on the
    # per-process hash seed, which would make the seeded draw irreproducible.
    others = sorted(set(data.perturbations) - set(target_perts))
    # Cap by the population actually available; sampling without replacement
    # more items than exist makes rng.choice raise.
    n_other = min(20, len(others))
    neg_parts = [control_pool]
    if n_other > 0:
        picked = rng.choice(len(others), size=n_other, replace=False)
        neg_parts.extend(data.pert_to_idx[others[i]] for i in picked)
    neg_pool = np.concatenate(neg_parts)
    neg = neg_pool[rng.choice(len(neg_pool), min(4000, len(neg_pool)), replace=False)]
    clf = TargetStateClassifier(data.d).to(device)
    clf.fit(data.emb[pos], data.emb[neg], device)
    return clf
def save_table(name, obj):
    """Serialize ``obj`` to ``RESULTS_DIR/<name>.json`` and report the path.

    The results directory is created on demand.
    """
    target_path = os.path.join(RESULTS_DIR, f"{name}.json")
    os.makedirs(RESULTS_DIR, exist_ok=True)
    save_json(obj, target_path)
    print(f" saved {RESULTS_DIR}/{name}.json")
# table 2 / 3 / 10 / 11: forward prediction
_DEFAULT_FORWARD_BASELINES = (
    "MeanControl", "AvgPerturbationEffect", "Additive", "LinearResponse",
    "NearestPerturbationCentroid", "EndpointMLP", "Random",
)


def table_forward(cache: ModelCache, split: str, max_perts=80, baseline_names=None):
    """Build a forward-prediction table (tables 2 / 3 / 10 / 11).

    Evaluates PIVOT and each baseline with ``evaluate_forward`` on the test
    perturbations of ``split`` (all single-gene perturbations for ``"cell"``).

    Args:
        cache: model cache providing the dataset and the trained PIVOT model.
        split: split name passed to ``cache.get`` and ``load_split``.
        max_perts: forwarded to ``evaluate_forward``.
        baseline_names: baselines to evaluate; ``None`` selects the default
            set, an empty sequence evaluates PIVOT only.

    Returns:
        ``{"split": split, "models": {name: evaluate_forward result}}``.
    """
    data = cache.data
    model, _info, _cfg = cache.get(split)
    device = next(model.parameters()).device
    sp = load_split(data.dir, split)
    # The "cell" split evaluates every single-gene perturbation (presumably
    # cells rather than perturbations are held out there — confirm in splits).
    test_perts = list(sp["test_perts"]) if split != "cell" else candidates_singles(data)
    # Prefer control cells from the split's test indices; with fewer than 50
    # of them fall back to every control cell.
    control_pool = sp["test_idx"][data.is_control[sp["test_idx"]]]
    if len(control_pool) < 50:
        control_pool = data.control_idx
    out = {"split": split, "models": {}}
    pred = PivotPredictor(model, data, device)
    out["models"]["PIVOT"] = evaluate_forward(pred, data, test_perts, control_pool, max_perts=max_perts)
    # `is None` (not `or`) so an explicit empty list really means "no baselines".
    names = _DEFAULT_FORWARD_BASELINES if baseline_names is None else baseline_names
    for bname in names:
        bl = build_baseline(bname).fit(data, sp["train_perts"], sp["train_idx"])
        out["models"][bname] = evaluate_forward(BaselinePredictor(bl), data, test_perts,
                                                control_pool, max_perts=max_perts)
    return out
# table 4 / 6: single-gene desired-state nomination & recovery
_DEFAULT_NOMINATION_BASELINES = (
    "Additive", "LinearResponse", "NearestPerturbationCentroid", "EndpointMLP", "Random",
)


def table_nomination(cache: ModelCache, split: str, max_targets=40, reward_kind="centroid",
                     with_guidance=True, baseline_names=None):
    """Build a single-gene desired-state nomination table (tables 4 / 6).

    Every method is scored by ``evaluate_nomination`` on up to ``max_targets``
    single-gene targets, using all single-gene perturbations as candidates.

    Args:
        cache: model cache providing the dataset and the trained PIVOT model.
        split: split name; for ``"cell"`` targets are drawn from all
            single-gene perturbations instead of the split's test perturbations.
        max_targets: maximum number of targets (the first N, in order).
        reward_kind: reward passed through to ``evaluate_nomination``.
        with_guidance: also evaluate PIVOT with ``method="guidance"``.
        baseline_names: baselines evaluated with ``method="ranking"``; ``None``
            selects the default set, an empty sequence evaluates none.

    Returns:
        Dict with split, reward kind, candidate/target counts and one result
        per method under ``"methods"``.
    """
    data = cache.data
    model, _info, cfg = cache.get(split)
    device = next(model.parameters()).device
    sp = load_split(data.dir, split)
    cands = candidates_singles(data)
    # Set for O(1) membership; every candidate is single-gene, so membership
    # already implies len(data.parse(p)) == 1.
    cand_set = set(cands)
    target_pool = list(sp["test_perts"]) if split != "cell" else cands
    targets = [p for p in target_pool if p in cand_set][:max_targets]
    control_pool = data.control_idx
    gene_cluster = data.functional_clusters(seed=cfg.seed)
    out = {"split": split, "reward": reward_kind, "n_candidates": len(cands),
           "n_targets": len(targets), "methods": {}}
    pred = PivotPredictor(model, data, device)
    out["methods"]["PIVOT-ranking"] = evaluate_nomination(
        pred, data, targets, cands, control_pool, reward_kind=reward_kind, method="ranking",
        gene_cluster=gene_cluster, device=device)
    if with_guidance:
        out["methods"]["PIVOT-guidance"] = evaluate_nomination(
            pred, data, targets, cands, control_pool, reward_kind=reward_kind, method="guidance",
            gene_cluster=gene_cluster, model=model, device=device)
    # `is None` (not `or`) so an explicit empty list really means "no baselines".
    names = _DEFAULT_NOMINATION_BASELINES if baseline_names is None else baseline_names
    for bname in names:
        bl = build_baseline(bname).fit(data, sp["train_perts"], sp["train_idx"])
        out["methods"][f"{bname}+ranking"] = evaluate_nomination(
            BaselinePredictor(bl), data, targets, cands, control_pool, reward_kind=reward_kind,
            method="ranking", gene_cluster=gene_cluster, device=device)
    return out
# table 7: combinatorial nomination
def table_combinatorial(cache: ModelCache, split="combination", max_targets=26):
    """Build the combinatorial-nomination table (table 7).

    Targets are the first ``max_targets`` two-gene test perturbations of
    ``split``; candidates are ``data.combos``. PIVOT is evaluated with
    ``evaluate_combinatorial``, which also receives a fitted Additive baseline.
    """
    data = cache.data
    trained = cache.get(split)
    model = trained[0]
    device = next(model.parameters()).device
    split_info = load_split(data.dir, split)
    pair_tests = [pert for pert in split_info["test_perts"] if len(data.parse(pert)) == 2]
    targets = pair_tests[:max_targets]
    combo_cands = data.combos
    gene_pool = [data.parse(single)[0] for single in candidates_singles(data)]
    control_pool = data.control_idx
    additive = build_baseline("Additive").fit(data, split_info["train_perts"], split_info["train_idx"])
    pred = PivotPredictor(model, data, device)
    summary = {
        "split": split,
        "n_targets": len(targets),
        "n_combo_candidates": len(combo_cands),
    }
    summary["result"] = evaluate_combinatorial(
        pred, data, targets, combo_cands, gene_pool, control_pool, model, device,
        additive=additive,
    )
    return summary
if __name__ == "__main__":
    # Script entry point: build each requested table and save it as
    # experiments/results/<dataset>_<table>.json.
    import argparse

    ap = argparse.ArgumentParser()
    ap.add_argument("--dataset", default="norman")
    ap.add_argument("--tables", nargs="+", default=["forward_cell"],
                    help="any of forward_<split>, nom_<split>, combinatorial")
    args = ap.parse_args()
    # Reject unknown table names before loading the dataset instead of
    # silently skipping them inside the loop.
    unknown = [t for t in args.tables
               if t != "combinatorial"
               and not (t.startswith(("forward_", "nom_")) and t.split("_", 1)[1])]
    if unknown:
        ap.error("unknown table(s): " + ", ".join(unknown)
                 + " (expected forward_<split>, nom_<split> or combinatorial)")
    cache = ModelCache(args.dataset)
    t0 = time.perf_counter()  # monotonic clock for elapsed time
    for tname in args.tables:
        print(f"=== {tname} ===")
        if tname.startswith("forward_"):
            save_table(f"{args.dataset}_{tname}", table_forward(cache, tname.split("_", 1)[1]))
        elif tname.startswith("nom_"):
            save_table(f"{args.dataset}_{tname}", table_nomination(cache, tname.split("_", 1)[1]))
        else:  # validated above: only "combinatorial" remains
            save_table(f"{args.dataset}_combinatorial", table_combinatorial(cache))
    print(f"done in {time.perf_counter()-t0:.1f}s")