| <div align="center"> |
| <h1> Retrieval as Generation: A Unified Framework with Self-Triggered Information Planning </h1> |
| <p> |
| <a href="README.md">English</a> | <strong>简体中文</strong> |
| </p> |
| |
| <p> |
| <a href="https://arxiv.org/abs/2604.11407"><img src="https://img.shields.io/badge/Paper-arXiv-b31b1b?logo=arxiv&logoColor=white" /></a> |
| <a href="https://wisdomshell.github.io/GRIP/"><img src="https://img.shields.io/badge/Project-Homepage-2ea44f?logo=githubpages&logoColor=white" /></a> |
| <a href="#overview"><img src="https://img.shields.io/badge/Task-Agentic%20RAG-purple.svg" /></a> |
| <a href="https://github.com/WisdomShell/GRIP"><img src="https://img.shields.io/badge/GitHub-Repository-181717?logo=github&logoColor=white" /></a> |
| <a href="https://2026.aclweb.org/"><img src="https://img.shields.io/badge/Venue-ACL%202026-blue" /></a> |
| <a href="#installation"><img src="https://img.shields.io/badge/Python-3.9%2B-3776AB?logo=python&logoColor=white" /></a> |
| </p> |
| <h2>[ACL'26 Main Conference]</h2> |
| <a href="https://deepblue666.github.io/">Bo Li</a>  |
| <a>Mingda Wang</a>  |
| <a>GeXiang Fang</a>  |
| <a>Shikun Zhang</a>  |
| <a>Wei Ye</a>  |
| <div> |
| </div> |
| </div> |
| |
| 传统的 RAG(检索增强生成)系统将检索视为一种外部的、一次性的干预行为——在生成开始前僵硬地预取文档。这种方式在复杂推理过程中信息需求逐步涌现时往往表现不佳。即便是动态搜索方法,也高度依赖于孤立的外部控制器或启发式规则。 |
|
|
| 我们认为,正如人类的认知过程一样,检索应当是一种内在的、生成式的能力。大语言模型必须能够自主地评估自身知识状态、触发搜索,并根据不断演进的推理状态构建具有上下文关联的后续查询。 |
|
|
| GRIP(Generation-guided Retrieval with Information Planning,生成引导的信息规划检索)正是这一新范式的具体体现。在"检索即生成"框架下,我们的模型通过特定控制词元(control tokens),将检索决策直接内化于词元级的解码过程中。这一方法将模型从依赖辅助性多阶段搜索模块中解放出来,在单一自回归轨迹内实现端到端、自触发的信息规划。 |
|
|
| ## 🌟 核心特性 |
|
|
| - 🎯 **词元驱动控制**:通过显式控制词元(如 `[RETRIEVE]`、`[ANSWER]`、`[INTERMEDIARY]`),将检索行为直接嵌入模型的生成策略,无需外部分类器。 |
| - 🔄 **自触发规划**:自主决定何时回退到内部知识、如何根据部分推理重新构建针对性查询,以及何时终止搜索。 |
| - ⚖️ **自适应检索深度**:根据问题复杂度动态调整检索轮次,在成功避免冗余搜索的同时,还能突破严格的训练预算限制进行外推。 |
| - 🚀 **最先进的性能**:在五个问答基准测试上,以更小的骨干模型(LLaMA3-8B)超越了强力的开源 RAG 基线(如 GainRAG、R1-Searcher),并达到了与 GPT-4o 相当的竞争性水平。 |
| - 🧩 **统一解码轨迹**:将多步推理与即时证据整合紧密耦合于单一、连续的生成流程中。 |
| - 🛠️ **优化的训练方案**:采用针对四种不同行为模式的结构化有监督微调(SFT),并通过基于规则的强化学习(DAPO)进一步精炼,以确保准确且均衡的检索控制。 |
|
|
|
|
| ## 🚀 快速开始 |
|
|
| ### 安装 |
|
|
| ```bash |
| git clone https://github.com/WisdomShell/GRIP |
| cd GRIP |
| conda create -n GRIP python=3.9 |
| conda activate GRIP |
| cd GRIP/model/Train |
| pip install -e . |
| cd ../ |
| pip install -r requirements.txt |
| ``` |
|
|
| ## 准备工作 |
|
|
| ### 构建 Wikipedia 索引 |
|
|
| 下载 Wikipedia 数据转储文件。 |
|
|
| ```python |
| mkdir wiki_data |
| cd wiki_data |
| wget https://dl.fbaipublicfiles.com/dpr/wikipedia_split/psgs_w100.tsv.gz |
| gzip -d psgs_w100.tsv.gz |
| ``` |
|
|
| 使用 Elasticsearch 对 Wikipedia 数据建立索引。 |
|
|
| ```python |
| mkdir ret |
| cd ret |
| wget -O elasticsearch-7.17.9.tar.gz https://artifacts.elastic.co/downloads/elasticsearch/elasticsearch-7.17.9-linux-x86_64.tar.gz |
| tar zxvf elasticsearch-7.17.9.tar.gz |
| rm elasticsearch-7.17.9.tar.gz |
| cd elasticsearch-7.17.9 |
| nohup bin/elasticsearch |
| python data_generation/index.py --data_path path/to/your/psgs_w100.tsv --index_name wiki |
| ``` |
|
|
| ## 检查点与数据集 |
|
|
| 以下是本工作中用于 SFT 和 RL 训练的数据集,以及已训练完成的 GRIP 模型权重。 |
|
|
| | 数据集 | HF 数据集仓库 | |
| |------------------------------|-----------------------------------------------------------------------------------------------------------| |
| | GRIP_SFT_Train_Data | [WisdomShell/GRIP_SFT_Data](https://huggingface.co/datasets/WisdomShell/GRIP_SFT_Data) | |
| | GRIP_RL_Train_Data | [WisdomShell/GRIP_RL_Data](https://huggingface.co/datasets/WisdomShell/GRIP_RL_Data) | |
|
|
| | 模型 | HF 模型仓库 | |
| |------------------------------|-----------------------------------------------------------------------------------------------------------| |
| | Meta-LLaMa-3-8b-GRIP | [WisdomShell/LLaMa-3-8b-GRIP](https://huggingface.co/WisdomShell/GRIP-Llama-3-8B) | |
|
|
| ## 生成 SFT 和 RL 训练数据 |
|
|
| 在此之前,您需要下载 `NaturalQuestion-open` 训练集、`WebQuestions` 训练集和 `TriviaQA` 训练集,提取其中的问题与答案,将它们合并为一个 jsonl 文件,并转换为以下格式: |
|
|
| ```json |
| { |
| "question": "", |
| "answer":["Answer", ...] |
| } |
| ``` |
|
|
| 使用 `Meta-LLaMa-3-8B-Instruct` 模型执行以下代码: |
|
|
| ```python |
| bash data_generation/first.sh |
| ``` |
|
|
| 将您的 *OpenAI token* 写入 `use_gpt_for_data.py` 文件,并配置 `C.jsonl` 文件路径。 |
| 生成完成后,将自动覆盖原始文件。 |
|
|
| ``` |
| python generation_train_data/use_gpt_for_data.py |
| ``` |
|
|
| 将 A、B、C、D 所在的目录写入 `merge_dataset.py` 文件。 |
| 输出路径将保存 SFT_Train_data 和 RL_Train_data。 |
|
|
| ``` |
| python generation_train_data/merge_dataset.py |
| ``` |
|
|
| ## 训练 |
|
|
| ### SFT |
|
|
| 1. 数据处理 |
| |
| - 脚本:`Train/examples/data_preprocess/grip/sft.py` |
| - 您需要指定 `data_path` 参数,即 GRIP 合成数据的路径: |
| |
| ```python |
| parser.add_argument('--data_path', default='<PATH_TO_RAW_DATASET_ROOT>/SFT_data.jsonl') |
| ``` |
| - 您需要指定 `dataset` 的名称,以便在后续训练中使用: |
| |
| ```python |
| # 数据路径默认保存在 "datasets" 文件夹中 |
| parser.add_argument('--save_dir', default='datasets/GRIPSFT') |
| ``` |
| 2. 训练脚本 |
| |
| - 脚本:`Train/examples/sft/run_sft_llama.sh` |
| - 使用模型的 **Base 版本** 进行训练: |
| |
| ```bash |
| set -x |
| |
| NAME=GRIPSFT # 在此指定上一步处理后的训练数据名称 |
| torchrun --standalone --nnodes=1 --nproc_per_node=8 -m verl.trainer.fsdp_sft_trainer \ |
| data.train_files=datasets/$NAME/train.parquet \ |
| data.val_files=datasets/$NAME/test.parquet \ |
| data.prompt_key=extra_info \ |
| data.response_key=extra_info \ |
| optim.lr=1e-6 \ |
| data.prompt_dict_keys=['question'] \ |
| +data.response_dict_keys=['answer'] \ |
| data.micro_batch_size=4 \ |
| model.partial_pretrain=meta-llama/Meta-Llama-3-8B-Base \ #使用 Base 版本训练 |
| trainer.default_local_dir=/path/to/your/SFT_model \ # 微调模型保存路径 |
| trainer.project_name=GRIPSFT \ |
| trainer.experiment_name=$NAME \ |
| trainer.logger=['console'] \ # 上报到 `console` 或 `wandb` |
| trainer.total_epochs=8 \ # 训练轮次 |
| trainer.default_hdfs_dir=null $@ \ |
| ulysses_sequence_parallel_size=2 \ |
| use_remove_padding=true |
| ``` |
| |
| ### RL |
|
|
| 1. 数据处理 |
| |
| - 脚本:`Train/examples/data_preprocess/grip/rl.py` |
| - 您需要指定 `data_path` 参数,即 GRIP 合成数据的路径: |
| |
| ```python |
| parser.add_argument('--data_path', default='<PATH_TO_RAW_DATASET_ROOT>/RL_data.jsonl') |
| ``` |
| - 您需要指定 `dataset` 的名称,以便在后续训练中使用: |
| |
| ```python |
| # 数据路径默认保存在 "datasets" 文件夹中 |
| parser.add_argument('--save_dir', default='datasets/GRIPRL') |
| ``` |
| - 您需要指定 `data_source` 的名称,以便在后续训练中选择奖励模型: |
| |
| ```python |
| parser.add_argument('--data_source', default='GRIPRL') # 必填 |
| ``` |
| 2. 使用 `DAPO` 训练脚本 |
| |
| - 脚本:`Train/recipe/dapo/dapo_4w_continue_rl_ep3_llama.sh` |
| - 您需要修改以下参数以适配 RL 训练: |
| |
| ```bash |
| ... |
| # 路径配置 |
| MODEL_PATH=<PATH_TO_SAVE>/GRIPSFT_LLaMa/global_step_xxx # SFT 检查点 |
| CKPTS_DIR=<PATH_TO_SAVE>/RL_model # RL 模型保存路径 |
| TRAIN_FILE=datasets/GRIPRL/train.parquet # RL 数据集 |
| TEST_FILE=datasets/GRIPRL/test.parquet # RL 数据集 |
| ... |
| ``` |
| 3. 奖励模型的具体实现位于文件 `Train/verl/utils/reward_score/grip.py` 中。 |
| 4. 训练完成后,您需要通过脚本 `Train/scripts/merge.sh` 将模型保存的分片合并为 Hugging Face 格式。 |
| |
| ### 使用 GRIP 进行本地推理 |
|
|
| #### 测试数据格式对齐 |
|
|
| ```json |
| { |
| "question": "Test Query", |
| "answer": ["Answer List", ...] |
| } |
| ``` |
|
|
| #### 多轮 GRIP 推理 |
|
|
| - 主脚本:`inference/inference.sh` |
| |
| ```python |
| # 模型保存路径 |
| parser.add_argument('--model_path', type=str, default="/path/to/your/RL_model/step_xxx") |
| # 预测文件输出路径 |
| parser.add_argument('--output_file', type=str, default="output/rl_step_xxx_hotpot.jsonl") |
| # 待预测文件 |
| parser.add_argument('--input_file', type=str, default="test_data/hotpotQA.jsonl") |
| ``` |
| - 该脚本将按以下格式生成预测结果: |
| |
| ```json |
| { |
| "Question": "String", |
| "prediction": ["String",......] |
| } |
| ``` |
|
|
| ## 评估 |
|
|
| ```python |
| python eval/eval.py \ |
| --references_path test_dataset.jsonl \ |
| --predictions_path prediction.jsonl |
| ``` |
|
|
| ## 🤝 贡献 |
|
|
| 欢迎贡献!请参阅 [CONTRIBUTING.md](CONTRIBUTING.md) 了解相关指引。 |
|
|
| ## 📄 引用 |
|
|
| ```bibtex |
| @article{li2026retrieval, |
| title={Retrieval as Generation: A Unified Framework with Self-Triggered Information Planning}, |
| author={Li, Bo and Wang, Mingda and Fang, Gexiang and Zhang, Shikun and Ye, Wei}, |
| journal={arXiv preprint arXiv:2604.11407}, |
| year={2026} |
| } |
| ``` |
|
|
| ## 📝 许可证 |
|
|
| 本项目采用 Apache 2.0 许可证——详情请参阅 [LICENSE](LICENSE) 文件。 |
|
|
| ## 🙏 致谢 |
|
|
| 特别感谢开源社区及所有使本项目成为可能的贡献者。 |