"""Second round of table fills, computed entirely from PIVOT's own pipeline
(no fabricated numbers):

a. PIVOT reward-guided nomination of combinatorial perturbations (Table 7).
b. Wall-clock time and candidate-query counts for the inverse-search
   ablations (Tables 11, 12).
c. Data-scaling counts: number of perturbations and cells per perturbation
   for each training fraction (Table 15).
d. Held-out gene MSE per perturbation representation (Table 10), trained on
   the Replogle K562 gene split.

All results, including (d), are written to
experiments/results/norman_timing_scaling.json."""
| import sys, os, json, time |
| sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) |
| import numpy as np, torch |
| from src.data.perturb_data import load_dataset |
| from src.data.splits import load_split |
| from src.training.train import TrainConfig, train |
| from src.experiments.predictors import PivotPredictor |
| from src.experiments.forward_eval import evaluate_forward |
| from src.experiments.nomination_eval import evaluate_nomination |
| from src.utils.common import save_json |
|
|
# Results for every section, keyed by section name; dumped with save_json.
out = {}
# GPU index used for all training runs (override with the PIVOT_GPU env var).
gpu = int(os.environ.get("PIVOT_GPU", default="3"))
# Norman dataset shared by sections (a)-(c), plus its functional gene
# clusters (seed 0), which every nomination run receives as gene_cluster.
data = load_dataset("norman")
gc = data.functional_clusters(seed=0)
|
|
| |
# (a) Reward-guided nomination of two-gene combinations (Table 7): train on
# the "combination" split, then nominate for the first 26 two-gene test
# perturbations using data.combos as the candidate set.
spc = load_split(data.dir, "combination")
combo_cands = data.combos
ctgt = [p for p in spc["test_perts"] if len(data.parse(p)) == 2][:26]
combo_cfg = TrainConfig(dataset="norman", split="combination", epochs=60, device_index=gpu)
mc, _ = train(combo_cfg, data=data, verbose=False)
devc = next(mc.parameters()).device
predc = PivotPredictor(mc, data, devc)
combo_kw = {
    "reward_kind": "centroid",
    "method": "guidance",
    "guidance_init": "warm",
    "rerank": True,
    "gene_cluster": gc,
    "model": mc,
    "device": devc,
}
g = evaluate_nomination(predc, data, ctgt, combo_cands, data.control_idx, **combo_kw)
# Keep only the headline metrics reported in the table.
out["combo_guidance"] = {key: g[key] for key in ("top1", "top5", "ndcg", "endpoint_dist")}
print("combo_guidance", out["combo_guidance"], flush=True)
|
|
| |
# Single-gene nomination setup shared by the timing ablations below.
# Candidates: every single-gene perturbation in the dataset.
# Targets: single-gene test perturbations of the "perturbation" split that are
# also candidates, capped at the first 30.
sp = load_split(data.dir, "perturbation")
cands = [p for p in data.perturbations if len(data.parse(p)) == 1]
test_singles = [p for p in sp["test_perts"] if len(data.parse(p)) == 1 and p in cands]
targets = test_singles[:30]
pert_cfg = TrainConfig(dataset="norman", split="perturbation", epochs=60, device_index=gpu)
mf, _ = train(pert_cfg, data=data, verbose=False)
dev = next(mf.parameters()).device
pred = PivotPredictor(mf, data, dev)
# Query count reported for exhaustive ranking: every candidate is scored.
NC = len(cands)
# Extra queries added on top of 25 for the mean_top-init and rerank strategies.
KN = 10
|
|
|
|
def timed(**kw):
    """Run ``evaluate_nomination`` on the single-gene targets and time it.

    Keyword arguments are forwarded to ``evaluate_nomination`` on top of the
    fixed setup: ``pred``/``mf`` on ``dev``, the ``targets``/``cands`` lists,
    a centroid reward and the functional clusters ``gc``.

    Returns:
        ``(result, seconds_per_target)``: whatever ``evaluate_nomination``
        returns, and the wall-clock time divided by the number of targets
        (at least 1, so an empty target list does not divide by zero).
    """
    def _sync():
        # CUDA kernels launch asynchronously. Without a barrier, the clock can
        # be read before queued GPU work finishes, or can include work queued
        # by the previous call.
        if torch.cuda.is_available() and getattr(dev, "type", None) == "cuda":
            torch.cuda.synchronize(dev)

    _sync()
    t0 = time.perf_counter()
    r = evaluate_nomination(pred, data, targets, cands, data.control_idx, reward_kind="centroid",
                            gene_cluster=gc, model=mf, device=dev, **kw)
    _sync()
    return r, (time.perf_counter() - t0) / max(1, len(targets))
