"""Master experiment runner: trains PIVOT, runs the baselines, and assembles every
results table. Each table function returns a JSON-serializable dict, which is saved to
experiments/results/<dataset>_<table>.json. No result is hard-coded; every number
comes from these calls.
"""
| from __future__ import annotations |
|
|
| import json |
| import os |
| import time |
|
|
| import numpy as np |
| import torch |
|
|
| from src.data.perturb_data import load_dataset |
| from src.data.splits import load_split |
| from src.training.train import TrainConfig, train |
| from src.evaluation.baselines import BASELINES, build_baseline |
| from src.evaluation.rewards import TargetStateClassifier |
| from src.experiments.predictors import PivotPredictor, BaselinePredictor |
| from src.experiments.forward_eval import evaluate_forward |
| from src.experiments.nomination_eval import evaluate_nomination |
| from src.experiments.combinatorial_eval import evaluate_combinatorial |
| from src.utils.common import save_json |
|
|
# Directory that save_table writes one JSON file per results table into.
RESULTS_DIR = "experiments/results"
# Device index handed to TrainConfig; override with the PIVOT_GPU environment variable.
DEVICE_INDEX = int(os.getenv("PIVOT_GPU", "3"))
|
|
|
|
class ModelCache:
    """Memoise trained models so every distinct configuration is trained only once."""

    def __init__(self, dataset, embedding="pca"):
        """Load *dataset* with the given embedding and start with an empty cache."""
        self.data = load_dataset(dataset, embedding=embedding)
        self.dataset = dataset
        self.embedding = embedding
        self._cache = {}

    def get(self, split, **kw):
        """Return ``(model, info, cfg)`` for *split*, training on the first request.

        Extra keyword arguments are forwarded to TrainConfig and are part of the
        cache key, so they must be hashable and mutually sortable by name.
        """
        # Sorting the overrides makes the key independent of keyword order.
        cache_key = (split, self.embedding, tuple(sorted(kw.items())))
        try:
            return self._cache[cache_key]
        except KeyError:
            pass

        config = TrainConfig(
            dataset=self.dataset,
            embedding=self.embedding,
            split=split,
            device_index=DEVICE_INDEX,
            **kw,
        )
        trained_model, train_info = train(config, data=self.data, verbose=False)
        entry = (trained_model, train_info, config)
        self._cache[cache_key] = entry
        return entry
|
|
|
|
def candidates_singles(data):
    """Return, in order, the perturbations that data.parse splits into exactly one part."""
    singles = []
    for pert in data.perturbations:
        if len(data.parse(pert)) == 1:
            singles.append(pert)
    return singles
|
|
|
|
def get_classifier(data, target_perts, control_pool, device, seed=0, *,
                   max_cells=4000, n_neg_perts=20):
    """Train one target-state classifier (for r_clf / target-clf).

    Positives are the union of the target perturbations' cells; negatives are the
    control pool plus the cells of randomly chosen non-target perturbations.

    Args:
        data: dataset object; ``perturbations``, ``pert_to_idx``, ``emb`` and ``d``
            are read from it.
        target_perts: perturbation names whose cells form the positive class.
        control_pool: cell indices that are always part of the negative pool.
        device: device the classifier is moved to and fitted on.
        seed: seed for the sampling RNG.
        max_cells: cap on the number of positive and of negative cells sampled.
        n_neg_perts: cap on the number of non-target perturbations contributing
            negative cells.

    Returns:
        The fitted TargetStateClassifier.

    Raises:
        ValueError: if ``target_perts`` is empty.
    """
    target_perts = list(target_perts)
    if not target_perts:
        raise ValueError("get_classifier needs at least one target perturbation")

    rng = np.random.default_rng(seed)
    pos = np.concatenate([data.pert_to_idx[p] for p in target_perts])
    pos = pos[rng.choice(len(pos), min(max_cells, len(pos)), replace=False)]

    # Keep the dataset's own ordering: iterating a set of strings is not stable
    # across interpreter runs, which would make the seeded draw irreproducible.
    target_set = set(target_perts)
    others = [p for p in data.perturbations if p not in target_set]
    # Bound the draw by the number of *non-target* perturbations; sampling without
    # replacement fails if more items are requested than exist.
    n_draw = min(n_neg_perts, len(others))
    picked = rng.choice(len(others), size=n_draw, replace=False) if n_draw else []

    neg_parts = [control_pool]
    neg_parts.extend(data.pert_to_idx[others[i]] for i in picked)
    neg_pool = np.concatenate(neg_parts)
    neg = neg_pool[rng.choice(len(neg_pool), min(max_cells, len(neg_pool)), replace=False)]

    clf = TargetStateClassifier(data.d).to(device)
    clf.fit(data.emb[pos], data.emb[neg], device)
    return clf
