"""GEARS + ranking baseline for inverse nomination, done entirely in GEARS' native gene space.

The comparison is fair because top-k recovery is representation-agnostic.

Pipeline:
  * Train GEARS on our held-out-perturbation split.
  * Predict post-perturbation expression for every candidate.
  * Rank candidates by L2 distance to each held-out target's true mean expression.
  * Emit top-1/5, nDCG@10 and median rank for single-gene and combination targets.

Runs in the isolated pivot_gears env.

Combination-split test combos are withheld from GEARS training: they are put
into the GEARS test set. Otherwise GEARS could be trained on the very
conditions it is later asked to rank. Pivot names ("A_B") are matched to GEARS
conditions regardless of gene order, because GEARS stores each combination
under a single order ("A+B" or "B+A").
"""
import json
import os
import pickle

import numpy as np

PROC = "data/processed/norman"
SPLIT_PATH = "gears_data/pivot_custom_split.pkl"
OUT_PATH = "experiments/results/gears_ranking.json"
B = 64  # number of perturbations sent to GEARS per predict() call

with np.load(f"{PROC}/splits/perturbation.npz", allow_pickle=True) as split:
    my_test = [str(p) for p in split["test_perts"]]
    my_train = [str(p) for p in split["train_perts"]]
with np.load(f"{PROC}/splits/combination.npz", allow_pickle=True) as splitC:
    combo_test = [str(p) for p in splitC["test_perts"]]


def to_gears(p):
    """Map a pivot name ("A" or "A_B") to GEARS style ("A+ctrl" / "A+B"), keeping pivot gene order."""
    g = [x for x in p.split("_") if x]
    return g[0] + "+ctrl" if len(g) == 1 else g[0] + "+" + g[1]


def resolve_cond(p, avail):
    """Return the GEARS condition in *avail* that names pivot perturbation *p*, or None.

    A single gene "G" resolves to "G+ctrl". A pair "A_B" resolves to whichever of
    "A+B" / "B+A" GEARS actually stores. Empty names and names with more than two
    genes resolve to None.
    """
    g = [x for x in p.split("_") if x]
    if len(g) == 1:
        options = (g[0] + "+ctrl",)
    elif len(g) == 2:
        options = (g[0] + "+" + g[1], g[1] + "+" + g[0])
    else:
        options = ()
    return next((c for c in options if c in avail), None)


def ndcg_at_k(ranked, true, k=10):
    """nDCG@k for a single relevant item; the ideal DCG is 1, so nDCG equals DCG."""
    rel = [1.0 if r == true else 0.0 for r in ranked[:k]]
    dcg = sum(r / np.log2(i + 2) for i, r in enumerate(rel))
    return float(dcg)


import torch
from gears import PertData, GEARS

dev = "cuda" if torch.cuda.is_available() else "cpu"
pert_data = PertData("./gears_data")
pert_data.load(data_name="norman")
avail = set(map(str, pert_data.adata.obs["condition"].unique()))


def _resolve_all(perts):
    """Set of GEARS conditions for the pivot names in *perts* that exist in *avail*."""
    return {c for c in (resolve_cond(p, avail) for p in perts) if c is not None}


# ---- custom GEARS split: held-out singles/combos + combination-split targets are never trained on
combo_targets = _resolve_all(combo_test)
held_out = _resolve_all(my_test) | combo_targets
train_pool = _resolve_all(my_train)
# Combination-split targets that would otherwise have leaked into GEARS training.
n_combo_withheld = len((train_pool - _resolve_all(my_test)) & combo_targets)
print(f"withholding {n_combo_withheld} combination targets from GEARS training", flush=True)

gears_test = sorted(held_out)
gears_train = sorted(train_pool - held_out)
if "ctrl" in avail:
    gears_train = sorted(set(gears_train) | {"ctrl"})
