yuccaaa commited on
Commit
5c699da
·
verified ·
1 Parent(s): 0c22c24

Upload ms-swift/docs/source/Instruction/强化微调.md with huggingface_hub

Browse files
ms-swift/docs/source/Instruction/强化微调.md ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 强化微调
2
+
3
+ 强化微调是目前模型训练非常重要的功能之一,它本身的实现是多种多样的,SWIFT目前已经支持了强化微调所需要的原子能力,如采样、强化学习和微调。目前我们提供了拒绝采样微调的一个具体示例,可以查看[这里](https://github.com/modelscope/ms-swift/tree/main/examples/train/rft/rft.py)。
4
+
5
+ ## 强化微调的概念
6
+
7
+ 强化微调是从2022年开始(甚至更早)就被提出的概念。其方式一般有下列流程:
8
+
9
+ 1. 使用某个模型生成数据,或进行原始数据扩充
10
+ 2. 使用数据训练目标模型
11
+ 3. 如果有必要,重复上述过程
12
+
13
+ 步骤1:
14
+
15
+ - 如果生成数据的模型是更大的模型,如GPT、Qwen-Max、DeepSeek-V3/R1等,则该强化微调可以理解为蒸馏
16
+ - 如果生成数据的模型是本模型,则可以理解为自我提升(self-improvement)微调
17
+ - 如果采样过程是采样一个batch,然后通过KL散度和reward进行拟合训练并不断循环,则可以理解为PPO、GRPO等on-policy算法
18
+ - 采样数据的算法包含蒙特卡洛采样、do_sample采样、group beam search、dvts等
19
+ - 采样过程可以引入ORM(结果判断),PRM(过程打分),多样性过滤,语种过滤等
20
+
21
+ 步骤2:
22
+
23
+ - 如果使用SFT,则称为拒绝采样微调
24
+ - 如果是强化学习,则称为强化学习微调
25
+
26
+ 步骤3:
27
+
28
+ - 如果使用更大的模型蒸馏,例如更大模型的蒙特卡洛采样蒸馏,一般不会有循环
29
+ - 如果使用本模型进行采样,或者PPO等算法,则会有循环
30
+
31
+ 泛泛来说,常见强化微调的方式有下面几种:
32
+
33
+ 1. 蒸馏:使用蒙特卡洛、do_sample等方式从超大模型中采样大量优质数据,训练小模型
34
+ 2. 自我提升:从本模型中采样部分优质数据,筛选后训练本模型,循环执行
35
+ 3. on-policy RL:使用PPO、GRPO等方式循环训练
36
+
37
+ 采样过程一般很漫长,比训练过程漫长的多。如果使用GPT等模型蒸馏数据,则需要购买token。因此,强化微调的时间成本和花费成本比较高,所以一般作为微调的补充机制出现,当然也有特例,例如最近的DeepSeek-R1。
38
+
39
+ DeepSeek-R1使用了GRPO算法从零使base模型涌现CoT能力,该方法需要大规模集群支持,且模型需要足够大才能发生能力涌现,在本文中不详细讨论。如果需要了解该过程,请查看[论文解析](https://zhuanlan.zhihu.com/p/19714987272)。
40
+
41
+ 有关强化微调的一些论文:
42
+
43
+ - 拒绝采样微调:https://arxiv.org/pdf/2308.01825
44
+ - ReST:https://arxiv.org/pdf/2308.08998
45
+ - B-STAR:https://arxiv.org/pdf/2412.17256
46
+ - DeepSeekMath:https://arxiv.org/pdf/2402.03300
47
+ - Qwen-math-PRM:https://arxiv.org/pdf/2501.07301
48
+ - DeepSeek-R1:https://github.com/deepseek-ai/DeepSeek-R1/tree/main
49
+
50
+ ## 什么时候使用强化微调
51
+
52
+ 在LLaMA3之后,我们发现一个非常明显但却是不常被提及的特点:使用某个含有CoT的train数据集训练Instruct模型,再通过对应的test集进行评测,会发现test集评测效果变差。例如,使用gsm8k训练集训练llama3.1-8b-instruct,对生成的ckpt使用test集进行评测,会发现掉点。
53
+
54
+ 这个特性主要来源于模型的知识遗忘问题。在模型厂商的微调中,会加入非常多的CoT数据集,模型在解决数学任务的时候,用到的能力很有可能不是来自于math数据集,而是来自arc数据集,这个推论有[一些工作可以证明](https://zhuanlan.zhihu.com/p/19269451950)。在继续训练通用任务后,知识遗忘破坏了模型原有能力,导致了掉点。
55
+
56
+ 然而,优先使用微调方式训练模型总是正确的。微调可以使模型快速适应数据集的分布,并且微调的成本很低。当有如下条件之一时使用强化微调:
57
+
58
+ 1. 已经微调过模型,能力不满足需求
59
+ 2. 需要更强的CoT能力
60
+ 3. 对基模型训练通用能力,而原始数据集已经导致模型效果无法提升
61
+ 4. 对应query的输出结果可以相对准确地评估好坏,例如结果清晰(数学,代码),过程清晰(翻译,风格)等
62
+
63
+ 强化微调非常依赖于reward评估是否准确。如果评估结果不准确,可能导致模型训练原地震荡,甚至越训越差。
64
+
65
+ ## SWIFT的实现
66
+
67
+ SWIFT支持sample命令,该命令就是用于模型采样。目前支持的采样方式有:
68
+
69
+ - do_sample:sample方式对模型进行采样,该方式支持对开源模型进行采样,后续会支持模型蒸馏
70
+ - sample方式后续会支持URL采样,用于大模型蒸馏
71
+
72
+ - mcts:蒙特卡洛采样,该方式在PR中,后续会支持
73
+ - dvts:调研中
74
+
75
+ 目前我们给出了一个较为通用的[RFT脚本](https://github.com/modelscope/ms-swift/tree/main/examples/train/rft/rft.py)。该脚本适用于自我提升方式的训练,且支持动态调整采样温度值、PRM阈值等超参数,并且训练方式灵活可变(微调、DPO等;或者每次迭代重新训练原模型或继续训��上个迭代的模型,甚至加载上个迭代的所有训练状态等)。开发者可以在该脚本中增加其他数据过滤(生成的数据集中,id相同的行来自同一个query),例如多样性判断、语种判断等。
76
+
77
+ ## 实验结果
78
+
79
+ 我们对该RFT脚本针对数学领域使用competition_math数据集进行了训练和评测,结果如下:
80
+
81
+ | 模型 | MATH指标 | 训练方式 | 迭代次数 | 训练后MATH指标 |
82
+ | ------------------------ | -------- | -------- | -------- | --------------------- |
83
+ | LLaMA3.1_8b | 12.0 | SFT | 3 | 25.2(LLaMA3.1_8b_sft) |
84
+ | LLaMA3.1_8b_sft | 25.2 | RFT | 2 | 32.4 |
85
+ | LLaMA3.1_8b_instruct | 52.2 | SFT | 2 | 39.0 |
86
+ | LLaMA3.1_8b_instruct | 52.2 | RFT | 3 | 58 |
87
+ | Qwen2.5_math_7b_instruct | 79.6 | RFT | 2 | 83.2 |
88
+
89
+ 可以看到,使用competition_math直接SFT后,instruct模型的掉点十分严重。而RFT后模型能力有提升,即使对Qwen2.5_math_7b_instruct这个SOTA的math模型也同样有一定提升空间。
90
+
91
+ 特别地,针对Qwen2.5_math_7b_instruct我们测试了gsm8k的指标:
92
+
93
+ | 模型 | gsm8k指标 | RFT后gsm8k指标 |
94
+ | ------------------------ | --------- | -------------- |
95
+ | Qwen2.5_math_7b_instruct | 92.8 | 91.6 |
96
+
97
+ 可以看到,RFT训练后gsm8k指标变化不大,并没有出现前述的掉点现象。
98
+
99
+ ## 未来计划
100
+
101
+ 1. 更多的采样方式,如MCTS
102
+ 2. 超大模型蒸馏训练
103
+ 3. 以PPO为主的on-policy训练