File size: 9,633 Bytes
3b4941f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 | """Ablation studies (tables 8, 12-15, 17, 18, 20).
Every ablation is a real training/eval run on real data. The epoch budget (EPOCHS)
is held fixed across the rows of a table for a fair comparison.
"""
from __future__ import annotations
import os
import time
import numpy as np
import torch
from src.data.perturb_data import PerturbData, load_dataset
from src.data.splits import load_split
from src.training.train import TrainConfig, train
from src.evaluation.baselines import build_baseline
from src.experiments.predictors import PivotPredictor, BaselinePredictor
from src.experiments.forward_eval import evaluate_forward
from src.experiments.nomination_eval import evaluate_nomination
from src.experiments.run_tables import (RESULTS_DIR, DEVICE_INDEX, candidates_singles,
save_table)
from src.models.encoders import REP_MODES
from src.utils.common import save_json
# Training-epoch budget that _train() passes to every TrainConfig; kept fixed
# across the rows of each ablation table so the rows are comparable.
EPOCHS = 45
def _train(data, split, **kw):
    """Train one model for an ablation row under the shared epoch budget.

    A TrainConfig is built for ``data`` on ``split`` with ``EPOCHS`` epochs and
    the shared ``DEVICE_INDEX``; any extra keyword arguments (for example
    ``components``, ``rep_mode``, ``train_frac`` or ``match``) are forwarded
    to TrainConfig unchanged.

    Returns:
        ``(model, info, cfg)``: the trained model, the info object returned by
        ``train``, and the config that was used.
    """
    config = TrainConfig(
        dataset=data.meta["name"],
        split=split,
        epochs=EPOCHS,
        device_index=DEVICE_INDEX,
        **kw,
    )
    trained, train_info = train(config, data=data, verbose=False)
    return trained, train_info, config
def _fwd_inv(data, model, split, max_perts=50, max_targets=30, reward="centroid"):
    """Compute compact forward + inverse metrics for one model (one ablation row).

    Args:
        data: dataset object; this function reads ``dir``, ``parse``,
            ``control_idx`` and ``functional_clusters`` from it.
        model: trained model; evaluation runs on the device of its parameters.
        split: split name for ``load_split``. For ``"cell"`` the candidate
            single perturbations are used as the test perturbations instead of
            the split's ``test_perts``.
        max_perts: cap on the number of test perturbations used for forward
            evaluation.
        max_targets: cap on the number of single-gene targets (that are also
            candidates) used for nomination.
        reward: reward kind forwarded to ``evaluate_nomination``.

    Returns:
        ``{"forward": {mse, de_corr, mmd, pearson},
        "inverse": {top1, top5, ndcg, endpoint_dist}}``.
    """
    device = next(model.parameters()).device
    sp = load_split(data.dir, split)
    cands = candidates_singles(data)
    # set for O(1) membership tests below; ``cands`` itself is passed on as-is
    cand_set = set(cands)
    test_perts = list(sp["test_perts"]) if split != "cell" else list(cands)
    # single-gene test perturbations that are also nomination candidates
    targets = [p for p in test_perts
               if len(data.parse(p)) == 1 and p in cand_set][:max_targets]
    control_pool = data.control_idx
    gc = data.functional_clusters(seed=0)
    pred = PivotPredictor(model, data, device)
    fwd = evaluate_forward(pred, data, test_perts[:max_perts], control_pool,
                           max_perts=max_perts)
    inv = evaluate_nomination(pred, data, targets, cands, control_pool,
                              reward_kind=reward, method="ranking",
                              gene_cluster=gc, device=device)
    return {"forward": {k: fwd[k] for k in ("mse", "de_corr", "mmd", "pearson")},
            "inverse": {k: inv[k] for k in ("top1", "top5", "ndcg", "endpoint_dist")}}
