## CLaRa: Bridging Retrieval and Generation with Continuous Latent Reasoning
This is the official open-source release of CLaRa, a state-of-the-art, end-to-end Retrieval-Augmented Generation model.

### Updates

- Dec 11, 2025. All data used in our experiments are available on [Huggingface](https://huggingface.co/datasets/apple/CLaRa_multi_stage).
- Dec 10, 2025. We are working on an MLX version of the model, to be announced soon.
- Dec 3, 2025. Evaluation data are available in `./evaluation/evaluation_data`.
- Nov 25, 2025. Models are available on Huggingface.

### Motivation

Retrieval-Augmented Generation (RAG) enhances large language models with external knowledge but suffers from **long contexts** and **disjoint retrieval-generation optimization**. Existing soft compression frameworks face two key limitations: (i) reconstruction-based objectives bias compressors toward surface patterns rather than semantic preservation; (ii) retrievers and compressors are trained separately, requiring double encoding despite compressed vectors being inherently retrievable. In this work, we investigate:

- **How can we improve semantic preservation in compressed representations through better pretraining objectives?**
- **How can we unify retrieval and generation optimization to avoid redundant encoding and disjoint objectives?**
We design a three-stage training approach and introduce document compression techniques to improve RAG efficiency. The key findings are listed below.

### Findings

- **Efficient Compression**: CLaRa achieves high compression rates (32x-64x) while preserving the information needed for accurate answer generation.
- **Three-Stage Training**: A carefully designed three-stage training approach (compression pretraining + compression instruction tuning + end-to-end fine-tuning) enables effective learning of both retrieval and generation.

For more interesting findings, please refer to our original paper!

---

### Three-Stage Training

CLaRa uses a carefully designed three-stage training approach:

**Stage 1: Compression Pretraining**
- Train the compressor using the SCP framework with QA pairs and paraphrases
- Retain key semantics through QA-based and paraphrase-guided supervision
- Support compression rates of 1x-256x

**Stage 2: Compression Instruction Tuning**
- Fine-tune the compressor on instruction-following tasks for downstream QA
- Use text-based QA output to ensure compressed representations retain sufficient semantics

**Stage 3: End-to-End Fine-tuning (CLaRa)**
- Jointly train the reranker and generator via a single language-modeling loss
- Unify retrieval and generation in a shared continuous space using a differentiable top-k estimator

In this repository, we release our implementation of **CLaRa**, built upon [OpenRLHF](https://github.com/OpenRLHF/OpenRLHF).
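To build intuition for the "differentiable top-k estimator" used in Stage 3: hard top-k selection has zero gradient, so some smooth relaxation is needed to train the reranker through the generator's loss. The sketch below is **not** the estimator from our paper; it is one common relaxation (successive tempered softmax with a without-replacement mask), and `soft_topk_weights`, `tau`, and the example scores are all illustrative.

```python
import numpy as np

def soft_topk_weights(scores, k, tau=0.1):
    """Smooth stand-in for hard top-k selection (illustrative only).

    Runs k rounds of tempered softmax, masking out mass that has already
    been assigned, so document selection stays differentiable.
    """
    logits = np.asarray(scores, dtype=float).copy()
    weights = np.zeros_like(logits)
    for _ in range(k):
        z = (logits - logits.max()) / tau
        p = np.exp(z) / np.exp(z).sum()
        weights += p
        # Down-weight already-selected items (without-replacement trick)
        logits = logits + np.log1p(-p + 1e-9)
    return weights

# Four candidate documents, keep k=2: the weights concentrate on the two
# highest-scoring documents and sum to ~k.
w = soft_topk_weights([0.1, 2.0, 0.5, 3.0], k=2)
```

As `tau` shrinks, the weights approach a hard top-k indicator; during training a larger `tau` keeps gradients informative.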
### Getting Started

```
├── scripts/                          # Training and evaluation scripts
│   ├── train_pretraining.sh          # Stage 1: Compression pretraining
│   ├── train_instruction_tuning.sh   # Stage 2: Compression instruction tuning
│   ├── train_stage_end_to_end.sh     # Stage 3: End-to-end training
│   └── evaluation_end_to_end.sh      # Evaluation scripts
├── openrlhf/                         # Core training framework
│   ├── models/                       # Model implementations
│   │   └── modeling_clara.py         # CLaRa model definition
│   ├── datasets/                     # Dataset handling
│   │   └── sft_dataset.py            # Training dataset
│   ├── trainer/                      # Training utilities
│   │   └── sft_trainer.py            # SFT trainer
│   └── cli/                          # Command line interface
│       └── train_sft.py              # Main training script
├── evaluation/                       # Evaluation framework
├── example/                          # Example training data
│   ├── pretrain_data.jsonl
│   ├── instruction_tuning_data.jsonl
│   └── end_to_end_data.jsonl
└── README.md                         # This file
```

Video instruction for installation (from @Fahd Mirza): https://youtu.be/al2VoAKn8GU?si=Q8bq7QNMaTvcArwa

Video digest (from @Richard Aragon): https://www.youtube.com/watch?v=yRM92mmKNH4

#### 1. Prepare code and environment

Clone the repository and set up the environment:

```bash
# Create conda environment
env=clara
conda create -n $env python=3.10 -y
conda activate $env

# Install dependencies
pip install -r requirements.txt

# Set up environment variables
export PYTHONPATH=/path/to/clara:$PYTHONPATH
```

Key dependencies include:

- PyTorch >= 2.0
- Transformers >= 4.20
- DeepSpeed >= 0.18
- Flash Attention 2
- Accelerate

#### 2. Data preparation

Prepare training data in JSONL format, one JSON object per line. For the pretraining stage:

```json
{
  "data_type": "qa",
  "question": ["Question 1"],
  "answers": ["Answer 1"],
  "docs": ["Document 1"]
}
```

For end-to-end training:

```json
{
  "question": "Single question text",
  "docs": ["Document 1", "Document 2", ...],
  "gold_answer": "Reference answer"
}
```

#### 3. Start training

**Stage 1: Salient Compressor Pretraining (SCP)**

Pretrain the document compressor:

```bash
bash scripts/train_pretraining.sh
```

Key parameters:
- `--compress_rate`: Compression rate (default: 32)
- `--doc_max_length`: Maximum document length (default: 256)
- `--stage stage1`: Training stage
- `--mse_loss`: Use MSE loss to align compressed and original representations
- `--qa_loss`: Use QA loss for semantic preservation

**Stage 2: Compression Instruction Tuning**

Fine-tune the compressor on instruction-following tasks:

```bash
bash scripts/train_instruction_tuning.sh
```

Key parameters:
- `--pretrain_checkpoint`: Path to the stage 1 checkpoint
- `--stage stage1_2`: Training stage
- `--generation_top_k`: Number of top-k documents used for generation (default: 5)
- `--mse_loss`: Use MSE loss for compression training
- `--do_eval_gen`: Enable generation evaluation

**Stage 3: End-to-End Training**

Fine-tune the model end-to-end with retrieval:

```bash
bash scripts/train_stage_end_to_end.sh
```

Key parameters:
- `--pretrain_checkpoint`: Path to the stage 2 checkpoint
- `--stage stage2`: Training stage
- `--generation_top_k`: Number of top-k documents used for generation
- `--do_eval_gen`: Enable generation evaluation

#### 4. Distributed Training

The training scripts support distributed training across multiple nodes and GPUs. Common parameters:

- `--max_len`: Maximum sequence length (default: 2048 for stages 1/2, 1024 for stage 3)
- `--train_batch_size`: Training batch size
- `--micro_train_batch_size`: Micro batch size for gradient accumulation
- `--learning_rate`: Learning rate (default: 1e-4 for stages 1/2, 5e-6 for stage 3)
- `--max_epochs`: Maximum training epochs
- `--zero_stage`: DeepSpeed ZeRO optimization stage (default: 2)
- `--bf16`: Use bfloat16 precision
- `--flash_attn`: Use Flash Attention 2

### Inference

The CLaRa models can be loaded and used for inference. We provide three models corresponding to the three training stages:
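Before launching a run, it can help to sanity-check that each line of your JSONL files matches one of the two formats from step 2. The helper below is a minimal sketch, not part of the released code; the field names follow this README's examples, and `detect_format` is a hypothetical name.

```python
import json

# Required fields for the two JSONL formats described in step 2.
PRETRAIN_KEYS = {"data_type", "question", "answers", "docs"}
END_TO_END_KEYS = {"question", "docs", "gold_answer"}

def detect_format(line: str) -> str:
    """Return which training format a JSONL line matches, or raise."""
    row = json.loads(line)
    keys = set(row)
    # Pretraining records carry lists of questions/answers per document set
    if PRETRAIN_KEYS <= keys and isinstance(row["question"], list):
        return "pretrain"
    # End-to-end records carry a single question string and a gold answer
    if END_TO_END_KEYS <= keys and isinstance(row["question"], str):
        return "end_to_end"
    raise ValueError(f"unrecognized record with keys: {sorted(keys)}")

example = '{"question": "Who wrote Hamlet?", "docs": ["..."], "gold_answer": "Shakespeare"}'
```

Running `detect_format` over every line before training surfaces schema mistakes early, instead of partway through an epoch.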
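The `--compress_rate` and `--doc_max_length` flags above imply a simple token budget: each document is squeezed into `doc_max_length / compress_rate` continuous vectors. The arithmetic below is only that budget calculation; the actual memory-slot layout is defined by the model code.

```python
def memory_vectors(doc_max_length: int, compress_rate: int) -> int:
    """Continuous vectors per document implied by the compression flags."""
    assert doc_max_length % compress_rate == 0, "rate must divide length"
    return doc_max_length // compress_rate

# Defaults: 256-token documents at CR=32 -> 8 continuous vectors each.
per_doc = memory_vectors(256, 32)

# With generation_top_k=5 documents, the generator conditions on
# 5 * 8 = 40 vectors instead of 5 * 256 = 1280 raw document tokens.
context = 5 * per_doc
```

Lower compression rates (e.g. CR=4 -> 64 vectors per document) trade context length for fidelity, which matches the accuracy trends in the results tables below.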
<details>
<summary>Stage 1: Compression Pretraining model (click to expand)</summary>

```python
from transformers import AutoModel

model_path = "path/to/stage1/model"
model = AutoModel.from_pretrained(
    model_path,
    trust_remote_code=True
).to('cuda')

# Example documents
documents = [
    [
        "Document 1 content...",
        "Document 2 content...",
        "Document 3 content..."
    ]
]
questions = ["" for _ in range(len(documents))]

# Generate a paraphrase from the compressed representations
output = model.generate_from_paraphrase(
    questions=questions,
    documents=documents,
    max_new_tokens=64
)
print('Generated paraphrase:', output[0])
```

</details>
<details>
<summary>Stage 2: Compression Instruction Tuning model (click to expand)</summary>

```python
from transformers import AutoModel

model_path = "path/to/stage2/model"
model = AutoModel.from_pretrained(
    model_path,
    trust_remote_code=True
).to('cuda')

# Example documents and question
documents = [
    [
        "Document 1 content...",
        "Document 2 content...",
        "Document 3 content..."
    ]
]
questions = ["Your question here"]

# Generate an answer from the compressed representations
output = model.generate_from_text(
    questions=questions,
    documents=documents,
    max_new_tokens=64
)
print('Generated answer:', output[0])
```

</details>
<details>
<summary>Stage 3: End-to-End (CLaRa) model (click to expand)</summary>

```python
from transformers import AutoModel

model_path = "path/to/stage3/model"
model = AutoModel.from_pretrained(
    model_path,
    trust_remote_code=True
).to('cuda')

# Example documents and question
# Note: Stage 3 supports retrieval over multiple candidate documents
documents = [
    ["Document 1 content..." for _ in range(20)]  # 20 candidate documents
]
questions = ["Your question here"]

# Generate an answer with retrieval and reranking
# The top-k is decided by generation_top_k in config.json
output, topk_indices = model.generate_from_questions(
    questions=questions,
    documents=documents,
    max_new_tokens=64
)
print('Generated answer:', output[0])
print('Top-k selected document indices:', topk_indices)
```

</details>
### Evaluation

The evaluation framework is based on standard RAG benchmarks.

**End-to-end evaluation:**

```bash
bash scripts/evaluation_end_to_end.sh
```

**Instruction tuning evaluation:**

```bash
bash scripts/evaluation_instruction_tuning.sh
```

Supported datasets:

- **HotpotQA**: Multi-hop question answering
- **MuSiQue**: Multi-hop question answering with diverse reasoning
- **2WikiMultiHopQA**: Multi-hop question answering over Wikipedia
- **Natural Questions**: Open-domain question answering

### Results

#### Compression Performance

We evaluate our document compressor on four QA datasets (NQ, HotpotQA, MuSiQue, 2WikiMultiHopQA) under two settings: **Normal** (retrieving the top-5 documents) and **Oracle** (the gold document is included). CLaRa consistently outperforms all baselines across compression ratios.
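For reference, open-domain QA benchmarks like these are typically scored with SQuAD-style exact match after answer normalization. The helper below is a generic sketch of that metric for readers unfamiliar with it, not this project's evaluation code (which lives under `./evaluation`).

```python
import re
import string

def normalize_answer(s: str) -> str:
    """SQuAD-style normalization: lowercase, strip punctuation,
    drop English articles, and collapse whitespace."""
    s = s.lower()
    s = "".join(ch for ch in s if ch not in set(string.punctuation))
    s = re.sub(r"\b(a|an|the)\b", " ", s)
    return " ".join(s.split())

def exact_match(prediction: str, gold: str) -> int:
    """1 if the normalized prediction equals the normalized gold answer."""
    return int(normalize_answer(prediction) == normalize_answer(gold))
```

With this normalization, "The Eiffel Tower!" and "eiffel tower" count as a match, so small formatting differences in generated answers are not penalized.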
**Main Results (Mistral-7B, Normal Setting)**

| Model | CR | NQ | HotpotQA | MuSiQue | 2Wiki | Avg |
|:---|:---:|:---:|:---:|:---:|:---:|:---:|
| AutoCompressor | - | 17.24 | 14.61 | 3.81 | 19.89 | 13.89 |
| XRAG | 128 | 32.35 | 25.16 | 3.64 | 28.79 | 22.48 |
| COCOM | 16 | 24.12 | 21.48 | 3.52 | 24.48 | 18.40 |
| PCC | 16 | 31.38 | 22.29 | 3.43 | 19.47 | 19.14 |
| LLMLingua-2 | 4 | 47.53 | 37.05 | 9.02 | 44.35 | 34.49 |
| PISCO | 16 | 54.39 | 41.94 | 10.09 | 44.88 | 37.83 |
| Mistral-7B w/ retrieval | - | 54.58 | 42.94 | 8.94 | 44.24 | 37.67 |
| **CLaRa (CR=4)** | **4** | **57.05** | **45.09** | **10.34** | **46.94** | **39.86** |
| **CLaRa (CR=16)** | **16** | **55.56** | **43.72** | **10.55** | **46.00** | **38.96** |
| **CLaRa (CR=32)** | **32** | **54.64** | **43.52** | **10.55** | **46.58** | **38.82** |

**Oracle Setting Results (Mistral-7B)**

| Model | CR | NQ | HotpotQA | MuSiQue | 2Wiki | Avg |
|:---|:---:|:---:|:---:|:---:|:---:|:---:|
| PISCO | 16 | 73.44 | 66.53 | 33.80 | 60.45 | 58.55 |
| Mistral-7B w/ retrieval | - | 71.64 | 70.77 | 45.72 | 68.83 | 64.24 |
| **CLaRa (CR=4)** | **4** | **76.50** | **73.81** | **46.26** | **70.48** | **66.76** |
| **CLaRa (CR=16)** | **16** | **75.48** | **70.79** | **43.15** | **66.16** | **63.90** |
| **CLaRa (CR=32)** | **32** | **73.77** | **69.51** | **38.31** | **64.54** | **61.53** |
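The headline gains can be recovered from the Avg columns of the tables above. Which CLaRa setting each comparison uses is our inference from the numbers (CR=16 against PISCO, CR=4 elsewhere), not a statement from the paper; the snippet below is just that arithmetic, in percentage points.

```python
# Avg-column values copied from the two tables above.
normal = {"PISCO": 37.83, "LLMLingua-2": 34.49, "text": 37.67,
          "CLaRa CR=4": 39.86, "CLaRa CR=16": 38.96}
oracle = {"PISCO": 58.55, "text": 64.24, "CLaRa CR=4": 66.76,
          "CLaRa CR=16": 63.90}

gain_pisco_normal = normal["CLaRa CR=16"] - normal["PISCO"]   # ~ +1.13
gain_pisco_oracle = oracle["CLaRa CR=16"] - oracle["PISCO"]   # ~ +5.35
gain_lingua = normal["CLaRa CR=4"] - normal["LLMLingua-2"]    # ~ +5.37

# Gain over the text-based baseline, averaged over the two settings.
gain_text = ((normal["CLaRa CR=4"] - normal["text"])
             + (oracle["CLaRa CR=4"] - oracle["text"])) / 2   # ~ +2.36
```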
**Key Findings:**

- ✅ CLaRa outperforms PISCO by **+1.13** (Normal) and **+5.35** (Oracle) points on average
- ✅ CLaRa outperforms LLMLingua-2 by **+5.37** points (Normal) on average
- ✅ CLaRa matches or exceeds the text-based baseline, with a **+2.36** point average gain on Mistral-7B

#### Retrieval Performance
For detailed experimental results and analysis, please refer to our paper.

## Acknowledgments

We sincerely appreciate the following works:

- Our implementation is built upon the [OpenRLHF framework](https://github.com/OpenRLHF/OpenRLHF).
- Our document compression techniques are inspired by [PISCO-mistral](https://huggingface.co/naver/pisco-mistral).

## Citation

```bibtex
@misc{he2025clarabridgingretrievalgeneration,
      title={CLaRa: Bridging Retrieval and Generation with Continuous Latent Reasoning},
      author={Jie He and Richard He Bai and Sinead Williamson and Jeff Z. Pan and Navdeep Jaitly and Yizhe Zhang},
      year={2025},
      eprint={2511.18659},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2511.18659},
}
```