PIVOT / scripts /timing_and_scaling.py
bryan7264's picture
pivot: code + trained checkpoints (norman, replogle k562)
3b4941f verified
Raw
History Blame
5.83 kB
"""second round of table fills, all from pivot's own pipeline (no fabrication):
a. pivot reward-guidance for combinatorial nomination (Table 7).
b. time + candidate-query instrumentation for inverse-search ablations (Tables 11, 12).
c. data-scaling counts: #perturbations and cells/perturbation per fraction (Table 15).
d. held-out gene mse per perturbation representation (Table 10), trained on Replogle K562
gene split.
writes experiments/results/norman_timing_scaling.json."""
import sys, os, json, time
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import numpy as np, torch
from src.data.perturb_data import load_dataset
from src.data.splits import load_split
from src.training.train import TrainConfig, train
from src.experiments.predictors import PivotPredictor
from src.experiments.forward_eval import evaluate_forward
from src.experiments.nomination_eval import evaluate_nomination
from src.utils.common import save_json
# Shared state for every section below: the GPU index (overridable through the
# PIVOT_GPU environment variable), the Norman dataset, its seed-0 functional
# clusters, and the dict that accumulates all reported numbers.
gpu_env = os.getenv("PIVOT_GPU", "3")
gpu = int(gpu_env)
data = load_dataset("norman")
gc = data.functional_clusters(seed=0)
out = dict()
# ===== a. pivot reward-guidance for combinatorial nomination (Table 7) =====
# Train on the "combination" split, then nominate the first 26 two-component
# test perturbations against every combo in data.combos, using warm-started
# guidance with reranking and a centroid reward.
spc = load_split(data.dir, "combination")
combo_cands = data.combos
two_part_tests = [p for p in spc["test_perts"] if len(data.parse(p)) == 2]
ctgt = two_part_tests[:26]
combo_cfg = TrainConfig(dataset="norman", split="combination", epochs=60, device_index=gpu)
mc, _ = train(combo_cfg, data=data, verbose=False)
devc = next(mc.parameters()).device
predc = PivotPredictor(mc, data, devc)
combo_res = evaluate_nomination(
    predc, data, ctgt, combo_cands, data.control_idx,
    reward_kind="centroid",
    method="guidance",
    guidance_init="warm",
    rerank=True,
    gene_cluster=gc,
    model=mc,
    device=devc,
)
# keep only the metrics reported in the table, in a fixed key order
out["combo_guidance"] = {key: combo_res[key] for key in ("top1", "top5", "ndcg", "endpoint_dist")}
print("combo_guidance", out["combo_guidance"], flush=True)
# ===== b. time + query instrumentation for inverse search (Tables 11, 12) =====
# Candidates: every single-component perturbation. Targets: the first 30
# single-component test perturbations of the "perturbation" split that are also
# candidates.
sp = load_split(data.dir, "perturbation")


def _is_single(pert):
    """True when data.parse(pert) yields exactly one component."""
    return len(data.parse(pert)) == 1


cands = [pert for pert in data.perturbations if _is_single(pert)]
targets = [pert for pert in sp["test_perts"] if _is_single(pert) and pert in cands][:30]
single_cfg = TrainConfig(dataset="norman", split="perturbation", epochs=60, device_index=gpu)
mf, _ = train(single_cfg, data=data, verbose=False)
dev = next(mf.parameters()).device
pred = PivotPredictor(mf, data, dev)
# NC: candidate-pool size; KN: extra-query count added for top-K based strategies below
NC, KN = len(cands), 10
def timed(**kw):
    """Run evaluate_nomination on the section-b targets and time it.

    Fixed arguments: the perturbation-split predictor/model/device, the
    single-component candidate pool, a centroid reward and the functional gene
    clusters. Any keyword arguments in ``kw`` are forwarded on top of those.

    Returns:
        (result, seconds_per_target) where result is whatever evaluate_nomination
        returns and the wall time is divided by len(targets) (clamped to 1 so an
        empty target list cannot divide by zero).
    """
    def _sync():
        # CUDA kernels are launched asynchronously: without a barrier the clock
        # can start while earlier queued work is still running, or stop before
        # this call's kernels have finished, skewing the measured time.
        if dev.type == "cuda":
            torch.cuda.synchronize(dev)

    _sync()
    t0 = time.perf_counter()
    r = evaluate_nomination(pred, data, targets, cands, data.control_idx, reward_kind="centroid",
                            gene_cluster=gc, model=mf, device=dev, **kw)
    _sync()
    return r, (time.perf_counter() - t0) / max(1, len(targets))
