File size: 6,657 Bytes
3b4941f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
"""deterministic, serialized data splits.

six split regimes, each guaranteeing the stated generalization gap with no leakage:
  cell         held-out cells of seen perturbations (interpolation over cells)
  perturbation held-out perturbation labels (cells never in train; genes may be seen elsewhere)
  gene         held-out target genes (gene id never appears in any train perturbation)
  gene_family  held-out functional families (effect-correlation clusters of genes)
  combination  held-out combos; train on singles + remaining combos
  cell_type    cross-dataset transfer (handled at the experiment level; not produced by make_split)

each split serializes train/val/test cell indices plus the train/test perturbation
label sets, so forward training and inverse-nomination eval read identical partitions.
control cells are always available in train (and proportionally in val/test for matching).
"""
from __future__ import annotations

import argparse
import os

import numpy as np

from src.data.perturb_data import PerturbData, load_dataset

# Split regimes that make_split() builds and the CLI below iterates over.
# "cell_type" is deliberately absent: it is handled at the experiment level,
# and make_split() raises ValueError for any name not listed here.
SPLIT_TYPES = (
    "cell",
    "perturbation",
    "gene",
    "gene_family",
    "combination",
)


def _split_control(control_idx, rng, fracs=(0.7, 0.15, 0.15)):
    perm = rng.permutation(control_idx)
    n = len(perm)
    a, b = int(fracs[0] * n), int((fracs[0] + fracs[1]) * n)
    return perm[:a], perm[a:b], perm[b:]


