"""master experiment runner, trains pivot, runs baselines, assembles every results table. each table function returns a json-serializable dict and is saved to experiments/results/.json. nothing is hard-coded, every number comes from these calls. """ from __future__ import annotations import json import os import time import numpy as np import torch from src.data.perturb_data import load_dataset from src.data.splits import load_split from src.training.train import TrainConfig, train from src.evaluation.baselines import BASELINES, build_baseline from src.evaluation.rewards import TargetStateClassifier from src.experiments.predictors import PivotPredictor, BaselinePredictor from src.experiments.forward_eval import evaluate_forward from src.experiments.nomination_eval import evaluate_nomination from src.experiments.combinatorial_eval import evaluate_combinatorial from src.utils.common import save_json RESULTS_DIR = "experiments/results" DEVICE_INDEX = int(os.environ.get("PIVOT_GPU", "3")) class ModelCache: """train-once cache keyed by config tuple.""" def __init__(self, dataset, embedding="pca"): self.data = load_dataset(dataset, embedding=embedding) self.dataset = dataset self.embedding = embedding self._cache = {} def get(self, split, **kw): key = (split, self.embedding, tuple(sorted(kw.items()))) if key in self._cache: return self._cache[key] cfg = TrainConfig(dataset=self.dataset, embedding=self.embedding, split=split, device_index=DEVICE_INDEX, **kw) model, info = train(cfg, data=self.data, verbose=False) self._cache[key] = (model, info, cfg) return model, info, cfg def candidates_singles(data): return [p for p in data.perturbations if len(data.parse(p)) == 1] def get_classifier(data, target_perts, control_pool, device, seed=0): """train one target-state classifier: positives = union of target cells, negatives = control + random other-perturbation cells (for r_clf / target-clf).""" rng = np.random.default_rng(seed) pos = np.concatenate([data.pert_to_idx[p] for p in target_perts]) pos = pos[rng.choice(len(pos), min(4000, len(pos)), replace=False)] neg_pool = np.concatenate([control_pool, np.concatenate([data.pert_to_idx[p] for p in rng.choice(list(set(data.perturbations) - set(target_perts)), size=min(20, len(data.perturbations)), replace=False)])]) neg = neg_pool[rng.choice(len(neg_pool), min(4000, len(neg_pool)), replace=False)] clf = TargetStateClassifier(data.d).to(device) clf.fit(data.emb[pos], data.emb[neg], device) return clf def save_table(name, obj): os.makedirs(RESULTS_DIR, exist_ok=True) save_json(obj, os.path.join(RESULTS_DIR, f"{name}.json")) print(f" saved {RESULTS_DIR}/{name}.json") # table 2 / 3 / 10 / 11: forward prediction def table_forward(cache: ModelCache, split: str, max_perts=80, baseline_names=None): data = cache.data model, info, cfg = cache.get(split) device = next(model.parameters()).device sp = load_split(data.dir, split) test_perts = list(sp["test_perts"]) if split != "cell" else candidates_singles(data) control_pool = sp["test_idx"][data.is_control[sp["test_idx"]]] if len(control_pool) < 50: control_pool = data.control_idx out = {"split": split, "models": {}} pred = PivotPredictor(model, data, device) out["models"]["PIVOT"] = evaluate_forward(pred, data, test_perts, control_pool, max_perts=max_perts) for bname in (baseline_names or ["MeanControl", "AvgPerturbationEffect", "Additive", "LinearResponse", "NearestPerturbationCentroid", "EndpointMLP", "Random"]): bl = build_baseline(bname).fit(data, sp["train_perts"], sp["train_idx"]) out["models"][bname] = evaluate_forward(BaselinePredictor(bl), data, test_perts, control_pool, max_perts=max_perts) return out # table 4 / 6: single-gene desired-state nomination & recovery def table_nomination(cache: ModelCache, split: str, max_targets=40, reward_kind="centroid", with_guidance=True, baseline_names=None): data = cache.data model, info, cfg = cache.get(split) device = next(model.parameters()).device sp = load_split(data.dir, split) cands = candidates_singles(data) targets = [p for p in (list(sp["test_perts"]) if split != "cell" else cands) if len(data.parse(p)) == 1 and p in cands][:max_targets] control_pool = data.control_idx gene_cluster = data.functional_clusters(seed=cfg.seed) out = {"split": split, "reward": reward_kind, "n_candidates": len(cands), "n_targets": len(targets), "methods": {}} pred = PivotPredictor(model, data, device) out["methods"]["PIVOT-ranking"] = evaluate_nomination( pred, data, targets, cands, control_pool, reward_kind=reward_kind, method="ranking", gene_cluster=gene_cluster, device=device) if with_guidance: out["methods"]["PIVOT-guidance"] = evaluate_nomination( pred, data, targets, cands, control_pool, reward_kind=reward_kind, method="guidance", gene_cluster=gene_cluster, model=model, device=device) for bname in (baseline_names or ["Additive", "LinearResponse", "NearestPerturbationCentroid", "EndpointMLP", "Random"]): bl = build_baseline(bname).fit(data, sp["train_perts"], sp["train_idx"]) out["methods"][f"{bname}+ranking"] = evaluate_nomination( BaselinePredictor(bl), data, targets, cands, control_pool, reward_kind=reward_kind, method="ranking", gene_cluster=gene_cluster, device=device) return out # table 7: combinatorial nomination def table_combinatorial(cache: ModelCache, split="combination", max_targets=26): data = cache.data model, info, cfg = cache.get(split) device = next(model.parameters()).device sp = load_split(data.dir, split) targets = [p for p in sp["test_perts"] if len(data.parse(p)) == 2][:max_targets] combo_cands = data.combos gene_pool = [data.parse(p)[0] for p in candidates_singles(data)] control_pool = data.control_idx additive = build_baseline("Additive").fit(data, sp["train_perts"], sp["train_idx"]) pred = PivotPredictor(model, data, device) out = {"split": split, "n_targets": len(targets), "n_combo_candidates": len(combo_cands)} out["result"] = evaluate_combinatorial(pred, data, targets, combo_cands, gene_pool, control_pool, model, device, additive=additive) return out if __name__ == "__main__": import argparse ap = argparse.ArgumentParser() ap.add_argument("--dataset", default="norman") ap.add_argument("--tables", nargs="+", default=["forward_cell"]) args = ap.parse_args() cache = ModelCache(args.dataset) t0 = time.time() for tname in args.tables: print(f"=== {tname} ===") if tname.startswith("forward_"): save_table(f"{args.dataset}_{tname}", table_forward(cache, tname.split("_", 1)[1])) elif tname.startswith("nom_"): save_table(f"{args.dataset}_{tname}", table_nomination(cache, tname.split("_", 1)[1])) elif tname == "combinatorial": save_table(f"{args.dataset}_combinatorial", table_combinatorial(cache)) print(f"done in {time.time()-t0:.1f}s")