# table 11 strategies: label -> (extra evaluate_nomination kwargs, query count).
# The query counts are hard-coded here: NC for ranking (every candidate), 25 for
# guidance-only variants, 25 + KN when a top-KN stage is involved.
# NOTE(review): 25 presumably equals evaluate_nomination's default
# guidance_steps (not passed below) — confirm.
_GUIDE_QUERIES = 25
strat = {
    "ranking_only": ({"method": "ranking"}, NC),
    "random_opt": ({"method": "guidance", "guidance_init": "random", "rerank": False},
                   _GUIDE_QUERIES),
    "mean_top_init": ({"method": "guidance", "guidance_init": "mean_top", "rerank": False},
                      _GUIDE_QUERIES + KN),
    "guidance_no_norm": ({"method": "guidance", "guidance_normalize": False, "rerank": False},
                         _GUIDE_QUERIES),
    "guidance_norm": ({"method": "guidance", "guidance_normalize": True, "rerank": False},
                      _GUIDE_QUERIES),
    "guidance_rerank": ({"method": "guidance", "rerank": True}, _GUIDE_QUERIES + KN),
}
out["search_timing"] = {}
for label, (kwargs, n_queries) in strat.items():
    res, secs = timed(**kwargs)
    row = {
        "sec_per_target": round(secs, 3),
        "queries": n_queries,
        "top5": res["top5"],
        "ndcg": res["ndcg"],
        "endpoint_dist": res["endpoint_dist"],
    }
    out["search_timing"][label] = row
    print("timing", label, row, flush=True)
# table 12: time vs guidance steps; 0 steps is timed as plain ranking
out["guidance_step_time"] = {}
for n_steps in (0, 5, 10, 25, 50, 100):
    if n_steps:
        kwargs = {"method": "guidance", "guidance_steps": n_steps, "rerank": False,
                  "guidance_normalize": True}
    else:
        kwargs = {"method": "ranking"}
    _, secs = timed(**kwargs)
    secs = round(secs, 3)
    out["guidance_step_time"][str(n_steps)] = secs
    print("step_time", n_steps, secs, flush=True)
# ===== c. data-scaling counts (Table 15) =====
# train_frac selects the first int(frac * n_train_perts) training perturbations.
# Both reported counts are computed over that same prefix, so the
# cells/perturbation column reflects the data actually used at each fraction
# (previously it was the whole-train-set median repeated for every fraction).
pert_train = [str(p) for p in sp["train_perts"]]
n_train = len(pert_train)


def _median_cells_per_pert(perts):
    """Return the rounded median cell count over ``perts``, or None if none are known.

    Only perturbations present in data.pert_to_idx contribute a count.
    """
    counts = [len(data.pert_to_idx[p]) for p in perts if p in data.pert_to_idx]
    if not counts:
        # np.median([]) is nan (with a RuntimeWarning) and int(round(nan)) raises
        return None
    return int(round(float(np.median(counts))))


out["data_scaling_counts"] = {}
for frac in [0.1, 0.25, 0.5, 0.75, 1.0]:
    # at least one perturbation per fraction
    # NOTE(review): assumed to match train_frac's own clamp — confirm
    n_sel = max(1, int(frac * n_train))
    out["data_scaling_counts"][str(frac)] = {
        "n_perts": n_sel,
        "cells_per_pert": _median_cells_per_pert(pert_train[:n_sel]),
    }
print("data_scaling_counts", out["data_scaling_counts"], flush=True)
save_json(out, "experiments/results/norman_timing_scaling.json")  # checkpoint before slow part
# ===== d. held-out gene mse per representation (Table 10), Replogle K562 gene split =====
# For each perturbation representation, train a fresh model (45 epochs) on the
# K562 "gene" split and record forward MSE on the first 60 held-out test perts.
rep_data = load_dataset("replogle_k562")
spg = load_split(rep_data.dir, "gene")
gene_targets = list(spg["test_perts"])[:60]
out["heldout_gene_mse"] = {}
for rep in ["op_only", "gene_only", "random_id", "gene_op", "gene_pathway_op"]:
    try:
        m, _ = train(TrainConfig(dataset="replogle_k562", split="gene", epochs=45,
                                 rep_mode=rep, device_index=gpu), data=rep_data, verbose=False)
        d = next(m.parameters()).device
        p = PivotPredictor(m, rep_data, d)
        f = evaluate_forward(p, rep_data, gene_targets, rep_data.control_idx, max_perts=60)
        out["heldout_gene_mse"][rep] = round(float(f["mse"]), 4)
    except Exception as e:
        # best effort: one failing representation is recorded as null (and
        # logged) instead of aborting the remaining runs
        out["heldout_gene_mse"][rep] = None
        print("rep failed", rep, repr(e), flush=True)
    finally:
        # release this run's model/predictor so it can be garbage-collected
        # before the next representation starts training
        m = p = None
    print("heldout_gene_mse", rep, out["heldout_gene_mse"][rep], flush=True)
    # checkpoint after every representation: each is a full training run, and a
    # hard crash later in the loop should not lose the finished ones
    save_json(out, "experiments/results/norman_timing_scaling.json")
print("FILL3_DONE", flush=True)