File size: 6,152 Bytes
3b4941f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
"""extra experiments to fill in the runnable table cells. external methods we
can't reproduce stay unreported; only real computed values get emitted."""
import sys, os, json, time
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import numpy as np, torch
from src.data.perturb_data import load_dataset
from src.data.splits import load_split
from src.training.train import TrainConfig, train
from src.experiments.predictors import PivotPredictor, BaselinePredictor
from src.experiments.forward_eval import evaluate_forward
from src.experiments.nomination_eval import evaluate_nomination
from src.evaluation.baselines import build_baseline
from src.utils.common import save_json

gpu = int(os.environ.get("PIVOT_GPU", "3"))  # device index passed to TrainConfig; override via PIVOT_GPU
data = load_dataset("norman")
out = {}  # result rows keyed by table-row name; saved as JSON at the end of the script

# ===== core-ablation extras (centroid reward, held-out perturbation) =====
sp = load_split(data.dir, "perturbation")
# candidate pool for nomination: every single-gene perturbation in the dataset
cands = [p for p in data.perturbations if len(data.parse(p)) == 1]
# Set copy for O(1) membership. Every member of `cands` already parses to a single
# gene, so membership alone implies the single-gene filter; no need to re-parse.
cand_set = set(cands)
# first 30 held-out test perturbations that are also nomination candidates
targets = [p for p in sp["test_perts"] if p in cand_set][:30]
gc = data.functional_clusters(seed=0)  # gene clusters passed as gene_cluster= below

def fwd_inv(model, reward="centroid"):
    """Evaluate one trained model in the forward and inverse (nomination) directions.

    Reads the module-level split ``sp``, the nomination ``targets``/``cands`` and
    the gene clusters ``gc``.

    Args:
        model: trained model; its parameters' device is used for evaluation.
        reward: reward kind forwarded as ``reward_kind`` to both nomination runs.

    Returns:
        ``(forward, ranking, guidance)``:
        - ``forward``: forward metrics on the first 50 test perturbations.
        - ``ranking``: nomination results with ``method="ranking"``.
        - ``guidance``: nomination results with warm-start guidance and reranking disabled.
    """
    device = next(model.parameters()).device
    predictor = PivotPredictor(model, data, device)
    test_head = list(sp["test_perts"])[:50]
    forward = evaluate_forward(predictor, data, test_head, data.control_idx, max_perts=50)

    # keyword arguments common to both nomination runs
    shared = dict(reward_kind=reward, gene_cluster=gc, device=device)
    ranking = evaluate_nomination(predictor, data, targets, cands, data.control_idx,
                                  method="ranking", **shared)
    guidance = evaluate_nomination(predictor, data, targets, cands, data.control_idx,
                                   method="guidance", guidance_init="warm", rerank=False,
                                   model=model, **shared)
    return forward, ranking, guidance

# velocity-only ablation: train with the tangent ("tan") loss component alone
velocity_cfg = TrainConfig(dataset="norman", split="perturbation", epochs=60,
                           device_index=gpu, components=["tan"])
mv, _ = train(velocity_cfg, data=data, verbose=False)
fv, rv, _ = fwd_inv(mv)
# row = forward metrics followed by the ranking-based inverse metrics
out["velocity_only"] = {
    **{key: fv[key] for key in ("mse", "de_corr", "mmd")},
    **{key: rv[key] for key in ("endpoint_dist", "top5", "ndcg")},
}
print("velocity_only", out["velocity_only"], flush=True)

# full model: two inverse rows (ranking only, guidance without reranking);
# both rows carry the same forward metrics of the full model
full_cfg = TrainConfig(dataset="norman", split="perturbation", epochs=60, device_index=gpu)
mf, info = train(full_cfg, data=data, verbose=False)
dev = next(mf.parameters()).device
ff, rf, gf = fwd_inv(mf)
forward_part = {key: ff[key] for key in ("mse", "de_corr", "mmd")}
out["ranking_only"] = {
    **forward_part,
    **{key: rf[key] for key in ("endpoint_dist", "top5", "ndcg")},
}
out["guidance_no_rerank"] = {
    **forward_part,
    **{key: gf[key] for key in ("endpoint_dist", "top5", "ndcg")},
}
for row_name in ("ranking_only", "guidance_no_rerank"):
    print(row_name, out[row_name], flush=True)

# ===== inverse-table baseline: average perturbation effect + ranking (cosine) =====
# baseline fitted on the perturbation split's training data
avg_model = build_baseline("AvgPerturbationEffect").fit(data, sp["train_perts"], sp["train_idx"])
bl = BaselinePredictor(avg_model)
ra = evaluate_nomination(bl, data, targets, cands, data.control_idx, reward_kind="cosine",
                         method="ranking", gene_cluster=gc, device=dev)
out["avg_effect_ranking"] = summary = {
    metric: ra[metric] for metric in ("top1", "top5", "ndcg", "func_top5")
}
# median of the per-target ranks reported by evaluate_nomination
summary["med_rank"] = float(np.median(ra["_per"]["rank"]))
print("avg_effect_ranking", out["avg_effect_ranking"], flush=True)

# ===== gpu memory for compute table =====
# Peak GPU memory of one endpoint-ranking pass of the full model `mf`: all
# single-gene candidates, 256 control cells, centroid reward toward targets[0].
from src.evaluation import inference as inf
from src.evaluation.rewards import Reward

