Update train.py
Browse files
train.py
CHANGED
|
@@ -124,13 +124,13 @@ config = {"train_batch_size": 2,
|
|
| 124 |
"shuffle_buffer": 1000,
|
| 125 |
"learning_rate": 5e-4,
|
| 126 |
"lr_scheduler_type": "cosine",
|
| 127 |
-
"num_warmup_steps":
|
| 128 |
"gradient_accumulation_steps": 1,
|
| 129 |
-
"max_train_steps":
|
| 130 |
-
"max_eval_steps":
|
| 131 |
"seq_length": 1024,
|
| 132 |
"seed": 1,
|
| 133 |
-
"save_checkpoint_steps":
|
| 134 |
args = Namespace(**config, **acc_state)
|
| 135 |
samples_per_step = accelerator.state.num_processes * args.train_batch_size
|
| 136 |
set_seed(args.seed)
|
|
|
|
| 124 |
"shuffle_buffer": 1000,
|
| 125 |
"learning_rate": 5e-4,
|
| 126 |
"lr_scheduler_type": "cosine",
|
| 127 |
+
"num_warmup_steps": 2000,
|
| 128 |
"gradient_accumulation_steps": 1,
|
| 129 |
+
"max_train_steps": 150000,
|
| 130 |
+
"max_eval_steps": -1,
|
| 131 |
"seq_length": 1024,
|
| 132 |
"seed": 1,
|
| 133 |
+
"save_checkpoint_steps": 15000}
|
| 134 |
args = Namespace(**config, **acc_state)
|
| 135 |
samples_per_step = accelerator.state.num_processes * args.train_batch_size
|
| 136 |
set_seed(args.seed)
|