# PIVOT / scripts/gears_ranking.py
# author: bryan7264 - commit 3b4941f
# "pivot: code + trained checkpoints (norman, replogle k562)"
"""gears + ranking baseline for inverse nomination, done entirely in gears' native
gene space (fair: top-k recovery is representation-agnostic). trains gears on our
held-out-perturbation split, predicts post-perturbation expression for every
candidate, and ranks candidates by l2 distance to each held-out target's true mean
expression. emits top-1/5, ndcg@10, median rank for single-gene and combination
targets. runs in the isolated pivot_gears env."""
import os, json, pickle
import numpy as np
PROC = "data/processed/norman"
# NpzFile keeps its underlying zip archive open until closed, so pull out the
# perturbation-name arrays inside `with` blocks. allow_pickle is needed because the
# names are stored as object arrays (local project data, not untrusted input).
with np.load(f"{PROC}/splits/perturbation.npz", allow_pickle=True) as split:
    my_test = [str(p) for p in split["test_perts"]]
    my_train = [str(p) for p in split["train_perts"]]
with np.load(f"{PROC}/splits/combination.npz", allow_pickle=True) as splitC:
    combo_test = [str(p) for p in splitC["test_perts"]]
def to_gears(p):
    """Translate a pivot perturbation name into GEARS condition syntax.

    "A" -> "A+ctrl" and "A_B" -> "A+B". Empty pieces produced by repeated,
    leading or trailing underscores are ignored. Only the first two genes are
    used. An input with no gene at all raises IndexError.
    """
    genes = list(filter(None, p.split("_")))
    if len(genes) == 1:
        return f"{genes[0]}+ctrl"
    first, second = genes[0], genes[1]
    return f"{first}+{second}"
def ndcg_at_k(ranked, true, k=10):
    """NDCG@k for a ranking in which `true` is the only relevant item.

    With one relevant item the ideal DCG is 1/log2(2) = 1, so NDCG equals the raw
    DCG. That is 1/log2(pos + 2) when `true` sits at 0-based position `pos` within
    the first k entries, and 0.0 otherwise. Contributions add up if `true` occurs
    more than once. Always returns a Python float.
    """
    gains = (1.0 / np.log2(pos + 2) for pos, item in enumerate(ranked[:k]) if item == true)
    return float(sum(gains))
from gears import PertData, GEARS
import torch  # only used to choose the device GEARS runs on

# Prefer the GPU, but fall back to CPU instead of crashing on hosts without CUDA.
dev = "cuda" if torch.cuda.is_available() else "cpu"
pert_data = PertData("./gears_data")
pert_data.load(data_name="norman")
avail = set(map(str, pert_data.adata.obs["condition"].unique()))
# Express our split in GEARS condition names, keeping only conditions GEARS has;
# anything in our test set is removed from training.
gears_test = sorted({to_gears(p) for p in my_test} & avail)
gears_train = sorted(({to_gears(p) for p in my_train} & avail) - set(gears_test))
if "ctrl" in avail:
    gears_train = sorted(set(gears_train) | {"ctrl"})
# Hold out ~10% of training conditions (never ctrl) as GEARS' validation set.
# The seed and size formula are fixed so the draw is reproducible across runs.
rng = np.random.default_rng(0)
val_pool = [c for c in gears_train if c != "ctrl"]
if not val_pool:
    raise RuntimeError("no non-control training conditions overlap GEARS' norman data")
