Unfreeze only the gate and experts
Browse files- scripts/train.py +15 -0
scripts/train.py
CHANGED
|
@@ -88,7 +88,22 @@ def main():
|
|
| 88 |
warmup_ratio=0.1,
|
| 89 |
max_grad_norm=1.0,
|
| 90 |
)
|
|
|
|
|
|
|
|
|
|
| 91 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
# Trainer
|
| 93 |
trainer = Trainer(
|
| 94 |
model=model,
|
|
|
|
| 88 |
warmup_ratio=0.1,
|
| 89 |
max_grad_norm=1.0,
|
| 90 |
)
|
| 91 |
+
# Freeze all parameters
|
| 92 |
+
for param in model.parameters():
|
| 93 |
+
param.requires_grad = False
|
| 94 |
|
| 95 |
+
# Unfreeze only the small experts and their gating networks
|
| 96 |
+
for name, param in model.named_parameters():
|
| 97 |
+
# Unfreeze small expert layers
|
| 98 |
+
if "mlp.experts" in name and any(f"mlp.experts.{i}." in name for i in range(0, config.num_experts, config.small_expert_frequency)):
|
| 99 |
+
param.requires_grad = True
|
| 100 |
+
print(f"Unfreezing small expert parameter: {name}")
|
| 101 |
+
|
| 102 |
+
# Unfreeze gating network parameters
|
| 103 |
+
if "mlp.gate" in name:
|
| 104 |
+
param.requires_grad = True
|
| 105 |
+
print(f"Unfreezing gating network parameter: {name}")
|
| 106 |
+
|
| 107 |
# Trainer
|
| 108 |
trainer = Trainer(
|
| 109 |
model=model,
|