"""Extra experiments to fill in the runnable table cells.

External methods we cannot reproduce stay unreported; only values that are
actually computed here get emitted.
"""
import sys, os, json, time
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import numpy as np, torch
from src.data.perturb_data import load_dataset
from src.data.splits import load_split
from src.training.train import TrainConfig, train
from src.experiments.predictors import PivotPredictor, BaselinePredictor
from src.experiments.forward_eval import evaluate_forward
from src.experiments.nomination_eval import evaluate_nomination
from src.evaluation.baselines import build_baseline
from src.utils.common import save_json
# GPU index for every training run; override with the PIVOT_GPU env variable.
gpu = int(os.environ.get("PIVOT_GPU", "3"))
data = load_dataset("norman")
out = {}  # every emitted result, written to JSON at the end of the script

# ===== core-ablation extras (centroid reward, held-out perturbation) =====
sp = load_split(data.dir, "perturbation")
# Candidate pool for nomination: all single-gene perturbations.
cands = [pert for pert in data.perturbations if len(data.parse(pert)) == 1]
_cand_set = set(cands)  # O(1) membership test for the target filter below
# Targets: the first 30 held-out single-gene perturbations that are candidates.
targets = [
    pert
    for pert in sp["test_perts"]
    if len(data.parse(pert)) == 1 and pert in _cand_set
][:30]
gc = data.functional_clusters(seed=0)
def fwd_inv(model, reward="centroid"):
    """Run the forward and single-gene nomination evaluations for one model.

    The model is wrapped in a ``PivotPredictor`` on the device of its own
    parameters. Three evaluations follow:

    * forward evaluation on the first 50 held-out perturbations of ``sp``;
    * nomination by ranking the candidate pool ``cands`` for each of ``targets``;
    * nomination by guidance (warm init, no reranking).

    Args:
        model: trained model (must expose ``parameters()``).
        reward: reward kind passed to both nomination evaluations.

    Returns:
        ``(forward_metrics, ranking_metrics, guidance_metrics)`` exactly as
        returned by ``evaluate_forward`` / ``evaluate_nomination``.
    """
    device = next(model.parameters()).device
    predictor = PivotPredictor(model, data, device)
    forward = evaluate_forward(predictor, data, list(sp["test_perts"])[:50],
                               data.control_idx, max_perts=50)
    # Arguments shared by both nomination runs.
    shared = dict(reward_kind=reward, gene_cluster=gc, device=device)
    ranking = evaluate_nomination(predictor, data, targets, cands, data.control_idx,
                                  method="ranking", **shared)
    guidance = evaluate_nomination(predictor, data, targets, cands, data.control_idx,
                                   method="guidance", guidance_init="warm", rerank=False,
                                   model=model, **shared)
    return forward, ranking, guidance
# velocity-only: train with only the tangent (velocity) loss
mv, _ = train(TrainConfig(dataset="norman", split="perturbation", epochs=60, device_index=gpu,
                          components=["tan"]), data=data, verbose=False)
# NOTE(review): fwd_inv also runs the guidance evaluation, whose result is
# discarded here; only forward + ranking metrics enter this row.
fv, rv, _ = fwd_inv(mv)
out["velocity_only"] = {"mse": fv["mse"], "de_corr": fv["de_corr"], "mmd": fv["mmd"],
                        "endpoint_dist": rv["endpoint_dist"], "top5": rv["top5"], "ndcg": rv["ndcg"]}
print("velocity_only", out["velocity_only"], flush=True)
# The ablation model is not used again in this script. Drop the last reference
# so its parameters are freed instead of staying allocated through the combo
# training and the compute-table peak-memory measurement. That measurement
# counts every live tensor on its device; mv and the full model are both
# trained with device_index=gpu, so presumably they share it.
del mv
# full model: ranking-only and guidance-without-reranking inverse rows
mf, info = train(
    TrainConfig(dataset="norman", split="perturbation", epochs=60, device_index=gpu),
    data=data,
    verbose=False,
)
dev = next(mf.parameters()).device  # reused by the baseline and GPU-memory sections
ff, rf, gf = fwd_inv(mf)
# Both rows come from the same trained model, so they share the forward metrics
# and differ only in the inverse (nomination) metrics.
fwd_keys = ("mse", "de_corr", "mmd")
inv_keys = ("endpoint_dist", "top5", "ndcg")
for row_name, inv_metrics in (("ranking_only", rf), ("guidance_no_rerank", gf)):
    out[row_name] = {**{k: ff[k] for k in fwd_keys},
                     **{k: inv_metrics[k] for k in inv_keys}}
for row_name in ("ranking_only", "guidance_no_rerank"):
    print(row_name, out[row_name], flush=True)
# ===== inverse-table baseline: average perturbation effect + ranking (cosine) =====
avg_effect = build_baseline("AvgPerturbationEffect").fit(data, sp["train_perts"], sp["train_idx"])
ra = evaluate_nomination(BaselinePredictor(avg_effect), data, targets, cands, data.control_idx,
                         reward_kind="cosine", method="ranking", gene_cluster=gc, device=dev)
avg_row = {key: ra[key] for key in ("top1", "top5", "ndcg", "func_top5")}
# Median over the per-target "rank" entries (presumably the rank of the true
# perturbation for each target -- confirm against evaluate_nomination).
avg_row["med_rank"] = float(np.median(ra["_per"]["rank"]))
out["avg_effect_ranking"] = avg_row
print("avg_effect_ranking", out["avg_effect_ranking"], flush=True)
# ===== gpu memory for compute table =====
# Peak device memory of one endpoint-ranking pass of the full model `mf`:
# 256 control cells, all single-gene candidates, and a centroid reward toward
# the first target. The peak counter is reset first, so the figure is the live
# allocations at that point plus whatever this pass allocates on top.
from src.evaluation import inference as inf
from src.evaluation.rewards import Reward

