TRM-coding committed on
Commit
3f1968c
·
verified ·
1 Parent(s): 377e9d8

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +4 -4
train.py CHANGED
@@ -124,13 +124,13 @@ config = {"train_batch_size": 2,
124
  "shuffle_buffer": 1000,
125
  "learning_rate": 5e-4,
126
  "lr_scheduler_type": "cosine",
127
- "num_warmup_steps": 0,
128
  "gradient_accumulation_steps": 1,
129
- "max_train_steps": 15,
130
- "max_eval_steps": 15,
131
  "seq_length": 1024,
132
  "seed": 1,
133
- "save_checkpoint_steps": 10}
134
  args = Namespace(**config, **acc_state)
135
  samples_per_step = accelerator.state.num_processes * args.train_batch_size
136
  set_seed(args.seed)
 
124
  "shuffle_buffer": 1000,
125
  "learning_rate": 5e-4,
126
  "lr_scheduler_type": "cosine",
127
+ "num_warmup_steps": 2000,
128
  "gradient_accumulation_steps": 1,
129
+ "max_train_steps": 150000,
130
+ "max_eval_steps": -1,
131
  "seq_length": 1024,
132
  "seed": 1,
133
+ "save_checkpoint_steps": 15000}
134
  args = Namespace(**config, **acc_state)
135
  samples_per_step = accelerator.state.num_processes * args.train_batch_size
136
  set_seed(args.seed)