"""ablation studies (tables 8, 12-18, 20). every ablation is a real training/eval run on real data. epoch budget (EPOCHS) is held fixed across rows of a table for fair comparison. """ from __future__ import annotations import os import time import numpy as np import torch from src.data.perturb_data import PerturbData, load_dataset from src.data.splits import load_split from src.training.train import TrainConfig, train from src.evaluation.baselines import build_baseline from src.experiments.predictors import PivotPredictor, BaselinePredictor from src.experiments.forward_eval import evaluate_forward from src.experiments.nomination_eval import evaluate_nomination from src.experiments.run_tables import (RESULTS_DIR, DEVICE_INDEX, candidates_singles, save_table) from src.models.encoders import REP_MODES from src.utils.common import save_json EPOCHS = 45 def _train(data, split, **kw): cfg = TrainConfig(dataset=data.meta["name"], split=split, epochs=EPOCHS, device_index=DEVICE_INDEX, **kw) model, info = train(cfg, data=data, verbose=False) return model, info, cfg def _fwd_inv(data, model, split, max_perts=50, max_targets=30, reward="centroid"): """compact forward + inverse metrics for one model (for ablation rows).""" device = next(model.parameters()).device sp = load_split(data.dir, split) cands = candidates_singles(data) test_perts = list(sp["test_perts"]) if split != "cell" else cands targets = [p for p in test_perts if len(data.parse(p)) == 1 and p in cands][:max_targets] control_pool = data.control_idx gc = data.functional_clusters(seed=0) pred = PivotPredictor(model, data, device) fwd = evaluate_forward(pred, data, [p for p in test_perts][:max_perts], control_pool, max_perts=max_perts) inv = evaluate_nomination(pred, data, targets, cands, control_pool, reward_kind=reward, method="ranking", gene_cluster=gc, device=device) return {"forward": {k: fwd[k] for k in ["mse", "de_corr", "mmd", "pearson"]}, "inverse": {k: inv[k] for k in ["top1", "top5", "ndcg", "endpoint_dist"]}} # table 8: component ablation def ablation_components(data, split="perturbation"): variants = { "flow-map-only": ["map"], "no-tangent": ["map", "semi"], "no-semigroup": ["map", "tan"], "PIVOT-full": ["map", "tan", "semi"], } out = {"split": split, "rows": {}} for name, comps in variants.items(): model, info, cfg = _train(data, split, components=comps) out["rows"][name] = _fwd_inv(data, model, split) # endpointmlp baseline (no flow map) sp = load_split(data.dir, split) bl = build_baseline("EndpointMLP").fit(data, sp["train_perts"], sp["train_idx"]) device = torch.device(f"cuda:{DEVICE_INDEX}") cands = candidates_singles(data) targets = [p for p in sp["test_perts"] if len(data.parse(p)) == 1 and p in cands][:30] gc = data.functional_clusters(seed=0) fwd = evaluate_forward(BaselinePredictor(bl), data, list(sp["test_perts"])[:50], data.control_idx, max_perts=50) inv = evaluate_nomination(BaselinePredictor(bl), data, targets, cands, data.control_idx, method="ranking", gene_cluster=gc, device=device) out["rows"]["EndpointMLP"] = {"forward": {k: fwd[k] for k in ["mse", "de_corr", "mmd", "pearson"]}, "inverse": {k: inv[k] for k in ["top1", "top5", "ndcg", "endpoint_dist"]}} return out # table 12: perturbation representation def ablation_representation(data, split="perturbation"): out = {"split": split, "rows": {}} for rep in REP_MODES: model, info, cfg = _train(data, split, rep_mode=rep) out["rows"][rep] = _fwd_inv(data, model, split) return out # table 13: inverse-search strategy (single model) def ablation_inverse_search(data, split="perturbation"): model, info, cfg = _train(data, split) device = next(model.parameters()).device sp = load_split(data.dir, split) cands = candidates_singles(data) targets = [p for p in sp["test_perts"] if len(data.parse(p)) == 1 and p in cands][:30] gc = data.functional_clusters(seed=0) pred = PivotPredictor(model, data, device) common = dict(reward_kind="centroid", gene_cluster=gc, model=model, device=device) strategies = { "ranking-only": dict(method="ranking"), "random-init-opt": dict(method="guidance", guidance_init="random", rerank=False), "mean-top-init": dict(method="guidance", guidance_init="mean_top", rerank=False), "guidance-no-norm": dict(method="guidance", guidance_init="warm", guidance_normalize=False, rerank=False), "guidance-norm": dict(method="guidance", guidance_init="warm", guidance_normalize=True, rerank=False), "guidance+rerank": dict(method="guidance", guidance_init="warm", guidance_normalize=True, rerank=True), } out = {"split": split, "rows": {}} for name, kw in strategies.items(): r = evaluate_nomination(pred, data, targets, cands, data.control_idx, **common, **kw) out["rows"][name] = {k: r[k] for k in ["top1", "top5", "ndcg", "endpoint_dist"]} return out # table 14: guidance steps def ablation_guidance_steps(data, split="perturbation"): model, info, cfg = _train(data, split) device = next(model.parameters()).device sp = load_split(data.dir, split) cands = candidates_singles(data) targets = [p for p in sp["test_perts"] if len(data.parse(p)) == 1 and p in cands][:30] gc = data.functional_clusters(seed=0) pred = PivotPredictor(model, data, device) out = {"split": split, "rows": {}} for steps in [0, 5, 10, 25, 50, 100]: if steps == 0: r = evaluate_nomination(pred, data, targets, cands, data.control_idx, method="ranking", gene_cluster=gc, device=device) else: r = evaluate_nomination(pred, data, targets, cands, data.control_idx, method="guidance", gene_cluster=gc, model=model, device=device, guidance_steps=steps, rerank=True) out["rows"][str(steps)] = {k: r[k] for k in ["top1", "top5", "ndcg", "endpoint_dist"]} return out # table 15: reward definitions def ablation_reward(data, split="perturbation"): model, info, cfg = _train(data, split) device = next(model.parameters()).device sp = load_split(data.dir, split) cands = candidates_singles(data) targets = [p for p in sp["test_perts"] if len(data.parse(p)) == 1 and p in cands][:25] gc = data.functional_clusters(seed=0) pred = PivotPredictor(model, data, device) out = {"split": split, "rows": {}} for rk in ["centroid", "cosine", "nn_target", "mmd", "wasserstein"]: r = evaluate_nomination(pred, data, targets, cands, data.control_idx, reward_kind=rk, method="ranking", gene_cluster=gc, device=device) out["rows"][rk] = {k: r[k] for k in ["top1", "top5", "ndcg", "endpoint_dist"]} return out # table 17: data scaling def ablation_datascale(data, split="perturbation"): out = {"split": split, "rows": {}} for frac in [0.1, 0.25, 0.5, 0.75, 1.0]: model, info, cfg = _train(data, split, train_frac=frac) out["rows"][str(frac)] = _fwd_inv(data, model, split) return out # table 18: control matching def ablation_matching(data, split="perturbation"): from src.data.perturb_data import MATCH_STRATEGIES out = {"split": split, "rows": {}} for ms in MATCH_STRATEGIES: model, info, cfg = _train(data, split, match=ms) out["rows"][ms] = _fwd_inv(data, model, split) return out # table 20: compute cost def ablation_cost(data, split="perturbation"): model, info, cfg = _train(data, split) device = next(model.parameters()).device n_params = sum(p.numel() for p in model.parameters()) cands = candidates_singles(data) pred = PivotPredictor(model, data, device) c0 = data.emb[data.control_idx[:128]] # pivot endpoint ranking: time per query (one target, rank all candidates) import time as _t t0 = _t.time() _ = [pred.population(c, c0) for c in cands] rank_time = _t.time() - t0 out = {"split": split, "PIVOT": {"params": int(n_params), "train_time_s": info["duration_s"], "rank_time_per_query_s": rank_time, "candidate_evals": len(cands)}} return out ABLATIONS = { "components": ablation_components, # table 8 "representation": ablation_representation, # table 12 "inverse_search": ablation_inverse_search, # table 13 "guidance_steps": ablation_guidance_steps, # table 14 "reward": ablation_reward, # table 15 "datascale": ablation_datascale, # table 17 "matching": ablation_matching, # table 18 "cost": ablation_cost, # table 20 } if __name__ == "__main__": import argparse ap = argparse.ArgumentParser() ap.add_argument("--dataset", default="norman") ap.add_argument("--ablations", nargs="+", default=list(ABLATIONS.keys())) args = ap.parse_args() data = load_dataset(args.dataset) t0 = time.time() for name in args.ablations: print(f"=== ablation: {name} ===", flush=True) res = ABLATIONS[name](data) save_table(f"{args.dataset}_ablation_{name}", res) print(f" done {name} ({time.time()-t0:.0f}s elapsed)", flush=True) print(f"all ablations done in {time.time()-t0:.0f}s")