
GKD

Version Requirement: ms-swift >= 3.12

If you are new to GKD, please refer to the GKD Documentation first.

GKD (Generalized Knowledge Distillation) is a training method that transfers knowledge from a teacher model to a student model by computing the Jensen-Shannon Divergence (JSD) loss between their output distributions.
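The generalized JSD can be sketched as follows. This is a minimal PyTorch illustration of the loss described above (following the TRL-style formulation with a mixture distribution); the function name is illustrative and the actual Megatron-SWIFT implementation may differ in masking and reduction details.

```python
import torch
import torch.nn.functional as F

def generalized_jsd_loss(student_logits, teacher_logits, beta=0.5, temperature=0.9):
    """Generalized JSD between teacher (P) and student (Q) token distributions.

    Sketch only, not the actual ms-swift internals:
      beta = 0.0 -> forward KL(P || Q)
      beta = 0.5 -> symmetric JSD
      beta = 1.0 -> reverse KL(Q || P)
    """
    q = F.log_softmax(student_logits / temperature, dim=-1)  # student log-probs
    p = F.log_softmax(teacher_logits / temperature, dim=-1)  # teacher log-probs
    if beta == 0.0:
        # F.kl_div(input, target, log_target=True) computes KL(target || input)
        return F.kl_div(q, p, log_target=True, reduction="batchmean")
    if beta == 1.0:
        return F.kl_div(p, q, log_target=True, reduction="batchmean")
    b = torch.tensor(beta, dtype=q.dtype)
    # Log of the mixture M = (1 - beta) * Q + beta * P, computed stably in log space
    m = torch.logsumexp(torch.stack([q + torch.log(1 - b), p + torch.log(b)]), dim=0)
    return (beta * F.kl_div(m, p, log_target=True, reduction="batchmean")
            + (1 - beta) * F.kl_div(m, q, log_target=True, reduction="batchmean"))
```

At `beta=0.5` the loss is symmetric in student and teacher, which is why that setting is called symmetric JSD in the parameter table below.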

Feature Support

Megatron GKD currently supports the following features:

  • Training Modes: Full parameter training and LoRA fine-tuning
  • Parallelism Strategies: Context Parallel (CP), Pipeline Parallel (PP), Tensor Parallel (TP), and Expert Parallel (EP)
  • Model Support: Compatible with LLMs and MLLMs in Megatron-SWIFT
  • Teacher Offload: Supports offloading teacher model to CPU to save GPU memory
  • Online Generation: Supports on-policy generation using vLLM for student model

Current Limitations

  • Teacher Model Online Generation (seq_kd=True): Teacher model generation in Sequential KD mode is not yet supported
  • Non-vLLM Generation: On-policy generation currently only supports vLLM
  • Teacher model with different parallel parameters: Will be supported in future versions

⚠️ Notes:

  • On-policy Generation: Requires vLLM (--use_vllm true --vllm_mode colocate/server)
  • When lmbda > 0 but vLLM is not enabled, it will automatically fall back to off-policy mode (using dataset responses)
  • When seq_kd=True, since teacher generation is not yet supported, it will automatically fall back to off-policy mode. If needed, please use swift infer to pre-generate responses for the dataset

Parameters

GKD-specific Parameters

| Parameter | Type | Default | Description |
|---|---|---|---|
| `--teacher_model` | str | Required | Path or model ID of the teacher model |
| `--beta` | float | 0.5 | JSD interpolation coefficient: 0.0 = forward KL, 0.5 = symmetric JSD, 1.0 = reverse KL |
| `--lmbda` | float | 0.5 | On-policy learning probability: 0.0 = pure off-policy, 1.0 = pure on-policy |
| `--seq_kd` | bool | False | Use teacher-generated responses (not yet supported) |
| `--temperature` | float | 0.9 | Temperature for sampling and loss computation |
| `--sft_alpha` | float | 0 | Proportion of SFT loss mixed in; applied only to completions not generated by the student |
| `--max_completion_length` | int | 512 | Maximum number of tokens to generate |
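To show how these flags combine, here is a hedged launch sketch. The `megatron rlhf --rlhf_type gkd` entry point, the placeholder values, and the omitted student-model/dataset flags are assumptions for illustration, not a confirmed command; the Megatron GKD scripts referenced at the end of this page show the exact invocation.

```shell
# Illustrative sketch only -- entry point and values are assumptions;
# see the official Megatron GKD scripts for a working command.
megatron rlhf \
    --rlhf_type gkd \
    --teacher_model <teacher-model-id-or-path> \
    --beta 0.5 \
    --lmbda 0.5 \
    --temperature 0.9 \
    --sft_alpha 0.1 \
    --max_completion_length 512 \
    --use_vllm true \
    --vllm_mode colocate \
    --micro_batch_size 1 \
    --global_batch_size 16
    # (student model, dataset, and parallelism flags omitted)
```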

Batch-related Parameters

These are the same as in Megatron SFT; use the following parameters to control the batch size:

| Parameter | Description |
|---|---|
| `--micro_batch_size` | Training batch size per GPU |
| `--global_batch_size` | Global batch size: micro_batch_size × dp_size × gradient_accumulation_steps |
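The relationship between the two batch sizes can be checked with a quick computation. The numbers below are hypothetical, chosen only to illustrate the formula:

```python
# Hypothetical sizes for illustration (not taken from the document):
micro_batch_size = 2              # --micro_batch_size: per-GPU batch per step
dp_size = 4                       # data-parallel world size
gradient_accumulation_steps = 8   # derived from global / (micro * dp)

# --global_batch_size must satisfy: micro_batch_size * dp_size * gradient_accumulation_steps
global_batch_size = micro_batch_size * dp_size * gradient_accumulation_steps
print(global_batch_size)  # 64
```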

Three Training Modes

GKD supports three training modes, selected per batch by the lmbda and seq_kd parameters (together with use_vllm):

Mode 1: On-Policy Learning

  • Trigger: random() < lmbda and use_vllm=True
  • Data source: Responses generated by the student model

Mode 2: Sequential KD (Not Yet Supported)

  • Trigger: random() >= lmbda and seq_kd=True
  • Data source: Responses generated by the teacher model

Mode 3: Off-Policy Learning

  • Trigger: Other cases
  • Data source: Labeled responses from the dataset
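The per-batch selection rules above can be sketched as a single random draw. The function name is illustrative, not the actual ms-swift internals:

```python
import random

def pick_gkd_mode(lmbda: float, seq_kd: bool, use_vllm: bool, rng=random) -> str:
    """One random draw per batch decides the data source (sketch of the rules above)."""
    draw = rng.random()
    if draw < lmbda and use_vllm:
        return "on_policy"   # Mode 1: student generates the completion via vLLM
    if draw >= lmbda and seq_kd:
        return "seq_kd"      # Mode 2: teacher-generated (currently falls back to off-policy)
    return "off_policy"      # Mode 3: labeled response from the dataset
```

With `lmbda=1.0` and vLLM enabled, every batch is on-policy; with `lmbda=0.0` and `seq_kd=False`, every batch uses the dataset's labeled responses.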

Reference

For more parameters, please refer to Command-line Parameters

For training scripts, please refer to Megatron GKD Scripts