File size: 9,633 Bytes
3b4941f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
"""ablation studies (tables 8, 12-18, 20).

every ablation is a real training/eval run on real data. epoch budget (EPOCHS) is
held fixed across rows of a table for fair comparison.
"""
from __future__ import annotations

import os
import time

import numpy as np
import torch

from src.data.perturb_data import PerturbData, load_dataset
from src.data.splits import load_split
from src.training.train import TrainConfig, train
from src.evaluation.baselines import build_baseline
from src.experiments.predictors import PivotPredictor, BaselinePredictor
from src.experiments.forward_eval import evaluate_forward
from src.experiments.nomination_eval import evaluate_nomination
from src.experiments.run_tables import (RESULTS_DIR, DEVICE_INDEX, candidates_singles,
                                        save_table)
from src.models.encoders import REP_MODES
from src.utils.common import save_json

# epoch budget passed to every TrainConfig built by _train, so all ablation
# rows are trained for the same number of epochs.
EPOCHS = 45


def _train(data, split, **kw):
    """train one model on `split` of `data` using the shared EPOCHS budget.

    any extra keyword arguments are forwarded unchanged to TrainConfig.
    returns the tuple (model, info, cfg), where model/info come from train()
    and cfg is the TrainConfig that was used.
    """
    cfg = TrainConfig(
        dataset=data.meta["name"],
        split=split,
        epochs=EPOCHS,
        device_index=DEVICE_INDEX,
        **kw,
    )
    trained, run_info = train(cfg, data=data, verbose=False)
    return trained, run_info, cfg


def _fwd_inv(data, model, split, max_perts=50, max_targets=30, reward="centroid"):
    """compact forward + inverse metrics for one trained model (one ablation row).

    forward metrics come from evaluate_forward on at most `max_perts` test
    perturbations; inverse metrics come from evaluate_nomination (ranking
    method, `reward` as reward_kind) on at most `max_targets` targets.
    targets are test perturbations for which data.parse(p) has length 1
    (presumably single-gene perturbations) and that are also candidates.
    for the "cell" split the candidate list itself is used as the test set.

    returns {"forward": {mse, de_corr, mmd, pearson},
             "inverse": {top1, top5, ndcg, endpoint_dist}}.
    """
    fwd_keys = ["mse", "de_corr", "mmd", "pearson"]
    inv_keys = ["top1", "top5", "ndcg", "endpoint_dist"]

    device = next(model.parameters()).device
    split_info = load_split(data.dir, split)
    candidates = candidates_singles(data)
    if split == "cell":
        test_perts = candidates
    else:
        test_perts = list(split_info["test_perts"])

    eligible = [p for p in test_perts if len(data.parse(p)) == 1 and p in candidates]
    targets = eligible[:max_targets]
    controls = data.control_idx
    clusters = data.functional_clusters(seed=0)
    predictor = PivotPredictor(model, data, device)

    fwd = evaluate_forward(predictor, data, list(test_perts)[:max_perts], controls,
                           max_perts=max_perts)
    inv = evaluate_nomination(predictor, data, targets, candidates, controls,
                              reward_kind=reward, method="ranking",
                              gene_cluster=clusters, device=device)

    forward_metrics = {k: fwd[k] for k in fwd_keys}
    inverse_metrics = {k: inv[k] for k in inv_keys}
    return {"forward": forward_metrics, "inverse": inverse_metrics}