|
|
|
|
| |
# (b) Inverse-search ablation (Tables 11, 12). Each entry maps a strategy name
# to (keyword arguments for evaluate_nomination, reported query count).
# Query counts are fixed bookkeeping values, not measured:
#   - NC for exhaustive ranking;
#   - 25 for the random-init and (un)normalized guidance runs;
#   - 25 + KN for mean_top_init and guidance_rerank.
# NOTE(review): 25 presumably matches the default guidance_steps; confirm.
strat = {
    "ranking_only": ({"method": "ranking"}, NC),
    "random_opt": ({"method": "guidance", "guidance_init": "random", "rerank": False}, 25),
    "mean_top_init": ({"method": "guidance", "guidance_init": "mean_top", "rerank": False},
                      25 + KN),
    "guidance_no_norm": ({"method": "guidance", "guidance_normalize": False, "rerank": False},
                         25),
    "guidance_norm": ({"method": "guidance", "guidance_normalize": True, "rerank": False}, 25),
    "guidance_rerank": ({"method": "guidance", "rerank": True}, 25 + KN),
}
search_rows = {}
out["search_timing"] = search_rows
for label, (eval_kw, n_queries) in strat.items():
    res, sec = timed(**eval_kw)
    row = {"sec_per_target": round(sec, 3), "queries": n_queries}
    for metric in ("top5", "ndcg", "endpoint_dist"):
        row[metric] = res[metric]
    search_rows[label] = row
    print("timing", label, row, flush=True)
|
|
| |
# Per-target runtime as a function of the number of guidance steps.
step_times = {}
out["guidance_step_time"] = step_times
for n_steps in (0, 5, 10, 25, 50, 100):
    # Zero steps means no guidance at all: time plain ranking instead.
    if n_steps == 0:
        call_kw = {"method": "ranking"}
    else:
        call_kw = {"method": "guidance", "guidance_steps": n_steps, "rerank": False,
                   "guidance_normalize": True}
    _, sec = timed(**call_kw)
    step_times[str(n_steps)] = round(sec, 3)
    print("step_time", n_steps, round(sec, 3), flush=True)
|
|
| |
| |
# (c) Data-scaling counts (Table 15): for each training fraction, the number
# of training perturbations kept, and the median number of cells per training
# perturbation (the same median is reported for every fraction).
pert_train = [str(p) for p in sp["train_perts"]]
n_train = len(pert_train)
# Only perturbations present in pert_to_idx contribute to the median;
# n_train still counts every perturbation listed in the split.
cells_per = [len(data.pert_to_idx[p]) for p in pert_train if p in data.pert_to_idx]
if cells_per:
    med_cells = float(np.median(cells_per))
    cells_per_pert = int(round(med_cells))
else:
    # np.median([]) is NaN and int(round(nan)) raises ValueError. Record the
    # gap instead of aborting before results are saved and section (d) runs.
    med_cells = float("nan")
    cells_per_pert = None
    print("data_scaling_counts: no training perturbation found in pert_to_idx", flush=True)
out["data_scaling_counts"] = {str(f): {"n_perts": max(1, int(f * n_train)),
                                       "cells_per_pert": cells_per_pert}
                              for f in [0.1, 0.25, 0.5, 0.75, 1.0]}
print("data_scaling_counts", out["data_scaling_counts"], flush=True)
save_json(out, "experiments/results/norman_timing_scaling.json")
|
|
| |
# (d) Held-out gene MSE per perturbation representation (Table 10): train one
# model per rep_mode on the Replogle K562 "gene" split and evaluate forward
# prediction on the first N_GENE_TARGETS held-out test perturbations.
rep_data = load_dataset("replogle_k562")
spg = load_split(rep_data.dir, "gene")
N_GENE_TARGETS = 60
gene_targets = list(spg["test_perts"])[:N_GENE_TARGETS]
out["heldout_gene_mse"] = {}
for rep in ["op_only", "gene_only", "random_id", "gene_op", "gene_pathway_op"]:
    m = p = f = None
    try:
        m, _ = train(TrainConfig(dataset="replogle_k562", split="gene", epochs=45,
                                 rep_mode=rep, device_index=gpu), data=rep_data, verbose=False)
        d = next(m.parameters()).device
        p = PivotPredictor(m, rep_data, d)
        f = evaluate_forward(p, rep_data, gene_targets, rep_data.control_idx,
                             max_perts=N_GENE_TARGETS)
        out["heldout_gene_mse"][rep] = round(float(f["mse"]), 4)
    except Exception as e:
        # Best effort: one failing representation should not lose the others.
        # Record None, log the error, and move on.
        out["heldout_gene_mse"][rep] = None
        print("rep failed", rep, repr(e), flush=True)
    finally:
        # Drop this rep's model and outputs before the next rep trains, so two
        # models are never resident at once. Then return cached allocator
        # blocks (including any left by a failed run).
        del m, p, f
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    print("heldout_gene_mse", rep, out["heldout_gene_mse"][rep], flush=True)
    # Save after every rep so partial results survive a later crash.
    save_json(out, "experiments/results/norman_timing_scaling.json")


print("FILL3_DONE", flush=True)
|
|