|
|
|
|
def save_table(name, obj):
    """Write *obj* to ``<RESULTS_DIR>/<name>.json``, creating the directory if needed."""
    os.makedirs(RESULTS_DIR, exist_ok=True)
    target = os.path.join(RESULTS_DIR, f"{name}.json")
    save_json(obj, target)
    print(f" saved {RESULTS_DIR}/{name}.json")
|
|
|
|
| |
# Baselines evaluated by table_forward when the caller does not name any.
_FORWARD_BASELINES = (
    "MeanControl", "AvgPerturbationEffect", "Additive", "LinearResponse",
    "NearestPerturbationCentroid", "EndpointMLP", "Random",
)


def table_forward(cache: ModelCache, split: str, max_perts=80, baseline_names=None):
    """Build the forward-prediction table for one split.

    The model returned by ``cache.get(split)`` and every requested baseline are
    scored with evaluate_forward on the same perturbations and control pool.

    Args:
        cache: model cache that also carries the loaded dataset.
        split: split name passed to ``cache.get`` and ``load_split``. For "cell"
            the evaluation set is every single perturbation in the dataset
            rather than the split's ``test_perts``.
        max_perts: forwarded to evaluate_forward.
        baseline_names: names passed to build_baseline. ``None`` selects the
            default set; an empty sequence evaluates PIVOT only.

    Returns:
        ``{"split": split, "models": {name: evaluate_forward result}}``.
    """
    data = cache.data
    model, _info, _cfg = cache.get(split)
    device = next(model.parameters()).device
    sp = load_split(data.dir, split)
    test_perts = candidates_singles(data) if split == "cell" else list(sp["test_perts"])

    test_idx = sp["test_idx"]
    control_pool = test_idx[data.is_control[test_idx]]
    if len(control_pool) < 50:
        # Too few control rows among the test indices: use data.control_idx instead.
        control_pool = data.control_idx

    # Test against None rather than truthiness so an explicit empty list means
    # "no baselines" instead of silently selecting every default.
    if baseline_names is None:
        baseline_names = _FORWARD_BASELINES

    out = {"split": split, "models": {}}
    pred = PivotPredictor(model, data, device)
    out["models"]["PIVOT"] = evaluate_forward(pred, data, test_perts, control_pool,
                                              max_perts=max_perts)
    for bname in baseline_names:
        bl = build_baseline(bname).fit(data, sp["train_perts"], sp["train_idx"])
        out["models"][bname] = evaluate_forward(BaselinePredictor(bl), data, test_perts,
                                                control_pool, max_perts=max_perts)
    return out
|
|
|
|
| |
# Baselines ranked by table_nomination when the caller does not name any.
_NOMINATION_BASELINES = (
    "Additive", "LinearResponse", "NearestPerturbationCentroid", "EndpointMLP", "Random",
)