# table 8: component ablation
def _endpoint_mlp_row(data, split, max_perts=50, max_targets=30, reward="centroid"):
    """Forward + inverse metrics for the EndpointMLP baseline (no flow map).

    Mirrors ``_fwd_inv``: it uses the same test-perturbation selection (the
    candidate singles for the ``"cell"`` split), the same limits and the same
    explicit reward kind. This keeps the baseline row comparable with the
    PIVOT rows of table 8.
    """
    sp = load_split(data.dir, split)
    baseline = build_baseline("EndpointMLP").fit(data, sp["train_perts"], sp["train_idx"])
    predictor = BaselinePredictor(baseline)
    # the baseline exposes no parameters to read a device from: prefer the
    # configured GPU, but fall back to CPU when CUDA is unavailable
    if torch.cuda.is_available():
        device = torch.device(f"cuda:{DEVICE_INDEX}")
    else:
        device = torch.device("cpu")
    cands = candidates_singles(data)
    cand_set = set(cands)
    test_perts = list(sp["test_perts"]) if split != "cell" else list(cands)
    targets = [p for p in test_perts
               if len(data.parse(p)) == 1 and p in cand_set][:max_targets]
    gc = data.functional_clusters(seed=0)
    fwd = evaluate_forward(predictor, data, test_perts[:max_perts], data.control_idx,
                           max_perts=max_perts)
    inv = evaluate_nomination(predictor, data, targets, cands, data.control_idx,
                              reward_kind=reward, method="ranking",
                              gene_cluster=gc, device=device)
    return {"forward": {k: fwd[k] for k in ("mse", "de_corr", "mmd", "pearson")},
            "inverse": {k: inv[k] for k in ("top1", "top5", "ndcg", "endpoint_dist")}}


def ablation_components(data, split="perturbation"):
    """Table 8: ablate PIVOT's training components.

    One model is trained per variant; each variant name maps to the list of
    components passed to training (``"map"``, ``"tan"``, ``"semi"``). An
    EndpointMLP baseline row (no flow map) is appended and evaluated under the
    same protocol as the PIVOT rows.

    Returns:
        ``{"split": split, "rows": {variant_name: metrics}}``.
    """
    variants = {
        "flow-map-only": ["map"],
        "no-tangent": ["map", "semi"],
        "no-semigroup": ["map", "tan"],
        "PIVOT-full": ["map", "tan", "semi"],
    }
    out = {"split": split, "rows": {}}
    for name, comps in variants.items():
        model, _info, _cfg = _train(data, split, components=comps)
        out["rows"][name] = _fwd_inv(data, model, split)
    out["rows"]["EndpointMLP"] = _endpoint_mlp_row(data, split)
    return out
# table 12: perturbation representation
def ablation_representation(data, split="perturbation"):
    """Table 12: compare perturbation representations.

    Trains one model per entry of ``REP_MODES`` (passed to training as
    ``rep_mode``) and records each model's ``_fwd_inv`` metrics under the mode name.
    """
    rows = {}
    for mode in REP_MODES:
        trained, _, _ = _train(data, split, rep_mode=mode)
        rows[mode] = _fwd_inv(data, trained, split)
    return {"split": split, "rows": rows}
# table 13: inverse-search strategy (single model)
def ablation_inverse_search(data, split="perturbation"):
    """Table 13: compare inverse-search strategies on one trained model.

    A single model is trained. Every strategy then nominates candidates for
    the same (up to 30) single-gene test targets with the centroid reward.
    Each row keeps top1/top5/ndcg/endpoint_dist.
    """
    model, _, _ = _train(data, split)
    device = next(model.parameters()).device
    split_info = load_split(data.dir, split)
    cands = candidates_singles(data)
    targets = [pert for pert in split_info["test_perts"]
               if len(data.parse(pert)) == 1 and pert in cands][:30]
    clusters = data.functional_clusters(seed=0)
    pred = PivotPredictor(model, data, device)
    shared = {"reward_kind": "centroid", "gene_cluster": clusters,
              "model": model, "device": device}

    def _guided(init, **extra):
        # gradient-guided search options; only the given keys are passed on
        return {"method": "guidance", "guidance_init": init, **extra}

    strategies = [
        ("ranking-only", {"method": "ranking"}),
        ("random-init-opt", _guided("random", rerank=False)),
        ("mean-top-init", _guided("mean_top", rerank=False)),
        ("guidance-no-norm", _guided("warm", guidance_normalize=False, rerank=False)),
        ("guidance-norm", _guided("warm", guidance_normalize=True, rerank=False)),
        ("guidance+rerank", _guided("warm", guidance_normalize=True, rerank=True)),
    ]
    metric_keys = ("top1", "top5", "ndcg", "endpoint_dist")
    rows = {}
    for label, opts in strategies:
        res = evaluate_nomination(pred, data, targets, cands, data.control_idx,
                                  **shared, **opts)
        rows[label] = {k: res[k] for k in metric_keys}
    return {"split": split, "rows": rows}
