"""deterministic, serialized data splits.
six split regimes, each enforcing the stated generalization gap without leakage (make_split builds the first five; cell_type is a marker only):
cell held-out cells of seen perturbations (interpolation over cells)
perturbation held-out perturbation labels (cells never in train; genes may be seen elsewhere)
gene held-out target genes (gene id never appears in any train perturbation)
gene_family held-out functional families (effect-correlation clusters of genes)
combination held-out combos; train on singles + remaining combos
cell_type cross-dataset transfer (handled at the experiment level, marker only here)
each split serializes train/val/test cell indices plus the train/test perturbation
label sets, so forward training and inverse-nomination eval read identical partitions.
control cells are always available in train (and proportionally in val/test for matching).
"""
from __future__ import annotations
import argparse
import os
import numpy as np
from src.data.perturb_data import PerturbData, load_dataset
# Split regimes that make_split() can build, in the order the CLI writes them.
# "cell_type" transfer is handled at the experiment level and is not listed here.
SPLIT_TYPES: tuple[str, ...] = (
    "cell",
    "perturbation",
    "gene",
    "gene_family",
    "combination",
)
def _split_control(control_idx, rng, fracs=(0.7, 0.15, 0.15)):
    """Shuffle control-cell indices and cut them into (train, val, test) pieces.

    ``fracs[0]`` is the train share and ``fracs[1]`` the val share. The test
    piece takes everything after the second cut, so ``fracs[2]`` is never read.
    Cut points are truncated with ``int``. Consumes one ``rng.permutation`` draw.
    """
    perm = rng.permutation(control_idx)
    n = len(perm)
    a, b = int(fracs[0] * n), int((fracs[0] + fracs[1]) * n)
    return perm[:a], perm[a:b], perm[b:]


