luxopes committed on
Commit
776b5b5
·
verified ·
1 Parent(s): 711e74d

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +3 -1
train.py CHANGED
@@ -282,7 +282,7 @@ def train(
282
  grad_accum=8,
283
  epochs=1,
284
  lr=1e-5,
285
- warmup_steps=0,
286
  ):
287
  accelerator = Accelerator(
288
  mixed_precision="bf16" if torch.cuda.is_bf16_supported() else "fp16",
@@ -491,6 +491,7 @@ if __name__ == "__main__":
491
 
492
  model = Transformer(args)
493
 
 
494
  RESUME_FROM = "checkpoints/step_200000.pt"
495
 
496
  if os.path.exists(RESUME_FROM):
@@ -505,6 +506,7 @@ if __name__ == "__main__":
505
  # Old format: checkpoint is directly the model state_dict
506
  model.load_state_dict(checkpoint)
507
  print(f"[Resume] Loaded model (old format)")
 
508
 
509
  train(
510
  model,
 
282
  grad_accum=8,
283
  epochs=1,
284
  lr=1e-5,
285
+ warmup_steps=500,
286
  ):
287
  accelerator = Accelerator(
288
  mixed_precision="bf16" if torch.cuda.is_bf16_supported() else "fp16",
 
491
 
492
  model = Transformer(args)
493
 
494
+ '''
495
  RESUME_FROM = "checkpoints/step_200000.pt"
496
 
497
  if os.path.exists(RESUME_FROM):
 
506
  # Old format: checkpoint is directly the model state_dict
507
  model.load_state_dict(checkpoint)
508
  print(f"[Resume] Loaded model (old format)")
509
+ '''
510
 
511
  train(
512
  model,