# table 14: guidance steps
def ablation_guidance_steps(data, split="perturbation"):
    """Table 14: nomination quality as a function of guidance steps.

    0 steps means plain ranking. Any positive count runs guidance for that
    many steps with reranking enabled. All rows share one trained model and
    the same (up to 30) single-gene test targets.
    """
    model, _, _ = _train(data, split)
    device = next(model.parameters()).device
    split_info = load_split(data.dir, split)
    cands = candidates_singles(data)
    targets = [pert for pert in split_info["test_perts"]
               if len(data.parse(pert)) == 1 and pert in cands][:30]
    clusters = data.functional_clusters(seed=0)
    pred = PivotPredictor(model, data, device)
    positional = (pred, data, targets, cands, data.control_idx)
    metric_keys = ("top1", "top5", "ndcg", "endpoint_dist")
    rows = {}
    for n_steps in (0, 5, 10, 25, 50, 100):
        if n_steps:
            res = evaluate_nomination(*positional, method="guidance",
                                      gene_cluster=clusters, model=model,
                                      device=device, guidance_steps=n_steps,
                                      rerank=True)
        else:
            res = evaluate_nomination(*positional, method="ranking",
                                      gene_cluster=clusters, device=device)
        rows[str(n_steps)] = {k: res[k] for k in metric_keys}
    return {"split": split, "rows": rows}
# table 15: reward definitions
def ablation_reward(data, split="perturbation"):
    """Table 15: compare reward definitions for ranking-based nomination.

    One model is trained. Each reward kind then ranks the candidate singles
    for the same (up to 25) single-gene test targets.
    """
    model, _, _ = _train(data, split)
    device = next(model.parameters()).device
    split_info = load_split(data.dir, split)
    cands = candidates_singles(data)
    targets = [pert for pert in split_info["test_perts"]
               if len(data.parse(pert)) == 1 and pert in cands][:25]
    clusters = data.functional_clusters(seed=0)
    pred = PivotPredictor(model, data, device)
    metric_keys = ("top1", "top5", "ndcg", "endpoint_dist")

    def _score(kind):
        res = evaluate_nomination(pred, data, targets, cands, data.control_idx,
                                  reward_kind=kind, method="ranking",
                                  gene_cluster=clusters, device=device)
        return {k: res[k] for k in metric_keys}

    reward_kinds = ("centroid", "cosine", "nn_target", "mmd", "wasserstein")
    return {"split": split, "rows": {kind: _score(kind) for kind in reward_kinds}}
# table 17: data scaling
def ablation_datascale(data, split="perturbation"):
    """Table 17: performance versus the fraction of training data used.

    Trains one model per ``train_frac`` value. Each row is keyed by the
    fraction rendered with ``str``.
    """
    rows = {}
    for fraction in (0.1, 0.25, 0.5, 0.75, 1.0):
        trained, _, _ = _train(data, split, train_frac=fraction)
        rows[str(fraction)] = _fwd_inv(data, trained, split)
    return {"split": split, "rows": rows}
