# PIVOT
Perturbation-Informed Vector-field Optimization for Transcriptomic state control.
PIVOT learns a perturbation-conditioned flow map over single-cell state embeddings and
uses its Jacobian for differentiable inverse design: given a control cell state and a
desired target state, it nominates gene-level interventions that move cells toward the
target. The same model also does ordinary forward response prediction.
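The inverse-design idea can be illustrated with a toy example. This is only a sketch under stated assumptions, not the PIVOT implementation: `flow`, `gene_embeddings`, `numerical_jacobian`, and `nominate` are hypothetical names, the flow map is a made-up linear response, and a finite-difference Jacobian stands in for autograd. It linearizes the flow map around the unperturbed input, solves a least-squares problem for the perturbation direction that moves the predicted state toward the target, and ranks genes by how well their embeddings align with that direction.

```python
import numpy as np

def numerical_jacobian(f, u, eps=1e-5):
    """Central-difference Jacobian of f at u (a stand-in for autograd)."""
    u = np.asarray(u, dtype=float)
    base = np.asarray(f(u))
    J = np.zeros((base.size, u.size))
    for j in range(u.size):
        step = np.zeros_like(u)
        step[j] = eps
        J[:, j] = (f(u + step) - f(u - step)) / (2 * eps)
    return J

def nominate(flow, gene_embeddings, target, top_k=3):
    """Rank genes by alignment with the least-squares perturbation
    direction that moves flow(0) toward the target state."""
    u0 = np.zeros(gene_embeddings.shape[1])   # "no perturbation" input
    J = numerical_jacobian(flow, u0)          # d(state) / d(perturbation)
    residual = target - flow(u0)              # desired change in state
    # Solve J @ du ~= residual for a direction in perturbation space.
    du, *_ = np.linalg.lstsq(J, residual, rcond=None)
    scores = gene_embeddings @ du             # alignment score per gene
    order = np.argsort(-scores)[:top_k]
    return order, scores

# Toy data (all hypothetical): 5-d cell state, 4-d perturbation embedding.
rng = np.random.default_rng(0)
state_dim, pert_dim, n_genes = 5, 4, 6
control = rng.normal(size=state_dim)
A = rng.normal(size=(state_dim, pert_dim))
flow = lambda u: control + A @ u              # assumed linear response

gene_embeddings = rng.normal(size=(n_genes, pert_dim))
gene_embeddings /= np.linalg.norm(gene_embeddings, axis=1, keepdims=True)

# Build the target by applying gene 2, then try to recover it.
target = flow(gene_embeddings[2])
order, scores = nominate(flow, gene_embeddings, target)
print("top nominations:", order)
```

Because the target here was produced by gene 2 and the toy map is linear, gene 2 should come out on top. A learned nonlinear flow map would give only a local approximation, so the ranking would be a set of candidates rather than an exact recovery.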
## Layout
```
src/
  data/         loading + preprocessing of perturb-seq data, splits
  models/       perturbation encoder, flow map, the PIVOT module
  training/     training loop and losses
  evaluation/   inference, rewards, metrics, baselines
  experiments/  drivers for the result tables, ablations, figures
  utils/
scripts/        figure generation, extra ablations, GEARS comparison
experiments/    saved result json
```
## Setup
```bash
pip install -r requirements.txt
```
Data is not committed. Download and preprocess from the public sources first (Norman 2019
and Replogle 2022 are pulled from scPerturb):
```bash
python -m src.data.preprocess norman
python -m src.data.preprocess replogle_k562
```
This writes a 50-component PCA embedding computed over 2000 highly variable genes, plus the
held-out splits, to `data/processed/<dataset>/`.
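For readers unfamiliar with this kind of embedding, here is a rough NumPy sketch of the two steps. It is an illustration under assumptions, not the code in `src.data.preprocess`: `hvg_pca` is a hypothetical helper, gene selection is simplified to plain variance ranking (real pipelines usually use a dispersion-based method), and the input is synthetic.

```python
import numpy as np

def hvg_pca(X, n_hvg=2000, n_components=50):
    """X: cells x genes matrix of log-normalized expression.

    Returns the PCA embedding, the component matrix, the per-gene mean,
    and the indices of the selected genes.
    """
    n_hvg = min(n_hvg, X.shape[1])
    # Simplified HVG selection: keep the genes with the largest variance.
    variances = X.var(axis=0)
    hvg_idx = np.sort(np.argsort(-variances)[:n_hvg])
    Xh = X[:, hvg_idx]

    # PCA via SVD of the centered matrix.
    mean = Xh.mean(axis=0)
    Xc = Xh - mean
    _, _, Vt = np.linalg.svd(Xc, full_matrices=False)
    k = min(n_components, Vt.shape[0])
    components = Vt[:k]                 # (k, n_hvg), orthonormal rows
    embedding = Xc @ components.T       # (n_cells, k)
    return embedding, components, mean, hvg_idx

# Synthetic counts standing in for a perturb-seq matrix.
rng = np.random.default_rng(0)
X = np.log1p(rng.poisson(1.0, size=(300, 2500)).astype(float))
emb, comps, mean, idx = hvg_pca(X)
print(emb.shape, comps.shape)
```

Keeping `components`, `mean`, and `hvg_idx` around matters because new cells have to be projected with the same genes and the same basis as the training data.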
## Running things
```bash
# train one model
python -m src.training.train --dataset norman --split perturbation
# forward + nomination tables
python -m src.experiments.run_tables --dataset norman --tables forward_cell forward_perturbation
# ablations
python -m src.experiments.run_ablations --dataset norman
# figures
python scripts/figures.py
```
The GEARS head-to-head runs in its own conda env (older torch + pyg), since the package is
finicky about versions:
```bash
bash scripts/setup_gears_env.sh
conda run -n pivot_gears python scripts/gears_ranking.py
```
## Models
Trained PIVOT checkpoints live under `models/`:
- `models/norman/` - trained on Norman 2019 (CRISPRa K562)
- `models/replogle_k562/` - trained on Replogle 2022 (CRISPRi K562)
Each folder has `model.ptw` (a plain torch state dict), `config.json` (the training config),
and `train_info.json` (history + run info). Loading needs the matching preprocessed dataset,
since the perturbation encoder vocabulary comes from the data:
```python
import json

import torch

from src.data.perturb_data import load_dataset
from src.training.train import TrainConfig, make_model

# Rebuild the training config that was saved next to the weights.
with open("models/norman/config.json") as f:
    cfg = TrainConfig(**json.load(f))

# The preprocessed dataset supplies the perturbation encoder vocabulary.
data = load_dataset(cfg.dataset)
model = make_model(data, cfg, device="cpu")
model.load_state_dict(torch.load("models/norman/model.ptw", map_location="cpu"))
model.eval()  # inference mode
```
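Before loading, it can help to confirm that a checkpoint folder is complete. The helper below is a small stdlib-only sketch; `missing_checkpoint_files` is a hypothetical name and not part of the repo. It checks only for the three files listed above, and the demo runs against a temporary folder instead of a real checkpoint.

```python
import tempfile
from pathlib import Path

# File names taken from the checkpoint description above.
EXPECTED_FILES = ("model.ptw", "config.json", "train_info.json")

def missing_checkpoint_files(folder):
    """Return the expected checkpoint files that are absent from folder."""
    folder = Path(folder)
    return [name for name in EXPECTED_FILES if not (folder / name).is_file()]

# Demo on a throwaway folder that contains only a config file.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "config.json").write_text("{}")
    missing = missing_checkpoint_files(tmp)

print("missing:", missing)
```

On a real checkout, `missing_checkpoint_files("models/norman")` would be expected to return an empty list once the weights are downloaded.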
## License
MIT, Bryan Cheng 2026. See `LICENSE`.