GKD
版本依赖:ms-swift >= 3.12
如果你是首次使用 GKD,请先参考 GKD文档。
GKD(Generalized Knowledge Distillation,广义知识蒸馏)是一种将教师模型的知识迁移到学生模型的训练方法,通过计算两个模型输出分布之间的 Jensen-Shannon 散度(JSD)损失来实现知识蒸馏。
功能支持
Megatron GKD 当前已支持以下功能:
- 训练模式:全参数训练与 LoRA 微调
- 并行策略:支持上下文并行(CP)、流水线并行(PP)、张量并行(TP)和专家并行(EP)
- 模型支持:兼容 Megatron-SWIFT 中的 LLM 及 MLLM
- Teacher Offload:支持将教师模型卸载到 CPU 以节省 GPU 显存
- 在线生成:支持使用 vLLM 进行学生模型的 on-policy 生成
当前限制
- 教师模型在线生成(
seq_kd=True):当前 Sequential KD 模式下的教师模型生成暂不支持 - 非vLLM生成:On-policy 生成当前仅支持 vLLM
- 教师模型使用与学生模型不同的并行参数: 将在未来版本支持
⚠️ 注意事项:
- On-policy 生成:需要启用 vLLM(
--use_vllm true --vllm_mode colocate/server) - 当
lmbda > 0但未启用 vLLM 时,将自动回退到离线学习模式(使用数据集响应) - 当
seq_kd=True时,由于教师生成暂不支持,将自动回退到离线学习模式,如需使用,请提前用swift infer推理数据集
参数说明
GKD 特有参数
| 参数 | 类型 | 默认值 | 说明 |
|---|---|---|---|
--teacher_model |
str | 必需 | 教师模型路径或模型 ID |
--beta |
float | 0.5 | JSD 散度插值系数: • 0.0: Forward KL • 0.5: 对称 JSD • 1.0: Reverse KL |
--lmbda |
float | 0.5 | On-Policy 学习触发概率: • 0.0: 纯 Off-Policy • 1.0: 纯 On-Policy |
--seq_kd |
bool | False | 是否使用教师生成的响应(当前暂不支持) |
--temperature |
float | 0.9 | 温度参数,用于采样和损失计算 |
--sft_alpha |
float | 0 | 混合一定比例的sft loss,对非student生成结果生效 |
--max_completion_length |
int | 512 | 生成时的最大 token 数 |
批量相关参数
与 Megatron SFT 相同,使用以下参数控制批量大小:
| 参数 | 说明 |
|---|---|
--micro_batch_size |
每张 GPU 的训练批次大小 |
--global_batch_size |
全局批次大小:micro_batch_size × dp_size × gradient_accumulation_steps |
三种训练模式
GKD 支持三种训练模式,通过 lmbda 和 seq_kd 参数控制:
Mode 1: On-Policy 学习
- 触发条件:
random() < lmbda且use_vllm=True - 数据来源:学生模型生成的响应
Mode 2: Sequential KD(当前暂不支持)
- 触发条件:
random() >= lmbda且seq_kd=True - 数据来源:教师模型生成的响应
Mode 3: Off-Policy 学习
- 触发条件:其他情况
- 数据来源:数据集中的标注响应
参考
更多参数请参考命令行文档
训练脚本请参考 Megatron GKD 脚本