| """ablation studies (tables 8, 12-18, 20). |
| |
| every ablation is a real training/eval run on real data. epoch budget (EPOCHS) is |
| held fixed across rows of a table for fair comparison. |
| """ |
| from __future__ import annotations |
|
|
| import os |
| import time |
|
|
| import numpy as np |
| import torch |
|
|
| from src.data.perturb_data import PerturbData, load_dataset |
| from src.data.splits import load_split |
| from src.training.train import TrainConfig, train |
| from src.evaluation.baselines import build_baseline |
| from src.experiments.predictors import PivotPredictor, BaselinePredictor |
| from src.experiments.forward_eval import evaluate_forward |
| from src.experiments.nomination_eval import evaluate_nomination |
| from src.experiments.run_tables import (RESULTS_DIR, DEVICE_INDEX, candidates_singles, |
| save_table) |
| from src.models.encoders import REP_MODES |
| from src.utils.common import save_json |
|
|
# epoch budget used by every training run in this module (see _train); it is
# kept the same for all rows of a table so the rows stay comparable.
EPOCHS: int = 45
|
|
|
|
def _train(data, split, **kw):
    """train one model on `data` / `split` with the shared EPOCHS budget.

    any extra keyword arguments are forwarded unchanged to TrainConfig.
    returns (model, info, cfg): model and info are what train() returned, cfg
    is the TrainConfig the run was launched with.
    """
    config = TrainConfig(
        dataset=data.meta["name"],
        split=split,
        epochs=EPOCHS,
        device_index=DEVICE_INDEX,
        **kw,
    )
    trained, run_info = train(config, data=data, verbose=False)
    return trained, run_info, config
|
|
|
|
def _fwd_inv(data, model, split, max_perts=50, max_targets=30, reward="centroid"):
    """compact forward + inverse metrics for one model (for ablation rows).

    test perturbations are the split's "test_perts", except for the "cell"
    split where the candidate singles are used instead. nomination targets are
    the test perturbations that parse to exactly one element and are also
    candidates, capped at `max_targets`.

    returns {"forward": {...}, "inverse": {...}} holding the selected metrics
    from evaluate_forward and evaluate_nomination (method="ranking", scored
    with `reward`).
    """
    forward_keys = ("mse", "de_corr", "mmd", "pearson")
    inverse_keys = ("top1", "top5", "ndcg", "endpoint_dist")

    device = next(model.parameters()).device
    split_info = load_split(data.dir, split)
    candidates = candidates_singles(data)
    if split == "cell":
        test_perts = candidates
    else:
        test_perts = list(split_info["test_perts"])

    def is_single_candidate(pert):
        return len(data.parse(pert)) == 1 and pert in candidates

    targets = list(filter(is_single_candidate, test_perts))[:max_targets]
    controls = data.control_idx
    clusters = data.functional_clusters(seed=0)
    predictor = PivotPredictor(model, data, device)

    forward = evaluate_forward(
        predictor, data, list(test_perts)[:max_perts], controls, max_perts=max_perts
    )
    inverse = evaluate_nomination(
        predictor,
        data,
        targets,
        candidates,
        controls,
        reward_kind=reward,
        method="ranking",
        gene_cluster=clusters,
        device=device,
    )
    return {
        "forward": {key: forward[key] for key in forward_keys},
        "inverse": {key: inverse[key] for key in inverse_keys},
    }
|
|
|
|
| |
def ablation_components(data, split="perturbation"):
    """component ablation: one trained model per component subset + a baseline row.

    each variant is trained through _train(components=...) and scored with
    _fwd_inv. an "EndpointMLP" baseline is then fitted on the split's training
    perturbations and scored with the same protocol, so every row of the table
    is measured the same way.

    returns {"split": split, "rows": {row_name: {"forward": ..., "inverse": ...}}}.
    """
    variants = {
        "flow-map-only": ["map"],
        "no-tangent": ["map", "semi"],
        "no-semigroup": ["map", "tan"],
        "PIVOT-full": ["map", "tan", "semi"],
    }
    out = {"split": split, "rows": {}}
    for name, comps in variants.items():
        model, _, _ = _train(data, split, components=comps)
        out["rows"][name] = _fwd_inv(data, model, split)

    # baseline row -- mirrors the evaluation protocol of _fwd_inv.
    sp = load_split(data.dir, split)
    bl = build_baseline("EndpointMLP").fit(data, sp["train_perts"], sp["train_idx"])
    # fall back to CPU instead of handing a CUDA device to the evaluator on a
    # host without CUDA.
    device = torch.device(f"cuda:{DEVICE_INDEX}" if torch.cuda.is_available() else "cpu")
    cands = candidates_singles(data)
    # same test-set rule as _fwd_inv: the "cell" split evaluates on the
    # candidate singles, every other split on its own test perturbations.
    test_perts = list(sp["test_perts"]) if split != "cell" else cands
    targets = [p for p in test_perts if len(data.parse(p)) == 1 and p in cands][:30]
    gc = data.functional_clusters(seed=0)
    pred = BaselinePredictor(bl)
    fwd = evaluate_forward(pred, data, list(test_perts)[:50], data.control_idx, max_perts=50)
    # reward_kind is pinned to the value _fwd_inv uses for the PIVOT rows, so
    # the baseline is not scored with whatever evaluate_nomination defaults to.
    inv = evaluate_nomination(pred, data, targets, cands, data.control_idx,
                              reward_kind="centroid", method="ranking",
                              gene_cluster=gc, device=device)
    out["rows"]["EndpointMLP"] = {
        "forward": {k: fwd[k] for k in ["mse", "de_corr", "mmd", "pearson"]},
        "inverse": {k: inv[k] for k in ["top1", "top5", "ndcg", "endpoint_dist"]},
    }
    return out
