| """real gears head-to-head on norman, aligned to our held-out-perturbation test set. |
| runs in the isolated pivot_gears env (torch cu118 + pyg + cell-gears, gpu-capable).""" |
| import sys, os, json |
| import numpy as np |
| import torch |
|
|
| |
# Our own perturbation-level split. ``test_perts`` holds the held-out
# conditions this GEARS run is aligned to. allow_pickle is needed because the
# stored arrays may hold Python objects; the file is assumed to be a locally
# produced, trusted artifact. ``split`` stays open so later code can read
# other arrays (e.g. train_perts) from it.
split = np.load(
    "data/processed/norman/splits/perturbation.npz",
    allow_pickle=True,
)
my_test = list(map(str, split["test_perts"]))
|
|
def to_gears(p):
    """Convert one of our perturbation names to GEARS-style condition syntax.

    Our names join gene symbols with underscores ("A" or "A_B"). A single gene
    becomes "A+ctrl" and a combination becomes "A+B". Empty fragments (from
    doubled, leading or trailing underscores) are ignored.

    Args:
        p: perturbation name from our split file.

    Returns:
        The condition string. Names with more than two genes are joined in
        full ("A+B+C") instead of being silently truncated to the first two.

    Raises:
        ValueError: if ``p`` contains no gene symbol at all.
    """
    genes = [x for x in p.split("_") if x]
    if not genes:
        raise ValueError(f"perturbation name {p!r} contains no gene symbol")
    if len(genes) == 1:
        return genes[0] + "+ctrl"
    return "+".join(genes)
|
|
| import pickle |
| from gears import PertData, GEARS |
| from gears.inference import evaluate, compute_metrics |
|
|
# Train perturbations come from the split file already loaded as ``split``;
# there is no need to read it from disk a second time.
my_train = [str(p) for p in split["train_perts"]]
# Fall back to CPU so the run does not crash on a machine without a visible GPU.
dev = "cuda" if torch.cuda.is_available() else "cpu"
pert_data = PertData("./gears_data")
pert_data.load(data_name="norman")
avail = set(map(str, pert_data.adata.obs["condition"].unique()))
# sorted() so the printed sample is identical across runs (set order is not).
print("example GEARS conditions:", sorted(avail)[:6], flush=True)


def _to_avail(p):
    """Map one of our perturbation names to a condition present in ``avail``.

    Tries the ``to_gears`` spelling first, then the same genes in reverse
    order ("B+A" / "ctrl+A"). Otherwise an exact-string match would drop a
    condition whose gene order differs. Returns None when neither spelling
    exists.
    """
    name = to_gears(p)
    if name in avail:
        return name
    flipped = "+".join(reversed(name.split("+")))
    return flipped if flipped in avail else None


test_map = {p: _to_avail(p) for p in my_test}
unmatched = sorted(p for p, c in test_map.items() if c is None)
if unmatched:
    # Report test perts that cannot be compared head-to-head instead of
    # dropping them silently.
    print(f"WARNING: {len(unmatched)}/{len(my_test)} of our test perts have no "
          f"GEARS condition and are excluded: {unmatched[:10]}", flush=True)
gears_test = sorted({c for c in test_map.values() if c is not None})
if not gears_test:
    # Fail before an expensive training run that would have nothing to test on.
    sys.exit("none of our test perturbations match a GEARS norman condition")
gears_train = sorted({c for c in map(_to_avail, my_train) if c is not None}
                     - set(gears_test))
if "ctrl" in avail:
    gears_train = sorted(set(gears_train) | {"ctrl"})
| |
# Carve a seeded validation set (about 10% of the training conditions, at least
# one) out of train. The "ctrl" condition is never drawn, so it stays in train.
rng = np.random.default_rng(0)
val_candidates = [c for c in gears_train if c != "ctrl"]
if not val_candidates:
    # rng.choice would otherwise fail with an opaque "a cannot be empty" error.
    sys.exit("no non-control training perturbations left to draw a validation set from")
val = sorted(rng.choice(val_candidates, size=max(1, len(gears_train) // 10),
                        replace=False).tolist())
gears_train = sorted(set(gears_train) - set(val))
split_dict = {"train": gears_train, "val": val, "test": gears_test}
print(f"custom split -> train {len(gears_train)} | val {len(val)} | test {len(gears_test)} "
      f"(our test {len(my_test)})", flush=True)
os.makedirs("gears_data", exist_ok=True)
sd_path = "gears_data/pivot_custom_split.pkl"
# Close (and therefore flush) the pickle before GEARS reads it back from sd_path.
with open(sd_path, "wb") as fh:
    pickle.dump(split_dict, fh)
pert_data.prepare_split(split="custom", split_dict_path=sd_path)
pert_data.get_dataloader(batch_size=64, test_batch_size=128)
|
|
# Seed torch so weight initialisation and training-batch shuffling repeat
# across reruns, matching the seeded validation draw. Some CUDA kernels may
# still introduce small run-to-run differences.
torch.manual_seed(0)
model = GEARS(pert_data, device=dev)
model.model_initialize(hidden_size=64)
model.train(epochs=20)
|
|
# Evaluate the validation-selected weights. GEARS.train() is expected to keep
# them in ``model.best_model`` (its own test pass uses them); ``model.model``
# holds the last-epoch weights. Fall back to the latter if the attribute is
# absent. NOTE(review): confirm against the installed cell-gears version.
eval_model = getattr(model, "best_model", None)
if eval_model is None:
    eval_model = model.model
res = evaluate(pert_data.dataloader["test_loader"], eval_model, False, dev)
out = compute_metrics(res)
# compute_metrics may return (metrics, per-pert metrics) or just metrics.
metrics, per_pert = out if isinstance(out, tuple) else (out, {})
keep = {k: float(v) for k, v in metrics.items()
        if any(s in k for s in ("pearson", "mse", "r2", "spearman"))}
print("GEARS held-out-perturbation metrics:", flush=True)
for k, v in sorted(keep.items()):
    print(f" {k:24s} {v:.4f}", flush=True)
results_path = "experiments/results/gears_norman.json"
# Create the output directory so the dump cannot fail after a long training run.
os.makedirs(os.path.dirname(results_path), exist_ok=True)
with open(results_path, "w") as fh:
    json.dump({"n_test_perts": len(gears_test), "test_perts": gears_test, "metrics": keep},
              fh, indent=2, default=float)
print("GEARS_RUN_DONE", flush=True)
|
|