GRIP-Llama-3-8B / README_zh.md
WarrenWang01's picture
Update README_zh.md
4522ee8 verified

Retrieval as Generation: A Unified Framework with Self-Triggered Information Planning

English | 简体中文

[ACL'26 Main Conference]

Bo LiMingda WangGeXiang FangShikun ZhangWei Ye

传统的 RAG(检索增强生成)系统将检索视为一种外部的、一次性的干预行为——在生成开始前僵硬地预取文档。这种方式在复杂推理过程中信息需求逐步涌现时往往表现不佳。即便是动态搜索方法,也高度依赖于孤立的外部控制器或启发式规则。

我们认为,正如人类的认知过程一样,检索应当是一种内在的、生成式的能力。大语言模型必须能够自主地评估自身知识状态、触发搜索,并根据不断演进的推理状态构建具有上下文关联的后续查询。

GRIP(Generation-guided Retrieval with Information Planning,生成引导的信息规划检索)正是这一新范式的具体体现。在"检索即生成"框架下,我们的模型通过特定控制词元(control tokens),将检索决策直接内化于词元级的解码过程中。这一方法将模型从依赖辅助性多阶段搜索模块中解放出来,在单一自回归轨迹内实现端到端、自触发的信息规划。

🌟 核心特性

  • 🎯 词元驱动控制:通过显式控制词元(如 [RETRIEVE][ANSWER][INTERMEDIARY]),将检索行为直接嵌入模型的生成策略,无需外部分类器。
  • 🔄 自触发规划:自主决定何时回退到内部知识、如何根据部分推理重新构建针对性查询,以及何时终止搜索。
  • ⚖️ 自适应检索深度:根据问题复杂度动态调整检索轮次,在成功避免冗余搜索的同时,还能突破严格的训练预算限制进行外推。
  • 🚀 最先进的性能:在五个问答基准测试上,以更小的骨干模型(LLaMA3-8B)超越了强力的开源 RAG 基线(如 GainRAG、R1-Searcher),并达到了与 GPT-4o 相当的竞争性水平。
  • 🧩 统一解码轨迹:将多步推理与即时证据整合紧密耦合于单一、连续的生成流程中。
  • 🛠️ 优化的训练方案:采用针对四种不同行为模式的结构化有监督微调(SFT),并通过基于规则的强化学习(DAPO)进一步精炼,以确保准确且均衡的检索控制。

🚀 快速开始

安装

git clone https://github.com/WisdomShell/GRIP
cd GRIP
conda create -n GRIP python=3.9
conda activate GRIP
cd GRIP/model/Train
pip install -e .
cd ../
pip install -r requirements.txt

准备工作

构建 Wikipedia 索引

下载 Wikipedia 数据转储文件。

mkdir wiki_data
cd wiki_data
wget https://dl.fbaipublicfiles.com/dpr/wikipedia_split/psgs_w100.tsv.gz
gzip -d psgs_w100.tsv.gz

使用 Elasticsearch 对 Wikipedia 数据建立索引。

mkdir ret
cd ret
wget -O elasticsearch-7.17.9.tar.gz https://artifacts.elastic.co/downloads/elasticsearch/elasticsearch-7.17.9-linux-x86_64.tar.gz
tar zxvf elasticsearch-7.17.9.tar.gz
rm elasticsearch-7.17.9.tar.gz 
cd elasticsearch-7.17.9
nohup bin/elasticsearch 
python data_generation/index.py --data_path path/to/your/psgs_w100.tsv --index_name wiki

检查点与数据集

以下是本工作中用于 SFT 和 RL 训练的数据集,以及已训练完成的 GRIP 模型权重。

数据集 HF 数据集仓库
GRIP_SFT_Train_Data WisdomShell/GRIP_SFT_Data
GRIP_RL_Train_Data WisdomShell/GRIP_RL_Data
模型 HF 模型仓库
Meta-LLaMa-3-8b-GRIP WisdomShell/LLaMa-3-8b-GRIP

生成 SFT 和 RL 训练数据

在此之前,您需要下载 NaturalQuestion-open 训练集、WebQuestions 训练集和 TriviaQA 训练集,提取其中的问题与答案,将它们合并为一个 jsonl 文件,并转换为以下格式:

{
    "question": "",
    "answer":["Answer", ...]
}

使用 Meta-LLaMa-3-8B-Instruct 模型执行以下代码:

bash data_generation/first.sh

将您的 OpenAI token 写入 use_gpt_for_data.py 文件,并配置 C.jsonl 文件路径。 生成完成后,将自动覆盖原始文件。

python generation_train_data/use_gpt_for_data.py

将 A、B、C、D 所在的目录写入 merge_dataset.py 文件。 输出路径将保存 SFT_Train_data 和 RL_Train_data。

python generation_train_data/merge_dataset.py

训练

SFT

  1. 数据处理

    • 脚本:Train/examples/data_preprocess/grip/sft.py

    • 您需要指定 data_path 参数,即 GRIP 合成数据的路径:

      parser.add_argument('--data_path', default='<PATH_TO_RAW_DATASET_ROOT>/SFT_data.jsonl')
      
    • 您需要指定 dataset 的名称,以便在后续训练中使用:

      # 数据路径默认保存在 "datasets" 文件夹中
      parser.add_argument('--save_dir', default='datasets/GRIPSFT')
      
  2. 训练脚本

    • 脚本:Train/examples/sft/run_sft_llama.sh

    • 使用模型的 Base 版本 进行训练:

      set -x
      
      NAME=GRIPSFT  # 在此指定上一步处理后的训练数据名称
      torchrun --standalone --nnodes=1 --nproc_per_node=8 -m verl.trainer.fsdp_sft_trainer \
          data.train_files=datasets/$NAME/train.parquet \  
          data.val_files=datasets/$NAME/test.parquet \
          data.prompt_key=extra_info \
          data.response_key=extra_info \
          optim.lr=1e-6 \
          data.prompt_dict_keys=['question'] \
          +data.response_dict_keys=['answer'] \
          data.micro_batch_size=4 \
          model.partial_pretrain=meta-llama/Meta-Llama-3-8B-Base \       #使用 Base 版本训练
          trainer.default_local_dir=/path/to/your/SFT_model \  # 微调模型保存路径
          trainer.project_name=GRIPSFT \
          trainer.experiment_name=$NAME \
          trainer.logger=['console'] \  # 上报到 `console` 或 `wandb`
          trainer.total_epochs=8 \      # 训练轮次
          trainer.default_hdfs_dir=null $@ \
          ulysses_sequence_parallel_size=2 \
          use_remove_padding=true
      

RL

  1. 数据处理

    • 脚本:Train/examples/data_preprocess/grip/rl.py

    • 您需要指定 data_path 参数,即 GRIP 合成数据的路径:

      parser.add_argument('--data_path', default='<PATH_TO_RAW_DATASET_ROOT>/RL_data.jsonl')
      
    • 您需要指定 dataset 的名称,以便在后续训练中使用:

      # 数据路径默认保存在 "datasets" 文件夹中
      parser.add_argument('--save_dir', default='datasets/GRIPRL')
      
    • 您需要指定 data_source 的名称,以便在后续训练中选择奖励模型:

      parser.add_argument('--data_source', default='GRIPRL')    # 必填
      
  2. 使用 DAPO 训练脚本

    • 脚本:Train/recipe/dapo/dapo_4w_continue_rl_ep3_llama.sh

    • 您需要修改以下参数以适配 RL 训练:

      ...
       # 路径配置
       MODEL_PATH=<PATH_TO_SAVE>/GRIPSFT_LLaMa/global_step_xxx  # SFT 检查点
       CKPTS_DIR=<PATH_TO_SAVE>/RL_model  # RL 模型保存路径
       TRAIN_FILE=datasets/GRIPRL/train.parquet  # RL 数据集
       TEST_FILE=datasets/GRIPRL/test.parquet  # RL 数据集
       ...
      
  3. 奖励模型的具体实现位于文件 Train/verl/utils/reward_score/grip.py 中。

  4. 训练完成后,您需要通过脚本 Train/scripts/merge.sh 将模型保存的分片合并为 Hugging Face 格式。

使用 GRIP 进行本地推理

测试数据格式对齐

{
    "question": "Test Query",
    "answer": ["Answer List", ...]
}

多轮 GRIP 推理

  • 主脚本:inference/inference.sh

    # 模型保存路径
    parser.add_argument('--model_path', type=str, default="/path/to/your/RL_model/step_xxx")
    # 预测文件输出路径
    parser.add_argument('--output_file', type=str, default="output/rl_step_xxx_hotpot.jsonl")
    # 待预测文件
    parser.add_argument('--input_file', type=str, default="test_data/hotpotQA.jsonl")
    
  • 该脚本将按以下格式生成预测结果:

    {
          "Question": "String", 
          "prediction": ["String",......]
      }
    

评估

python eval/eval.py \
    --references_path test_dataset.jsonl \
    --predictions_path prediction.jsonl

🤝 贡献

欢迎贡献!请参阅 CONTRIBUTING.md 了解相关指引。

📄 引用

@article{li2026retrieval,
  title={Retrieval as Generation: A Unified Framework with Self-Triggered Information Planning},
  author={Li, Bo and Wang, Mingda and Fang, Gexiang and Zhang, Shikun and Ye, Wei},
  journal={arXiv preprint arXiv:2604.11407},
  year={2026}
}

📝 许可证

本项目采用 Apache 2.0 许可证——详情请参阅 LICENSE 文件。

🙏 致谢

特别感谢开源社区及所有使本项目成为可能的贡献者。