Charlie81 committed on
Commit
3ed7a55
·
1 Parent(s): 8fc755e

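changes to training script: remove per-parameter unfreeze logging; resume from the checkpoint with the highest step number instead of the most recently modified one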

Browse files
Files changed (1) hide show
  1. scripts/train.py +10 -3
scripts/train.py CHANGED
@@ -102,7 +102,6 @@ def main():
102
  ):
103
  param.requires_grad = True
104
  trainable_params.append(name)
105
- print(f"Unfreezing parameter: {name}")
106
 
107
 
108
  print(f"Total trainable parameters: {len(trainable_params)}")
@@ -192,13 +191,21 @@ def main():
192
  model.zero_grad()
193
 
194
  # Check for existing checkpoint
 
 
195
  checkpoint_dir = None
196
  if os.path.isdir(training_args.output_dir):
197
- checkpoints = [os.path.join(training_args.output_dir, d) for d in os.listdir(training_args.output_dir) if d.startswith("checkpoint-")]
 
 
 
 
198
  if checkpoints:
199
- checkpoint_dir = max(checkpoints, key=os.path.getmtime)
 
200
  print(f"Resuming from checkpoint: {checkpoint_dir}")
201
 
 
202
  # Train
203
  print("Starting training...")
204
  trainer.train(resume_from_checkpoint=checkpoint_dir)
 
102
  ):
103
  param.requires_grad = True
104
  trainable_params.append(name)
 
105
 
106
 
107
  print(f"Total trainable parameters: {len(trainable_params)}")
 
191
  model.zero_grad()
192
 
193
  # Check for existing checkpoint
194
+ import re
195
+
196
  checkpoint_dir = None
197
  if os.path.isdir(training_args.output_dir):
198
+ checkpoints = [
199
+ os.path.join(training_args.output_dir, d)
200
+ for d in os.listdir(training_args.output_dir)
201
+ if re.match(r"checkpoint-\d+", d)
202
+ ]
203
  if checkpoints:
204
+ # Extract step numbers and find the highest
205
+ checkpoint_dir = max(checkpoints, key=lambda x: int(x.split('-')[-1]))
206
  print(f"Resuming from checkpoint: {checkpoint_dir}")
207
 
208
+
209
  # Train
210
  print("Starting training...")
211
  trainer.train(resume_from_checkpoint=checkpoint_dir)