# GKD
Version Requirement: ms-swift >= 3.12
If you are new to GKD, please refer to the GKD Documentation first.
GKD (Generalized Knowledge Distillation) is a training method that transfers knowledge from a teacher model to a student model by computing the Jensen-Shannon Divergence (JSD) loss between their output distributions.
## Feature Support
Megatron GKD currently supports the following features:
- Training Modes: Full parameter training and LoRA fine-tuning
- Parallelism Strategies: Context Parallel (CP), Pipeline Parallel (PP), Tensor Parallel (TP), and Expert Parallel (EP)
- Model Support: Compatible with LLMs and MLLMs in Megatron-SWIFT
- Teacher Offload: Supports offloading the teacher model to CPU to save GPU memory
- Online Generation: Supports on-policy generation for the student model using vLLM
## Current Limitations
- Teacher Model Online Generation (`seq_kd=True`): Teacher-model generation in sequential KD mode is not yet supported
- Non-vLLM Generation: On-policy generation currently supports vLLM only
- Teacher model with different parallel parameters: Will be supported in a future version
⚠️ Notes:
- On-policy generation requires vLLM (`--use_vllm true --vllm_mode colocate/server`)
- When `lmbda > 0` but vLLM is not enabled, training automatically falls back to off-policy mode (using dataset responses)
- When `seq_kd=True`, training automatically falls back to off-policy mode because teacher generation is not yet supported. If needed, use `swift infer` to pre-generate responses for the dataset
## Parameters

### GKD-specific Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| `--teacher_model` | str | Required | Path or model ID of the teacher model |
| `--beta` | float | 0.5 | JSD interpolation coefficient: • 0.0: forward KL • 0.5: symmetric JSD • 1.0: reverse KL |
| `--lmbda` | float | 0.5 | On-policy learning probability: • 0.0: pure off-policy • 1.0: pure on-policy |
| `--seq_kd` | bool | False | Use teacher-generated responses (not yet supported) |
| `--temperature` | float | 0.9 | Temperature for sampling and loss computation |
| `--sft_alpha` | float | 0 | Proportion of SFT loss to mix in; applied to completions not generated by the student |
| `--max_completion_length` | int | 512 | Maximum number of generated tokens |
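To make the `--beta` endpoints concrete, here is a minimal pure-Python sketch of the generalized JSD from the GKD paper, computed on toy discrete distributions. This is an illustration only, not the actual ms-swift/Megatron implementation (which operates on token-level log-probabilities); the endpoint convention below is chosen to match the table above.

```python
import math

def kl(p, q):
    """KL(p || q) for discrete distributions given as probability lists."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def generalized_jsd(teacher, student, beta):
    """Generalized JSD (illustrative sketch, not the ms-swift code).

    beta=0 -> forward KL(teacher || student); beta=1 -> reverse KL.
    Otherwise: beta*KL(P||M) + (1-beta)*KL(Q||M), M = beta*P + (1-beta)*Q.
    The endpoints are the (scaled) limits of the interpolated form.
    """
    if beta == 0.0:
        return kl(teacher, student)
    if beta == 1.0:
        return kl(student, teacher)
    mix = [beta * p + (1 - beta) * q for p, q in zip(teacher, student)]
    return beta * kl(teacher, mix) + (1 - beta) * kl(student, mix)
```

At `beta=0.5` the loss is symmetric in teacher and student; moving `beta` toward 1.0 weights the mode-seeking, reverse-KL behavior more heavily.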
### Batch-related Parameters
Same as Megatron SFT; use the following parameters to control batch size:
| Parameter | Description |
|---|---|
| `--micro_batch_size` | Training batch size per GPU |
| `--global_batch_size` | Global batch size: `micro_batch_size × dp_size × gradient_accumulation_steps` |
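As a concrete check of the relationship above, the number of gradient accumulation steps is fully determined by the other three quantities. The helper below is hypothetical, purely for illustration (Megatron derives this internally from the two flags and the parallel layout):

```python
def grad_accum_steps(global_batch_size, micro_batch_size, dp_size):
    """gradient_accumulation_steps implied by
    global = micro_batch_size * dp_size * gradient_accumulation_steps."""
    per_step = micro_batch_size * dp_size
    assert global_batch_size % per_step == 0, "global batch must divide evenly"
    return global_batch_size // per_step

# Example: 8 GPUs with TP=2 gives dp_size=4; with micro_batch_size=2 and
# global_batch_size=64, each optimizer step accumulates 8 micro-batches.
```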
## Three Training Modes
GKD supports three training modes, controlled by the `lmbda` and `seq_kd` parameters:
### Mode 1: On-Policy Learning
- Trigger: `random() < lmbda` and `use_vllm=True`
- Data source: Responses generated by the student model
### Mode 2: Sequential KD (Not Yet Supported)
- Trigger: `random() >= lmbda` and `seq_kd=True`
- Data source: Responses generated by the teacher model
### Mode 3: Off-Policy Learning
- Trigger: Other cases
- Data source: Labeled responses from the dataset
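The three triggers above can be summarized in a small sketch. This is a hypothetical helper for illustration, not the actual ms-swift code, and it shows the intended design: per the notes earlier, Mode 2 currently falls back to dataset responses.

```python
import random

def pick_data_source(lmbda, seq_kd, use_vllm):
    """Illustrative per-batch mode selection (not the actual implementation)."""
    if use_vllm and random.random() < lmbda:
        return "student"   # Mode 1: on-policy, student generates via vLLM
    if seq_kd:
        return "teacher"   # Mode 2: sequential KD (not yet supported; currently
                           # falls back to dataset responses in practice)
    return "dataset"       # Mode 3: off-policy, labeled responses from the dataset
```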
## Reference
For more parameters, please refer to Command-line Parameters.
For training scripts, please refer to Megatron GKD Scripts.