"""deterministic, serialized data splits.

six split regimes; the first five (SPLIT_TYPES) are built by make_split and each
guarantees the stated generalization gap with no leakage:
    cell          held-out cells of seen perturbations (interpolation over cells)
    perturbation  held-out perturbation labels (their cells never in train; genes may be seen elsewhere)
    gene          held-out target genes (a held-out gene id never appears in any train perturbation)
    gene_family   held-out functional families (effect-correlation clusters of genes)
    combination   held-out combos; train on singles + remaining combos
    cell_type     cross-dataset transfer (handled at the experiment level, marker only here)

each split serializes train/val/test cell indices plus the train/test perturbation
label sets, so forward training and inverse-nomination eval read identical partitions.
control cells are split 70/15/15 across train/val/test, so they are always available in
train (and proportionally in val/test for matching). the gene, gene_family and
combination splits put only control cells in val.
"""
| from __future__ import annotations |
|
|
| import argparse |
| import os |
|
|
| import numpy as np |
|
|
| from src.data.perturb_data import PerturbData, load_dataset |
|
|
# Split regimes materialized by this module. "cell_type" is intentionally not
# listed: cross-dataset transfer is handled at the experiment level.
SPLIT_TYPES = (
    "cell",
    "perturbation",
    "gene",
    "gene_family",
    "combination",
)
|
|
|
|
| def _split_control(control_idx, rng, fracs=(0.7, 0.15, 0.15)): |
| perm = rng.permutation(control_idx) |
| n = len(perm) |
| a, b = int(fracs[0] * n), int((fracs[0] + fracs[1]) * n) |
| return perm[:a], perm[a:b], perm[b:] |
|
|
|
|
def _sample_holdout(rng, items, holdout_frac):
    """Draw ``max(1, int(holdout_frac * len(items)))`` distinct items and return them as a set.

    At least one item is always held out. ``Generator.choice`` raises ValueError
    when ``items`` is empty or the requested size exceeds ``len(items)``.
    """
    size = max(1, int(holdout_frac * len(items)))
    return set(rng.choice(items, size=size, replace=False))


