| """gears + ranking baseline for inverse nomination, done entirely in gears' native |
| gene space (fair: top-k recovery is representation-agnostic). trains gears on our |
| held-out-perturbation split, predicts post-perturbation expression for every |
| candidate, and ranks candidates by l2 distance to each held-out target's true mean |
| expression. emits top-1/5, ndcg@10, median rank for single-gene and combination |
| targets. runs in the isolated pivot_gears env.""" |
| import os, json, pickle |
| import numpy as np |
|
|
# Root directory of the preprocessed dataset; the split files live under <PROC>/splits/.
PROC = "data/processed/norman"

# Held-out-perturbation split: provides the train and test perturbation names.
# NOTE(review): allow_pickle=True unpickles object arrays — presumably these
# files are produced locally by our own pipeline; confirm they are trusted.
split = np.load(f"{PROC}/splits/perturbation.npz", allow_pickle=True)
my_train = list(map(str, split["train_perts"]))
my_test = list(map(str, split["test_perts"]))

# Combination split: only its test perturbations are used in this script.
splitC = np.load(f"{PROC}/splits/combination.npz", allow_pickle=True)
combo_test = list(map(str, splitC["test_perts"]))
|
|
|
|
def to_gears(p):
    """Convert an underscore-separated perturbation name to a "+"-joined condition.

    Empty fragments produced by leading, trailing or doubled underscores are
    ignored. A single gene ``"A"`` becomes ``"A+ctrl"``; two or more genes are
    joined with ``"+"`` (``"A_B"`` -> ``"A+B"``). Previously any gene beyond the
    second was silently dropped, which mapped e.g. ``"A_B_C"`` onto the
    unrelated condition ``"A+B"``; now every gene is kept so such a name simply
    fails to match instead of matching the wrong condition.

    NOTE(review): single genes are always emitted as ``GENE+ctrl``. If the
    dataset also labels some singles as ``ctrl+GENE`` they will not match —
    confirm against the condition labels actually present.

    Args:
        p: Perturbation name, genes separated by ``"_"``.

    Returns:
        The condition string built from the genes in ``p``.

    Raises:
        ValueError: If ``p`` contains no gene name at all.
    """
    genes = [x for x in p.split("_") if x]
    if not genes:
        raise ValueError(f"perturbation name {p!r} contains no gene")
    if len(genes) == 1:
        return genes[0] + "+ctrl"
    return "+".join(genes)
|
|
|
|
def ndcg_at_k(ranked, true, k=10):
    """Return the discounted gain of ``true`` within the first ``k`` of ``ranked``.

    Each position ``pos`` (0-based) in ``ranked[:k]`` that equals ``true``
    contributes ``1 / log2(pos + 2)``. With exactly one relevant item the ideal
    DCG is 1, so this value already is the NDCG and no normalisation is applied.

    Args:
        ranked: Candidates ordered from best to worst.
        true: The single relevant candidate.
        k: Cut-off; only the first ``k`` entries are considered.

    Returns:
        The gain as a Python float; 0.0 when ``true`` is not in the top ``k``.
    """
    gain = 0.0
    for pos, item in enumerate(ranked[:k]):
        if item == true:
            gain += 1.0 / np.log2(pos + 2)
    return float(gain)
|
|
|
|
| from gears import PertData, GEARS |
|
|
# Device handed to GEARS; hard-coded, so this script requires a CUDA device.
dev = "cuda"
pert_data = PertData("./gears_data")
pert_data.load(data_name="norman")

# Condition labels present in the loaded data; names outside this set are dropped below.
avail = set(map(str, pert_data.adata.obs["condition"].unique()))

# Translate our perturbation split into condition names, keeping only known
# conditions and making sure nothing in the test set is also trained on.
gears_test = sorted({to_gears(p) for p in my_test} & avail)
gears_train = sorted(({to_gears(p) for p in my_train} & avail) - set(gears_test))
if "ctrl" in avail:
    gears_train = sorted(set(gears_train) | {"ctrl"})

# NOTE(review): the combination-split test targets are scored later in this
# file but are not excluded from training here. Report any overlap so leakage
# is visible — confirm the two splits are meant to be disjoint.
combo_in_train = sorted({to_gears(p) for p in combo_test} & set(gears_train))
if combo_in_train:
    print(f"WARNING: {len(combo_in_train)} combination-split test conditions "
          f"are in the GEARS train/val set, e.g. {combo_in_train[:3]}", flush=True)