val = sorted(rng.choice(val_pool, size=max(1, len(gears_train) // 10), replace=False).tolist())
gears_train = sorted(set(gears_train) - set(val))
split_dict = {"train": gears_train, "val": val, "test": gears_test}
os.makedirs("gears_data", exist_ok=True)
split_path = "gears_data/pivot_custom_split.pkl"
# close (and flush) the split file before GEARS reads it back
with open(split_path, "wb") as fh:
    pickle.dump(split_dict, fh)
pert_data.prepare_split(split="custom", split_dict_path=split_path)
pert_data.get_dataloader(batch_size=64, test_batch_size=128)
model = GEARS(pert_data, device=dev)
model.model_initialize(hidden_size=64)
model.train(epochs=20)
print("GEARS trained", flush=True)
adata = pert_data.adata
genes = list(adata.var["gene_name"]) if "gene_name" in adata.var else list(adata.var_names)
cond = adata.obs["condition"].astype(str).values
# Keep X in its native (possibly sparse) form. Densifying the whole cell x gene
# matrix can cost gigabytes, while only per-condition mean vectors are needed below.
X = adata.X


def true_mean(gcond):
    """Return the observed mean expression of cells whose condition is `gcond`.

    The result is a 1-D array over genes, or None when no cell carries that
    condition.
    """
    m = cond == gcond
    if not m.any():
        return None
    mu = X[m].mean(axis=0)
    # a sparse matrix yields a (1, n_genes) np.matrix; flatten to a plain 1-D array
    return np.asarray(mu).ravel()
# candidate single genes present in gears as "GENE+ctrl"
single_conds = sorted(c for c in avail if c.endswith("+ctrl"))
single_genes = [c.split("+")[0] for c in single_conds]


def _predict_single_chunk(chunk):
    """Run model.predict for a chunk of single genes and return its result dict.

    An exception from the batched call would otherwise discard every gene in the
    chunk, so on failure retry gene by gene and keep the ones that succeed.
    """
    try:
        return model.predict([[g] for g in chunk])
    except Exception as e:
        print("predict chunk failed", chunk[:3], e, flush=True)
    out = {}
    if len(chunk) > 1:
        for g in chunk:
            try:
                out.update(model.predict([[g]]))
            except Exception as e:
                print("predict failed for", g, e, flush=True)
    return out


# predict every candidate's post-perturbation expression (gene space)
pred = {}
B = 64
for i in range(0, len(single_genes), B):
    chunk = single_genes[i:i + B]
    out = _predict_single_chunk(chunk)
    for g in chunk:
        # accept either the bare gene or its "GENE+ctrl" condition name as the result key
        for key in (g, g + "+ctrl"):
            if key in out:
                pred[g] = np.asarray(out[key])
                break
print(f"predicted {len(pred)}/{len(single_genes)} single candidates", flush=True)
cand_genes = [g for g in single_genes if g in pred]
if not cand_genes:
    raise RuntimeError("GEARS produced no single-gene predictions; nothing to rank")
P = np.stack([pred[g] for g in cand_genes])  # [n_cand, n_genes]
# Single-gene evaluation: held-out single targets whose gene has a GEARS prediction.
single_targets = [
    t for t in my_test
    if "_" not in t and to_gears(t).split("+")[0] in pred
]
single_scores = {"top1": [], "top5": [], "ndcg": [], "rank": []}
res = {"single": single_scores}
for target in single_targets:
    target_cond = to_gears(target)
    observed = true_mean(target_cond)
    if observed is None:
        continue
    # rank all candidates by L2 distance from predicted profile to the observed mean
    dist = np.linalg.norm(P - observed[None], axis=1)
    ranked = [cand_genes[j] for j in np.argsort(dist)]
    target_gene = target_cond.split("+")[0]
    hit = ranked.index(target_gene) + 1 if target_gene in ranked else len(ranked)
    single_scores["top1"].append(float(ranked[0] == target_gene))
    single_scores["top5"].append(float(target_gene in ranked[:5]))
    single_scores["ndcg"].append(ndcg_at_k(ranked, target_gene, 10))
    single_scores["rank"].append(int(hit))
# combination targets: predict combos and rank among observed combos
combo_conds = sorted(c for c in avail if c.count("+") == 1 and "ctrl" not in c)


def _predict_combo_chunk(chunk):
    """Run model.predict for a chunk of "A+B" conditions and return its result dict.

    An exception from the batched call would otherwise discard every combo in the
    chunk, so on failure retry one combo at a time and keep the ones that succeed.
    """
    try:
        return model.predict([c.split("+") for c in chunk])
    except Exception as e:
        print("combo predict failed", e, flush=True)
    out = {}
    if len(chunk) > 1:
        for c in chunk:
            try:
                out.update(model.predict([c.split("+")]))
            except Exception as e:
                print("combo predict failed for", c, e, flush=True)
    return out


cpred = {}
for i in range(0, len(combo_conds), B):
    chunk = combo_conds[i:i + B]
    out = _predict_combo_chunk(chunk)
    for c in chunk:
        # accept either the "A+B" condition name or its "A_B" form as the result key
        for key in (c, c.replace("+", "_")):
            if key in out:
                cpred[c] = np.asarray(out[key])
                break
print(f"predicted {len(cpred)}/{len(combo_conds)} combo candidates", flush=True)
cc = [c for c in combo_conds if c in cpred]
if cc:
    PC = np.stack([cpred[c] for c in cc])
    # Match targets to candidates by gene *set*, so "A_B" in our split still finds the
    # candidate when GEARS stores the pair as "B+A".
    by_genes = {frozenset(c.split("+")): c for c in cc}
    combo_tg = []
    for t in combo_test:
        gene_set = frozenset(to_gears(t).split("+"))
        if gene_set in by_genes:
            combo_tg.append(by_genes[gene_set])
    # NOTE(review): combo_test comes from combination.npz, but GEARS' training set was
    # built from perturbation.npz, so a combo target may also be a training condition.
    # Confirm the two splits are disjoint before reading these numbers as held-out.
    res["combo"] = {"top1": [], "top5": [], "overlap": [], "ndcg": []}
    for tg in combo_tg:
        tm = true_mean(tg)
        if tm is None:
            continue
        d = np.linalg.norm(PC - tm[None], axis=1)
        order = [cc[i] for i in np.argsort(d)]
        res["combo"]["top1"].append(float(order[0] == tg))
        res["combo"]["top5"].append(float(tg in order[:5]))
        res["combo"]["ndcg"].append(ndcg_at_k(order, tg, 10))
        # fraction of the target's genes present in the top-1 nominated combo
        s_true = set(tg.split("+"))
        res["combo"]["overlap"].append(len(set(order[0].split("+")) & s_true) / len(s_true))
# Summarise each group: mean of every metric list (None if empty), target count n,
# and median rank where a "rank" metric exists.
agg = {}
for grp, metrics in res.items():
    summary = {k: (float(np.mean(v)) if v else None) for k, v in metrics.items()}
    # each evaluated target appends to every metric list, so any list's length is n
    summary["n"] = len(next(iter(metrics.values())))
    if metrics.get("rank"):
        summary["med_rank"] = float(np.median(metrics["rank"]))
    agg[grp] = summary
agg["n_candidates_single"] = len(cand_genes)
agg["n_candidates_combo"] = len(cc)
out_path = "experiments/results/gears_ranking.json"
# create the results directory up front; failing here would waste the whole training run
os.makedirs(os.path.dirname(out_path), exist_ok=True)
with open(out_path, "w") as fh:
    json.dump(agg, fh, indent=2, default=float)
print("GEARS_RANK_DONE", json.dumps(agg), flush=True)