# The velocity-only model is not used past this point; drop the reference so its
# parameters can be released. reset_peak_memory_stats() restarts the peak from the
# *current* allocation, so anything still resident would be counted in the peak.
del mv
if dev.type == "cuda":
    torch.cuda.reset_peak_memory_stats(dev)
    c0 = torch.as_tensor(data.emb[data.control_idx[:256]], dtype=torch.float32, device=dev)
    _ = inf.endpoint_ranking(
        mf, data, cands, c0,
        Reward("centroid", target_c=data.emb[data.pert_to_idx[targets[0]]].mean(0), device=dev),
        device=dev,
    )
    out["gpu_mem_mb"] = round(torch.cuda.max_memory_allocated(dev) / 1e6, 1)
    print("gpu_mem_mb", out["gpu_mem_mb"], flush=True)
else:
    # No CUDA peak to read on this device; leave the cell unreported rather than guess.
    print("gpu_mem_mb skipped: model is on", dev, flush=True)

# ===== combination table: additive + random baselines (combination split) =====
from src.experiments.nomination_eval import rank_candidates
import src.evaluation.metrics as M

spc = load_split(data.dir, "combination")
combo_cands = data.combos
# first 26 two-gene perturbations in the combination split's test set
ctgt = [pert for pert in spc["test_perts"] if len(data.parse(pert)) == 2][:26]
combo_cfg = TrainConfig(dataset="norman", split="combination", epochs=60, device_index=gpu)
mc, _ = train(combo_cfg, data=data, verbose=False)
devc = next(mc.parameters()).device
# NOTE(review): mc/predc are trained and built here, but nothing below evaluates
# predc (only the additive and random rows are emitted), even though this section
# was labelled "pivot guidance". Either emit a pivot row or drop the training run.
predc = PivotPredictor(mc, data, devc)
rng = np.random.default_rng(0)  # shared by combo_eval and the random baseline below
# additive baseline, ranked over the observed combos
addb = build_baseline("Additive").fit(data, spc["train_perts"], spc["train_idx"])
addp = BaselinePredictor(addb)
def combo_eval(predictor):
    """Score `predictor` on the two-gene targets `ctgt` by ranking `combo_cands`.

    For each target:
    - 128 control cells are drawn without replacement from the shared `rng`.
    - The candidates are ranked by `rank_candidates` with a centroid score spec
      built from the target's observed cells.

    Returns:
        (exact1, exact5, overlap), each averaged over `ctgt`:
        - exact1 / exact5: `M.top_k_accuracy` of the ranking at k=1 and k=5.
        - overlap: `M.partial_overlap` between the genes of the top-ranked
          candidate and the target's genes.

    Raises:
        ValueError: if `ctgt` is empty (nothing to average over).
    """
    if not ctgt:
        raise ValueError("combo_eval: no combination targets to evaluate (ctgt is empty)")
    # Loop-invariant: the control mean is the same for every target, so compute it once.
    # NOTE(review): assumes rank_candidates does not modify control_ref in place.
    control_ref = data.emb[data.control_idx].mean(0)
    hit1 = hit5 = overlap = 0.0
    for p in ctgt:
        target_cells = data.emb[data.pert_to_idx[p]]  # gathered once, reused for c_star
        c0n = data.emb[rng.choice(data.control_idx, 128, replace=False)]
        sk = dict(kind="centroid", c_star=target_cells.mean(0), target_sample=target_cells,
                  device=devc, control_ref=control_ref)
        ranked, _ = rank_candidates(predictor, combo_cands, c0n, sk)
        hit1 += M.top_k_accuracy(ranked, p, 1)
        hit5 += M.top_k_accuracy(ranked, p, 5)
        overlap += M.partial_overlap(data.parse(ranked[0]), set(data.parse(p)))
    n = len(ctgt)
    return hit1 / n, hit5 / n, overlap / n
ae1, ae5, aov = combo_eval(addp)
out["combo_additive"] = {"exact1": ae1, "exact5": ae5, "overlap": aov}
# random baseline: 5 distinct combos drawn per target from the shared rng;
# the first draw is the top-1 guess, all 5 form the top-5 set
r_e1 = r_e5 = r_ov = 0.0
for p in ctgt:
    pick = list(rng.choice(combo_cands, 5, replace=False))
    r_e1 += float(p == pick[0])
    r_e5 += float(p in pick)
    r_ov += M.partial_overlap(data.parse(pick[0]), set(data.parse(p)))
n = len(ctgt)
out["combo_random"] = {"exact1": r_e1 / n, "exact5": r_e5 / n, "overlap": r_ov / n}
print("combo_additive", out["combo_additive"], "combo_random", out["combo_random"], flush=True)

# Resolve the results file against the same root that is put on sys.path at the top
# of the script (the parent of this script's directory), not the current working
# directory. When run from that root, the location is unchanged.
results_path = os.path.join(
    os.path.abspath(os.path.join(os.path.dirname(__file__), "..")),
    "experiments", "results", "norman_extra_ablations.json",
)
os.makedirs(os.path.dirname(results_path), exist_ok=True)
save_json(out, results_path)
print("FILL_DONE", flush=True)