def make_split(data: PerturbData, split_type: str, seed: int = 0, holdout_frac: float = 0.2):
    """Build one deterministic train/val/test partition of ``data``.

    Args:
        data: dataset exposing ``perturbations``, ``pert_to_idx`` (label -> cell
            indices), ``control_idx``, ``obs`` with a ``"perturbation"`` column, and,
            depending on ``split_type``, ``genes_vocab`` / ``parse`` /
            ``functional_clusters`` or ``combos`` / ``meta``.
        split_type: one of ``SPLIT_TYPES``.
        seed: seed for ``np.random.default_rng``. All randomness comes from this
            one generator, drawn in a fixed order, so identical inputs reproduce
            identical splits.
        holdout_frac: fraction of perturbations / genes / families / combos held
            out for test (at least one is always held out). Ignored by ``"cell"``,
            which uses a fixed 70/15/15 per-perturbation cell split.

    Returns:
        dict with ``split_type``, ``seed``, int64 ``train_idx`` / ``val_idx`` /
        ``test_idx`` cell indices, and ``train_perts`` / ``test_perts`` label arrays.
        For ``"cell"``, both label arrays list every perturbation. For ``"gene"``,
        ``"gene_family"`` and ``"combination"``, ``val_idx`` holds only control cells.

    Raises:
        ValueError: unknown ``split_type``, or ``"combination"`` on a dataset
            without combos.
        AssertionError: a leakage invariant is violated. It is raised explicitly
            rather than via ``assert``, so the guard also runs under ``python -O``.
    """
    rng = np.random.default_rng(seed)
    perts = list(data.perturbations)
    # Controls are split first. Keep this order: every later draw depends on the
    # rng state this call leaves behind.
    ctrl_tr, ctrl_va, ctrl_te = _split_control(data.control_idx, rng)

    def cells(pert_list):
        """Concatenate the cell indices of every label in ``pert_list``."""
        if not pert_list:
            return np.array([], dtype=np.int64)
        return np.concatenate([data.pert_to_idx[p] for p in pert_list])

    if split_type == "cell":
        # Every perturbation appears in all three splits; its own cells are cut 70/15/15.
        tr, va, te = [], [], []
        for p in perts:
            idx = rng.permutation(data.pert_to_idx[p])
            n = len(idx)
            a, b = int(0.7 * n), int(0.85 * n)
            tr.append(idx[:a])
            va.append(idx[a:b])
            te.append(idx[b:])
        train_idx = np.concatenate(tr + [ctrl_tr])
        val_idx = np.concatenate(va + [ctrl_va])
        test_idx = np.concatenate(te + [ctrl_te])
        train_perts = test_perts = perts

    elif split_type == "perturbation":
        held = _sample_holdout(rng, perts, holdout_frac)
        remaining = [p for p in perts if p not in held]
        # Carve whole perturbations (~1/8 of the non-held ones) out of train for validation.
        n_val = max(1, len(remaining) // 8)
        val_perts = set(rng.choice(remaining, size=n_val, replace=False)) if remaining else set()
        train_perts = [p for p in remaining if p not in val_perts]
        test_perts = sorted(held)
        train_idx = np.concatenate([cells(train_perts), ctrl_tr])
        val_idx = np.concatenate([cells(sorted(val_perts)), ctrl_va]) if val_perts else ctrl_va
        test_idx = np.concatenate([cells(test_perts), ctrl_te])
        if val_perts & set(train_perts):
            raise AssertionError("LEAKAGE: val/train perturbation overlap")

    elif split_type in ("gene", "gene_family"):
        if split_type == "gene":
            units = data.genes_vocab
            unit_of = {g: g for g in units}
        else:
            fam = data.functional_clusters(seed=seed)
            # Genes without a cluster assignment all share the "fam-1" unit.
            unit_of = {g: f"fam{fam.get(g, -1)}" for g in data.genes_vocab}
            units = sorted(set(unit_of.values()))
        held_units = _sample_holdout(rng, units, holdout_frac)
        held_genes = {g for g in data.genes_vocab if unit_of[g] in held_units}
        # A perturbation goes to test if ANY of its parsed genes is held out, so
        # no held-out gene reaches a train perturbation.
        test_perts = [p for p in perts if any(g in held_genes for g in data.parse(p))]
        test_set = set(test_perts)  # built once, not per element
        train_perts = [p for p in perts if p not in test_set]
        train_idx = np.concatenate([cells(train_perts), ctrl_tr])
        val_idx = ctrl_va
        test_idx = np.concatenate([cells(test_perts), ctrl_te]) if test_perts else ctrl_te

    elif split_type == "combination":
        combos = data.combos
        if not combos:
            raise ValueError(f"{data.meta['name']} has no combinations")
        held = _sample_holdout(rng, combos, holdout_frac)
        test_perts = sorted(held)
        # Singles and the non-held combos all stay in train.
        train_perts = [p for p in perts if p not in held]
        train_idx = np.concatenate([cells(train_perts), ctrl_tr])
        val_idx = ctrl_va
        test_idx = np.concatenate([cells(test_perts), ctrl_te])

    else:
        raise ValueError(split_type)

    train_idx = train_idx.astype(np.int64)
    val_idx = val_idx.astype(np.int64)
    test_idx = test_idx.astype(np.int64)

    if split_type != "cell":
        # Positional lookup instead of label-based ``obs.loc``.
        # NOTE(review): assumes the *_idx arrays are positional row indices into
        # ``obs`` (the same indices the splits serialize); with a default RangeIndex
        # this matches the old ``.loc`` behaviour exactly.
        labels = np.asarray(data.obs["perturbation"])
        train_label_set = set(labels[train_idx])
        for p in test_perts:
            if p in train_label_set:
                raise AssertionError(f"LEAKAGE: held-out perturbation {p} present in train cells")

    return dict(
        split_type=split_type,
        seed=seed,
        train_idx=train_idx,
        val_idx=val_idx,
        test_idx=test_idx,
        train_perts=np.array(train_perts),
        test_perts=np.array(test_perts),
    )
|
|
|
|
def save_split(d: dict, out_dir: str):
    """Serialize the array fields of split dict ``d`` into ``<out_dir>/<split_type>.npz``.

    Only values that are ``np.ndarray`` are stored under their own keys. The
    scalar fields are packed into the string array ``meta = [split_type, str(seed)]``.
    The seed is not part of the file name, so saving another seed for the same
    split type overwrites the earlier file. Returns the path that was written.
    """
    os.makedirs(out_dir, exist_ok=True)
    target = os.path.join(out_dir, f"{d['split_type']}.npz")
    arrays = {key: value for key, value in d.items() if isinstance(value, np.ndarray)}
    meta = np.array([d["split_type"], str(d["seed"])])
    np.savez(target, **arrays, meta=meta)
    return target
|
|
|
|
def load_split(cache_dir: str, split_type: str) -> dict:
    """Load a split saved by ``save_split`` from ``<cache_dir>/splits/<split_type>.npz``.

    Returns a dict holding every array in the archive, including ``meta``
    (``[split_type, str(seed)]``). Raises FileNotFoundError if the file is missing.
    """
    path = os.path.join(cache_dir, "splits", f"{split_type}.npz")
    # NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which
    # can execute code. save_split only writes numeric/unicode arrays, so load
    # only trusted caches.
    # NpzFile keeps the archive open. Read every member eagerly inside the
    # context manager so the handle is closed before returning.
    with np.load(path, allow_pickle=True) as z:
        return {k: z[k] for k in z.files}
|
|
|
|
if __name__ == "__main__":
    # CLI: materialize every applicable split type for one processed dataset
    # as <root>/<dataset>/splits/<split_type>.npz and print a size summary.
    parser = argparse.ArgumentParser()
    parser.add_argument("dataset")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--root", default="data/processed")
    cli = parser.parse_args()

    data = load_dataset(cli.dataset, root=cli.root)
    split_dir = os.path.join(cli.root, cli.dataset, "splits")
    for kind in SPLIT_TYPES:
        if kind == "combination" and not data.combos:
            # make_split raises for combination splits on datasets without combos.
            print(f"skip {kind} (no combos in {cli.dataset})")
            continue
        split = make_split(data, kind, seed=cli.seed)
        written = save_split(split, split_dir)
        print(
            f"{kind:14s} train_cells={len(split['train_idx']):7d} "
            f"val={len(split['val_idx']):6d} test={len(split['test_idx']):6d} "
            f"train_perts={len(split['train_perts']):4d} test_perts={len(split['test_perts']):4d} -> {written}"
        )
|
|