"""real gears head-to-head on norman, aligned to our held-out-perturbation test set.
runs in the isolated pivot_gears env (torch cu118 + pyg + cell-gears, gpu-capable)."""
import sys, os, json
import pickle

import numpy as np
import torch

# our held-out perturbation split (no src import; read the npz directly)
SPLIT_PATH = "data/processed/norman/splits/perturbation.npz"
GEARS_DATA_DIR = "./gears_data"
SPLIT_DICT_PATH = "gears_data/pivot_custom_split.pkl"
RESULTS_PATH = "experiments/results/gears_norman.json"
# substrings selecting which GEARS metrics are reported
METRIC_KEYS = ("pearson", "mse", "r2", "spearman")


def to_gears(p):
    """Return the canonical GEARS condition name for one of our perturbation labels.

    Our labels join gene symbols with "_" (empty pieces are ignored). A single
    gene maps to "GENE+ctrl"; otherwise the first two genes are joined as "A+B"
    and any further genes are dropped. Raises IndexError for a label with no genes.
    """
    genes = [x for x in p.split("_") if x]
    return genes[0] + "+ctrl" if len(genes) == 1 else genes[0] + "+" + genes[1]


def gears_candidates(p):
    """List every GEARS spelling that could denote our label *p*.

    to_gears() only emits one gene order, so a condition stored as "ctrl+A" or
    "B+A" would otherwise be silently dropped from the split. Labels with no
    genes or with more than two genes return [] instead of being truncated onto
    the wrong (two-slot) condition name.
    """
    genes = [x for x in p.split("_") if x]
    if len(genes) == 1:
        return [f"{genes[0]}+ctrl", f"ctrl+{genes[0]}"]
    if len(genes) == 2:
        a, b = genes
        return [f"{a}+{b}", f"{b}+{a}"]
    return []


def map_to_gears(perts, avail):
    """Map our labels onto GEARS conditions present in *avail*.

    Returns (matched, unmatched): the set of every available GEARS condition
    matching any label, and the list of labels with no available match.
    """
    matched, unmatched = set(), []
    for p in perts:
        hits = [c for c in gears_candidates(p) if c in avail]
        if hits:
            matched.update(hits)
        else:
            unmatched.append(p)
    return matched, unmatched


def carve_val(train, rng):
    """Split a validation set off *train* (gears requires a non-empty val split).

    Draws max(1, len(train) // 10) non-"ctrl" conditions without replacement
    from *rng* (a numpy Generator). Returns (sorted remaining train, sorted val).
    Raises ValueError if *train* has no non-"ctrl" condition to draw from.
    """
    pool = [c for c in train if c != "ctrl"]
    if not pool:
        raise ValueError("no non-ctrl training conditions available to carve a val split")
    # sized on the full train list (ctrl included) to keep the original draw size
    size = max(1, len(train) // 10)
    val = sorted(rng.choice(pool, size=size, replace=False).tolist())
    return sorted(set(train) - set(val)), val


def select_metrics(metrics, keys=METRIC_KEYS):
    """Keep metrics whose name contains any of *keys*, cast to plain floats for JSON."""
    return {k: float(v) for k, v in metrics.items() if any(s in k for s in keys)}


def main():
    """Train GEARS on our custom norman split and write held-out test metrics to RESULTS_PATH."""
    # gears is only installed in the pivot_gears env; importing here keeps the
    # helpers above importable (and testable) without it.
    from gears import PertData, GEARS
    from gears.inference import evaluate, compute_metrics

    with np.load(SPLIT_PATH, allow_pickle=True) as split:
        my_test = [str(p) for p in split["test_perts"]]
        my_train = [str(p) for p in split["train_perts"]]

    dev = "cuda" if torch.cuda.is_available() else "cpu"
    torch.manual_seed(0)  # make model init / loader shuffling repeatable between runs
    print(f"device: {dev}", flush=True)

    pert_data = PertData(GEARS_DATA_DIR)
    pert_data.load(data_name="norman")
    avail = set(map(str, pert_data.adata.obs["condition"].unique()))
    print("example GEARS conditions:", list(avail)[:6], flush=True)

    test_set, test_missing = map_to_gears(my_test, avail)
    train_set, _ = map_to_gears(my_train, avail)
    gears_test = sorted(test_set)
    gears_train = sorted(train_set - test_set)  # never let a test condition leak into train
    if "ctrl" in avail:
        gears_train = sorted(set(gears_train) | {"ctrl"})
    if test_missing:
        print(f"WARNING: {len(test_missing)} of our test perts have no GEARS condition, "
              f"e.g. {test_missing[:5]}", flush=True)
    if not gears_test:
        raise RuntimeError("none of our test perturbations matched a GEARS condition")

    # carve a small val set from train (gears requires non-empty val)
    gears_train, val = carve_val(gears_train, np.random.default_rng(0))
    split_dict = {"train": gears_train, "val": val, "test": gears_test}
    print(f"custom split -> train {len(gears_train)} | val {len(val)} | test {len(gears_test)} "
          f"(our test {len(my_test)})", flush=True)

    os.makedirs(os.path.dirname(SPLIT_DICT_PATH), exist_ok=True)
    # close (and flush) the pickle before prepare_split re-reads it
    with open(SPLIT_DICT_PATH, "wb") as fh:
        pickle.dump(split_dict, fh)
    pert_data.prepare_split(split="custom", split_dict_path=SPLIT_DICT_PATH)
    pert_data.get_dataloader(batch_size=64, test_batch_size=128)

    model = GEARS(pert_data, device=dev)
    model.model_initialize(hidden_size=64)
    model.train(epochs=20)

    # NOTE(review): cell-gears' GEARS.train stores its val-selected weights as
    # `best_model` and scores the test set with them; verify for the installed
    # version. Fall back to the last-epoch weights if the attribute is absent.
    best = getattr(model, "best_model", None)
    net = best if best is not None else model.model
    res = evaluate(pert_data.dataloader["test_loader"], net, False, dev)
    out = compute_metrics(res)
    metrics = out[0] if isinstance(out, tuple) else out
    keep = select_metrics(metrics)

    print("GEARS held-out-perturbation metrics:", flush=True)
    for k, v in sorted(keep.items()):
        print(f" {k:24s} {v:.4f}", flush=True)

    os.makedirs(os.path.dirname(RESULTS_PATH), exist_ok=True)
    with open(RESULTS_PATH, "w") as fh:
        json.dump({"n_test_perts": len(gears_test), "test_perts": gears_test, "metrics": keep},
                  fh, indent=2, default=float)
    print("GEARS_RUN_DONE", flush=True)


if __name__ == "__main__":
    main()