def make_split(data: PerturbData, split_type: str, seed: int = 0, holdout_frac: float = 0.2):
    """Build one deterministic train/val/test partition of ``data``.

    Args:
        data: dataset exposing ``perturbations``, ``pert_to_idx`` (label -> cell
            indices), ``control_idx`` and ``obs["perturbation"]``. Depending on
            ``split_type`` it also needs ``genes_vocab``, ``parse``,
            ``functional_clusters``, ``combos`` and ``meta["name"]``.
        split_type: one of ``SPLIT_TYPES``.
        seed: seeds the single ``np.random.default_rng`` stream used for all draws.
        holdout_frac: fraction of perturbations / genes / families / combos held
            out for test. At least one unit is always held out.

    Returns:
        dict with ``split_type``, ``seed``, int64 ``train_idx``/``val_idx``/``test_idx``
        (positional cell indices, control cells included in all three) and string
        arrays ``train_perts``/``test_perts``. For "gene", "gene_family" and
        "combination", ``val_idx`` contains control cells only.

    Raises:
        ValueError: unknown ``split_type``, or "combination" on a dataset without
            combinations.
        AssertionError: a leakage check failed. It is raised explicitly so the
            check also runs under ``python -O``.
    """
    # NOTE: every random draw comes from this one generator in a fixed order:
    # controls first, then the split-specific draws. Reordering the draws changes
    # the partition produced for a given seed, so keep the order stable.
    rng = np.random.default_rng(seed)
    perts = list(data.perturbations)
    ctrl_tr, ctrl_va, ctrl_te = _split_control(data.control_idx, rng)

    def cells(pert_list):
        """Concatenate the cell indices of every perturbation in ``pert_list``."""
        if not pert_list:
            return np.array([], dtype=np.int64)
        return np.concatenate([data.pert_to_idx[p] for p in pert_list])

    if split_type == "cell":
        # All perturbations are seen; split each perturbation's cells 70/15/15.
        tr, va, te = [], [], []
        for p in perts:
            idx = rng.permutation(data.pert_to_idx[p])
            n = len(idx)
            a, b = int(0.7 * n), int(0.85 * n)
            tr.append(idx[:a])
            va.append(idx[a:b])
            te.append(idx[b:])
        train_idx = np.concatenate(tr + [ctrl_tr])
        val_idx = np.concatenate(va + [ctrl_va])
        test_idx = np.concatenate(te + [ctrl_te])
        train_perts = test_perts = perts
    elif split_type == "perturbation":
        n_held = max(1, int(holdout_frac * len(perts)))
        held = set(rng.choice(perts, size=n_held, replace=False))
        remaining = [p for p in perts if p not in held]
        # Carve a disjoint validation set of perturbations (no train/val cell overlap).
        n_val = max(1, len(remaining) // 8)
        val_perts = set(rng.choice(remaining, size=n_val, replace=False)) if remaining else set()
        train_perts = [p for p in remaining if p not in val_perts]
        test_perts = sorted(held)
        train_idx = np.concatenate([cells(train_perts), ctrl_tr])
        # sorted() keeps the cell order independent of set iteration order.
        val_idx = np.concatenate([cells(sorted(val_perts)), ctrl_va]) if val_perts else ctrl_va
        test_idx = np.concatenate([cells(test_perts), ctrl_te])
        # Leakage guard: val perturbations must be disjoint from train.
        if val_perts & set(train_perts):
            raise AssertionError("LEAKAGE: val/train perturbation overlap")
    elif split_type in ("gene", "gene_family"):
        if split_type == "gene":
            units = data.genes_vocab
            unit_of = {g: g for g in units}
        else:
            fam = data.functional_clusters(seed=seed)  # gene -> cluster id
            # Genes missing from ``fam`` all share the catch-all unit "fam-1".
            unit_of = {g: f"fam{fam.get(g, -1)}" for g in data.genes_vocab}
            units = sorted(set(unit_of.values()))
        n_held = max(1, int(holdout_frac * len(units)))
        held_units = set(rng.choice(units, size=n_held, replace=False))
        held_genes = {g for g in data.genes_vocab if unit_of[g] in held_units}
        # Any perturbation touching a held-out gene goes to test, so no held-out
        # gene appears in any train perturbation.
        test_perts = [p for p in perts if any(g in held_genes for g in data.parse(p))]
        test_set = set(test_perts)  # built once so each membership test is O(1)
        train_perts = [p for p in perts if p not in test_set]
        train_idx = np.concatenate([cells(train_perts), ctrl_tr])
        val_idx = ctrl_va
        test_idx = np.concatenate([cells(test_perts), ctrl_te]) if test_perts else ctrl_te
    elif split_type == "combination":
        combos = data.combos
        if not combos:
            raise ValueError(f"{data.meta['name']} has no combinations")
        n_held = max(1, int(holdout_frac * len(combos)))
        held = set(rng.choice(combos, size=n_held, replace=False))
        test_perts = sorted(held)
        # Train keeps the singles plus every combination that was not held out.
        train_perts = [p for p in perts if p not in held]
        train_idx = np.concatenate([cells(train_perts), ctrl_tr])
        val_idx = ctrl_va
        test_idx = np.concatenate([cells(test_perts), ctrl_te])
    else:
        raise ValueError(split_type)

    # Cast once so the leakage guard and the returned arrays use the same int64 indices.
    train_idx = train_idx.astype(np.int64)
    val_idx = val_idx.astype(np.int64)
    test_idx = test_idx.astype(np.int64)

    # Leakage guard: held-out perturbations must not label any train cell.
    # The indices are positional, so labels are looked up by position. Label-based
    # ``obs.loc`` breaks (KeyError or wrong rows) when the obs index is not 0..n-1.
    if split_type != "cell":
        labels = np.asarray(data.obs["perturbation"])
        train_label_set = set(labels[train_idx])
        for p in test_perts:
            if p in train_label_set:
                raise AssertionError(f"LEAKAGE: held-out perturbation {p} present in train cells")

    return dict(
        split_type=split_type,
        seed=seed,
        train_idx=train_idx,
        val_idx=val_idx,
        test_idx=test_idx,
        # dtype=str keeps an empty label list a string array (np.array([]) is float64).
        train_perts=np.array(train_perts, dtype=str),
        test_perts=np.array(test_perts, dtype=str),
    )
def save_split(d: dict, out_dir: str):
    """Persist the array-valued entries of split dict ``d`` as ``<out_dir>/<split_type>.npz``.

    Non-array entries are not stored as-is. ``split_type`` and ``seed`` are kept
    together in the two-element string array ``meta``. Creates ``out_dir`` if it
    is missing and returns the path that was written.
    """
    os.makedirs(out_dir, exist_ok=True)
    target = os.path.join(out_dir, f"{d['split_type']}.npz")
    arrays = {name: value for name, value in d.items() if isinstance(value, np.ndarray)}
    meta = np.array([d["split_type"], str(d["seed"])])
    np.savez(target, **arrays, meta=meta)
    return target
def load_split(cache_dir: str, split_type: str) -> dict:
    """Load the split stored at ``<cache_dir>/splits/<split_type>.npz``.

    Returns:
        dict mapping every array name in the archive (e.g. ``train_idx``,
        ``test_perts``, ``meta``) to a fully materialized ndarray.
    """
    path = os.path.join(cache_dir, "splits", f"{split_type}.npz")
    # np.load on an .npz returns an NpzFile that keeps the archive open. The
    # context manager closes it once every member has been read into memory.
    # NOTE(review): allow_pickle=True lets object arrays be unpickled, which can
    # execute arbitrary code; it is kept for compatibility, so only load trusted files.
    with np.load(path, allow_pickle=True) as z:
        return {k: z[k] for k in z.files}
def main(argv=None):
    """Build, save and summarize every applicable split for one dataset.

    Args:
        argv: CLI arguments to parse; ``None`` lets argparse read ``sys.argv[1:]``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("dataset")
    ap.add_argument("--seed", type=int, default=0)
    ap.add_argument("--root", default="data/processed")
    # Exposes make_split's holdout_frac; the default matches make_split's own default.
    ap.add_argument("--holdout-frac", type=float, default=0.2)
    args = ap.parse_args(argv)
    data = load_dataset(args.dataset, root=args.root)
    # Splits are written to <root>/<dataset>/splits/<split_type>.npz.
    out = os.path.join(args.root, args.dataset, "splits")
    for st in SPLIT_TYPES:
        if st == "combination" and not data.combos:
            print(f"skip {st} (no combos in {args.dataset})")
            continue
        d = make_split(data, st, seed=args.seed, holdout_frac=args.holdout_frac)
        p = save_split(d, out)
        print(f"{st:14s} train_cells={len(d['train_idx']):7d} "
              f"val={len(d['val_idx']):6d} test={len(d['test_idx']):6d} "
              f"train_perts={len(d['train_perts']):4d} test_perts={len(d['test_perts']):4d} -> {p}")


if __name__ == "__main__":
    main()
|