---
license: mit
pipeline_tag: other
tags:
- biology
- genomics
- gene-perturbation
- RAG
---

# PT-RAG: Retrieval-Augmented Generation for Predicting Cellular Responses to Gene Perturbation

PT-RAG (Perturbation-aware Two-stage Retrieval-Augmented Generation) is a novel framework that extends Retrieval-Augmented Generation to cellular biology. It is designed to predict how cells respond to genetic perturbations by using a two-stage differentiable retrieval pipeline.

- **Paper:** [Retrieval-Augmented Generation for Predicting Cellular Responses to Gene Perturbation](https://huggingface.co/papers/2603.07233)
- **GitHub Repository:** [https://github.com/difra100/PT-RAG_ICLR](https://github.com/difra100/PT-RAG_ICLR)
- **Status:** Accepted at ICLR 2026 Workshop (Gen² @ ICLR 2026)

## Overview

PT-RAG addresses the challenge of modeling single-cell perturbation responses by leveraging context-aware retrieval. Unlike standard RAG systems, it uses a differentiable mechanism to learn what constitutes relevant context. The pipeline consists of two stages:

1. **Candidate Retrieval**: Retrieving candidate perturbations using GenePT embeddings.
2. **Adaptive Refinement**: Refining the selection through Gumbel-Softmax discrete sampling, conditioned on the cell state and the input perturbation.
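
The two stages can be sketched in NumPy as follows. This is a minimal illustration under stated assumptions, not the repository's implementation. Stage 1 is assumed to be top-k cosine-similarity search over GenePT-style perturbation embeddings (random vectors stand in for them here). Stage 2 is assumed to make a relaxed keep/drop decision for each candidate with a two-class Gumbel-Softmax. The scorer `keep_logits` and its matrix `W` are hypothetical placeholders for whatever learned module conditions the selection on cell state and input perturbation. `k = 32` simply reuses the value of `training.topk_rag` from the training example.

```python
import numpy as np


def retrieve_candidates(query_emb, bank_embs, k):
    """Stage 1 (sketch): indices of the k bank perturbations whose embeddings
    have the highest cosine similarity to the query perturbation."""
    q = query_emb / np.linalg.norm(query_emb)
    bank = bank_embs / np.linalg.norm(bank_embs, axis=1, keepdims=True)
    sims = bank @ q
    return np.argsort(sims)[::-1][:k]  # most similar first


def keep_logits(cand_embs, cell_state, query_emb, W):
    """HYPOTHETICAL scorer: (k, 2) drop/keep logits per candidate, conditioned
    on the cell state and the input perturbation. W stands in for learned weights."""
    context = np.concatenate([cell_state, query_emb])  # shape (c + d,)
    keep = cand_embs @ (W @ context)                   # shape (k,)
    return np.stack([np.zeros_like(keep), keep], axis=1)


def gumbel_softmax_keep(logits, tau, rng):
    """Stage 2 (sketch): two-class Gumbel-Softmax per candidate. Returns the
    relaxed "keep" weight for each candidate, with values in [0, 1]."""
    u = rng.uniform(1e-10, 1.0, size=logits.shape)
    y = (logits - np.log(-np.log(u))) / tau  # add Gumbel(0, 1) noise, scale by temperature
    y = y - y.max(axis=1, keepdims=True)     # numerical stability
    p = np.exp(y)
    p = p / p.sum(axis=1, keepdims=True)
    return p[:, 1]


# Toy run with random stand-ins for GenePT embeddings and a cell-state vector.
rng = np.random.default_rng(0)
d, c, n_bank, k = 16, 8, 100, 32
bank = rng.normal(size=(n_bank, d))
query = rng.normal(size=d)
cell_state = rng.normal(size=c)
W = rng.normal(scale=0.1, size=(d, c + d))

idx = retrieve_candidates(query, bank, k)
mask = gumbel_softmax_keep(keep_logits(bank[idx], cell_state, query, W), tau=0.5, rng=rng)
print(idx[:5], mask[:5].round(3))
```

NumPy has no autograd, so this sketch only shows the forward pass. In a training setup the same operations would be written in a differentiable framework, so that gradients can flow through the relaxed mask back to the scorer. Lowering `tau` pushes the mask toward hard 0/1 decisions.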

## Installation

To set up the environment and install the necessary dependencies:

```bash
# Create a new conda environment
conda create -n ptrag python=3.11 -y
conda activate ptrag

# Install the base package (run from the repository root)
pip install -e .

# Install RAG dependencies
pip install -r requirements.txt
```
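
After installation, a quick check from Python can confirm that the environment is usable. This is a hypothetical helper, not part of the repository. It assumes that the package installed by `pip install -e .` is importable as `state`, as the `python -m state.__main__` commands in this README suggest. It also compares against the Python 3.11 used in the `conda create` command.

```python
import importlib.util
import sys


def check_environment(package="state", expected_python=(3, 11)):
    """Sanity-check the environment created above: is this the Python version the
    conda env was created with, and can the editable install be found?"""
    return {
        "python_matches": sys.version_info[:2] == tuple(expected_python),
        "package_found": importlib.util.find_spec(package) is not None,
    }


if __name__ == "__main__":
    print(check_environment())
```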

## Sample Usage

### Training PT-RAG

To train a model with differentiable retrieval and sparsity regularization:

```bash
python -m state.__main__ tx train \
  data.kwargs.toml_config_path=datasets/repogle_nadig_jurkat.toml \
  training.rag=true \
  training.differentiable_rag=true \
  training.retrieve_than_predict=true \
  training.gumbel_sparsity_loss=true \
  training.gumbel_sparsity_weight=0.1 \
  training.topk_rag=32 \
  training.use_genept=true \
  model=state \
  output_dir=experiments/ptrag_model \
  name=jurkat_ptrag_sparsity0.1
```
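
To compare several sparsity strengths, the same command can be generated from Python. The helper names below are hypothetical, and the sweep values are illustrative rather than settings from the paper. Every override is copied from the command above; only `training.gumbel_sparsity_weight` and `name` vary. With the default `dry_run=True`, the script only prints the commands.

```python
import shlex
import subprocess
import sys


def build_train_cmd(sparsity_weight, topk=32, output_dir="experiments/ptrag_model"):
    """Reproduce the training command above with a configurable sparsity weight.
    The run name follows the README's pattern, e.g. jurkat_ptrag_sparsity0.1."""
    return [
        sys.executable, "-m", "state.__main__", "tx", "train",  # README uses `python`
        "data.kwargs.toml_config_path=datasets/repogle_nadig_jurkat.toml",
        "training.rag=true",
        "training.differentiable_rag=true",
        "training.retrieve_than_predict=true",
        "training.gumbel_sparsity_loss=true",
        f"training.gumbel_sparsity_weight={sparsity_weight}",
        f"training.topk_rag={topk}",
        "training.use_genept=true",
        "model=state",
        f"output_dir={output_dir}",
        f"name=jurkat_ptrag_sparsity{sparsity_weight}",
    ]


def run_sweep(weights, dry_run=True):
    """Build (and, unless dry_run, launch sequentially) one training run per weight."""
    cmds = [build_train_cmd(w) for w in weights]
    for cmd in cmds:
        if dry_run:
            print(shlex.join(cmd))
        else:
            subprocess.run(cmd, check=True)
    return cmds


if __name__ == "__main__":
    # Illustrative weights only; 0.1 is the value used in the README example.
    run_sweep([0.05, 0.1, 0.2])
```

Runs share `output_dir` and are distinguished by `name`, matching how the original command names its run after the sparsity weight.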

### Inference

The differentiable RAG index and the learned weights are loaded automatically during inference:

```bash
python -m state.__main__ tx predict \
  --output-dir experiments/ptrag_model \
  --checkpoint last.ckpt \
  --eval-genept-pert
```
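
When scripting evaluation, a small wrapper can check that the training output exists before it launches inference. This is a hypothetical sketch. It passes exactly the flags shown above and assumes that `--output-dir` should point at an existing directory written by training. How `last.ckpt` is resolved relative to that directory is left to the CLI.

```python
import subprocess
import sys
from pathlib import Path


def build_predict_cmd(output_dir="experiments/ptrag_model", checkpoint="last.ckpt"):
    """Reproduce the inference command above (README uses `python`)."""
    return [
        sys.executable, "-m", "state.__main__", "tx", "predict",
        "--output-dir", str(output_dir),
        "--checkpoint", checkpoint,
        "--eval-genept-pert",
    ]


def run_predict(output_dir="experiments/ptrag_model", checkpoint="last.ckpt"):
    """Fail fast if the training output directory is missing, then run inference."""
    if not Path(output_dir).is_dir():
        raise FileNotFoundError(f"training output directory not found: {output_dir}")
    return subprocess.run(build_predict_cmd(output_dir, checkpoint), check=True)
```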

## Citation

If you find this work useful, please cite:

```bibtex
@article{difrancesco2026retrieval,
  title={Retrieval-Augmented Generation for Predicting Cellular Responses to Gene Perturbation},
  author={Di Francesco, Andrea Giuseppe and Rubbi, Andrea and Liò, Pietro},
  journal={arXiv preprint arXiv:2603.07233},
  year={2026}
}
```

## Acknowledgments

This repository builds upon the [State](https://github.com/ArcInstitute/state) model from the Arc Institute. Evaluation metrics are computed using the [GenGeneEval (GGE)](https://github.com/AndreaRubbi/GGE) library.