# PIVOT / scripts/gears_compare.py
# last commit 3b4941f by bryan7264: "pivot: code + trained checkpoints (norman, replogle k562)" (2.75 kB)
"""real gears head-to-head on norman, aligned to our held-out-perturbation test set.
runs in the isolated pivot_gears env (torch cu118 + pyg + cell-gears, gpu-capable)."""
import sys, os, json
import numpy as np
import torch
# held-out-perturbation test labels from our own split; the npz is read directly
# rather than importing from src, because this script runs in the separate gears env
split = np.load("data/processed/norman/splits/perturbation.npz", allow_pickle=True)
my_test = list(map(str, split["test_perts"]))
def to_gears(p):
    """Convert one of our perturbation labels to GEARS condition syntax.

    Our labels join gene names with "_"; empty pieces (from doubled, leading or
    trailing underscores) are ignored. A single gene becomes "GENE+ctrl"; two or
    more genes are joined with "+" in their original order, e.g. "A_B" -> "A+B".
    Every gene is kept, so a label with three or more genes can never alias onto
    a two-gene GEARS condition.

    Raises:
        ValueError: if the label contains no gene name at all.
    """
    genes = [x for x in p.split("_") if x]
    if not genes:
        raise ValueError(f"empty perturbation label: {p!r}")
    if len(genes) == 1:
        return genes[0] + "+ctrl"
    return "+".join(genes)
import pickle
from gears import PertData, GEARS
from gears.inference import evaluate, compute_metrics
# our training perturbation labels, from the same split file as the test set;
# the context manager closes the NpzFile's underlying file handle once read
with np.load("data/processed/norman/splits/perturbation.npz", allow_pickle=True) as npz:
    my_train = [str(p) for p in npz["train_perts"]]
# use the gpu when one is visible, otherwise fall back to cpu instead of crashing
dev = "cuda" if torch.cuda.is_available() else "cpu"
pert_data = PertData("./gears_data")
pert_data.load(data_name="norman")
avail = set(map(str, pert_data.adata.obs["condition"].unique()))
print("example GEARS conditions:", list(avail)[:6], flush=True)


def cond_key(cond):
    """Return an order-insensitive key for a "+"-joined condition: its sorted non-ctrl genes."""
    return tuple(sorted(g for g in cond.split("+") if g and g != "ctrl"))


# index GEARS conditions by gene set so our "A+B" also finds a stored "B+A"
# (or "A+ctrl" a stored "ctrl+A"); gene-less conditions such as "ctrl" are
# deliberately not indexed, so only real perturbations are matched this way
cond_by_key = {}
for cond in sorted(avail):
    if cond_key(cond):
        cond_by_key.setdefault(cond_key(cond), cond)


def resolve(p):
    """Map one of our labels to a GEARS condition name, or None if GEARS has no match.

    The exact to_gears() spelling wins when GEARS has it; otherwise fall back to
    the order-insensitive gene-set match.
    """
    name = to_gears(p)
    return name if name in avail else cond_by_key.get(cond_key(name))


test_map = {p: resolve(p) for p in my_test}
gears_test = sorted({c for c in test_map.values() if c is not None})
missing = sorted(p for p, c in test_map.items() if c is None)
if missing:
    # report alignment losses explicitly: a silently shrunken test set would make
    # the head-to-head comparison with our own numbers misleading
    print(f"WARNING: {len(missing)} of our test perts have no GEARS condition and are "
          f"excluded, e.g. {missing[:6]}", flush=True)
if not gears_test:
    # fail before spending time on training against an empty test set
    sys.exit("no held-out test perturbation matched a GEARS norman condition; "
             "check the label formats printed above")
# train conditions never overlap the test conditions
gears_train = sorted({c for c in map(resolve, my_train) if c is not None} - set(gears_test))
if "ctrl" in avail:
    gears_train = sorted(set(gears_train) | {"ctrl"})
# carve a small val set from train (gears requires non-empty val); fixed seed so
# the split is reproducible
candidates = [c for c in gears_train if c != "ctrl"]
if len(candidates) < 2:
    # fail with a clear message instead of a numpy sampling error, and never leave
    # train with nothing but "ctrl"
    sys.exit(f"need at least 2 non-ctrl train conditions to carve a val set, "
             f"got {len(candidates)}")
rng = np.random.default_rng(0)
# size is derived from len(gears_train) (which can include "ctrl") rather than
# len(candidates); changing it would alter the seeded draw and hence the val set
val = sorted(rng.choice(candidates, size=max(1, len(gears_train) // 10),
                        replace=False).tolist())
gears_train = sorted(set(gears_train) - set(val))
split_dict = {"train": gears_train, "val": val, "test": gears_test}
print(f"custom split -> train {len(gears_train)} | val {len(val)} | test {len(gears_test)} "
      f"(our test {len(my_test)})", flush=True)
os.makedirs("gears_data", exist_ok=True)
sd_path = "gears_data/pivot_custom_split.pkl"
# close (and so flush) the pickle before its path is handed to prepare_split
with open(sd_path, "wb") as fh:
    pickle.dump(split_dict, fh)
pert_data.prepare_split(split="custom", split_dict_path=sd_path)
pert_data.get_dataloader(batch_size=64, test_batch_size=128)
model = GEARS(pert_data, device=dev)
model.model_initialize(hidden_size=64)
model.train(epochs=20)
# GEARS.train() presumably stores its best-on-validation copy as `best_model`
# (confirm against the installed cell-gears version); evaluate that checkpoint,
# and fall back to the live (last-epoch) model when the attribute is absent
eval_model = getattr(model, "best_model", None)
if eval_model is None:
    eval_model = model.model
res = evaluate(pert_data.dataloader["test_loader"], eval_model, False, dev)
out = compute_metrics(res)
# accept either (aggregate, per-perturbation) or a bare aggregate dict
metrics, per_pert = out if isinstance(out, tuple) else (out, {})
keep = {k: float(v) for k, v in metrics.items()
        if any(s in k for s in ["pearson", "mse", "r2", "spearman"])}
print("GEARS held-out-perturbation metrics:", flush=True)
for k, v in sorted(keep.items()):
    print(f" {k:24s} {v:.4f}", flush=True)
# make sure the output dir exists so a finished training run is not lost at the last step
os.makedirs("experiments/results", exist_ok=True)
with open("experiments/results/gears_norman.json", "w") as fh:
    json.dump({"n_test_perts": len(gears_test), "test_perts": gears_test, "metrics": keep},
              fh, indent=2, default=float)
print("GEARS_RUN_DONE", flush=True)