File size: 10,457 Bytes
8c7dd4f ab2dd8f 8c7dd4f c8d04cb 80c7076 8c7dd4f 4d3a2a5 8c7dd4f 4522ee8 8c7dd4f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 | <div align="center">
<h1> Retrieval as Generation: A Unified Framework with Self-Triggered Information Planning </h1>
<p>
<a href="README.md">English</a> | <strong>简体中文</strong>
</p>
<p>
<a href="https://arxiv.org/abs/2604.11407"><img src="https://img.shields.io/badge/Paper-arXiv-b31b1b?logo=arxiv&logoColor=white" /></a>
<a href="https://wisdomshell.github.io/GRIP/"><img src="https://img.shields.io/badge/Project-Homepage-2ea44f?logo=githubpages&logoColor=white" /></a>
<a href="#overview"><img src="https://img.shields.io/badge/Task-Agentic%20RAG-purple.svg" /></a>
<a href="https://github.com/WisdomShell/GRIP"><img src="https://img.shields.io/badge/GitHub-Repository-181717?logo=github&logoColor=white" /></a>
<a href="https://2026.aclweb.org/"><img src="https://img.shields.io/badge/Venue-ACL%202026-blue" /></a>
<a href="#installation"><img src="https://img.shields.io/badge/Python-3.9%2B-3776AB?logo=python&logoColor=white" /></a>
</p>
<h2>[ACL'26 Main Conference]</h2>
<a href="https://deepblue666.github.io/">Bo Li</a> 
<a>Mingda Wang</a> 
<a>GeXiang Fang</a> 
<a>Shikun Zhang</a> 
<a>Wei Ye</a> 
<div>
</div>
</div>
传统的 RAG(检索增强生成)系统将检索视为一种外部的、一次性的干预行为——在生成开始前僵硬地预取文档。这种方式在复杂推理过程中信息需求逐步涌现时往往表现不佳。即便是动态搜索方法,也高度依赖于孤立的外部控制器或启发式规则。
我们认为,正如人类的认知过程一样,检索应当是一种内在的、生成式的能力。大语言模型必须能够自主地评估自身知识状态、触发搜索,并根据不断演进的推理状态构建具有上下文关联的后续查询。
GRIP(Generation-guided Retrieval with Information Planning,生成引导的信息规划检索)正是这一新范式的具体体现。在"检索即生成"框架下,我们的模型通过特定控制词元(control tokens),将检索决策直接内化于词元级的解码过程中。这一方法将模型从依赖辅助性多阶段搜索模块中解放出来,在单一自回归轨迹内实现端到端、自触发的信息规划。
## 🌟 核心特性
- 🎯 **词元驱动控制**:通过显式控制词元(如 `[RETRIEVE]`、`[ANSWER]`、`[INTERMEDIARY]`),将检索行为直接嵌入模型的生成策略,无需外部分类器。
- 🔄 **自触发规划**:自主决定何时回退到内部知识、如何根据部分推理重新构建针对性查询,以及何时终止搜索。
- ⚖️ **自适应检索深度**:根据问题复杂度动态调整检索轮次,在成功避免冗余搜索的同时,还能突破严格的训练预算限制进行外推。
- 🚀 **最先进的性能**:在五个问答基准测试上,以更小的骨干模型(LLaMA3-8B)超越了强力的开源 RAG 基线(如 GainRAG、R1-Searcher),并达到了与 GPT-4o 相当的竞争性水平。
- 🧩 **统一解码轨迹**:将多步推理与即时证据整合紧密耦合于单一、连续的生成流程中。
- 🛠️ **优化的训练方案**:采用针对四种不同行为模式的结构化有监督微调(SFT),并通过基于规则的强化学习(DAPO)进一步精炼,以确保准确且均衡的检索控制。
## 🚀 快速开始
### 安装
```bash
git clone https://github.com/WisdomShell/GRIP
cd GRIP
conda create -n GRIP python=3.9
conda activate GRIP
cd GRIP/model/Train
pip install -e .
cd ../
pip install -r requirements.txt
```
## 准备工作
### 构建 Wikipedia 索引
下载 Wikipedia 数据转储文件。
```python
mkdir wiki_data
cd wiki_data
wget https://dl.fbaipublicfiles.com/dpr/wikipedia_split/psgs_w100.tsv.gz
gzip -d psgs_w100.tsv.gz
```
使用 Elasticsearch 对 Wikipedia 数据建立索引。
```python
mkdir ret
cd ret
wget -O elasticsearch-7.17.9.tar.gz https://artifacts.elastic.co/downloads/elasticsearch/elasticsearch-7.17.9-linux-x86_64.tar.gz
tar zxvf elasticsearch-7.17.9.tar.gz
rm elasticsearch-7.17.9.tar.gz
cd elasticsearch-7.17.9
nohup bin/elasticsearch
python data_generation/index.py --data_path path/to/your/psgs_w100.tsv --index_name wiki
```
## 检查点与数据集
以下是本工作中用于 SFT 和 RL 训练的数据集,以及已训练完成的 GRIP 模型权重。
| 数据集 | HF 数据集仓库 |
|------------------------------|-----------------------------------------------------------------------------------------------------------|
| GRIP_SFT_Train_Data | [WisdomShell/GRIP_SFT_Data](https://huggingface.co/datasets/WisdomShell/GRIP_SFT_Data) |
| GRIP_RL_Train_Data | [WisdomShell/GRIP_RL_Data](https://huggingface.co/datasets/WisdomShell/GRIP_RL_Data) |
| 模型 | HF 模型仓库 |
|------------------------------|-----------------------------------------------------------------------------------------------------------|
| Meta-LLaMa-3-8b-GRIP | [WisdomShell/LLaMa-3-8b-GRIP](https://huggingface.co/WisdomShell/GRIP-Llama-3-8B) |
## 生成 SFT 和 RL 训练数据
在此之前,您需要下载 `NaturalQuestion-open` 训练集、`WebQuestions` 训练集和 `TriviaQA` 训练集,提取其中的问题与答案,将它们合并为一个 jsonl 文件,并转换为以下格式:
```json
{
"question": "",
"answer":["Answer", ...]
}
```
使用 `Meta-LLaMa-3-8B-Instruct` 模型执行以下代码:
```python
bash data_generation/first.sh
```
将您的 *OpenAI token* 写入 `use_gpt_for_data.py` 文件,并配置 `C.jsonl` 文件路径。
生成完成后,将自动覆盖原始文件。
```
python generation_train_data/use_gpt_for_data.py
```
将 A、B、C、D 所在的目录写入 `merge_dataset.py` 文件。
输出路径将保存 SFT_Train_data 和 RL_Train_data。
```
python generation_train_data/merge_dataset.py
```
## 训练
### SFT
1. 数据处理
- 脚本:`Train/examples/data_preprocess/grip/sft.py`
- 您需要指定 `data_path` 参数,即 GRIP 合成数据的路径:
```python
parser.add_argument('--data_path', default='<PATH_TO_RAW_DATASET_ROOT>/SFT_data.jsonl')
```
- 您需要指定 `dataset` 的名称,以便在后续训练中使用:
```python
# 数据路径默认保存在 "datasets" 文件夹中
parser.add_argument('--save_dir', default='datasets/GRIPSFT')
```
2. 训练脚本
- 脚本:`Train/examples/sft/run_sft_llama.sh`
- 使用模型的 **Base 版本** 进行训练:
```bash
set -x
NAME=GRIPSFT # 在此指定上一步处理后的训练数据名称
torchrun --standalone --nnodes=1 --nproc_per_node=8 -m verl.trainer.fsdp_sft_trainer \
data.train_files=datasets/$NAME/train.parquet \
data.val_files=datasets/$NAME/test.parquet \
data.prompt_key=extra_info \
data.response_key=extra_info \
optim.lr=1e-6 \
data.prompt_dict_keys=['question'] \
+data.response_dict_keys=['answer'] \
data.micro_batch_size=4 \
model.partial_pretrain=meta-llama/Meta-Llama-3-8B-Base \ #使用 Base 版本训练
trainer.default_local_dir=/path/to/your/SFT_model \ # 微调模型保存路径
trainer.project_name=GRIPSFT \
trainer.experiment_name=$NAME \
trainer.logger=['console'] \ # 上报到 `console` 或 `wandb`
trainer.total_epochs=8 \ # 训练轮次
trainer.default_hdfs_dir=null $@ \
ulysses_sequence_parallel_size=2 \
use_remove_padding=true
```
### RL
1. 数据处理
- 脚本:`Train/examples/data_preprocess/grip/rl.py`
- 您需要指定 `data_path` 参数,即 GRIP 合成数据的路径:
```python
parser.add_argument('--data_path', default='<PATH_TO_RAW_DATASET_ROOT>/RL_data.jsonl')
```
- 您需要指定 `dataset` 的名称,以便在后续训练中使用:
```python
# 数据路径默认保存在 "datasets" 文件夹中
parser.add_argument('--save_dir', default='datasets/GRIPRL')
```
- 您需要指定 `data_source` 的名称,以便在后续训练中选择奖励模型:
```python
parser.add_argument('--data_source', default='GRIPRL') # 必填
```
2. 使用 `DAPO` 训练脚本
- 脚本:`Train/recipe/dapo/dapo_4w_continue_rl_ep3_llama.sh`
- 您需要修改以下参数以适配 RL 训练:
```bash
...
# 路径配置
MODEL_PATH=<PATH_TO_SAVE>/GRIPSFT_LLaMa/global_step_xxx # SFT 检查点
CKPTS_DIR=<PATH_TO_SAVE>/RL_model # RL 模型保存路径
TRAIN_FILE=datasets/GRIPRL/train.parquet # RL 数据集
TEST_FILE=datasets/GRIPRL/test.parquet # RL 数据集
...
```
3. 奖励模型的具体实现位于文件 `Train/verl/utils/reward_score/grip.py` 中。
4. 训练完成后,您需要通过脚本 `Train/scripts/merge.sh` 将模型保存的分片合并为 Hugging Face 格式。
### 使用 GRIP 进行本地推理
#### 测试数据格式对齐
```json
{
"question": "Test Query",
"answer": ["Answer List", ...]
}
```
#### 多轮 GRIP 推理
- 主脚本:`inference/inference.sh`
```python
# 模型保存路径
parser.add_argument('--model_path', type=str, default="/path/to/your/RL_model/step_xxx")
# 预测文件输出路径
parser.add_argument('--output_file', type=str, default="output/rl_step_xxx_hotpot.jsonl")
# 待预测文件
parser.add_argument('--input_file', type=str, default="test_data/hotpotQA.jsonl")
```
- 该脚本将按以下格式生成预测结果:
```json
{
"Question": "String",
"prediction": ["String",......]
}
```
## 评估
```python
python eval/eval.py \
--references_path test_dataset.jsonl \
--predictions_path prediction.jsonl
```
## 🤝 贡献
欢迎贡献!请参阅 [CONTRIBUTING.md](CONTRIBUTING.md) 了解相关指引。
## 📄 引用
```bibtex
@article{li2026retrieval,
title={Retrieval as Generation: A Unified Framework with Self-Triggered Information Planning},
author={Li, Bo and Wang, Mingda and Fang, Gexiang and Zhang, Shikun and Ye, Wei},
journal={arXiv preprint arXiv:2604.11407},
year={2026}
}
```
## 📝 许可证
本项目采用 Apache 2.0 许可证——详情请参阅 [LICENSE](LICENSE) 文件。
## 🙏 致谢
特别感谢开源社区及所有使本项目成为可能的贡献者。 |