def table_nomination(cache: ModelCache, split: str, max_targets=40, reward_kind="centroid",
                     with_guidance=True, baseline_names=None):
    """Build the perturbation-nomination table for one split.

    Targets are single perturbations drawn from the split's ``test_perts`` (or
    from all single perturbations when ``split == "cell"``), truncated to
    ``max_targets``. Every method is scored with evaluate_nomination against the
    same candidates, targets and control pool.

    Args:
        cache: model cache that also carries the loaded dataset.
        split: split name passed to ``cache.get`` and ``load_split``.
        max_targets: maximum number of target perturbations evaluated.
        reward_kind: forwarded to evaluate_nomination and recorded in the output.
        with_guidance: also evaluate the "guidance" method for PIVOT.
        baseline_names: names passed to build_baseline. ``None`` selects the
            default set; an empty sequence evaluates PIVOT only.

    Returns:
        Dict with keys "split", "reward", "n_candidates", "n_targets" and
        "methods" (method name -> evaluate_nomination result).
    """
    data = cache.data
    model, _info, cfg = cache.get(split)
    device = next(model.parameters()).device
    sp = load_split(data.dir, split)

    cands = candidates_singles(data)
    cand_set = set(cands)  # constant-time membership for the target filter below
    pool = cands if split == "cell" else list(sp["test_perts"])
    targets = [p for p in pool
               if len(data.parse(p)) == 1 and p in cand_set][:max_targets]
    control_pool = data.control_idx
    gene_cluster = data.functional_clusters(seed=cfg.seed)

    # Test against None rather than truthiness so an explicit empty list means
    # "no baselines" instead of silently selecting every default.
    if baseline_names is None:
        baseline_names = _NOMINATION_BASELINES

    out = {"split": split, "reward": reward_kind, "n_candidates": len(cands),
           "n_targets": len(targets), "methods": {}}
    pred = PivotPredictor(model, data, device)
    out["methods"]["PIVOT-ranking"] = evaluate_nomination(
        pred, data, targets, cands, control_pool, reward_kind=reward_kind, method="ranking",
        gene_cluster=gene_cluster, device=device)
    if with_guidance:
        out["methods"]["PIVOT-guidance"] = evaluate_nomination(
            pred, data, targets, cands, control_pool, reward_kind=reward_kind,
            method="guidance", gene_cluster=gene_cluster, model=model, device=device)
    for bname in baseline_names:
        bl = build_baseline(bname).fit(data, sp["train_perts"], sp["train_idx"])
        out["methods"][f"{bname}+ranking"] = evaluate_nomination(
            BaselinePredictor(bl), data, targets, cands, control_pool,
            reward_kind=reward_kind, method="ranking", gene_cluster=gene_cluster,
            device=device)
    return out
|
|
|
|
| |
def table_combinatorial(cache: ModelCache, split="combination", max_targets=26):
    """Build the combinatorial table: evaluate PIVOT on two-part test perturbations.

    Targets are the first ``max_targets`` entries of the split's ``test_perts``
    that data.parse splits into exactly two parts. An "Additive" baseline fitted
    on the training portion is passed to evaluate_combinatorial alongside PIVOT.

    Returns:
        Dict with keys "split", "n_targets", "n_combo_candidates" and "result".
    """
    data = cache.data
    model, _, _ = cache.get(split)
    device = next(model.parameters()).device
    sp = load_split(data.dir, split)

    pair_targets = [pert for pert in sp["test_perts"] if len(data.parse(pert)) == 2]
    targets = pair_targets[:max_targets]
    combo_cands = data.combos
    # First parsed component of every single perturbation.
    gene_pool = [data.parse(single)[0] for single in candidates_singles(data)]
    control_pool = data.control_idx

    additive = build_baseline("Additive").fit(data, sp["train_perts"], sp["train_idx"])
    pred = PivotPredictor(model, data, device)

    out = {
        "split": split,
        "n_targets": len(targets),
        "n_combo_candidates": len(combo_cands),
    }
    out["result"] = evaluate_combinatorial(
        pred, data, targets, combo_cands, gene_pool, control_pool, model, device,
        additive=additive,
    )
    return out
|
|
|
|
def _table_job(tname):
    """Map a --tables entry to ``(table_function, extra_args)``, or None if unrecognised."""
    kind, _, split = tname.partition("_")
    if kind == "forward" and split:
        return table_forward, (split,)
    if kind == "nom" and split:
        return table_nomination, (split,)
    if tname == "combinatorial":
        return table_combinatorial, ()
    return None


def main(argv=None):
    """Command-line entry point: build and save each requested results table.

    Args:
        argv: argument list for argparse; ``None`` reads ``sys.argv``.

    Table names are ``forward_<split>``, ``nom_<split>`` or ``combinatorial``.
    Any other name is rejected through ``ArgumentParser.error`` (exit status 2)
    before the dataset is loaded, so a typo is not silently skipped.
    """
    import argparse

    ap = argparse.ArgumentParser()
    ap.add_argument("--dataset", default="norman")
    ap.add_argument("--tables", nargs="+", default=["forward_cell"])
    args = ap.parse_args(argv)

    # Validate every name up front so no training happens for an invalid request.
    unknown = [t for t in args.tables if _table_job(t) is None]
    if unknown:
        ap.error(f"unknown table name(s): {', '.join(unknown)}; "
                 "expected forward_<split>, nom_<split> or combinatorial")

    cache = ModelCache(args.dataset)
    t0 = time.time()
    for tname in args.tables:
        print(f"=== {tname} ===")
        build, extra = _table_job(tname)
        save_table(f"{args.dataset}_{tname}", build(cache, *extra))
    print(f"done in {time.time()-t0:.1f}s")


if __name__ == "__main__":
    main()
|
|