"""deterministic, serialized data splits. six split regimes, each guaranteeing the stated generalization gap with no leakage: cell held-out cells of seen perturbations (interpolation over cells) perturbation held-out perturbation labels (cells never in train; genes may be seen elsewhere) gene held-out target genes (gene id never appears in any train perturbation) gene_family held-out functional families (effect-correlation clusters of genes) combination held-out combos; train on singles + remaining combos cell_type cross-dataset transfer (handled at the experiment level, marker only here) each split serializes train/val/test cell indices plus the train/test perturbation label sets, so forward training and inverse-nomination eval read identical partitions. control cells are always available in train (and proportionally in val/test for matching). """ from __future__ import annotations import argparse import os import numpy as np from src.data.perturb_data import PerturbData, load_dataset SPLIT_TYPES = ("cell", "perturbation", "gene", "gene_family", "combination") def _split_control(control_idx, rng, fracs=(0.7, 0.15, 0.15)): perm = rng.permutation(control_idx) n = len(perm) a, b = int(fracs[0] * n), int((fracs[0] + fracs[1]) * n) return perm[:a], perm[a:b], perm[b:] def make_split(data: PerturbData, split_type: str, seed: int = 0, holdout_frac: float = 0.2): rng = np.random.default_rng(seed) perts = list(data.perturbations) ctrl_tr, ctrl_va, ctrl_te = _split_control(data.control_idx, rng) def cells(pert_list): if not pert_list: return np.array([], dtype=np.int64) return np.concatenate([data.pert_to_idx[p] for p in pert_list]) if split_type == "cell": # all perturbations seen; split each perturbation's cells 70/15/15 tr, va, te = [], [], [] for p in perts: idx = rng.permutation(data.pert_to_idx[p]) n = len(idx) a, b = int(0.7 * n), int(0.85 * n) tr.append(idx[:a]); va.append(idx[a:b]); te.append(idx[b:]) train_idx = np.concatenate(tr + [ctrl_tr]) val_idx = np.concatenate(va + [ctrl_va]) test_idx = np.concatenate(te + [ctrl_te]) train_perts = test_perts = perts elif split_type == "perturbation": held = set(rng.choice(perts, size=max(1, int(holdout_frac * len(perts))), replace=False)) remaining = [p for p in perts if p not in held] # carve a disjoint validation set of perturbations (no train/val cell overlap) n_val = max(1, len(remaining) // 8) val_perts = set(rng.choice(remaining, size=n_val, replace=False)) if remaining else set() train_perts = [p for p in remaining if p not in val_perts] test_perts = sorted(held) train_idx = np.concatenate([cells(train_perts), ctrl_tr]) val_idx = np.concatenate([cells(sorted(val_perts)), ctrl_va]) if val_perts else ctrl_va test_idx = np.concatenate([cells(test_perts), ctrl_te]) # leakage guard: val perturbations must be disjoint from train assert not (set(val_perts) & set(train_perts)), "LEAKAGE: val/train perturbation overlap" elif split_type in ("gene", "gene_family"): if split_type == "gene": units = data.genes_vocab unit_of = {g: g for g in units} else: fam = data.functional_clusters(seed=seed) # gene -> cluster id unit_of = {g: f"fam{fam.get(g, -1)}" for g in data.genes_vocab} units = sorted(set(unit_of.values())) held_units = set(rng.choice(units, size=max(1, int(holdout_frac * len(units))), replace=False)) held_genes = {g for g in data.genes_vocab if unit_of[g] in held_units} test_perts = [p for p in perts if any(g in held_genes for g in data.parse(p))] train_perts = [p for p in perts if p not in set(test_perts)] train_idx = np.concatenate([cells(train_perts), ctrl_tr]) val_idx = ctrl_va test_idx = np.concatenate([cells(test_perts), ctrl_te]) if test_perts else ctrl_te elif split_type == "combination": combos = data.combos if not combos: raise ValueError(f"{data.meta['name']} has no combinations") held = set(rng.choice(combos, size=max(1, int(holdout_frac * len(combos))), replace=False)) test_perts = sorted(held) train_perts = [p for p in perts if p not in held] train_idx = np.concatenate([cells(train_perts), ctrl_tr]) val_idx = ctrl_va test_idx = np.concatenate([cells(test_perts), ctrl_te]) else: raise ValueError(split_type) # leakage guard: test-only perturbations must not appear in train cells if split_type != "cell": train_label_set = set(np.unique(data.obs.loc[train_idx, "perturbation"].values)) for p in test_perts: assert p not in train_label_set or split_type == "cell", \ f"LEAKAGE: held-out perturbation {p} present in train cells" return dict( split_type=split_type, seed=seed, train_idx=train_idx.astype(np.int64), val_idx=val_idx.astype(np.int64), test_idx=test_idx.astype(np.int64), train_perts=np.array(train_perts), test_perts=np.array(test_perts), ) def save_split(d: dict, out_dir: str): os.makedirs(out_dir, exist_ok=True) path = os.path.join(out_dir, f"{d['split_type']}.npz") np.savez(path, **{k: v for k, v in d.items() if isinstance(v, np.ndarray)}, meta=np.array([d["split_type"], str(d["seed"])])) return path def load_split(cache_dir: str, split_type: str) -> dict: z = np.load(os.path.join(cache_dir, "splits", f"{split_type}.npz"), allow_pickle=True) return {k: z[k] for k in z.files} if __name__ == "__main__": ap = argparse.ArgumentParser() ap.add_argument("dataset") ap.add_argument("--seed", type=int, default=0) ap.add_argument("--root", default="data/processed") args = ap.parse_args() data = load_dataset(args.dataset, root=args.root) out = os.path.join(args.root, args.dataset, "splits") for st in SPLIT_TYPES: if st == "combination" and not data.combos: print(f"skip {st} (no combos in {args.dataset})") continue d = make_split(data, st, seed=args.seed) p = save_split(d, out) print(f"{st:14s} train_cells={len(d['train_idx']):7d} " f"val={len(d['val_idx']):6d} test={len(d['test_idx']):6d} " f"train_perts={len(d['train_perts']):4d} test_perts={len(d['test_perts']):4d} -> {p}")