# table 8: component ablation
def ablation_components(data, split="perturbation"):
    """table 8: one model per component subset, plus an EndpointMLP baseline row.

    each variant is trained through _train(components=...) and scored with
    _fwd_inv. the EndpointMLP baseline is fit on the split's training
    perturbations and scored with the same budgets and reward that _fwd_inv
    uses by default (50 forward perturbations, 30 targets, centroid reward,
    ranking method), so that every row of the table is comparable.

    returns {"split": split, "rows": {row_name: {"forward": ..., "inverse": ...}}}.
    """
    fwd_keys = ["mse", "de_corr", "mmd", "pearson"]
    inv_keys = ["top1", "top5", "ndcg", "endpoint_dist"]
    variants = {
        "flow-map-only": ["map"],
        "no-tangent": ["map", "semi"],
        "no-semigroup": ["map", "tan"],
        "PIVOT-full": ["map", "tan", "semi"],
    }
    out = {"split": split, "rows": {}}
    for name, comps in variants.items():
        model, _info, _cfg = _train(data, split, components=comps)
        out["rows"][name] = _fwd_inv(data, model, split)

    # endpointmlp baseline (no flow map)
    sp = load_split(data.dir, split)
    bl = build_baseline("EndpointMLP").fit(data, sp["train_perts"], sp["train_idx"])
    # fall back to cpu so the baseline row can still be evaluated without cuda
    device = torch.device(f"cuda:{DEVICE_INDEX}" if torch.cuda.is_available() else "cpu")
    cands = candidates_singles(data)
    # same test-set convention as _fwd_inv: the "cell" split scores the candidates
    test_perts = list(sp["test_perts"]) if split != "cell" else cands
    targets = [p for p in test_perts if len(data.parse(p)) == 1 and p in cands][:30]
    gc = data.functional_clusters(seed=0)
    pred = BaselinePredictor(bl)
    fwd = evaluate_forward(pred, data, list(test_perts)[:50], data.control_idx, max_perts=50)
    # reward_kind is spelled out so this row uses the same reward as the
    # _fwd_inv rows instead of whatever evaluate_nomination defaults to
    inv = evaluate_nomination(pred, data, targets, cands, data.control_idx,
                              reward_kind="centroid", method="ranking",
                              gene_cluster=gc, device=device)
    out["rows"]["EndpointMLP"] = {"forward": {k: fwd[k] for k in fwd_keys},
                                  "inverse": {k: inv[k] for k in inv_keys}}
    return out


# table 12: perturbation representation
def ablation_representation(data, split="perturbation"):
    """table 12: train and score one model per entry of REP_MODES.

    each model is trained with rep_mode set to the entry and scored with
    _fwd_inv. returns {"split": split, "rows": {rep_mode: metrics}}.
    """
    rows = {}
    for rep_mode in REP_MODES:
        trained = _train(data, split, rep_mode=rep_mode)[0]
        rows[rep_mode] = _fwd_inv(data, trained, split)
    return {"split": split, "rows": rows}


# table 13: inverse-search strategy (single model)
def ablation_inverse_search(data, split="perturbation"):
    """table 13: compare inverse-search strategies on a single trained model.

    one model is trained, then evaluate_nomination is run once per strategy
    with the centroid reward on at most 30 targets (test perturbations for
    which data.parse(p) has length 1 and that are also candidates).

    returns {"split": split, "rows": {strategy: {top1, top5, ndcg, endpoint_dist}}}.
    """
    metric_keys = ["top1", "top5", "ndcg", "endpoint_dist"]

    model, _info, _cfg = _train(data, split)
    device = next(model.parameters()).device
    split_info = load_split(data.dir, split)
    candidates = candidates_singles(data)
    eligible = [p for p in split_info["test_perts"]
                if len(data.parse(p)) == 1 and p in candidates]
    targets = eligible[:30]
    clusters = data.functional_clusters(seed=0)
    predictor = PivotPredictor(model, data, device)

    # keyword arguments shared by every strategy
    shared = {"reward_kind": "centroid", "gene_cluster": clusters,
              "model": model, "device": device}
    # strategy name -> strategy-specific keyword arguments
    strategies = {
        "ranking-only": {"method": "ranking"},
        "random-init-opt": {"method": "guidance", "guidance_init": "random",
                            "rerank": False},
        "mean-top-init": {"method": "guidance", "guidance_init": "mean_top",
                          "rerank": False},
        "guidance-no-norm": {"method": "guidance", "guidance_init": "warm",
                             "guidance_normalize": False, "rerank": False},
        "guidance-norm": {"method": "guidance", "guidance_init": "warm",
                          "guidance_normalize": True, "rerank": False},
        "guidance+rerank": {"method": "guidance", "guidance_init": "warm",
                            "guidance_normalize": True, "rerank": True},
    }

    rows = {}
    for label, extra in strategies.items():
        res = evaluate_nomination(predictor, data, targets, candidates,
                                  data.control_idx, **shared, **extra)
        rows[label] = {k: res[k] for k in metric_keys}
    return {"split": split, "rows": rows}