# table 18: control matching
def ablation_matching(data, split="perturbation"):
    """Table 18: compare control-matching strategies.

    Trains one model per entry of ``MATCH_STRATEGIES`` (passed to training as
    ``match``) and evaluates each with ``_fwd_inv``.
    """
    from src.data.perturb_data import MATCH_STRATEGIES

    rows = {}
    for strategy in MATCH_STRATEGIES:
        trained, _, _ = _train(data, split, match=strategy)
        rows[strategy] = _fwd_inv(data, trained, split)
    return {"split": split, "rows": rows}
# table 20: compute cost
def ablation_cost(data, split="perturbation"):
    """Table 20: parameter count, training time and ranking latency of PIVOT.

    Trains one model, then times the generation of a predicted population for
    every candidate single perturbation from the same 128 control cells. That
    is the work needed to rank all candidates for one target query.

    Returns:
        dict with ``"split"`` and a ``"PIVOT"`` entry holding ``params``,
        ``train_time_s`` (``info["duration_s"]`` from training),
        ``rank_time_per_query_s`` and ``candidate_evals`` (number of
        candidates evaluated).
    """
    model, info, _cfg = _train(data, split)
    device = next(model.parameters()).device
    n_params = sum(p.numel() for p in model.parameters())
    cands = candidates_singles(data)
    pred = PivotPredictor(model, data, device)
    c0 = data.emb[data.control_idx[:128]]

    def _sync():
        # CUDA kernels launch asynchronously; wait for them so the wall-clock
        # interval covers the actual compute, not just the kernel launches
        if device.type == "cuda":
            torch.cuda.synchronize(device)

    # untimed warm-up so one-off initialisation cost is not billed to the query
    first = next(iter(cands), None)
    if first is not None:
        pred.population(first, c0)
    _sync()

    # pivot endpoint ranking: time per query (one target, rank all candidates).
    # outputs are discarded per iteration instead of being accumulated in a
    # list, so they do not pile up in memory while timing.
    t0 = time.perf_counter()
    for c in cands:
        pred.population(c, c0)
    _sync()
    rank_time = time.perf_counter() - t0

    return {"split": split,
            "PIVOT": {"params": int(n_params), "train_time_s": info["duration_s"],
                      "rank_time_per_query_s": rank_time,
                      "candidate_evals": len(cands)}}
# Registry: CLI ablation name -> runner. Each runner takes the dataset and
# returns the result dict that the script passes to save_table. Insertion
# order is the default run order for --ablations.
ABLATIONS = dict(
    components=ablation_components,           # table 8
    representation=ablation_representation,   # table 12
    inverse_search=ablation_inverse_search,   # table 13
    guidance_steps=ablation_guidance_steps,   # table 14
    reward=ablation_reward,                   # table 15
    datascale=ablation_datascale,             # table 17
    matching=ablation_matching,               # table 18
    cost=ablation_cost,                       # table 20
)
def main(argv=None):
    """Run the selected ablations for one dataset and save each result table.

    Args:
        argv: argument list to parse; ``None`` lets argparse read
            ``sys.argv[1:]``.
    """
    import argparse
    ap = argparse.ArgumentParser()
    ap.add_argument("--dataset", default="norman")
    # validate names up front: an unknown name would otherwise raise KeyError
    # only after all earlier (long-running) ablations had already finished
    ap.add_argument("--ablations", nargs="+", default=list(ABLATIONS.keys()),
                    choices=list(ABLATIONS))
    args = ap.parse_args(argv)
    data = load_dataset(args.dataset)
    t0 = time.perf_counter()  # monotonic clock for elapsed-time reporting
    for name in args.ablations:
        print(f"=== ablation: {name} ===", flush=True)
        res = ABLATIONS[name](data)
        save_table(f"{args.dataset}_ablation_{name}", res)
        print(f" done {name} ({time.perf_counter()-t0:.0f}s elapsed)", flush=True)
    print(f"all ablations done in {time.perf_counter()-t0:.0f}s")


if __name__ == "__main__":
    main()
|