Charlie81 committed
Commit f9596a0 · Parent: 50cd1ec

update training script

Files changed (1)
  1. scripts/train.py +26 -10
scripts/train.py CHANGED
@@ -61,7 +61,6 @@ def main():
         tokenized["labels"] = tokenized["input_ids"].copy()
         return tokenized
 
-
     tokenized_dataset = dataset.map(
         tokenize_function,
         batched=True,
@@ -74,7 +73,7 @@ def main():
         output_dir="./output",
         per_device_train_batch_size=1,
         gradient_accumulation_steps=8,
-        learning_rate=1e-5,
+        learning_rate=1e-4,  # Higher LR for expert training
         num_train_epochs=1,
         logging_dir="./logs",
         logging_steps=10,
@@ -88,21 +87,28 @@ def main():
         warmup_ratio=0.1,
         max_grad_norm=1.0,
     )
-    # Freeze all parameters
+
+    # Freeze all parameters first
     for param in model.parameters():
         param.requires_grad = False
 
     # Unfreeze only the small experts and their gating networks
     for name, param in model.named_parameters():
         # Unfreeze small expert layers
-        if "mlp.experts" in name and any(f"mlp.experts.{i}." in name for i in range(0, config.num_experts, config.small_expert_count)):
+        if "mlp.small_experts" in name:
             param.requires_grad = True
             print(f"Unfreezing small expert parameter: {name}")
 
-        # Unfreeze gating network parameters
-        if "mlp.gate" in name:
+        # Unfreeze small gating network parameters
+        if "mlp.small_gate" in name:
             param.requires_grad = True
-            print(f"Unfreezing gating network parameter: {name}")
+            print(f"Unfreezing small gate parameter: {name}")
+
+    # Create custom data collator to handle router logits
+    def data_collator(features):
+        batch = default_data_collator(features)
+        batch["output_router_logits"] = True  # Ensure we get router logits for aux loss
+        return batch
 
     # Trainer
     trainer = Trainer(
@@ -110,14 +116,24 @@ def main():
         args=training_args,
         train_dataset=tokenized_dataset,
         tokenizer=tokenizer,
-        data_collator=default_data_collator,
+        data_collator=data_collator,
     )
 
     # Train
     trainer.train()
 
-    # Save
-    trainer.save_model("./final_model")
+    # Save only the small experts and gates
+    print("Saving only small experts and gates...")
+    small_expert_state_dict = {
+        name: param for name, param in model.named_parameters()
+        if "mlp.small_experts" in name or "mlp.small_gate" in name
+    }
+
+    os.makedirs("./final_model", exist_ok=True)
+    torch.save(small_expert_state_dict, "./final_model/small_experts_and_gates.bin")
+
+    # Also save config
+    config.save_pretrained("./final_model")
 
 if __name__ == "__main__":
     main()
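
Note on restoring the partial checkpoint: the commit replaces trainer.save_model with a torch.save of only the small-expert and small-gate tensors, so loading them back later requires re-instantiating the base model and overlaying those tensors. The following is a minimal sketch, not part of the commit; it assumes the base model loads through transformers.AutoModelForCausalLM (with trust_remote_code=True for the custom MoE architecture), uses a placeholder path path/to/base_model for the real checkpoint, and reuses the ./final_model paths from the script above.

import torch
from transformers import AutoModelForCausalLM

# Hypothetical base checkpoint path; the real one is not shown in this commit.
model = AutoModelForCausalLM.from_pretrained("path/to/base_model", trust_remote_code=True)

# Weights written by train.py: only parameters whose names contain
# "mlp.small_experts" or "mlp.small_gate".
small_state = torch.load("./final_model/small_experts_and_gates.bin", map_location="cpu")

# strict=False keeps every other (frozen) parameter at its base-model value.
missing, unexpected = model.load_state_dict(small_state, strict=False)
print(f"Overlaid {len(small_state)} tensors, {len(unexpected)} unexpected keys")

# Sanity check mirroring the unfreeze logic in the training script.
n_small = sum(p.numel() for n, p in model.named_parameters()
              if "mlp.small_experts" in n or "mlp.small_gate" in n)
print(f"Small-expert + small-gate parameters: {n_small:,}")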