PIVOT / src /experiments /run_ablations.py
bryan7264's picture
pivot: code + trained checkpoints (norman, replogle k562)
3b4941f verified
Raw
History Blame
9.63 kB
"""ablation studies (tables 8, 12-18, 20).
every ablation is a real training/eval run on real data. epoch budget (EPOCHS) is
held fixed across rows of a table for fair comparison.
"""
from __future__ import annotations
import os
import time
import numpy as np
import torch
from src.data.perturb_data import PerturbData, load_dataset
from src.data.splits import load_split
from src.training.train import TrainConfig, train
from src.evaluation.baselines import build_baseline
from src.experiments.predictors import PivotPredictor, BaselinePredictor
from src.experiments.forward_eval import evaluate_forward
from src.experiments.nomination_eval import evaluate_nomination
from src.experiments.run_tables import (RESULTS_DIR, DEVICE_INDEX, candidates_singles,
save_table)
from src.models.encoders import REP_MODES
from src.utils.common import save_json
# epoch budget passed to every TrainConfig built by _train, so all ablation rows share it.
EPOCHS = 45
def _train(data, split, **kw):
    """train one model on `data` / `split` with the shared EPOCHS budget.

    any extra keyword arguments are forwarded to TrainConfig unchanged.
    returns the tuple (model, info, cfg).
    """
    config = TrainConfig(
        dataset=data.meta["name"],
        split=split,
        epochs=EPOCHS,
        device_index=DEVICE_INDEX,
        **kw,
    )
    trained, run_info = train(config, data=data, verbose=False)
    return trained, run_info, config
def _fwd_inv(data, model, split, max_perts=50, max_targets=30, reward="centroid"):
    """compact forward + inverse metrics for one model (for ablation rows).

    the model is wrapped in a PivotPredictor on whatever device its parameters
    live on. forward metrics use the first `max_perts` held-out perturbations;
    inverse (nomination) metrics use the first `max_targets` held-out
    perturbations that parse to a single element and are also candidates,
    ranked with reward `reward`.

    returns {"forward": {...}, "inverse": {...}} restricted to the metric keys
    listed below.
    """
    forward_keys = ("mse", "de_corr", "mmd", "pearson")
    inverse_keys = ("top1", "top5", "ndcg", "endpoint_dist")

    dev = next(model.parameters()).device
    split_info = load_split(data.dir, split)
    candidates = candidates_singles(data)

    # for the "cell" split the candidate singles themselves are evaluated;
    # otherwise the split's own test perturbations are used.
    if split == "cell":
        held_out = candidates
    else:
        held_out = list(split_info["test_perts"])

    single_targets = [
        pert
        for pert in held_out
        if len(data.parse(pert)) == 1 and pert in candidates
    ]
    single_targets = single_targets[:max_targets]

    controls = data.control_idx
    clusters = data.functional_clusters(seed=0)
    predictor = PivotPredictor(model, data, dev)

    forward = evaluate_forward(
        predictor, data, list(held_out)[:max_perts], controls, max_perts=max_perts
    )
    inverse = evaluate_nomination(
        predictor,
        data,
        single_targets,
        candidates,
        controls,
        reward_kind=reward,
        method="ranking",
        gene_cluster=clusters,
        device=dev,
    )
    return {
        "forward": {key: forward[key] for key in forward_keys},
        "inverse": {key: inverse[key] for key in inverse_keys},
    }
# table 8: component ablation
def ablation_components(data, split="perturbation"):
    """table 8: component ablation plus an EndpointMLP baseline row.

    trains one model per entry of `variants` (each passes a different
    `components` list to the trainer) and scores it with _fwd_inv, then fits
    the "EndpointMLP" baseline on the split's training perturbations and scores
    it with the same forward / nomination evaluators.

    returns {"split": split, "rows": {row_name: {"forward": ..., "inverse": ...}}}.
    """
    variants = {
        "flow-map-only": ["map"],
        "no-tangent": ["map", "semi"],
        "no-semigroup": ["map", "tan"],
        "PIVOT-full": ["map", "tan", "semi"],
    }
    out = {"split": split, "rows": {}}
    for name, comps in variants.items():
        model, _, _ = _train(data, split, components=comps)
        out["rows"][name] = _fwd_inv(data, model, split)

    # endpointmlp baseline (no flow map)
    sp = load_split(data.dir, split)
    bl = build_baseline("EndpointMLP").fit(data, sp["train_perts"], sp["train_idx"])
    # fall back to cpu instead of handing evaluate_nomination a cuda device
    # that does not exist on this machine.
    device = torch.device(f"cuda:{DEVICE_INDEX}" if torch.cuda.is_available() else "cpu")
    cands = candidates_singles(data)
    # NOTE(review): unlike _fwd_inv, this row always reads sp["test_perts"]
    # (no special case for split == "cell") -- confirm that is intended.
    test_perts = list(sp["test_perts"])
    # 30 targets / 50 forward perturbations match the _fwd_inv defaults used
    # for the PIVOT rows above.
    targets = [p for p in test_perts if len(data.parse(p)) == 1 and p in cands][:30]
    gc = data.functional_clusters(seed=0)
    baseline_pred = BaselinePredictor(bl)
    fwd = evaluate_forward(baseline_pred, data, test_perts[:50], data.control_idx, max_perts=50)
    # reward_kind is spelled out so the baseline is ranked with the same reward
    # that _fwd_inv uses for the PIVOT rows, not the evaluator's default.
    inv = evaluate_nomination(baseline_pred, data, targets, cands, data.control_idx,
                              reward_kind="centroid", method="ranking",
                              gene_cluster=gc, device=device)
    out["rows"]["EndpointMLP"] = {
        "forward": {k: fwd[k] for k in ["mse", "de_corr", "mmd", "pearson"]},
        "inverse": {k: inv[k] for k in ["top1", "top5", "ndcg", "endpoint_dist"]},
    }
    return out