|
|
|
|
| |
def ablation_representation(data, split="perturbation"):
    """representation ablation: train and score one model per entry of REP_MODES.

    returns {"split": split, "rows": {rep_mode: _fwd_inv metrics}}.
    """
    rows = {}
    for rep_mode in REP_MODES:
        trained = _train(data, split, rep_mode=rep_mode)[0]
        rows[rep_mode] = _fwd_inv(data, trained, split)
    return {"split": split, "rows": rows}
|
|
|
|
| |
def ablation_inverse_search(data, split="perturbation"):
    """inverse-search ablation: one trained model, several nomination strategies.

    a single model is trained, then evaluate_nomination is run once per
    strategy (plain ranking and several guidance configurations) on the same
    targets, all with reward_kind="centroid".

    returns {"split": split, "rows": {strategy_name: nomination metrics}}.
    """
    trained, _, _ = _train(data, split)
    device = next(trained.parameters()).device
    split_info = load_split(data.dir, split)
    candidates = candidates_singles(data)
    # targets: test perturbations that parse to one element and are candidates
    targets = [
        pert
        for pert in split_info["test_perts"]
        if len(data.parse(pert)) == 1 and pert in candidates
    ][:30]
    clusters = data.functional_clusters(seed=0)
    predictor = PivotPredictor(trained, data, device)

    shared = {
        "reward_kind": "centroid",
        "gene_cluster": clusters,
        "model": trained,
        "device": device,
    }
    strategies = {
        "ranking-only": {"method": "ranking"},
        "random-init-opt": {"method": "guidance", "guidance_init": "random", "rerank": False},
        "mean-top-init": {"method": "guidance", "guidance_init": "mean_top", "rerank": False},
        "guidance-no-norm": {"method": "guidance", "guidance_init": "warm",
                             "guidance_normalize": False, "rerank": False},
        "guidance-norm": {"method": "guidance", "guidance_init": "warm",
                          "guidance_normalize": True, "rerank": False},
        "guidance+rerank": {"method": "guidance", "guidance_init": "warm",
                            "guidance_normalize": True, "rerank": True},
    }

    metric_keys = ("top1", "top5", "ndcg", "endpoint_dist")
    rows = {}
    for label, options in strategies.items():
        result = evaluate_nomination(
            predictor, data, targets, candidates, data.control_idx, **shared, **options
        )
        rows[label] = {key: result[key] for key in metric_keys}
    return {"split": split, "rows": rows}
|
|
|
|
| |
def ablation_guidance_steps(data, split="perturbation"):
    """guidance-step ablation: nomination metrics for several guidance step counts.

    one model is trained. the "0" row is plain ranking; every other row runs
    method="guidance" with that many guidance_steps and rerank=True.

    returns {"split": split, "rows": {str(steps): nomination metrics}}.
    """
    trained, _, _ = _train(data, split)
    device = next(trained.parameters()).device
    split_info = load_split(data.dir, split)
    candidates = candidates_singles(data)
    targets = [
        pert
        for pert in split_info["test_perts"]
        if len(data.parse(pert)) == 1 and pert in candidates
    ][:30]
    clusters = data.functional_clusters(seed=0)
    predictor = PivotPredictor(trained, data, device)

    metric_keys = ("top1", "top5", "ndcg", "endpoint_dist")
    rows = {}
    for n_steps in (0, 5, 10, 25, 50, 100):
        if n_steps:
            method_kw = {
                "method": "guidance",
                "model": trained,
                "guidance_steps": n_steps,
                "rerank": True,
            }
        else:
            # zero steps means no guidance at all: fall back to ranking
            method_kw = {"method": "ranking"}
        result = evaluate_nomination(
            predictor,
            data,
            targets,
            candidates,
            data.control_idx,
            gene_cluster=clusters,
            device=device,
            **method_kw,
        )
        rows[str(n_steps)] = {key: result[key] for key in metric_keys}
    return {"split": split, "rows": rows}