def make_split(data: PerturbData, split_type: str, seed: int = 0, holdout_frac: float = 0.2):
    """Build one deterministic train/val/test partition of ``data``.

    A single ``np.random.default_rng(seed)`` is consumed in a fixed order:
    control cells first, then the split-specific draws. The same
    ``(data, split_type, seed, holdout_frac)`` therefore always yields the same
    partition.

    Args:
        data: dataset to split. This function reads ``perturbations``,
            ``control_idx``, ``pert_to_idx``, ``genes_vocab``,
            ``functional_clusters``, ``parse``, ``combos``, ``meta`` and ``obs``.
        split_type: one of ``SPLIT_TYPES``.
        seed: seed for the RNG (also forwarded to ``functional_clusters``).
        holdout_frac: fraction of perturbations / gene units / combos held out
            for test. At least one is always held out. The ``"cell"`` split
            ignores it and uses a fixed 70/15/15 split per perturbation.

    Returns:
        dict with ``split_type``, ``seed``, int64 ``train_idx`` / ``val_idx`` /
        ``test_idx`` cell-index arrays, and ``train_perts`` / ``test_perts``
        label arrays.

    Raises:
        ValueError: unknown ``split_type``, or ``"combination"`` on a dataset
            without combinations.
        AssertionError: a leakage guard failed. It is raised explicitly rather
            than via ``assert``, so the guards still run under ``python -O``.
    """
    rng = np.random.default_rng(seed)
    perts = list(data.perturbations)
    # controls are carved first, so this RNG draw is identical for every regime
    ctrl_tr, ctrl_va, ctrl_te = _split_control(data.control_idx, rng)

    def cells(pert_list):
        """Concatenate the cell indices of every perturbation in ``pert_list``."""
        if not pert_list:
            return np.array([], dtype=np.int64)
        return np.concatenate([data.pert_to_idx[p] for p in pert_list])

    if split_type == "cell":
        # all perturbations seen; split each perturbation's cells 70/15/15
        tr, va, te = [], [], []
        for p in perts:
            idx = rng.permutation(data.pert_to_idx[p])
            n = len(idx)
            a, b = int(0.7 * n), int(0.85 * n)
            tr.append(idx[:a])
            va.append(idx[a:b])
            te.append(idx[b:])
        train_idx = np.concatenate(tr + [ctrl_tr])
        val_idx = np.concatenate(va + [ctrl_va])
        test_idx = np.concatenate(te + [ctrl_te])
        train_perts = test_perts = perts

    elif split_type == "perturbation":
        held = set(rng.choice(perts, size=max(1, int(holdout_frac * len(perts))), replace=False))
        remaining = [p for p in perts if p not in held]
        # carve a disjoint validation set of perturbations (no train/val cell overlap)
        n_val = max(1, len(remaining) // 8)
        val_perts = set(rng.choice(remaining, size=n_val, replace=False)) if remaining else set()
        train_perts = [p for p in remaining if p not in val_perts]
        test_perts = sorted(held)
        train_idx = np.concatenate([cells(train_perts), ctrl_tr])
        val_idx = np.concatenate([cells(sorted(val_perts)), ctrl_va]) if val_perts else ctrl_va
        test_idx = np.concatenate([cells(test_perts), ctrl_te])
        # leakage guard: val perturbations must be disjoint from train
        if val_perts & set(train_perts):
            raise AssertionError("LEAKAGE: val/train perturbation overlap")

    elif split_type in ("gene", "gene_family"):
        if split_type == "gene":
            units = data.genes_vocab
            unit_of = {g: g for g in units}
        else:
            fam = data.functional_clusters(seed=seed)  # gene -> cluster id
            # genes with no cluster all share the "fam-1" unit
            unit_of = {g: f"fam{fam.get(g, -1)}" for g in data.genes_vocab}
            units = sorted(set(unit_of.values()))
        held_units = set(rng.choice(units, size=max(1, int(holdout_frac * len(units))), replace=False))
        held_genes = {g for g in data.genes_vocab if unit_of[g] in held_units}
        # any perturbation touching a held-out gene goes to test, so no train
        # perturbation contains a held-out gene
        test_perts = [p for p in perts if any(g in held_genes for g in data.parse(p))]
        test_set = set(test_perts)  # built once (was rebuilt for every element)
        train_perts = [p for p in perts if p not in test_set]
        train_idx = np.concatenate([cells(train_perts), ctrl_tr])
        val_idx = ctrl_va
        test_idx = np.concatenate([cells(test_perts), ctrl_te]) if test_perts else ctrl_te

    elif split_type == "combination":
        combos = data.combos
        if not combos:
            raise ValueError(f"{data.meta['name']} has no combinations")
        held = set(rng.choice(combos, size=max(1, int(holdout_frac * len(combos))), replace=False))
        test_perts = sorted(held)
        # train keeps every single plus the combos that were not held out
        train_perts = [p for p in perts if p not in held]
        train_idx = np.concatenate([cells(train_perts), ctrl_tr])
        val_idx = ctrl_va
        test_idx = np.concatenate([cells(test_perts), ctrl_te])
    else:
        raise ValueError(split_type)

    # leakage guard: test-only perturbations must not appear in train cells
    # ("cell" intentionally shares perturbation labels across train and test)
    if split_type != "cell":
        # NOTE(review): .loc treats train_idx as obs index *labels*; this equals
        # positional row selection only if obs has a default RangeIndex — confirm.
        train_label_set = set(np.unique(data.obs.loc[train_idx, "perturbation"].values))
        for p in test_perts:
            if p in train_label_set:
                raise AssertionError(f"LEAKAGE: held-out perturbation {p} present in train cells")

    return dict(
        split_type=split_type,
        seed=seed,
        train_idx=train_idx.astype(np.int64),
        val_idx=val_idx.astype(np.int64),
        test_idx=test_idx.astype(np.int64),
        train_perts=np.array(train_perts),
        test_perts=np.array(test_perts),
    )


def save_split(d: dict, out_dir: str):
    """Write the ndarray entries of split dict ``d`` to ``<out_dir>/<split_type>.npz``.

    Non-array entries (such as the ``split_type`` string and the ``seed``) are
    not stored under their own keys. Instead, ``[split_type, str(seed)]`` is
    stored as a string array under ``meta``, written after the other arrays.
    ``out_dir`` is created if missing.

    Returns:
        Path of the written ``.npz`` file.
    """
    os.makedirs(out_dir, exist_ok=True)
    stem = d["split_type"]
    target = os.path.join(out_dir, f"{stem}.npz")
    arrays = {key: value for key, value in d.items() if isinstance(value, np.ndarray)}
    meta = np.array([stem, str(d["seed"])])
    # meta goes last so archive member order matches the split dict's array order
    np.savez(target, **arrays, meta=meta)
    return target


def load_split(cache_dir: str, split_type: str) -> dict:
    """Load every array stored in ``<cache_dir>/splits/<split_type>.npz``.

    This is the layout produced when ``save_split`` is given
    ``<cache_dir>/splits`` as its output directory. Arrays are returned keyed
    by their stored names, in archive order. The archive is opened in a
    ``with`` block, so its file handle is released once all arrays are read
    (the previous version left it open).
    """
    path = os.path.join(cache_dir, "splits", f"{split_type}.npz")
    # NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which
    # can execute arbitrary code — only load split files from a trusted cache_dir.
    with np.load(path, allow_pickle=True) as z:
        return {k: z[k] for k in z.files}


if __name__ == "__main__":
    # CLI entry point: build every applicable split for one dataset, write each to
    # <root>/<dataset>/splits/<split_type>.npz, and print a one-line size summary.
    parser = argparse.ArgumentParser()
    parser.add_argument("dataset")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--root", default="data/processed")
    cli = parser.parse_args()

    dataset = load_dataset(cli.dataset, root=cli.root)
    split_dir = os.path.join(cli.root, cli.dataset, "splits")
    for split_type in SPLIT_TYPES:
        # make_split raises ValueError for "combination" when there are no combos
        if split_type == "combination" and not dataset.combos:
            print(f"skip {split_type} (no combos in {cli.dataset})")
            continue
        split = make_split(dataset, split_type, seed=cli.seed)
        saved_to = save_split(split, split_dir)
        n_train, n_val, n_test, n_train_perts, n_test_perts = (
            len(split[key])
            for key in ("train_idx", "val_idx", "test_idx", "train_perts", "test_perts")
        )
        print(f"{split_type:14s} train_cells={n_train:7d} "
              f"val={n_val:6d} test={n_test:6d} "
              f"train_perts={n_train_perts:4d} test_perts={n_test_perts:4d} -> {saved_to}")