"""second round of table fills, all from pivot's own pipeline (no fabrication): a. pivot reward-guidance for combinatorial nomination (Table 7). b. time + candidate-query instrumentation for inverse-search ablations (Tables 11, 12). c. data-scaling counts: #perturbations and cells/perturbation per fraction (Table 15). d. held-out gene mse per perturbation representation (Table 10), trained on Replogle K562 gene split. writes experiments/results/norman_timing_scaling.json.""" import sys, os, json, time sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) import numpy as np, torch from src.data.perturb_data import load_dataset from src.data.splits import load_split from src.training.train import TrainConfig, train from src.experiments.predictors import PivotPredictor from src.experiments.forward_eval import evaluate_forward from src.experiments.nomination_eval import evaluate_nomination from src.utils.common import save_json gpu = int(os.environ.get("PIVOT_GPU", "3")) data = load_dataset("norman") gc = data.functional_clusters(seed=0) out = {} # ===== a. pivot reward-guidance for combinatorial nomination (Table 7) ===== spc = load_split(data.dir, "combination") combo_cands = data.combos ctgt = [p for p in spc["test_perts"] if len(data.parse(p)) == 2][:26] mc, _ = train(TrainConfig(dataset="norman", split="combination", epochs=60, device_index=gpu), data=data, verbose=False) devc = next(mc.parameters()).device predc = PivotPredictor(mc, data, devc) g = evaluate_nomination(predc, data, ctgt, combo_cands, data.control_idx, reward_kind="centroid", method="guidance", guidance_init="warm", rerank=True, gene_cluster=gc, model=mc, device=devc) out["combo_guidance"] = {"top1": g["top1"], "top5": g["top5"], "ndcg": g["ndcg"], "endpoint_dist": g["endpoint_dist"]} print("combo_guidance", out["combo_guidance"], flush=True) # ===== b. 
# ===== b. time + candidate-query instrumentation for inverse search (Tables 11, 12) =====
sp = load_split(data.dir, "perturbation")
# Candidate pool: perturbations that data.parse splits into exactly one part.
cands = [p for p in data.perturbations if len(data.parse(p)) == 1]
# Targets: single-part test perturbations that are also candidates, first 30.
targets = [
    p for p in sp["test_perts"]
    if len(data.parse(p)) == 1 and p in cands
][:30]

mf, _ = train(
    TrainConfig(dataset="norman", split="perturbation", epochs=60, device_index=gpu),
    data=data,
    verbose=False,
)
dev = next(mf.parameters()).device
pred = PivotPredictor(mf, data, dev)

NC = len(cands)  # query count reported for pure ranking (one per candidate)
KN = 10          # extra queries reported for the mean_top / rerank strategies


def timed(**kw):
    """Run ``evaluate_nomination`` on ``targets`` and measure its wall time.

    Args:
        **kw: Extra keyword arguments forwarded unchanged to
            ``evaluate_nomination`` (e.g. ``method``, ``guidance_init``,
            ``guidance_steps``, ``rerank``, ``guidance_normalize``).

    Returns:
        A ``(result, seconds_per_target)`` tuple: the value returned by
        ``evaluate_nomination`` and the elapsed wall time divided by
        ``max(1, len(targets))``.
    """
    # CUDA kernels are launched asynchronously. Drain the device queue before
    # starting and before stopping the clock so that work left over from an
    # earlier call is not billed to this one and unfinished work is not
    # excluded from it.
    # NOTE(review): evaluate_nomination may already synchronise implicitly
    # when it converts metrics to Python numbers; the explicit calls make the
    # measurement independent of that.
    on_cuda = dev.type == "cuda"
    if on_cuda:
        torch.cuda.synchronize(dev)
    t0 = time.perf_counter()
    r = evaluate_nomination(
        pred,
        data,
        targets,
        cands,
        data.control_idx,
        reward_kind="centroid",
        gene_cluster=gc,
        model=mf,
        device=dev,
        **kw,
    )
    if on_cuda:
        torch.cuda.synchronize(dev)
    elapsed = time.perf_counter() - t0
    return r, elapsed / max(1, len(targets))


