File size: 10,457 Bytes
8c7dd4f
 
 
ab2dd8f
8c7dd4f
 
 
c8d04cb
80c7076
 
 
 
 
8c7dd4f
4d3a2a5
8c7dd4f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4522ee8
 
 
 
 
 
 
8c7dd4f
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
<div align="center">
  <h1> Retrieval as Generation: A Unified Framework with Self-Triggered Information Planning </h1>
  <p>
  <a href="README.md">English</a> | <strong>简体中文</strong>
</p>
  
  <p>
    <a href="https://arxiv.org/abs/2604.11407"><img src="https://img.shields.io/badge/Paper-arXiv-b31b1b?logo=arxiv&logoColor=white" /></a>
<a href="https://wisdomshell.github.io/GRIP/"><img src="https://img.shields.io/badge/Project-Homepage-2ea44f?logo=githubpages&logoColor=white" /></a>
<a href="#overview"><img src="https://img.shields.io/badge/Task-Agentic%20RAG-purple.svg" /></a>
<a href="https://github.com/WisdomShell/GRIP"><img src="https://img.shields.io/badge/GitHub-Repository-181717?logo=github&logoColor=white" /></a>
<a href="https://2026.aclweb.org/"><img src="https://img.shields.io/badge/Venue-ACL%202026-blue" /></a>
<a href="#installation"><img src="https://img.shields.io/badge/Python-3.9%2B-3776AB?logo=python&logoColor=white" /></a>
    </p>
    <h2>[ACL'26 Main Conference]</h2>
  <a href="https://deepblue666.github.io/">Bo Li</a>&emsp;
  <a>Mingda Wang</a>&emsp;
  <a>GeXiang Fang</a>&emsp;
  <a>Shikun Zhang</a>&emsp;
  <a>Wei Ye</a>&emsp;
  <div>
  </div>
</div>

传统的 RAG(检索增强生成)系统将检索视为一种外部的、一次性的干预行为——在生成开始前僵硬地预取文档。这种方式在复杂推理过程中信息需求逐步涌现时往往表现不佳。即便是动态搜索方法,也高度依赖于孤立的外部控制器或启发式规则。

我们认为,正如人类的认知过程一样,检索应当是一种内在的、生成式的能力。大语言模型必须能够自主地评估自身知识状态、触发搜索,并根据不断演进的推理状态构建具有上下文关联的后续查询。

GRIP(Generation-guided Retrieval with Information Planning,生成引导的信息规划检索)正是这一新范式的具体体现。在"检索即生成"框架下,我们的模型通过特定控制词元(control tokens),将检索决策直接内化于词元级的解码过程中。这一方法将模型从依赖辅助性多阶段搜索模块中解放出来,在单一自回归轨迹内实现端到端、自触发的信息规划。

## 🌟 核心特性

- 🎯 **词元驱动控制**:通过显式控制词元(如 `[RETRIEVE]``[ANSWER]``[INTERMEDIARY]`),将检索行为直接嵌入模型的生成策略,无需外部分类器。
- 🔄 **自触发规划**:自主决定何时回退到内部知识、如何根据部分推理重新构建针对性查询,以及何时终止搜索。
- ⚖️ **自适应检索深度**:根据问题复杂度动态调整检索轮次,在成功避免冗余搜索的同时,还能突破严格的训练预算限制进行外推。
- 🚀 **最先进的性能**:在五个问答基准测试上,以更小的骨干模型(LLaMA3-8B)超越了强力的开源 RAG 基线(如 GainRAG、R1-Searcher),并达到了与 GPT-4o 相当的竞争性水平。
- 🧩 **统一解码轨迹**:将多步推理与即时证据整合紧密耦合于单一、连续的生成流程中。
- 🛠️ **优化的训练方案**:采用针对四种不同行为模式的结构化有监督微调(SFT),并通过基于规则的强化学习(DAPO)进一步精炼,以确保准确且均衡的检索控制。


## 🚀 快速开始

### 安装

```bash
git clone https://github.com/WisdomShell/GRIP
cd GRIP
conda create -n GRIP python=3.9
conda activate GRIP
cd GRIP/model/Train
pip install -e .
cd ../
pip install -r requirements.txt
```

## 准备工作

### 构建 Wikipedia 索引

下载 Wikipedia 数据转储文件。

```python
mkdir wiki_data
cd wiki_data
wget https://dl.fbaipublicfiles.com/dpr/wikipedia_split/psgs_w100.tsv.gz
gzip -d psgs_w100.tsv.gz
```

使用 Elasticsearch 对 Wikipedia 数据建立索引。

```python
mkdir ret
cd ret
wget -O elasticsearch-7.17.9.tar.gz https://artifacts.elastic.co/downloads/elasticsearch/elasticsearch-7.17.9-linux-x86_64.tar.gz
tar zxvf elasticsearch-7.17.9.tar.gz
rm elasticsearch-7.17.9.tar.gz 
cd elasticsearch-7.17.9
nohup bin/elasticsearch 
python data_generation/index.py --data_path path/to/your/psgs_w100.tsv --index_name wiki
```

## 检查点与数据集

以下是本工作中用于 SFT 和 RL 训练的数据集,以及已训练完成的 GRIP 模型权重。

| 数据集 | HF 数据集仓库 |
|------------------------------|-----------------------------------------------------------------------------------------------------------|
| GRIP_SFT_Train_Data | [WisdomShell/GRIP_SFT_Data](https://huggingface.co/datasets/WisdomShell/GRIP_SFT_Data) | 
| GRIP_RL_Train_Data | [WisdomShell/GRIP_RL_Data](https://huggingface.co/datasets/WisdomShell/GRIP_RL_Data) |

| 模型 | HF 模型仓库 |
|------------------------------|-----------------------------------------------------------------------------------------------------------|
| Meta-LLaMa-3-8b-GRIP | [WisdomShell/LLaMa-3-8b-GRIP](https://huggingface.co/WisdomShell/GRIP-Llama-3-8B) |

## 生成 SFT 和 RL 训练数据

在此之前,您需要下载 `NaturalQuestion-open` 训练集、`WebQuestions` 训练集和 `TriviaQA` 训练集,提取其中的问题与答案,将它们合并为一个 jsonl 文件,并转换为以下格式:

```json
{
    "question": "",
    "answer":["Answer", ...]
}
```

使用 `Meta-LLaMa-3-8B-Instruct` 模型执行以下代码:

```python
bash data_generation/first.sh
```

将您的 *OpenAI token* 写入 `use_gpt_for_data.py` 文件,并配置 `C.jsonl` 文件路径。
生成完成后,将自动覆盖原始文件。

```
python generation_train_data/use_gpt_for_data.py
```

将 A、B、C、D 所在的目录写入 `merge_dataset.py` 文件。
输出路径将保存 SFT_Train_data 和 RL_Train_data。

```
python generation_train_data/merge_dataset.py
```

## 训练

### SFT

1. 数据处理
   
   - 脚本:`Train/examples/data_preprocess/grip/sft.py`
   - 您需要指定 `data_path` 参数,即 GRIP 合成数据的路径:
     
     ```python
     parser.add_argument('--data_path', default='<PATH_TO_RAW_DATASET_ROOT>/SFT_data.jsonl')
     ```
   - 您需要指定 `dataset` 的名称,以便在后续训练中使用:
     
     ```python
     # 数据路径默认保存在 "datasets" 文件夹中
     parser.add_argument('--save_dir', default='datasets/GRIPSFT')
     ```
2. 训练脚本
   
   - 脚本:`Train/examples/sft/run_sft_llama.sh`
   - 使用模型的 **Base 版本** 进行训练:
     
     ```bash
     set -x
     
     NAME=GRIPSFT  # 在此指定上一步处理后的训练数据名称
     torchrun --standalone --nnodes=1 --nproc_per_node=8 -m verl.trainer.fsdp_sft_trainer \
         data.train_files=datasets/$NAME/train.parquet \  
         data.val_files=datasets/$NAME/test.parquet \
         data.prompt_key=extra_info \
         data.response_key=extra_info \
         optim.lr=1e-6 \
         data.prompt_dict_keys=['question'] \
         +data.response_dict_keys=['answer'] \
         data.micro_batch_size=4 \
         model.partial_pretrain=meta-llama/Meta-Llama-3-8B-Base \       #使用 Base 版本训练
         trainer.default_local_dir=/path/to/your/SFT_model \  # 微调模型保存路径
         trainer.project_name=GRIPSFT \
         trainer.experiment_name=$NAME \
         trainer.logger=['console'] \  # 上报到 `console` 或 `wandb`
         trainer.total_epochs=8 \      # 训练轮次
         trainer.default_hdfs_dir=null $@ \
         ulysses_sequence_parallel_size=2 \
         use_remove_padding=true
     ```

### RL

1. 数据处理
   
   - 脚本:`Train/examples/data_preprocess/grip/rl.py`
   - 您需要指定 `data_path` 参数,即 GRIP 合成数据的路径:
     
     ```python
     parser.add_argument('--data_path', default='<PATH_TO_RAW_DATASET_ROOT>/RL_data.jsonl')
     ```
   - 您需要指定 `dataset` 的名称,以便在后续训练中使用:
     
     ```python
     # 数据路径默认保存在 "datasets" 文件夹中
     parser.add_argument('--save_dir', default='datasets/GRIPRL')
     ```
   - 您需要指定 `data_source` 的名称,以便在后续训练中选择奖励模型:
     
     ```python
     parser.add_argument('--data_source', default='GRIPRL')    # 必填
     ```
2. 使用 `DAPO` 训练脚本
   
   - 脚本:`Train/recipe/dapo/dapo_4w_continue_rl_ep3_llama.sh`
   - 您需要修改以下参数以适配 RL 训练:
     
     ```bash
     ...
      # 路径配置
      MODEL_PATH=<PATH_TO_SAVE>/GRIPSFT_LLaMa/global_step_xxx  # SFT 检查点
      CKPTS_DIR=<PATH_TO_SAVE>/RL_model  # RL 模型保存路径
      TRAIN_FILE=datasets/GRIPRL/train.parquet  # RL 数据集
      TEST_FILE=datasets/GRIPRL/test.parquet  # RL 数据集
      ...
     ```
3. 奖励模型的具体实现位于文件 `Train/verl/utils/reward_score/grip.py` 中。
4. 训练完成后,您需要通过脚本 `Train/scripts/merge.sh` 将模型保存的分片合并为 Hugging Face 格式。

### 使用 GRIP 进行本地推理

#### 测试数据格式对齐

```json
{
    "question": "Test Query",
    "answer": ["Answer List", ...]
}
```

#### 多轮 GRIP 推理

- 主脚本:`inference/inference.sh`
  
  ```python
  # 模型保存路径
  parser.add_argument('--model_path', type=str, default="/path/to/your/RL_model/step_xxx")
  # 预测文件输出路径
  parser.add_argument('--output_file', type=str, default="output/rl_step_xxx_hotpot.jsonl")
  # 待预测文件
  parser.add_argument('--input_file', type=str, default="test_data/hotpotQA.jsonl")
  ```
- 该脚本将按以下格式生成预测结果:
  
  ```json
  {
        "Question": "String", 
        "prediction": ["String",......]
    }
  ```

## 评估

```python
python eval/eval.py \
    --references_path test_dataset.jsonl \
    --predictions_path prediction.jsonl
```

## 🤝 贡献

欢迎贡献!请参阅 [CONTRIBUTING.md](CONTRIBUTING.md) 了解相关指引。

## 📄 引用

```bibtex
@article{li2026retrieval,
  title={Retrieval as Generation: A Unified Framework with Self-Triggered Information Planning},
  author={Li, Bo and Wang, Mingda and Fang, Gexiang and Zhang, Shikun and Ye, Wei},
  journal={arXiv preprint arXiv:2604.11407},
  year={2026}
}
```

## 📝 许可证

本项目采用 Apache 2.0 许可证——详情请参阅 [LICENSE](LICENSE) 文件。

## 🙏 致谢

特别感谢开源社区及所有使本项目成为可能的贡献者。