File size: 1,061 Bytes
05c5c96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36

# config.py - Training configuration
from qwen_distill import QwenDistillationConfig

class MyConfig(QwenDistillationConfig):
    def __init__(self):
        super().__init__()
        
        # Paths
        self.data_file = "data/train.txt"
        self.teacher_model_name = "Qwen/Qwen2.5-0.5B"
        
        # Student size (adjust based on your needs)
        # Small: 3 layers, 128 hidden = ~30M params
        # Medium: 5 layers, 256 hidden = ~100M params
        # Large: 8 layers, 384 hidden = ~250M params
        
        self.student_num_layers = 5
        self.student_hidden_dim = 256
        self.student_num_heads = 4
        
        # Training
        self.batch_size = 2
        self.gradient_accumulation_steps = 4
        self.max_steps = 2000
        self.learning_rate = 8e-4
        
        # Distillation
        self.temperature = 3.0
        self.alpha = 0.8  # 80% KD loss
        self.beta = 0.2   # 20% feature loss
        
        # Memory
        self.use_gradient_checkpointing = True
        self.mixed_precision = "fp16"