PIVOT / src /data /splits.py
bryan7264's picture
pivot: code + trained checkpoints (norman, replogle k562)
3b4941f verified
Raw
History Blame
6.66 kB
"""deterministic, serialized data splits.

six split regimes, each guaranteeing the stated generalization gap with no leakage:

    cell          held-out cells of seen perturbations (interpolation over cells)
    perturbation  held-out perturbation labels (cells never in train; genes may be seen elsewhere)
    gene          held-out target genes (gene id never appears in any train perturbation)
    gene_family   held-out functional families (effect-correlation clusters of genes)
    combination   held-out combos; train on singles + remaining combos
    cell_type     cross-dataset transfer (handled at the experiment level, marker only here)

each split serializes train/val/test cell indices plus the train/test perturbation
label sets, so forward training and inverse-nomination eval read identical partitions.
control cells are always available in train (and proportionally in val/test for matching).
"""
from __future__ import annotations
import argparse
import os
import numpy as np
from src.data.perturb_data import PerturbData, load_dataset
# Split regimes that make_split can build and that the CLI below iterates over.
# "cell_type" from the module docstring is absent on purpose: it is described
# there as handled at the experiment level, so no split file is generated for it.
SPLIT_TYPES = ("cell", "perturbation", "gene", "gene_family", "combination")
def _split_control(control_idx, rng, fracs=(0.7, 0.15, 0.15)):
perm = rng.permutation(control_idx)
n = len(perm)
a, b = int(fracs[0] * n), int((fracs[0] + fracs[1]) * n)
return perm[:a], perm[a:b], perm[b:]
def make_split(data: PerturbData, split_type: str, seed: int = 0, holdout_frac: float = 0.2):
    """Build one deterministic train/val/test partition of ``data``.

    Args:
        data: dataset object. This function reads ``perturbations``,
            ``control_idx``, ``pert_to_idx`` and ``obs``, plus (depending on the
            regime) ``genes_vocab``, ``functional_clusters``, ``parse``,
            ``combos`` and ``meta``.
        split_type: one of "cell", "perturbation", "gene", "gene_family",
            "combination".
        seed: seed of the NumPy generator that drives every random draw.
        holdout_frac: share of perturbations / genes / families / combos held
            out for test; at least one is always held out. Not used by "cell".

    Returns:
        dict with ``split_type``, ``seed``, the int64 arrays ``train_idx``,
        ``val_idx`` and ``test_idx``, and the label arrays ``train_perts`` and
        ``test_perts``.

    Raises:
        ValueError: unknown ``split_type``, or "combination" requested for a
            dataset without combinations.
        AssertionError: a leakage guard found a held-out label on the train side.
    """
    rng = np.random.default_rng(seed)
    perts = list(data.perturbations)
    # Controls are split first; every later draw reuses the same generator, so
    # the order of draws must not change or existing seeds yield other splits.
    ctrl_tr, ctrl_va, ctrl_te = _split_control(data.control_idx, rng)

    def cells(pert_list):
        # Cell indices of all perturbations in pert_list, concatenated in order.
        if not pert_list:
            return np.array([], dtype=np.int64)
        return np.concatenate([data.pert_to_idx[p] for p in pert_list])

    def n_held(pool):
        # Held-out count: holdout_frac of the pool, but never fewer than one.
        return max(1, int(holdout_frac * len(pool)))

    if split_type == "cell":
        # all perturbations seen; split each perturbation's cells 70/15/15
        tr, va, te = [], [], []
        for p in perts:
            idx = rng.permutation(data.pert_to_idx[p])
            n = len(idx)
            a, b = int(0.7 * n), int(0.85 * n)
            tr.append(idx[:a])
            va.append(idx[a:b])
            te.append(idx[b:])
        train_idx = np.concatenate(tr + [ctrl_tr])
        val_idx = np.concatenate(va + [ctrl_va])
        test_idx = np.concatenate(te + [ctrl_te])
        train_perts = test_perts = perts
    elif split_type == "perturbation":
        held = set(rng.choice(perts, size=n_held(perts), replace=False))
        remaining = [p for p in perts if p not in held]
        # carve a disjoint validation set of perturbations (no train/val cell
        # overlap). A single remaining perturbation stays in train: giving it
        # to validation would leave train with control cells only.
        n_val = max(1, len(remaining) // 8) if len(remaining) > 1 else 0
        val_perts = set(rng.choice(remaining, size=n_val, replace=False)) if n_val else set()
        train_perts = [p for p in remaining if p not in val_perts]
        test_perts = sorted(held)
        train_idx = np.concatenate([cells(train_perts), ctrl_tr])
        val_idx = np.concatenate([cells(sorted(val_perts)), ctrl_va]) if val_perts else ctrl_va
        test_idx = np.concatenate([cells(test_perts), ctrl_te])
        # leakage guard: val perturbations must be disjoint from train.
        # Raised explicitly so the check survives ``python -O``.
        if val_perts.intersection(train_perts):
            raise AssertionError("LEAKAGE: val/train perturbation overlap")
    elif split_type in ("gene", "gene_family"):
        if split_type == "gene":
            # every gene is its own hold-out unit
            units = data.genes_vocab
            unit_of = {g: g for g in units}
        else:
            fam = data.functional_clusters(seed=seed)  # gene -> cluster id
            # genes missing from the clustering share the single unit "fam-1"
            unit_of = {g: f"fam{fam.get(g, -1)}" for g in data.genes_vocab}
            units = sorted(set(unit_of.values()))
        held_units = set(rng.choice(units, size=n_held(units), replace=False))
        held_genes = {g for g in data.genes_vocab if unit_of[g] in held_units}
        # a perturbation is held out as soon as any of its genes is held out
        test_perts = [p for p in perts if any(g in held_genes for g in data.parse(p))]
        held_out = set(test_perts)  # built once instead of once per perturbation
        train_perts = [p for p in perts if p not in held_out]
        train_idx = np.concatenate([cells(train_perts), ctrl_tr])
        val_idx = ctrl_va  # validation holds control cells only in this regime
        test_idx = np.concatenate([cells(test_perts), ctrl_te]) if test_perts else ctrl_te
    elif split_type == "combination":
        combos = data.combos
        if not combos:
            raise ValueError(f"{data.meta['name']} has no combinations")
        held = set(rng.choice(combos, size=n_held(combos), replace=False))
        test_perts = sorted(held)
        # train keeps every single plus the combos that were not held out
        train_perts = [p for p in perts if p not in held]
        train_idx = np.concatenate([cells(train_perts), ctrl_tr])
        val_idx = ctrl_va  # validation holds control cells only in this regime
        test_idx = np.concatenate([cells(test_perts), ctrl_te])
    else:
        raise ValueError(split_type)

    # leakage guard: test-only perturbations must not appear in train cells.
    # Skipped for "cell", where train and test share all perturbations by design.
    if split_type != "cell":
        # NOTE(review): .loc is label-based, so this assumes data.obs is indexed
        # by the same integers stored in pert_to_idx / control_idx -- confirm.
        train_label_set = set(np.unique(data.obs.loc[train_idx, "perturbation"].values))
        for p in test_perts:
            if p in train_label_set:
                # explicit raise: a bare ``assert`` is stripped under ``python -O``
                raise AssertionError(
                    f"LEAKAGE: held-out perturbation {p} present in train cells"
                )

    return dict(
        split_type=split_type,
        seed=seed,
        train_idx=train_idx.astype(np.int64),
        val_idx=val_idx.astype(np.int64),
        test_idx=test_idx.astype(np.int64),
        train_perts=np.array(train_perts),
        test_perts=np.array(test_perts),
    )
def save_split(d: dict, out_dir: str):
    """Write a split dict to ``<out_dir>/<split_type>.npz`` and return that path.

    Every ndarray value of ``d`` is stored under its own key. The non-array
    entries ``split_type`` and ``seed`` are stored together as a two-element
    string array under the key ``meta``. ``out_dir`` is created when missing.
    """
    os.makedirs(out_dir, exist_ok=True)
    target = os.path.join(out_dir, "{}.npz".format(d["split_type"]))
    arrays = {}
    for key, value in d.items():
        if isinstance(value, np.ndarray):
            arrays[key] = value
    meta = np.array([d["split_type"], str(d["seed"])])
    np.savez(target, **arrays, meta=meta)
    return target
def load_split(cache_dir: str, split_type: str) -> dict:
    """Load the split stored at ``<cache_dir>/splits/<split_type>.npz``.

    Returns a plain dict mapping every key in the archive to its array. All
    arrays are read before the archive is closed, so the result stays usable
    and no file handle is left open.
    """
    path = os.path.join(cache_dir, "splits", f"{split_type}.npz")
    # NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which
    # can execute arbitrary code. Only point this at split files you created.
    with np.load(path, allow_pickle=True) as z:
        return {k: z[k] for k in z.files}
if __name__ == "__main__":
    # CLI entry point: build every regime in SPLIT_TYPES for one dataset and
    # write each split under <root>/<dataset>/splits.
    parser = argparse.ArgumentParser()
    parser.add_argument("dataset")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--root", default="data/processed")
    args = parser.parse_args()

    data = load_dataset(args.dataset, root=args.root)
    out = os.path.join(args.root, args.dataset, "splits")
    for regime in SPLIT_TYPES:
        # the combination regime needs combos; skip it for datasets without any
        if regime == "combination" and not data.combos:
            print(f"skip {regime} (no combos in {args.dataset})")
            continue
        split = make_split(data, regime, seed=args.seed)
        saved_to = save_split(split, out)
        n_train = len(split["train_idx"])
        n_val = len(split["val_idx"])
        n_test = len(split["test_idx"])
        n_train_perts = len(split["train_perts"])
        n_test_perts = len(split["test_perts"])
        # one summary line per regime: partition sizes and the output path
        print(
            f"{regime:14s} train_cells={n_train:7d} "
            f"val={n_val:6d} test={n_test:6d} "
            f"train_perts={n_train_perts:4d} test_perts={n_test_perts:4d} -> {saved_to}"
        )