# Carve a validation set out of the training conditions with a fixed seed;
# "ctrl" is never eligible for validation.
rng = np.random.default_rng(0)
val_pool = [c for c in gears_train if c != "ctrl"]
val = sorted(rng.choice(val_pool, size=max(1, len(gears_train) // 10),
                        replace=False).tolist())
gears_train = sorted(set(gears_train) - set(val))
split_dict = {"train": gears_train, "val": val, "test": gears_test}

# GEARS reads the custom split back from disk, so the file must be fully
# written and closed before prepare_split is called.
os.makedirs("gears_data", exist_ok=True)
with open("gears_data/pivot_custom_split.pkl", "wb") as split_file:
    pickle.dump(split_dict, split_file)
pert_data.prepare_split(split="custom", split_dict_path="gears_data/pivot_custom_split.pkl")
pert_data.get_dataloader(batch_size=64, test_batch_size=128)


model = GEARS(pert_data, device=dev)
model.model_initialize(hidden_size=64)
model.train(epochs=20)
print("GEARS trained", flush=True)
|
|
# Pull the per-cell condition labels and a dense expression matrix out of the
# dataset so true condition means can be computed directly.
adata = pert_data.adata

# Prefer the "gene_name" column when present, otherwise fall back to var_names.
if "gene_name" in adata.var:
    genes = list(adata.var["gene_name"])
else:
    genes = list(adata.var_names)

# One condition label per row of the expression matrix.
cond = adata.obs["condition"].astype(str).values

# Densify the matrix if it exposes todense(); either way end up with an ndarray.
X = adata.X
if hasattr(X, "todense"):
    X = X.todense()
X = np.asarray(X)
|
|
|
|
def true_mean(gcond):
    """Return the mean expression over all rows labelled ``gcond``.

    Uses the module-level ``cond`` labels and ``X`` matrix.

    Args:
        gcond: Condition label to select rows by.

    Returns:
        The mean of the selected rows of ``X`` along axis 0, or ``None`` when
        no row carries that label.
    """
    mask = cond == gcond
    if not mask.any():
        return None
    return X[mask].mean(axis=0)
|
|
|
|
| |
# Candidate pool: the gene of every "<gene>+ctrl" condition present in the data.
single_conds = sorted(c for c in avail if c.endswith("+ctrl"))
single_genes = [c.split("+")[0] for c in single_conds]

# gene -> predicted expression vector, filled chunk by chunk.
pred = {}
B = 64  # number of candidates sent to model.predict per call
for i in range(0, len(single_genes), B):
    chunk = single_genes[i:i + B]
    try:
        out = model.predict([[g] for g in chunk])
    except Exception as e:
        # One bad candidate makes the whole call fail. Retry one at a time so
        # the rest of the chunk is not silently removed from the candidate
        # pool, which would make the ranking task easier than it should be.
        print("predict chunk failed", chunk[:3], e, flush=True)
        out = {}
        for g in chunk:
            try:
                out.update(model.predict([[g]]))
            except Exception as e_one:
                print("predict failed", g, e_one, flush=True)
    for g in chunk:
        # The result may be keyed by the bare gene or by the condition name.
        for key in (g, g + "+ctrl"):
            if key in out:
                pred[g] = np.asarray(out[key])
                break
print(f"predicted {len(pred)}/{len(single_genes)} single candidates", flush=True)


# Only candidates with a prediction can be ranked; row i of P belongs to cand_genes[i].
cand_genes = [g for g in single_genes if g in pred]
if not cand_genes:
    raise ValueError("no single-gene candidate could be predicted; nothing to rank")
P = np.stack([pred[g] for g in cand_genes])
|
|
| |
# Single-gene targets: held-out names without "_" whose gene has a prediction.
single_targets = [
    name for name in my_test
    if "_" not in name and to_gears(name).split("+")[0] in pred
]

# Per-target scores; the key order here is also the order used in the output.
res = {"single": {metric: [] for metric in ("top1", "top5", "ndcg", "rank")}}
for t in single_targets:
    tg = to_gears(t)
    tm = true_mean(tg)
    if tm is not None:
        # Rank every candidate by L2 distance between its predicted expression
        # and the target's observed mean expression (closest first).
        dist = np.linalg.norm(P - tm[None], axis=1)
        ranked = [cand_genes[idx] for idx in np.argsort(dist)]
        tgene = tg.split("+")[0]

        hit_first = ranked[0] == tgene
        hit_top5 = tgene in ranked[:5]
        if tgene in ranked:
            position = int(ranked.index(tgene) + 1)  # 1-based rank
        else:
            position = len(ranked)  # target absent: count it as ranked last

        res["single"]["top1"].append(float(hit_first))
        res["single"]["top5"].append(float(hit_top5))
        res["single"]["ndcg"].append(ndcg_at_k(ranked, tgene, 10))
        res["single"]["rank"].append(position)
|
|
| |
# Candidate pool for combinations: conditions with exactly one "+" and no "ctrl".
combo_conds = sorted(c for c in avail if c.count("+") == 1 and "ctrl" not in c)

# condition -> predicted expression vector.
cpred = {}
for i in range(0, len(combo_conds), B):
    chunk = combo_conds[i:i + B]
    try:
        out = model.predict([c.split("+") for c in chunk])
    except Exception as e:
        # One bad pair makes the whole call fail. Retry one at a time so the
        # rest of the chunk is not silently removed from the candidate pool.
        print("combo predict failed", e, flush=True)
        out = {}
        for c in chunk:
            try:
                out.update(model.predict([c.split("+")]))
            except Exception as e_one:
                print("combo predict failed", c, e_one, flush=True)
    for c in chunk:
        # The result may be keyed by the "+"-joined or the "_"-joined name.
        for key in (c, c.replace("+", "_")):
            if key in out:
                cpred[c] = np.asarray(out[key])
                break
print(f"predicted {len(cpred)}/{len(combo_conds)} combo candidates", flush=True)

# Only candidates with a prediction can be ranked; row i of PC belongs to cc[i].
cc = [c for c in combo_conds if c in cpred]
if cc:
    PC = np.stack([cpred[c] for c in cc])
    # Targets are the combination-split test conditions that are also candidates.
    combo_tg = [tg for tg in map(to_gears, combo_test) if tg in cpred]
    # "rank" is recorded so a median rank is reported for combination targets
    # as well, as the module docstring promises.
    res["combo"] = {"top1": [], "top5": [], "overlap": [], "ndcg": [], "rank": []}
    for tg in combo_tg:
        tm = true_mean(tg)
        if tm is None:
            continue
        # Rank candidates by L2 distance to the target's observed mean (closest first).
        d = np.linalg.norm(PC - tm[None], axis=1)
        order = [cc[i] for i in np.argsort(d)]
        res["combo"]["top1"].append(float(order[0] == tg))
        res["combo"]["top5"].append(float(tg in order[:5]))
        res["combo"]["ndcg"].append(ndcg_at_k(order, tg, 10))
        # Fraction of the true pair's genes that appear in the top-ranked candidate.
        s_true = set(tg.split("+"))
        res["combo"]["overlap"].append(len(set(order[0].split("+")) & s_true) / len(s_true))
        # tg is always in order because targets were filtered to predicted candidates.
        res["combo"]["rank"].append(order.index(tg) + 1)
|
|
# Collapse the per-target scores into one summary per group ("single", "combo").
agg = {}
for grp, metrics in res.items():
    # Mean of every metric; None when no target was scored for that metric.
    agg[grp] = {k: (float(np.mean(v)) if v else None) for k, v in metrics.items()}
    # All metric lists in a group have the same length: the number of scored targets.
    agg[grp]["n"] = len(next(iter(metrics.values())))
    if metrics.get("rank"):
        agg[grp]["med_rank"] = float(np.median(metrics["rank"]))
agg["n_candidates_single"] = len(cand_genes)
agg["n_candidates_combo"] = len(cc)

# Create the output directory first so a missing folder cannot discard the
# results of a full training run, and close the file deterministically.
os.makedirs("experiments/results", exist_ok=True)
with open("experiments/results/gears_ranking.json", "w") as result_file:
    json.dump(agg, result_file, indent=2, default=float)
print("GEARS_RANK_DONE", json.dumps(agg), flush=True)
|
|