# table 11 strategies: label -> (kwargs, query_count)
# NOTE(review): the literal 25 presumably mirrors evaluate_nomination's default
# number of guidance steps -- confirm against its signature.
strat = {
    "ranking_only": (dict(method="ranking"), NC),
    "random_opt": (
        dict(method="guidance", guidance_init="random", rerank=False),
        25,
    ),
    "mean_top_init": (
        dict(method="guidance", guidance_init="mean_top", rerank=False),
        25 + KN,
    ),
    "guidance_no_norm": (
        dict(method="guidance", guidance_normalize=False, rerank=False),
        25,
    ),
    "guidance_norm": (
        dict(method="guidance", guidance_normalize=True, rerank=False),
        25,
    ),
    "guidance_rerank": (dict(method="guidance", rerank=True), 25 + KN),
}

out["search_timing"] = {}
for name, (kw, q) in strat.items():
    r, dt = timed(**kw)
    out["search_timing"][name] = {
        "sec_per_target": round(dt, 3),
        "queries": q,
        "top5": r["top5"],
        "ndcg": r["ndcg"],
        "endpoint_dist": r["endpoint_dist"],
    }
    print("timing", name, out["search_timing"][name], flush=True)

# table 12: time vs guidance steps
out["guidance_step_time"] = {}
for s in [0, 5, 10, 25, 50, 100]:
    # Zero guidance steps is reported as the plain ranking baseline.
    if s == 0:
        _, dt = timed(method="ranking")
    else:
        _, dt = timed(
            method="guidance",
            guidance_steps=s,
            rerank=False,
            guidance_normalize=True,
        )
    out["guidance_step_time"][str(s)] = round(dt, 3)
    print("step_time", s, round(dt, 3), flush=True)
# ===== c. data-scaling counts (Table 15) =====
# train_frac selects the first int(frac * n_train_perts) training perturbations.
RESULTS_PATH = "experiments/results/norman_timing_scaling.json"

pert_train = [str(p) for p in sp["train_perts"]]
n_train = len(pert_train)
# Cell count for every training perturbation that has an entry in pert_to_idx.
cells_per = [len(data.pert_to_idx[p]) for p in pert_train if p in data.pert_to_idx]
# np.median of an empty list is NaN and int(round(NaN)) raises ValueError, so
# record None (JSON null) when no training perturbation could be looked up.
med_cells = float(np.median(cells_per)) if cells_per else None
cells_per_pert = None if med_cells is None else int(round(med_cells))
# NOTE(review): the median is taken over all looked-up training perturbations,
# so cells_per_pert is the same value for every fraction -- confirm that this
# is what Table 15 intends.
out["data_scaling_counts"] = {
    str(frac): {
        "n_perts": max(1, int(frac * n_train)),
        "cells_per_pert": cells_per_pert,
    }
    for frac in [0.1, 0.25, 0.5, 0.75, 1.0]
}
print("data_scaling_counts", out["data_scaling_counts"], flush=True)
save_json(out, RESULTS_PATH)  # checkpoint before slow part

# ===== d. held-out gene mse per representation (Table 10), Replogle K562 gene split =====
rep_data = load_dataset("replogle_k562")
spg = load_split(rep_data.dir, "gene")
gene_targets = list(spg["test_perts"])[:60]
out["heldout_gene_mse"] = {}
for rep in ["op_only", "gene_only", "random_id", "gene_op", "gene_pathway_op"]:
    try:
        m, _ = train(
            TrainConfig(
                dataset="replogle_k562",
                split="gene",
                epochs=45,
                rep_mode=rep,
                device_index=gpu,
            ),
            data=rep_data,
            verbose=False,
        )
        d = next(m.parameters()).device
        p = PivotPredictor(m, rep_data, d)
        fwd = evaluate_forward(
            p, rep_data, gene_targets, rep_data.control_idx, max_perts=60
        )
        out["heldout_gene_mse"][rep] = round(float(fwd["mse"]), 4)
    except Exception as e:
        # Deliberate best-effort: one failing representation must not abort
        # the remaining ones; it is recorded as None and the error is logged.
        out["heldout_gene_mse"][rep] = None
        print("rep failed", rep, repr(e), flush=True)
    print("heldout_gene_mse", rep, out["heldout_gene_mse"][rep], flush=True)
    # Checkpoint after each representation so that an interrupt or kill in a
    # later training run does not lose the results already computed.
    save_json(out, RESULTS_PATH)

save_json(out, RESULTS_PATH)
print("FILL3_DONE", flush=True)