Charlie81 commited on
Commit
356573e
·
1 Parent(s): 6b0e19d

unfreeze only gate and experts

Browse files
Files changed (1) hide show
  1. scripts/train.py +15 -0
scripts/train.py CHANGED
@@ -88,7 +88,22 @@ def main():
88
  warmup_ratio=0.1,
89
  max_grad_norm=1.0,
90
  )
 
 
 
91
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  # Trainer
93
  trainer = Trainer(
94
  model=model,
 
88
  warmup_ratio=0.1,
89
  max_grad_norm=1.0,
90
  )
91
+ # Freeze all parameters
92
+ for param in model.parameters():
93
+ param.requires_grad = False
94
 
95
+ # Unfreeze only the small experts and their gating networks
96
+ for name, param in model.named_parameters():
97
+ # Unfreeze small expert layers
98
+ if "mlp.experts" in name and any(f"mlp.experts.{i}." in name for i in range(0, config.num_experts, config.small_expert_frequency)):
99
+ param.requires_grad = True
100
+ print(f"Unfreezing small expert parameter: {name}")
101
+
102
+ # Unfreeze gating network parameters
103
+ if "mlp.gate" in name:
104
+ param.requires_grad = True
105
+ print(f"Unfreezing gating network parameter: {name}")
106
+
107
  # Trainer
108
  trainer = Trainer(
109
  model=model,