|
|
|
|
| |
def ablation_reward(data, split="perturbation"):
    """reward ablation: ranking-based nomination under several reward kinds.

    one model is trained and evaluate_nomination (method="ranking") is run
    once per reward kind on the same (at most 25) targets.

    returns {"split": split, "rows": {reward_kind: nomination metrics}}.
    """
    trained, _, _ = _train(data, split)
    device = next(trained.parameters()).device
    split_info = load_split(data.dir, split)
    candidates = candidates_singles(data)
    targets = [
        pert
        for pert in split_info["test_perts"]
        if len(data.parse(pert)) == 1 and pert in candidates
    ][:25]
    clusters = data.functional_clusters(seed=0)
    predictor = PivotPredictor(trained, data, device)

    def score(kind):
        result = evaluate_nomination(
            predictor,
            data,
            targets,
            candidates,
            data.control_idx,
            reward_kind=kind,
            method="ranking",
            gene_cluster=clusters,
            device=device,
        )
        return {key: result[key] for key in ("top1", "top5", "ndcg", "endpoint_dist")}

    reward_kinds = ("centroid", "cosine", "nn_target", "mmd", "wasserstein")
    return {"split": split, "rows": {kind: score(kind) for kind in reward_kinds}}
|
|
|
|
| |
def ablation_datascale(data, split="perturbation"):
    """data-scale ablation: train with several train_frac values and score each.

    returns {"split": split, "rows": {str(frac): _fwd_inv metrics}}.
    """
    fractions = (0.1, 0.25, 0.5, 0.75, 1.0)
    result = {"split": split, "rows": {}}
    rows = result["rows"]
    for fraction in fractions:
        trained, *_ = _train(data, split, train_frac=fraction)
        rows[str(fraction)] = _fwd_inv(data, trained, split)
    return result
|
|
|
|
| |
def ablation_matching(data, split="perturbation"):
    """matching ablation: train and score one model per entry of MATCH_STRATEGIES.

    returns {"split": split, "rows": {strategy: _fwd_inv metrics}}.
    """
    from src.data.perturb_data import MATCH_STRATEGIES

    def run(strategy):
        trained = _train(data, split, match=strategy)[0]
        return _fwd_inv(data, trained, split)

    return {
        "split": split,
        "rows": {strategy: run(strategy) for strategy in MATCH_STRATEGIES},
    }
|
|
|
|
| |
def ablation_cost(data, split="perturbation"):
    """cost summary for one trained model: size, training time, ranking time.

    trains one model, then times a call to pred.population for every candidate
    single against the embeddings of the first 128 control indices.

    returns {"split": split, "PIVOT": {...}} where the inner dict holds:
      - "params": total element count over model.parameters()
      - "train_time_s": info["duration_s"] as reported by train()
      - "rank_time_per_query_s": seconds taken by the timed loop over all candidates
      - "candidate_evals": number of candidates evaluated in that loop
    """
    model, info, _ = _train(data, split)
    device = next(model.parameters()).device
    n_params = sum(p.numel() for p in model.parameters())
    cands = candidates_singles(data)
    pred = PivotPredictor(model, data, device)
    c0 = data.emb[data.control_idx[:128]]

    def _sync():
        # CUDA work is launched asynchronously; wait for it so the timer
        # brackets the actual computation and not just the kernel launches.
        if device.type == "cuda":
            torch.cuda.synchronize(device)

    _sync()
    # perf_counter is monotonic and high-resolution, unlike wall-clock time.time().
    t0 = time.perf_counter()
    for c in cands:
        # result is discarded on purpose -- only the elapsed time matters here
        pred.population(c, c0)
    _sync()
    rank_time = time.perf_counter() - t0

    out = {"split": split,
           "PIVOT": {"params": int(n_params), "train_time_s": info["duration_s"],
                     "rank_time_per_query_s": rank_time, "candidate_evals": len(cands)}}
    return out
|
|
|
|
# registry: command-line ablation name -> function taking the loaded dataset
# and returning the result dict that gets saved as a table.
ABLATIONS = dict(
    components=ablation_components,
    representation=ablation_representation,
    inverse_search=ablation_inverse_search,
    guidance_steps=ablation_guidance_steps,
    reward=ablation_reward,
    datascale=ablation_datascale,
    matching=ablation_matching,
    cost=ablation_cost,
)
|
|
|
|
if __name__ == "__main__":
    # command-line entry point: run the selected ablations on one dataset and
    # save each result via save_table as "<dataset>_ablation_<name>".
    import argparse
    ap = argparse.ArgumentParser()
    ap.add_argument("--dataset", default="norman")
    # choices= rejects unknown ablation names at parse time, before the
    # dataset is loaded and before any earlier ablation has been run, instead
    # of failing with a KeyError part-way through the loop.
    ap.add_argument("--ablations", nargs="+", choices=list(ABLATIONS),
                    default=list(ABLATIONS))
    args = ap.parse_args()
    data = load_dataset(args.dataset)
    t0 = time.time()
    for name in args.ablations:
        print(f"=== ablation: {name} ===", flush=True)
        res = ABLATIONS[name](data)
        save_table(f"{args.dataset}_ablation_{name}", res)
        # elapsed time is cumulative since the first ablation started
        print(f" done {name} ({time.time()-t0:.0f}s elapsed)", flush=True)
    print(f"all ablations done in {time.time()-t0:.0f}s")
|
|