---
license: mit
pipeline_tag: other
tags:
- biology
- genomics
- gene-perturbation
- RAG
---

# PT-RAG: Retrieval-Augmented Generation for Predicting Cellular Responses to Gene Perturbation

PT-RAG (Perturbation-aware Two-stage Retrieval-Augmented Generation) is a framework that extends Retrieval-Augmented Generation (RAG) to cellular biology. It predicts how cells respond to genetic perturbations using a two-stage, differentiable retrieval pipeline.

- **Paper:** [Retrieval-Augmented Generation for Predicting Cellular Responses to Gene Perturbation](https://huggingface.co/papers/2603.07233)
- **GitHub Repository:** [https://github.com/difra100/PT-RAG_ICLR](https://github.com/difra100/PT-RAG_ICLR)
- **Status:** Accepted at an ICLR 2026 workshop (Gen² @ ICLR 2026)

## Overview

PT-RAG addresses the challenge of modeling single-cell perturbation responses through context-aware retrieval. Unlike standard RAG systems, it uses a differentiable retrieval mechanism to learn what constitutes relevant context. The pipeline has two stages:

1. **Candidate Retrieval**: retrieves candidate perturbations using GenePT embeddings.
2. **Adaptive Refinement**: refines the selection through Gumbel-Softmax discrete sampling, conditioned on the cell state and the input perturbation.

## Installation

To set up the environment and install the necessary dependencies:

```bash
# Create a new conda environment
conda create -n ptrag python=3.11 -y
conda activate ptrag

# Install the base package
pip install -e .

# Install RAG dependencies
pip install -r requirements.txt
```

## Sample Usage

### Training PT-RAG

To train a model with differentiable retrieval and sparsity regularization:

```bash
python -m state.__main__ tx train \
    data.kwargs.toml_config_path=datasets/repogle_nadig_jurkat.toml \
    training.rag=true \
    training.differentiable_rag=true \
    training.retrieve_than_predict=true \
    training.gumbel_sparsity_loss=true \
    training.gumbel_sparsity_weight=0.1 \
    training.topk_rag=32 \
    training.use_genept=true \
    model=state \
    output_dir=experiments/ptrag_model \
    name=jurkat_ptrag_sparsity0.1
```

### Inference

The differentiable RAG index and learned weights are loaded automatically during inference:

```bash
python -m state.__main__ tx predict \
    --output-dir experiments/ptrag_model \
    --checkpoint last.ckpt \
    --eval-genept-pert
```

## Citation

If you find this work useful, please cite:

```bibtex
@article{difrancesco2026retrieval,
  title={Retrieval-Augmented Generation for Predicting Cellular Responses to Gene Perturbation},
  author={Di Francesco, Andrea Giuseppe and Rubbi, Andrea and Liò, Pietro},
  journal={arXiv preprint arXiv:2603.07233},
  year={2026}
}
```

## Acknowledgments

This repository builds upon the [State](https://github.com/ArcInstitute/state) model from the Arc Institute. Evaluation metrics are computed using the [GenGeneEval (GGE)](https://github.com/AndreaRubbi/GGE) library.
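For intuition, the two-stage retrieval described in the Overview can be sketched as follows, assuming PyTorch. This is a hypothetical, simplified illustration and not the PT-RAG implementation. The embedding bank is a random stand-in for GenePT embeddings. Stage 1 uses cosine-similarity top-k, with `k` loosely analogous to `training.topk_rag`. Stage 2 uses a hypothetical `RefinementScorer` MLP whose keep/drop logits go through `torch.nn.functional.gumbel_softmax`. The `sparsity_penalty` function (mean keep-probability, weighted by 0.1 in the spirit of `training.gumbel_sparsity_weight`) is an assumed form of the sparsity regularizer. See the repository for the actual model and loss.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def retrieve_candidates(query_emb: torch.Tensor, bank_emb: torch.Tensor, k: int) -> torch.Tensor:
    """Stage 1 (sketch): indices of the k bank entries most cosine-similar to each query.

    query_emb: (B, D) embeddings of the input perturbations.
    bank_emb:  (N, D) embeddings of perturbations available for retrieval.
    Returns a (B, k) LongTensor of candidate indices.
    """
    sims = F.normalize(query_emb, dim=-1) @ F.normalize(bank_emb, dim=-1).T  # (B, N)
    return sims.topk(k, dim=-1).indices


class RefinementScorer(nn.Module):
    """Stage 2 (hypothetical): keep/drop each candidate via straight-through Gumbel-Softmax."""

    def __init__(self, emb_dim: int, cell_dim: int, hidden: int = 64):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(2 * emb_dim + cell_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 2),  # logits for [drop, keep]
        )

    def forward(self, cell_state, query_emb, cand_emb, tau: float = 1.0, hard: bool = True):
        # cell_state: (B, C), query_emb: (B, D), cand_emb: (B, k, D)
        k = cand_emb.shape[1]
        # Condition every candidate on the cell state and the input perturbation.
        ctx = torch.cat([cell_state, query_emb], dim=-1).unsqueeze(1).expand(-1, k, -1)
        logits = self.mlp(torch.cat([ctx, cand_emb], dim=-1))  # (B, k, 2)
        # hard=True: one-hot samples in the forward pass, soft-sample gradients in the backward pass.
        keep_mask = F.gumbel_softmax(logits, tau=tau, hard=hard, dim=-1)[..., 1]  # (B, k)
        keep_prob = logits.softmax(dim=-1)[..., 1]  # (B, k), deterministic keep probabilities
        return keep_mask, keep_prob


def sparsity_penalty(keep_prob: torch.Tensor) -> torch.Tensor:
    # Assumed form: mean keep-probability; minimizing it favors retaining fewer candidates.
    return keep_prob.mean()


torch.manual_seed(0)
B, N, D, C, k = 4, 100, 16, 8, 32
bank = torch.randn(N, D)   # random stand-in for GenePT embeddings of known perturbations
query = torch.randn(B, D)  # random stand-in for the input perturbations' embeddings
cell = torch.randn(B, C)   # random stand-in for cell-state features

idx = retrieve_candidates(query, bank, k)  # (B, k)
cand = bank[idx]                           # (B, k, D)
scorer = RefinementScorer(D, C)
keep_mask, keep_prob = scorer(cell, query, cand, tau=0.5)

# Retrieved context: mean of the kept candidate embeddings (guarding against zero kept).
context = (keep_mask.unsqueeze(-1) * cand).sum(dim=1) / keep_mask.sum(dim=1, keepdim=True).clamp(min=1.0)
task_loss = context.pow(2).mean()  # placeholder for a real prediction loss
loss = task_loss + 0.1 * sparsity_penalty(keep_prob)
loss.backward()
print(idx.shape, keep_mask.shape, context.shape)
```

The straight-through (`hard=True`) estimator keeps each retrieval decision discrete in the forward pass while still letting gradients reach the scorer. This is the general reason Gumbel-Softmax is a common choice for differentiable discrete selection.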