# table 14: guidance steps
def ablation_guidance_steps(data, split="perturbation"):
    """table 14: nomination quality as a function of the number of guidance steps.

    one model is trained; 0 steps is scored with the plain ranking method and
    every other step count with the guidance method plus reranking.
    rows are keyed by the step count as a string.

    NOTE(review): reward_kind is not passed here, so evaluate_nomination's
    default reward is used -- confirm it matches the other tables.

    returns {"split": split, "rows": {"<steps>": {top1, top5, ndcg, endpoint_dist}}}.
    """
    metric_keys = ["top1", "top5", "ndcg", "endpoint_dist"]

    model, _info, _cfg = _train(data, split)
    device = next(model.parameters()).device
    split_info = load_split(data.dir, split)
    candidates = candidates_singles(data)
    eligible = [p for p in split_info["test_perts"]
                if len(data.parse(p)) == 1 and p in candidates]
    targets = eligible[:30]
    clusters = data.functional_clusters(seed=0)
    predictor = PivotPredictor(model, data, device)

    rows = {}
    for n_steps in (0, 5, 10, 25, 50, 100):
        if n_steps != 0:
            res = evaluate_nomination(predictor, data, targets, candidates,
                                      data.control_idx, method="guidance",
                                      gene_cluster=clusters, model=model, device=device,
                                      guidance_steps=n_steps, rerank=True)
        else:
            # zero guidance steps == pure candidate ranking
            res = evaluate_nomination(predictor, data, targets, candidates,
                                      data.control_idx, method="ranking",
                                      gene_cluster=clusters, device=device)
        rows[str(n_steps)] = {k: res[k] for k in metric_keys}
    return {"split": split, "rows": rows}


# table 15: reward definitions
def ablation_reward(data, split="perturbation"):
    """table 15: compare reward definitions for ranking-based nomination.

    one model is trained, then evaluate_nomination (ranking method) is run
    once per reward kind on at most 25 targets.

    returns {"split": split, "rows": {reward_kind: {top1, top5, ndcg, endpoint_dist}}}.
    """
    metric_keys = ["top1", "top5", "ndcg", "endpoint_dist"]
    reward_kinds = ("centroid", "cosine", "nn_target", "mmd", "wasserstein")

    model, _info, _cfg = _train(data, split)
    device = next(model.parameters()).device
    split_info = load_split(data.dir, split)
    candidates = candidates_singles(data)
    eligible = [p for p in split_info["test_perts"]
                if len(data.parse(p)) == 1 and p in candidates]
    targets = eligible[:25]
    clusters = data.functional_clusters(seed=0)
    predictor = PivotPredictor(model, data, device)

    rows = {}
    for kind in reward_kinds:
        res = evaluate_nomination(predictor, data, targets, candidates,
                                  data.control_idx, reward_kind=kind,
                                  method="ranking", gene_cluster=clusters, device=device)
        rows[kind] = {k: res[k] for k in metric_keys}
    return {"split": split, "rows": rows}


