Charlie81 committed on
Commit
e785830
·
1 Parent(s): 20c7ba3

save functionality

Browse files
Files changed (1) hide show
  1. scripts/train.py +27 -5
scripts/train.py CHANGED
@@ -10,6 +10,8 @@ from transformers import (
10
  from datasets import load_dataset
11
  from myolmoe import MyOlmoeForCausalLM, OlmoeConfig
12
  import os
 
 
13
 
14
  def main():
15
  print("Starting my COOL OLMoE training script for small experts")
@@ -73,10 +75,10 @@ def main():
73
  per_device_train_batch_size=2,
74
  gradient_accumulation_steps=8,
75
  learning_rate=1e-4,
76
- num_train_epochs=3,
77
  logging_dir="./logs",
78
  logging_steps=10,
79
- save_steps=1000,
80
  save_total_limit=2,
81
  bf16=True,
82
  gradient_checkpointing=False, # Disabled for now
@@ -134,6 +136,17 @@ def main():
134
  raise RuntimeError("Loss doesn't require gradients. Check model parameters.")
135
 
136
  return (loss, outputs) if return_outputs else loss
 
 
 
 
 
 
 
 
 
 
 
137
 
138
  # Initialize trainer
139
  trainer = CustomTrainer(
@@ -141,6 +154,7 @@ def main():
141
  args=training_args,
142
  train_dataset=tokenized_dataset,
143
  data_collator=data_collator,
 
144
  )
145
 
146
  # Test forward/backward pass before training
@@ -167,11 +181,19 @@ def main():
167
 
168
  # Reset gradients
169
  model.zero_grad()
170
-
 
 
 
 
 
 
 
 
171
  # Train
172
  print("Starting training...")
173
- trainer.train()
174
-
175
  # Save only the small experts and gates
176
  print("Saving small experts and gates...")
177
  small_expert_state_dict = {
 
10
  from datasets import load_dataset
11
  from myolmoe import MyOlmoeForCausalLM, OlmoeConfig
12
  import os
13
+ from transformers import TrainerCallback
14
+ import subprocess
15
 
16
  def main():
17
  print("Starting my COOL OLMoE training script for small experts")
 
75
  per_device_train_batch_size=2,
76
  gradient_accumulation_steps=8,
77
  learning_rate=1e-4,
78
+ num_train_epochs=0.001,
79
  logging_dir="./logs",
80
  logging_steps=10,
81
+ save_steps=2000,
82
  save_total_limit=2,
83
  bf16=True,
84
  gradient_checkpointing=False, # Disabled for now
 
136
  raise RuntimeError("Loss doesn't require gradients. Check model parameters.")
137
 
138
  return (loss, outputs) if return_outputs else loss
139
+
140
class GitPushCallback(TrainerCallback):
    """Trainer callback that commits and pushes each saved checkpoint to Git.

    Runs on every Trainer save event (``on_save``). Best-effort: any git
    failure is reported but never interrupts training.
    """

    def on_save(self, args, state, control, **kwargs):
        try:
            print("Pushing checkpoint to Git...")
            subprocess.run(["git", "add", "."], check=True)
            # `git commit` exits non-zero when nothing is staged (checkpoint
            # unchanged or gitignored), which would abort the push below.
            # Only commit when the tree is actually dirty.
            status = subprocess.run(
                ["git", "status", "--porcelain"],
                check=True,
                capture_output=True,
                text=True,
            )
            if status.stdout.strip():
                subprocess.run(
                    ["git", "commit", "-m", f"Checkpoint at step {state.global_step}"],
                    check=True,
                )
            subprocess.run(["git", "push"], check=True)
            print("Checkpoint pushed successfully.")
        except subprocess.CalledProcessError as e:
            # Deliberate best-effort: log and continue training.
            print(f"Git push failed: {e}")
150
 
151
  # Initialize trainer
152
  trainer = CustomTrainer(
 
154
  args=training_args,
155
  train_dataset=tokenized_dataset,
156
  data_collator=data_collator,
157
+ callbacks=[GitPushCallback()]
158
  )
159
 
160
  # Test forward/backward pass before training
 
181
 
182
  # Reset gradients
183
  model.zero_grad()
184
+
185
+ # Check for existing checkpoint
186
+ checkpoint_dir = None
187
+ if os.path.isdir(training_args.output_dir):
188
+ checkpoints = [os.path.join(training_args.output_dir, d) for d in os.listdir(training_args.output_dir) if d.startswith("checkpoint-")]
189
+ if checkpoints:
190
+ checkpoint_dir = max(checkpoints, key=os.path.getmtime)
191
+ print(f"Resuming from checkpoint: {checkpoint_dir}")
192
+
193
  # Train
194
  print("Starting training...")
195
+ trainer.train(resume_from_checkpoint=checkpoint_dir)
196
+
197
  # Save only the small experts and gates
198
  print("Saving small experts and gates...")
199
  small_expert_state_dict = {