PIVOT
Perturbation-Informed Vector-field Optimization for Transcriptomic state control.
PIVOT learns a perturbation-conditioned flow map over single-cell state embeddings and uses its Jacobian for differentiable inverse design: given a control cell state and a desired target state, it nominates gene-level interventions that move cells toward the target. The same model also handles standard forward prediction: given a control state and a perturbation, it predicts the resulting cell state.
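A minimal, pure-Python sketch of the Jacobian idea, for intuition only. It assumes a hypothetical toy flow map `toy_flow(x, u)`, where `x` is a cell-state embedding and `u` holds one intervention strength per gene. It estimates the Jacobian of the flow map with respect to `u` at `u = 0` by central differences, then scores each gene by how well its Jacobian column points along the desired shift (target minus the unperturbed prediction). This first-order ranking and every name in it are illustrative assumptions, not PIVOT's actual nomination procedure.

```python
from typing import Callable, List, Sequence, Tuple

Vector = List[float]
FlowMap = Callable[[Sequence[float], Sequence[float]], Vector]


def jacobian_wrt_u(f: FlowMap, x: Sequence[float], u: Sequence[float],
                   eps: float = 1e-5) -> List[Vector]:
    """Central-difference Jacobian J[i][g] = d f_i / d u_g, evaluated at (x, u)."""
    out_dim = len(f(x, u))
    J = [[0.0] * len(u) for _ in range(out_dim)]
    for g in range(len(u)):
        u_plus, u_minus = list(u), list(u)
        u_plus[g] += eps
        u_minus[g] -= eps
        f_plus, f_minus = f(x, u_plus), f(x, u_minus)
        for i in range(out_dim):
            J[i][g] = (f_plus[i] - f_minus[i]) / (2 * eps)
    return J


def nominate_genes(f: FlowMap, x_control: Sequence[float],
                   x_target: Sequence[float], n_genes: int,
                   top_k: int = 3) -> Tuple[List[int], Vector]:
    """Hypothetical first-order ranking: score_g = J[:, g] . (target - f(control, 0))."""
    u0 = [0.0] * n_genes
    J = jacobian_wrt_u(f, x_control, u0)
    delta = [t - s for t, s in zip(x_target, f(x_control, u0))]
    scores = [sum(J[i][g] * delta[i] for i in range(len(delta)))
              for g in range(n_genes)]
    ranked = sorted(range(n_genes), key=lambda g: scores[g], reverse=True)
    return ranked[:top_k], scores


# Hypothetical toy flow map: a 2-D state shifted linearly by 3 gene "knobs".
W = [[1.0, 0.0, -1.0],
     [0.0, 1.0, 0.0]]


def toy_flow(x: Sequence[float], u: Sequence[float]) -> Vector:
    return [x[i] + sum(W[i][g] * u[g] for g in range(len(u)))
            for i in range(len(x))]


top, scores = nominate_genes(toy_flow, [0.0, 0.0], [1.0, 0.0], n_genes=3, top_k=1)
print(top, [round(s, 6) for s in scores])  # [0] [1.0, 0.0, -1.0]
```

Gene 0 moves the state straight toward the target, so it ranks first; gene 2 pushes the opposite way and gets a negative score. A learned flow map is nonlinear, so a single linearization like this is only a local approximation.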
Layout
src/
  data/          loading and preprocessing of Perturb-seq data, splits
  models/        perturbation encoder, flow map, the PIVOT module
  training/      training loop and losses
  evaluation/    inference, rewards, metrics, baselines
  experiments/   drivers for the result tables, ablations, figures
  utils/
scripts/         figure generation, extra ablations, GEARS comparison
experiments/     saved result JSON
models/          trained checkpoints (see Models below)
Setup
pip install -r requirements.txt
Data is not committed to the repo. Download and preprocess it from the public sources first (Norman 2019 and Replogle 2022 are pulled from scPerturb):
python -m src.data.preprocess norman
python -m src.data.preprocess replogle_k562
This writes a 50-component PCA embedding computed over 2,000 highly variable genes, plus the held-out splits, to data/processed/<dataset>/.
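For intuition, a simplified pure-Python sketch of the same two steps on a toy matrix: keep the genes with the highest variance, then project the centered expression onto the top principal components, found here by power iteration with deflation. This is an assumption-laden stand-in, not the code in src.data.preprocess: real HVG selection usually works on normalized, log-transformed counts and ranks genes by dispersion (as in scanpy's `highly_variable_genes`), and the helper names `top_variance_genes` and `pca_embed` are hypothetical. In the real pipeline the sizes would be 2,000 genes and 50 components.

```python
import math
import random
from typing import List

Matrix = List[List[float]]


def top_variance_genes(X: Matrix, n_top: int) -> List[int]:
    """Indices of the n_top columns (genes) with the largest variance."""
    n = len(X)
    variances = []
    for g in range(len(X[0])):
        col = [row[g] for row in X]
        mean = sum(col) / n
        variances.append(sum((v - mean) ** 2 for v in col) / n)
    ranked = sorted(range(len(variances)), key=lambda g: variances[g], reverse=True)
    return ranked[:n_top]


def pca_embed(X: Matrix, n_components: int, n_iter: int = 200, seed: int = 0) -> Matrix:
    """Project centered rows of X onto the top principal axes (power iteration + deflation)."""
    n, d = len(X), len(X[0])
    means = [sum(row[j] for row in X) / n for j in range(d)]
    Xc = [[row[j] - means[j] for j in range(d)] for row in X]
    # Covariance matrix C = Xc^T Xc / n.
    C = [[sum(r[a] * r[b] for r in Xc) / n for b in range(d)] for a in range(d)]
    rng = random.Random(seed)
    axes = []
    for _ in range(n_components):
        v = [rng.gauss(0.0, 1.0) for _ in range(d)]
        for _ in range(n_iter):
            w = [sum(C[a][b] * v[b] for b in range(d)) for a in range(d)]
            norm = math.sqrt(sum(x * x for x in w)) or 1.0
            v = [x / norm for x in w]
        lam = sum(v[a] * sum(C[a][b] * v[b] for b in range(d)) for a in range(d))
        # Deflate: remove the found component so the next pass finds the next one.
        C = [[C[a][b] - lam * v[a] * v[b] for b in range(d)] for a in range(d)]
        axes.append(v)
    return [[sum(r[j] * ax[j] for j in range(d)) for ax in axes] for r in Xc]


# Toy data: 4 cells x 3 genes; gene 2 is constant and should be dropped.
X = [[0.0, 1.0, 5.0],
     [2.0, 0.0, 5.0],
     [4.0, 1.0, 5.0],
     [6.0, 0.0, 5.0]]
hvg = top_variance_genes(X, n_top=2)  # real pipeline: 2000 genes
Z = pca_embed([[row[g] for g in hvg] for row in X], n_components=1)  # real: 50
print(hvg)  # [0, 1]
```

The variance of each embedding column equals the corresponding covariance eigenvalue, which is a quick sanity check for any PCA implementation.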
Running things
# train one model
python -m src.training.train --dataset norman --split perturbation
# forward + nomination tables
python -m src.experiments.run_tables --dataset norman --tables forward_cell forward_perturbation
# ablations
python -m src.experiments.run_ablations --dataset norman
# figures
python scripts/figures.py
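A hedged sketch for running the steps above over both datasets in one go. It only chains the commands and flags shown here; it assumes train, run_tables and run_ablations accept `replogle_k562` for --dataset (the examples above only show `norman`). `build_commands` and `run_all` are hypothetical helpers, not part of the repo, and the default is a dry run that just prints the commands.

```python
import subprocess
import sys
from typing import List, Sequence

DATASETS = ["norman", "replogle_k562"]  # assumption: both are valid --dataset values


def build_commands(dataset: str) -> List[List[str]]:
    """The README's train/tables/ablations commands, parameterized by dataset."""
    py = sys.executable
    return [
        [py, "-m", "src.training.train", "--dataset", dataset, "--split", "perturbation"],
        [py, "-m", "src.experiments.run_tables", "--dataset", dataset,
         "--tables", "forward_cell", "forward_perturbation"],
        [py, "-m", "src.experiments.run_ablations", "--dataset", dataset],
    ]


def run_all(datasets: Sequence[str] = DATASETS, dry_run: bool = True) -> List[List[str]]:
    cmds = [cmd for ds in datasets for cmd in build_commands(ds)]
    cmds.append([sys.executable, "scripts/figures.py"])  # figures last
    for cmd in cmds:
        print(" ".join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)  # stop at the first failing step
    return cmds


if __name__ == "__main__":
    # Pass --run to actually execute; otherwise only print the commands.
    run_all(dry_run="--run" not in sys.argv)
```

Using `sys.executable` keeps every step in the same Python environment that launched the driver.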
The GEARS head-to-head comparison runs in its own conda env (older torch + PyG), because the GEARS package is sensitive to dependency versions:
bash scripts/setup_gears_env.sh
conda run -n pivot_gears python scripts/gears_ranking.py
Models
Trained PIVOT checkpoints live under models/:
models/norman/ - trained on Norman 2019 (CRISPRa K562)
models/replogle_k562/ - trained on Replogle 2022 (CRISPRi K562)
Each folder has model.ptw (a plain torch state dict), config.json (the training config),
and train_info.json (training history and run info). Loading requires the matching preprocessed
dataset, because the perturbation encoder's vocabulary is built from the data:
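import json
import torch
from src.data.perturb_data import load_dataset
from src.training.train import TrainConfig, make_model
# Rebuild the training config saved alongside the checkpoint.
with open("models/norman/config.json") as fh:
    cfg = TrainConfig(**json.load(fh))
# The perturbation encoder's vocabulary comes from the preprocessed data.
data = load_dataset(cfg.dataset)
model = make_model(data, cfg, device="cpu")
# model.ptw is a plain state dict, so load it into the freshly built model.
model.load_state_dict(torch.load("models/norman/model.ptw", map_location="cpu"))
model.eval()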
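Before building the model, it can help to confirm that a checkpoint folder holds all three files described above. A minimal standard-library sketch; `check_checkpoint_dir` is a hypothetical helper, and it assumes nothing about the contents of config.json or train_info.json beyond their being valid JSON.

```python
import json
from pathlib import Path
from typing import Any, Dict, Tuple

# The three files each checkpoint folder is expected to contain.
EXPECTED = ("model.ptw", "config.json", "train_info.json")


def check_checkpoint_dir(path: str) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Verify the checkpoint files exist; return the parsed (config, train_info)."""
    root = Path(path)
    missing = [name for name in EXPECTED if not (root / name).is_file()]
    if missing:
        raise FileNotFoundError(f"{root} is missing: {', '.join(missing)}")
    with open(root / "config.json") as fh:
        config = json.load(fh)
    with open(root / "train_info.json") as fh:
        train_info = json.load(fh)
    return config, train_info
```

For example, `check_checkpoint_dir("models/norman")` should return the same config dictionary that the loading snippet passes to `TrainConfig`, so `config["dataset"]` tells you which preprocessed dataset must be present.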
License
MIT, Bryan Cheng 2026. See LICENSE.