# table 17: data scaling
def ablation_datascale(data, split="perturbation"):
    """table 17: train with train_frac in {0.1, 0.25, 0.5, 0.75, 1.0} and score each model.

    rows are keyed by str(train_frac); each value is the _fwd_inv metric dict.
    returns {"split": split, "rows": rows}.
    """
    rows = {}
    for fraction in (0.1, 0.25, 0.5, 0.75, 1.0):
        trained = _train(data, split, train_frac=fraction)[0]
        rows[str(fraction)] = _fwd_inv(data, trained, split)
    return {"split": split, "rows": rows}


# table 18: control matching
def ablation_matching(data, split="perturbation"):
    """table 18: train and score one model per entry of MATCH_STRATEGIES.

    each model is trained with match set to the entry and scored with
    _fwd_inv. returns {"split": split, "rows": {strategy: metrics}}.
    """
    # imported at call time, as in the rest of this function's history
    from src.data.perturb_data import MATCH_STRATEGIES

    rows = {}
    for strategy in MATCH_STRATEGIES:
        trained = _train(data, split, match=strategy)[0]
        rows[strategy] = _fwd_inv(data, trained, split)
    return {"split": split, "rows": rows}


# table 20: compute cost
def ablation_cost(data, split="perturbation"):
    """table 20: parameter count, training time and ranking time of one model.

    one model is trained; the ranking time is the elapsed time of calling
    pred.population once per candidate against the first 128 control
    embeddings (i.e. the candidate sweep needed to rank for one query).

    returns {"split": split,
             "PIVOT": {params, train_time_s, rank_time_per_query_s, candidate_evals}}.
    """
    model, info, _cfg = _train(data, split)
    device = next(model.parameters()).device
    n_params = sum(p.numel() for p in model.parameters())
    cands = candidates_singles(data)
    pred = PivotPredictor(model, data, device)
    c0 = data.emb[data.control_idx[:128]]

    def _sync():
        # cuda kernels are launched asynchronously; without a barrier the
        # timer could stop before the queued gpu work has actually finished
        if device.type == "cuda":
            torch.cuda.synchronize(device)

    # pivot endpoint ranking: time per query (one target, rank all candidates)
    _sync()
    # perf_counter is monotonic and high-resolution, unlike wall-clock time.time
    t0 = time.perf_counter()
    for c in cands:
        pred.population(c, c0)  # result discarded; only the elapsed time matters
    _sync()
    rank_time = time.perf_counter() - t0

    out = {"split": split,
           "PIVOT": {"params": int(n_params), "train_time_s": info["duration_s"],
                     "rank_time_per_query_s": rank_time, "candidate_evals": len(cands)}}
    return out


# registry: ablation name -> runner function. insertion order is the default
# run order of the command-line entry point.
ABLATIONS = dict(
    components=ablation_components,          # table 8
    representation=ablation_representation,  # table 12
    inverse_search=ablation_inverse_search,  # table 13
    guidance_steps=ablation_guidance_steps,  # table 14
    reward=ablation_reward,                  # table 15
    datascale=ablation_datascale,            # table 17
    matching=ablation_matching,              # table 18
    cost=ablation_cost,                      # table 20
)


# command-line entry point: run the selected ablations on one dataset and save
# one table per ablation via save_table.
if __name__ == "__main__":
    import argparse

    ap = argparse.ArgumentParser()
    ap.add_argument("--dataset", default="norman")
    # choices= rejects unknown ablation names at parse time, before the dataset
    # is loaded, instead of failing later with a bare KeyError
    ap.add_argument("--ablations", nargs="+", default=list(ABLATIONS),
                    choices=list(ABLATIONS))
    args = ap.parse_args()
    data = load_dataset(args.dataset)
    t0 = time.time()
    for name in args.ablations:
        print(f"=== ablation: {name} ===", flush=True)
        res = ABLATIONS[name](data)
        save_table(f"{args.dataset}_ablation_{name}", res)
        # elapsed time is cumulative since the first ablation started
        print(f"  done {name} ({time.time()-t0:.0f}s elapsed)", flush=True)
    print(f"all ablations done in {time.time()-t0:.0f}s")