PIVOT

Perturbation-Informed Vector-field Optimization for Transcriptomic state control.

PIVOT learns a perturbation-conditioned flow map over single-cell state embeddings and uses its Jacobian for differentiable inverse design: given a control cell state and a desired target state, it nominates gene-level interventions that move cells toward the target. The same model also handles standard forward prediction, i.e. predicting how control cells respond to a given perturbation.
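To make the inverse-design idea concrete, here is a minimal NumPy sketch under clearly stated assumptions. It is not PIVOT's model or API: the learned flow map is replaced by a hypothetical toy map `F(x, u) = x + tanh(W u)`, where `x` is a control-state embedding, `u` holds gene-level intervention strengths, and `W` is an illustrative gene-to-embedding effect matrix. The names `toy_flow_map`, `toy_jacobian_u` and `nominate_interventions` are made up for this example. The point is only to show how the Jacobian with respect to the intervention drives gradient-based nomination.

```python
import numpy as np


def toy_flow_map(x, u, W):
    """Hypothetical stand-in for a perturbation-conditioned flow map.

    x: (d,) control-state embedding
    u: (g,) gene-level intervention strengths
    W: (d, g) illustrative gene-to-embedding effect matrix
    """
    return x + np.tanh(W @ u)


def toy_jacobian_u(u, W):
    """Jacobian dF/du of the toy map, shape (d, g): diag(1 - tanh^2(W u)) @ W."""
    s = np.tanh(W @ u)
    return (1.0 - s**2)[:, None] * W


def nominate_interventions(x, y_target, W, lr=0.2, steps=5000):
    """Minimize 0.5 * ||F(x, u) - y_target||^2 over u by gradient descent.

    Returns the optimized u and gene indices ranked by |u| (largest first).
    """
    u = np.zeros(W.shape[1])
    for _ in range(steps):
        residual = toy_flow_map(x, u, W) - y_target  # (d,)
        grad = toy_jacobian_u(u, W).T @ residual     # chain rule: J^T r
        u -= lr * grad
    ranking = np.argsort(-np.abs(u))
    return u, ranking


# Usage: hide a single-gene intervention in the target, then try to recover it.
rng = np.random.default_rng(0)
d, g = 8, 5                                  # embedding dims, candidate genes
W = rng.normal(size=(d, g)) / np.sqrt(d)
x_control = rng.normal(size=d)
u_true = np.zeros(g)
u_true[3] = 0.5
y_target = toy_flow_map(x_control, u_true, W)

u_hat, ranking = nominate_interventions(x_control, y_target, W)
print("top-ranked gene:", ranking[0])
```

The toy keeps fewer candidate genes than embedding dimensions so that the hidden intervention is identifiable. With many more genes than dimensions the problem becomes underdetermined, and some form of regularization or sparsity prior would presumably be needed; how PIVOT handles that is not described here.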

Layout

```
src/
  data/         loading + preprocessing of perturb-seq data, splits
  models/       perturbation encoder, flow map, the PIVOT module
  training/     training loop and losses
  evaluation/   inference, rewards, metrics, baselines
  experiments/  drivers for the result tables, ablations, figures
  utils/
scripts/        figure generation, extra ablations, GEARS comparison
experiments/    saved result json
```

Setup

```bash
pip install -r requirements.txt
```

Data is not committed to the repo. Download and preprocess it from the public sources first (Norman 2019 and Replogle 2022 are both pulled from scPerturb):

```bash
python -m src.data.preprocess norman
python -m src.data.preprocess replogle_k562
```

This writes a 50-component PCA embedding computed over 2000 highly variable genes, plus the held-out splits, to data/processed/<dataset>/.
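The README does not say how the highly variable genes are selected or how the PCA is computed, so the following is only a plain-NumPy sketch of the general recipe, not the repository's preprocessing code. It assumes a log-normalized cells-by-genes matrix and ranks genes by plain variance, which is a simplification of the dispersion-based HVG methods common in single-cell toolkits. The function name `hvg_pca_embed` is hypothetical.

```python
import numpy as np


def hvg_pca_embed(X, n_hvg=2000, n_pcs=50):
    """Simplified HVG selection + PCA (hypothetical helper, not src.data.preprocess).

    X: (cells, genes) log-normalized expression matrix.
    Genes are ranked by plain variance, a stand-in for dispersion-based HVG methods.
    Returns the (cells, n_pcs) embedding, the selected gene indices, the per-gene
    means, and the (n_pcs, n_hvg) loading matrix.
    """
    hvg_idx = np.argsort(X.var(axis=0))[::-1][:n_hvg]
    X_hvg = X[:, hvg_idx]
    mean = X_hvg.mean(axis=0)
    X_centered = X_hvg - mean
    # Rows of Vt are principal axes ordered by decreasing singular value.
    _, _, Vt = np.linalg.svd(X_centered, full_matrices=False)
    loadings = Vt[:n_pcs]
    Z = X_centered @ loadings.T
    return Z, hvg_idx, mean, loadings


# Usage on synthetic data (random numbers, not Perturb-seq counts).
rng = np.random.default_rng(0)
X = rng.normal(size=(300, 3000))
Z, hvg_idx, mean, loadings = hvg_pca_embed(X)
print(Z.shape, hvg_idx.shape)

# Because the loadings are orthonormal, an embedding maps back to an
# approximation of the HVG expression profile.
X_hvg_approx = Z @ loadings + mean
```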

Running things

```bash
# train one model
python -m src.training.train --dataset norman --split perturbation

# forward + nomination tables
python -m src.experiments.run_tables --dataset norman --tables forward_cell forward_perturbation

# ablations
python -m src.experiments.run_ablations --dataset norman

# figures
python scripts/figures.py
```

The GEARS head-to-head comparison runs in its own conda env (older torch + PyG), since GEARS is sensitive to package versions:

```bash
bash scripts/setup_gears_env.sh
conda run -n pivot_gears python scripts/gears_ranking.py
```

Models

Trained PIVOT checkpoints live under models/:

  • models/norman/ - trained on Norman 2019 (CRISPRa K562)
  • models/replogle_k562/ - trained on Replogle 2022 (CRISPRi K562)

Each folder contains model.ptw (a plain torch state dict), config.json (the training config), and train_info.json (training history and run info). Loading a checkpoint requires the matching preprocessed dataset, since the perturbation encoder's vocabulary is built from the data:

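```python
import json
import torch
from src.data.perturb_data import load_dataset
from src.training.train import TrainConfig, make_model

# Rebuild the training config saved alongside the checkpoint.
with open("models/norman/config.json") as f:
    cfg = TrainConfig(**json.load(f))

# The encoder vocabulary comes from the preprocessed data, so load it first.
data = load_dataset(cfg.dataset)
model = make_model(data, cfg, device="cpu")

# model.ptw is a plain state dict, so it loads straight into the rebuilt model.
model.load_state_dict(torch.load("models/norman/model.ptw", map_location="cpu"))
model.eval()
```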
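Since loading fails if any of the three files is missing, a small pre-flight check can help when checkpoints are fetched separately. The sketch below is a hypothetical convenience helper, not part of the repository: it only assumes the folder layout and file names described above (models/<dataset>/ containing model.ptw, config.json and train_info.json), and the function names are made up for illustration.

```python
from pathlib import Path

# File names taken from the README's description of each checkpoint folder.
# The helpers below are hypothetical, not part of the PIVOT codebase.
CHECKPOINT_FILES = ("model.ptw", "config.json", "train_info.json")


def missing_checkpoint_files(folder):
    """Return the expected checkpoint files that are absent from `folder`."""
    folder = Path(folder)
    return [name for name in CHECKPOINT_FILES if not (folder / name).is_file()]


def complete_checkpoints(root="models"):
    """List subfolders of `root` (e.g. norman, replogle_k562) holding all expected files."""
    root = Path(root)
    if not root.is_dir():
        return []
    return sorted(
        p.name for p in root.iterdir()
        if p.is_dir() and not missing_checkpoint_files(p)
    )


if __name__ == "__main__":
    for name in complete_checkpoints():
        print("ready:", name)
```

Note that this only checks that the files exist; a checkpoint still needs the matching preprocessed dataset under data/processed/ before it can be loaded.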

License

MIT, Bryan Cheng 2026. See LICENSE.