# table 12: perturbation representation
def ablation_representation(data, split="perturbation"):
    """table 12: one training run per entry of REP_MODES, each scored with _fwd_inv.

    returns {"split": split, "rows": {rep_mode: metrics}}.
    """
    rows = {}
    for mode in REP_MODES:
        trained = _train(data, split, rep_mode=mode)[0]
        rows[mode] = _fwd_inv(data, trained, split)
    return {"split": split, "rows": rows}
# table 13: inverse-search strategy (single model)
def ablation_inverse_search(data, split="perturbation"):
    """table 13: compare inverse-search strategies on a single trained model.

    one model is trained, then evaluate_nomination is called once per strategy
    with the shared keyword arguments plus that strategy's own settings.

    returns {"split": split, "rows": {strategy_name: nomination metrics}}.
    """
    metrics = ("top1", "top5", "ndcg", "endpoint_dist")

    model, _, _ = _train(data, split)
    dev = next(model.parameters()).device
    split_info = load_split(data.dir, split)
    candidates = candidates_singles(data)
    # first 30 test perturbations that parse to one element and are candidates
    targets = [
        pert
        for pert in split_info["test_perts"]
        if len(data.parse(pert)) == 1 and pert in candidates
    ][:30]
    clusters = data.functional_clusters(seed=0)
    predictor = PivotPredictor(model, data, dev)

    # arguments every strategy shares
    shared_kwargs = {
        "reward_kind": "centroid",
        "gene_cluster": clusters,
        "model": model,
        "device": dev,
    }
    strategies = {
        "ranking-only": {"method": "ranking"},
        "random-init-opt": {"method": "guidance", "guidance_init": "random", "rerank": False},
        "mean-top-init": {"method": "guidance", "guidance_init": "mean_top", "rerank": False},
        "guidance-no-norm": {"method": "guidance", "guidance_init": "warm",
                             "guidance_normalize": False, "rerank": False},
        "guidance-norm": {"method": "guidance", "guidance_init": "warm",
                          "guidance_normalize": True, "rerank": False},
        "guidance+rerank": {"method": "guidance", "guidance_init": "warm",
                            "guidance_normalize": True, "rerank": True},
    }

    rows = {}
    for label, strategy_kwargs in strategies.items():
        call_kwargs = {**shared_kwargs, **strategy_kwargs}
        scores = evaluate_nomination(
            predictor, data, targets, candidates, data.control_idx, **call_kwargs
        )
        rows[label] = {metric: scores[metric] for metric in metrics}
    return {"split": split, "rows": rows}
# table 14: guidance steps
def ablation_guidance_steps(data, split="perturbation"):
    """table 14: nomination metrics as a function of the number of guidance steps.

    one model is trained once. a step count of 0 is evaluated with plain
    ranking; every other count uses guidance with reranking enabled.

    returns {"split": split, "rows": {str(steps): nomination metrics}}.
    """
    metrics = ("top1", "top5", "ndcg", "endpoint_dist")
    step_grid = (0, 5, 10, 25, 50, 100)

    model, _, _ = _train(data, split)
    dev = next(model.parameters()).device
    split_info = load_split(data.dir, split)
    candidates = candidates_singles(data)
    targets = [
        pert
        for pert in split_info["test_perts"]
        if len(data.parse(pert)) == 1 and pert in candidates
    ][:30]
    clusters = data.functional_clusters(seed=0)
    predictor = PivotPredictor(model, data, dev)

    rows = {}
    for n_steps in step_grid:
        if n_steps:
            search_kwargs = dict(method="guidance", gene_cluster=clusters, model=model,
                                 device=dev, guidance_steps=n_steps, rerank=True)
        else:
            # zero steps: no guidance at all, rank the candidates directly
            search_kwargs = dict(method="ranking", gene_cluster=clusters, device=dev)
        scores = evaluate_nomination(
            predictor, data, targets, candidates, data.control_idx, **search_kwargs
        )
        rows[str(n_steps)] = {metric: scores[metric] for metric in metrics}
    return {"split": split, "rows": rows}
