Upload ms-swift/docs/source/Instruction/强化微调.md with huggingface_hub
Browse files
ms-swift/docs/source/Instruction/强化微调.md
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 强化微调
|
| 2 |
+
|
| 3 |
+
强化微调是目前模型训练非常重要的功能之一,它本身的实现是多种多样的,SWIFT目前已经支持了强化微调所需要的原子能力,如采样、强化学习和微调。目前我们提供了拒绝采样微调的一个具体示例,可以查看[这里](https://github.com/modelscope/ms-swift/tree/main/examples/train/rft/rft.py)。
|
| 4 |
+
|
| 5 |
+
## 强化微调的概念
|
| 6 |
+
|
| 7 |
+
强化微调是从2022年开始(甚至更早)就被提出的概念。其方式一般有下列流程:
|
| 8 |
+
|
| 9 |
+
1. 使用某个模型生成数据,或进行原始数据扩充
|
| 10 |
+
2. 使用数据训练目标模型
|
| 11 |
+
3. 如果有必要,重复上述过程
|
| 12 |
+
|
| 13 |
+
步骤1:
|
| 14 |
+
|
| 15 |
+
- 如果生成数据的模型是更大的模型,如GPT、Qwen-Max、DeepSeek-V3/R1等,则该强化微调可以理解为蒸馏
|
| 16 |
+
- 如果生成数据的模型是本模型,则可以理解为自我提升(self-improvement)微调
|
| 17 |
+
- 如果采样过程是采样一个batch,然后通过KL散度和reward进行拟合训练并不断循环,则可以理解为PPO、GRPO等on-policy算法
|
| 18 |
+
- 采样数据的算法包含蒙特卡洛采样、do_sample采样、group beam search、dvts等
|
| 19 |
+
- 采样过程可以引入ORM(结果判断),PRM(过程打分),多样性过滤,语种过滤等
|
| 20 |
+
|
| 21 |
+
步骤2:
|
| 22 |
+
|
| 23 |
+
- 如果使用SFT,则称为拒绝采样微调
|
| 24 |
+
- 如果是强化学习,则称为强化学习微调
|
| 25 |
+
|
| 26 |
+
步骤3:
|
| 27 |
+
|
| 28 |
+
- 如果使用更大的模型蒸馏,例如更大模型的蒙特卡洛采样蒸馏,一般不会有循环
|
| 29 |
+
- 如果使用本模型进行采样,或者PPO等算法,则会有循环
|
| 30 |
+
|
| 31 |
+
泛泛来说,常见强化微调的方式有下面几种:
|
| 32 |
+
|
| 33 |
+
1. 蒸馏:使用蒙特卡洛、do_sample等方式从超大模型中采样大量优质数据,训练小模型
|
| 34 |
+
2. 自我提升:从本模型中采样部分优质数据,筛选后训练本模型,循环执行
|
| 35 |
+
3. on-policy RL:使用PPO、GRPO等方式循环训练
|
| 36 |
+
|
| 37 |
+
采样过程一般很漫长,比训练过程漫长的多。如果使用GPT等模型蒸馏数据,则需要购买token。因此,强化微调的时间成本和花费成本比较高,所以一般作为微调的补充机制出现,当然也有特例,例如最近的DeepSeek-R1。
|
| 38 |
+
|
| 39 |
+
DeepSeek-R1使用了GRPO算法从零使base模型涌现CoT能力,该方法需要大规模集群支持,且模型需要足够大才能发生能力涌现,在本文中不详细讨论。如果需要了解该过程,请查看[论文解析](https://zhuanlan.zhihu.com/p/19714987272)。
|
| 40 |
+
|
| 41 |
+
有关强化微调的一些论文:
|
| 42 |
+
|
| 43 |
+
- 拒绝采样微调:https://arxiv.org/pdf/2308.01825
|
| 44 |
+
- ReST:https://arxiv.org/pdf/2308.08998
|
| 45 |
+
- B-STAR:https://arxiv.org/pdf/2412.17256
|
| 46 |
+
- DeepSeekMath:https://arxiv.org/pdf/2402.03300
|
| 47 |
+
- Qwen-math-PRM:https://arxiv.org/pdf/2501.07301
|
| 48 |
+
- DeepSeek-R1:https://github.com/deepseek-ai/DeepSeek-R1/tree/main
|
| 49 |
+
|
| 50 |
+
## 什么时候使用强化微调
|
| 51 |
+
|
| 52 |
+
在LLaMA3之后,我们发现一个非常明显但却是不常被提及的特点:使用某个含有CoT的train数据集训练Instruct模型,再通过对应的test集进行评测,会发现test集评测效果变差。例如,使用gsm8k训练集训练llama3.1-8b-instruct,对生成的ckpt使用test集进行评测,会发现掉点。
|
| 53 |
+
|
| 54 |
+
这个特性主要来源于模型的知识遗忘问题。在模型厂商的微调中,会加入非常多的CoT数据集,模型在解决数学任务的时候,用到的能力很有可能不是来自于math数据集,而是来自arc数据集,这个推论有[一些工作可以证明](https://zhuanlan.zhihu.com/p/19269451950)。在继续训练通用任务后,知识遗忘破坏了模型原有能力,导致了掉点。
|
| 55 |
+
|
| 56 |
+
然而,优先使用微调方式训练模型总是正确的。微调可以使模型快速适应数据集的分布,并且微调的成本很低。当有如下条件之一时使用强化微调:
|
| 57 |
+
|
| 58 |
+
1. 已经微调过模型,能力不满足需求
|
| 59 |
+
2. 需要更强的CoT能力
|
| 60 |
+
3. 对基模型训练通用能力,而原始数据集已经导致模型效果无法提升
|
| 61 |
+
4. 对应query的输出结果可以相对准确地评估好坏,例如结果清晰(数学,代码),过程清晰(翻译,风格)等
|
| 62 |
+
|
| 63 |
+
强化微调非常依赖于reward评估是否准确。如果评估结果不准确,可能导致模型训练原地震荡,甚至越训越差。
|
| 64 |
+
|
| 65 |
+
## SWIFT的实现
|
| 66 |
+
|
| 67 |
+
SWIFT支持sample命令,该命令就是用于模型采样。目前支持的采样方式有:
|
| 68 |
+
|
| 69 |
+
- do_sample:sample方式对模型进行采样,该方式支持对开源模型进行采样,后续会支持模型蒸馏
|
| 70 |
+
- sample方式后续会支持URL采样,用于大模型蒸馏
|
| 71 |
+
|
| 72 |
+
- mcts:蒙特卡洛采样,该方式在PR中,后续会支持
|
| 73 |
+
- dvts:调研中
|
| 74 |
+
|
| 75 |
+
目前我们给出了一个较为通用的[RFT脚本](https://github.com/modelscope/ms-swift/tree/main/examples/train/rft/rft.py)。该脚本适用于自我提升方式的训练,且支持动态调整采样温度值、PRM阈值等超参数,并且训练方式灵活可变(微调、DPO等;或者每次迭代重新训练原模型或继续训��上个迭代的模型,甚至加载上个迭代的所有训练状态等)。开发者可以在该脚本中增加其他数据过滤(生成的数据集中,id相同的行来自同一个query),例如多样性判断、语种判断等。
|
| 76 |
+
|
| 77 |
+
## 实验结果
|
| 78 |
+
|
| 79 |
+
我们对该RFT脚本针对数学领域使用competition_math数据集进行了训练和评测,结果如下:
|
| 80 |
+
|
| 81 |
+
| 模型 | MATH指标 | 训练方式 | 迭代次数 | 训练后MATH指标 |
|
| 82 |
+
| ------------------------ | -------- | -------- | -------- | --------------------- |
|
| 83 |
+
| LLaMA3.1_8b | 12.0 | SFT | 3 | 25.2(LLaMA3.1_8b_sft) |
|
| 84 |
+
| LLaMA3.1_8b_sft | 25.2 | RFT | 2 | 32.4 |
|
| 85 |
+
| LLaMA3.1_8b_instruct | 52.2 | SFT | 2 | 39.0 |
|
| 86 |
+
| LLaMA3.1_8b_instruct | 52.2 | RFT | 3 | 58 |
|
| 87 |
+
| Qwen2.5_math_7b_instruct | 79.6 | RFT | 2 | 83.2 |
|
| 88 |
+
|
| 89 |
+
可以看到,使用competition_math直接SFT后,instruct模型的掉点十分严重。而RFT后模型能力有提升,即使对Qwen2.5_math_7b_instruct这个SOTA的math模型也同样有一定提升空间。
|
| 90 |
+
|
| 91 |
+
特别地,针对Qwen2.5_math_7b_instruct我们测试了gsm8k的指标:
|
| 92 |
+
|
| 93 |
+
| 模型 | gsm8k指标 | RFT后gsm8k指标 |
|
| 94 |
+
| ------------------------ | --------- | -------------- |
|
| 95 |
+
| Qwen2.5_math_7b_instruct | 92.8 | 91.6 |
|
| 96 |
+
|
| 97 |
+
可以看到,RFT训练后gsm8k指标变化不大,并没有出现前述的掉点现象。
|
| 98 |
+
|
| 99 |
+
## 未来计划
|
| 100 |
+
|
| 101 |
+
1. 更多的采样方式,如MCTS
|
| 102 |
+
2. 超大模型蒸馏训练
|
| 103 |
+
3. 以PPO为主的on-policy训练
|