rng = np.random.default_rng(0)
val = sorted(rng.choice([c for c in gears_train if c != "ctrl"],
                        size=max(1, len(gears_train) // 10), replace=False).tolist())
gears_train = sorted(set(gears_train) - set(val))
split_dict = {"train": gears_train, "val": val, "test": gears_test}
os.makedirs("gears_data", exist_ok=True)
with open(SPLIT_PATH, "wb") as fh:  # closed (and flushed) before GEARS reads it back
    pickle.dump(split_dict, fh)
pert_data.prepare_split(split="custom", split_dict_path=SPLIT_PATH)
pert_data.get_dataloader(batch_size=64, test_batch_size=128)
model = GEARS(pert_data, device=dev)
model.model_initialize(hidden_size=64)
model.train(epochs=20)
print("GEARS trained", flush=True)

adata = pert_data.adata
cond = adata.obs["condition"].astype(str).values
X = adata.X


def true_mean(gcond):
    """Mean observed expression over the cells of condition *gcond*, or None if it has no cells.

    Only the selected rows are densified, rather than the whole cells x genes matrix.
    """
    m = cond == gcond
    if not m.any():
        return None
    rows = X[m]
    rows = np.asarray(rows.todense()) if hasattr(rows, "todense") else np.asarray(rows)
    return rows.mean(0)


def predict_many(queries, what):
    """Predict post-perturbation expression for each gene tuple in *queries*.

    Returns {gene_tuple: np.ndarray}. Queries are sent in chunks of B. If a chunk
    raises (e.g. one gene GEARS cannot handle), that chunk's queries are retried
    one by one, so a single bad candidate does not drop the other 63.
    """
    out = {}
    for i in range(0, len(queries), B):
        chunk = queries[i:i + B]
        try:
            got = model.predict([list(q) for q in chunk])
        except Exception as e:  # best-effort: log, then salvage individually
            print(f"{what} predict chunk failed, retrying individually", chunk[:3], e, flush=True)
            got = {}
            for q in chunk:
                try:
                    got.update(model.predict([list(q)]))
                except Exception as e1:
                    print(f"{what} predict failed", q, e1, flush=True)
        for q in chunk:
            # Candidate result keys: GEARS' "_".join form, plus the condition-name forms.
            for key in ("_".join(q), "+".join(q), "+".join(q) + "+ctrl"):
                if key in got:
                    out[q] = np.asarray(got[key])
                    break
    return out


def rank_candidates(pred_mat, names, target_mean):
    """Candidate *names* ordered by L2 distance of their predicted profile to *target_mean*."""
    d = np.linalg.norm(pred_mat - target_mean[None], axis=1)
    return [names[i] for i in np.argsort(d)]


# ---- single-gene candidates: GEARS conditions of the form "GENE+ctrl"
single_conds = sorted(c for c in avail if c.endswith("+ctrl") and c != "ctrl")
single_genes = [c.split("+")[0] for c in single_conds]
pred = {q[0]: v for q, v in predict_many([(g,) for g in single_genes], "single").items()}
print(f"predicted {len(pred)}/{len(single_genes)} single candidates", flush=True)
cand_genes = [g for g in single_genes if g in pred]
if not cand_genes:
    raise RuntimeError("GEARS produced no single-gene candidate predictions")
P = np.stack([pred[g] for g in cand_genes])  # [n_cand, n_genes]

res = {"single": {"top1": [], "top5": [], "ndcg": [], "rank": []}}
for t in my_test:
    if "_" in t:  # combinations are evaluated separately below
        continue
    tg = resolve_cond(t, avail)
    if tg is None:
        continue
    tgene = tg.split("+")[0]
    if tgene not in pred:
        continue
    tm = true_mean(tg)
    if tm is None:
        continue
    order = rank_candidates(P, cand_genes, tm)
    res["single"]["top1"].append(float(order[0] == tgene))
    res["single"]["top5"].append(float(tgene in order[:5]))
    res["single"]["ndcg"].append(ndcg_at_k(order, tgene, 10))
    # tgene is in pred, hence in cand_genes, hence in order.
    res["single"]["rank"].append(order.index(tgene) + 1)

# ---- combination targets: predict every observed combo and rank among them
combo_conds = sorted(c for c in avail if c.count("+") == 1 and "ctrl" not in c)
cpred = {"+".join(q): v
         for q, v in predict_many([tuple(c.split("+")) for c in combo_conds], "combo").items()}
print(f"predicted {len(cpred)}/{len(combo_conds)} combo candidates", flush=True)
cc = [c for c in combo_conds if c in cpred]
if cc:
    PC = np.stack([cpred[c] for c in cc])
    # Resolve against GEARS' stored gene order so "A_B" matches either "A+B" or "B+A".
    combo_tg = [tg for tg in (resolve_cond(t, avail) for t in combo_test) if tg in cpred]
    res["combo"] = {"top1": [], "top5": [], "overlap": [], "ndcg": []}
    for tg in combo_tg:
        tm = true_mean(tg)
        if tm is None:
            continue
        order = rank_candidates(PC, cc, tm)
        res["combo"]["top1"].append(float(order[0] == tg))
        res["combo"]["top5"].append(float(tg in order[:5]))
        res["combo"]["ndcg"].append(ndcg_at_k(order, tg, 10))
        s_true = set(tg.split("+"))
        # Fraction of the true genes recovered by the top-1 nominated combo.
        res["combo"]["overlap"].append(len(set(order[0].split("+")) & s_true) / len(s_true))

agg = {}
for grp, metrics in res.items():
    agg[grp] = {k: (float(np.mean(v)) if v else None) for k, v in metrics.items()}
    agg[grp]["n"] = len(next(iter(metrics.values())))
    if metrics.get("rank"):
        agg[grp]["med_rank"] = float(np.median(metrics["rank"]))
agg["n_candidates_single"] = len(cand_genes)
agg["n_candidates_combo"] = len(cc)
agg["n_combo_targets_withheld_from_train"] = n_combo_withheld
os.makedirs(os.path.dirname(OUT_PATH), exist_ok=True)
with open(OUT_PATH, "w") as fh:
    json.dump(agg, fh, indent=2, default=float)
print("GEARS_RANK_DONE", json.dumps(agg), flush=True)