# table 15: reward definitions
def ablation_reward(data, split="perturbation"):
    """table 15: ranking-based nomination under different reward definitions.

    one model is trained once and evaluate_nomination is run once per reward
    kind with method="ranking".

    returns {"split": split, "rows": {reward_kind: nomination metrics}}.
    """
    metrics = ("top1", "top5", "ndcg", "endpoint_dist")
    reward_kinds = ("centroid", "cosine", "nn_target", "mmd", "wasserstein")

    model, _, _ = _train(data, split)
    dev = next(model.parameters()).device
    split_info = load_split(data.dir, split)
    candidates = candidates_singles(data)
    # this table keeps the first 25 eligible targets
    targets = [
        pert
        for pert in split_info["test_perts"]
        if len(data.parse(pert)) == 1 and pert in candidates
    ][:25]
    clusters = data.functional_clusters(seed=0)
    predictor = PivotPredictor(model, data, dev)

    rows = {}
    for kind in reward_kinds:
        scores = evaluate_nomination(
            predictor,
            data,
            targets,
            candidates,
            data.control_idx,
            reward_kind=kind,
            method="ranking",
            gene_cluster=clusters,
            device=dev,
        )
        rows[kind] = {metric: scores[metric] for metric in metrics}
    return {"split": split, "rows": rows}
# table 17: data scaling
def ablation_datascale(data, split="perturbation"):
    """table 17: retrain with different `train_frac` values and score each model.

    returns {"split": split, "rows": {str(train_frac): metrics}}.
    """
    fractions = (0.1, 0.25, 0.5, 0.75, 1.0)

    def _row(fraction):
        # train at this fraction, then compute the compact forward/inverse metrics
        trained, _, _ = _train(data, split, train_frac=fraction)
        return _fwd_inv(data, trained, split)

    return {"split": split, "rows": {str(fraction): _row(fraction) for fraction in fractions}}
# table 18: control matching
def ablation_matching(data, split="perturbation"):
    """table 18: one training run per entry of MATCH_STRATEGIES, scored with _fwd_inv.

    returns {"split": split, "rows": {match_strategy: metrics}}.
    """
    from src.data.perturb_data import MATCH_STRATEGIES

    rows = {}
    for strategy in MATCH_STRATEGIES:
        trained = _train(data, split, match=strategy)[0]
        rows[strategy] = _fwd_inv(data, trained, split)
    return {"split": split, "rows": rows}
# table 20: compute cost
def ablation_cost(data, split="perturbation"):
    """table 20: parameter count, training time and ranking time for one model.

    trains a single model, then times one pass of `pred.population` over every
    candidate, starting from the embeddings of the first 128 control indices.

    returns {"split": split, "PIVOT": {"params", "train_time_s",
    "rank_time_per_query_s", "candidate_evals"}}; train_time_s is taken from
    the trainer's info["duration_s"].
    """
    model, info, _ = _train(data, split)
    device = next(model.parameters()).device
    n_params = sum(p.numel() for p in model.parameters())
    cands = candidates_singles(data)
    pred = PivotPredictor(model, data, device)
    c0 = data.emb[data.control_idx[:128]]

    def _sync():
        # cuda kernels are launched asynchronously; wait for them so the timer
        # covers finished work rather than just the launches.
        if device.type == "cuda":
            torch.cuda.synchronize(device)

    # pivot endpoint ranking: time per query (one target, rank all candidates)
    _sync()
    t0 = time.perf_counter()  # monotonic clock, unaffected by system clock changes
    for c in cands:
        pred.population(c, c0)
    _sync()
    rank_time = time.perf_counter() - t0

    out = {"split": split,
           "PIVOT": {"params": int(n_params), "train_time_s": info["duration_s"],
                     "rank_time_per_query_s": rank_time, "candidate_evals": len(cands)}}
    return out
# registry: command-line name -> ablation function (paper table noted per entry)
ABLATIONS = dict(
    components=ablation_components,  # table 8
    representation=ablation_representation,  # table 12
    inverse_search=ablation_inverse_search,  # table 13
    guidance_steps=ablation_guidance_steps,  # table 14
    reward=ablation_reward,  # table 15
    datascale=ablation_datascale,  # table 17
    matching=ablation_matching,  # table 18
    cost=ablation_cost,  # table 20
)
if __name__ == "__main__":
    # command-line entry point: run the selected ablations on one dataset and
    # save one table per ablation.
    import argparse
    ap = argparse.ArgumentParser()
    ap.add_argument("--dataset", default="norman")
    # choices= rejects a misspelled ablation name at parse time instead of
    # raising KeyError only after the earlier ablations have already run.
    ap.add_argument("--ablations", nargs="+", default=list(ABLATIONS.keys()),
                    choices=list(ABLATIONS.keys()))
    args = ap.parse_args()
    data = load_dataset(args.dataset)
    t0 = time.time()
    for name in args.ablations:
        print(f"=== ablation: {name} ===", flush=True)
        res = ABLATIONS[name](data)
        save_table(f"{args.dataset}_ablation_{name}", res)
        # elapsed time is cumulative since the first ablation started
        print(f" done {name} ({time.time()-t0:.0f}s elapsed)", flush=True)
    print(f"all ablations done in {time.time()-t0:.0f}s")