if dev.type == "cuda":
    torch.cuda.reset_peak_memory_stats(dev)
    c0 = torch.as_tensor(data.emb[data.control_idx[:256]], dtype=torch.float32, device=dev)
    mem_reward = Reward("centroid", target_c=data.emb[data.pert_to_idx[targets[0]]].mean(0),
                        device=dev)
    _ = inf.endpoint_ranking(mf, data, cands, c0, mem_reward, device=dev)
    out["gpu_mem_mb"] = round(torch.cuda.max_memory_allocated(dev) / 1e6, 1)
    print("gpu_mem_mb", out["gpu_mem_mb"], flush=True)
else:
    # CUDA memory statistics do not exist for other devices; leave the cell
    # unreported rather than crashing or emitting a made-up number.
    print("gpu_mem_mb skipped: model is on", dev, flush=True)
# ===== combination table: additive + random baselines (combination split) =====
# NOTE(review): the pivot model `mc` is trained for 60 epochs and wrapped in
# `predc`, but `predc` is never evaluated in this script. Only `devc` is
# consumed downstream (by combo_eval). Either add a pivot row to the
# combination table or drop this training run.
from src.experiments.nomination_eval import rank_candidates
import src.evaluation.metrics as M

spc = load_split(data.dir, "combination")
combo_cands = data.combos  # candidate pool: the observed combinations
# Targets: the first 26 held-out two-gene combinations.
ctgt = [pert for pert in spc["test_perts"] if len(data.parse(pert)) == 2][:26]
combo_cfg = TrainConfig(dataset="norman", split="combination", epochs=60,
                        device_index=gpu)
mc, _ = train(combo_cfg, data=data, verbose=False)
devc = next(mc.parameters()).device
predc = PivotPredictor(mc, data, devc)
# Seeded generator shared by combo_eval (control-cell draws) and the random baseline.
rng = np.random.default_rng(0)
# additive baseline ranking over observed combos
addp = BaselinePredictor(
    build_baseline("Additive").fit(data, spc["train_perts"], spc["train_idx"])
)
def combo_eval(predictor):
    """Score a predictor on the held-out two-gene combination targets.

    For every target in ``ctgt``, the predictor ranks all observed combinations
    in ``combo_cands`` via ``rank_candidates``. It uses a centroid reward toward
    the target's mean embedding, starting from 128 control cells drawn without
    replacement from the module-level ``rng``. Because ``rng`` is shared, the
    sampled controls depend on how many draws happened before this call.

    Args:
        predictor: object accepted by ``rank_candidates`` (e.g. a BaselinePredictor).

    Returns:
        ``(exact1, exact5, overlap)``: the fraction of targets ranked first, the
        fraction ranked in the top 5, and the mean partial overlap between the
        genes of the top-ranked combination and the target's genes.

    Raises:
        ValueError: if ``ctgt`` is empty (nothing to average over).
    """
    n = len(ctgt)
    if n == 0:
        raise ValueError("combo_eval: no held-out combination targets (ctgt is empty)")
    # Loop-invariant: the mean control embedding used as the reward reference.
    control_ref = data.emb[data.control_idx].mean(0)
    e1 = e5 = ov = 0.0
    for p in ctgt:
        target_sample = data.emb[data.pert_to_idx[p]]
        cstar = target_sample.mean(0)
        c0n = data.emb[rng.choice(data.control_idx, 128, replace=False)]
        sk = dict(kind="centroid", c_star=cstar, target_sample=target_sample,
                  device=devc, control_ref=control_ref)
        ranked, _ = rank_candidates(predictor, combo_cands, c0n, sk)
        e1 += M.top_k_accuracy(ranked, p, 1)
        e5 += M.top_k_accuracy(ranked, p, 5)
        ov += M.partial_overlap(data.parse(ranked[0]), set(data.parse(p)))
    return e1 / n, e5 / n, ov / n
# Additive baseline row: (exact1, exact5, overlap) from combo_eval, keyed in that order.
out["combo_additive"] = dict(zip(("exact1", "exact5", "overlap"), combo_eval(addp)))
# random combos: per target, draw 5 distinct observed combinations uniformly
# (without replacement) and score them as if they were a ranking.
hits_top1, hits_top5, overlaps = [], [], []
for target in ctgt:
    picks = list(rng.choice(combo_cands, 5, replace=False))
    hits_top1.append(float(target == picks[0]))
    hits_top5.append(float(target in picks))
    overlaps.append(M.partial_overlap(data.parse(picks[0]), set(data.parse(target))))
n = len(ctgt)
out["combo_random"] = {
    "exact1": sum(hits_top1, 0.0) / n,
    "exact5": sum(hits_top5, 0.0) / n,
    "overlap": sum(overlaps, 0.0) / n,
}
print("combo_additive", out["combo_additive"], "combo_random", out["combo_random"], flush=True)
# Write every collected result. The directory is created up front so that a
# missing results folder cannot throw away the long training/evaluation runs above.
results_path = "experiments/results/norman_extra_ablations.json"
os.makedirs(os.path.dirname(results_path), exist_ok=True)
save_json(out, results_path)
print("FILL_DONE", flush=True)