Aman K
committed on
Commit
·
2e5979b
1
Parent(s):
0b86536
Updated code to have different seed and reduced lr
Browse files- run.sh +1 -1
- run_mlm_flax.py +10 -0
run.sh
CHANGED
|
@@ -11,7 +11,7 @@
|
|
| 11 |
--preprocessing_num_workers="64" \
|
| 12 |
--per_device_train_batch_size="64" \
|
| 13 |
--per_device_eval_batch_size="64" \
|
| 14 |
-
--learning_rate="
|
| 15 |
--warmup_steps="1000" \
|
| 16 |
--overwrite_output_dir \
|
| 17 |
--num_train_epochs="8" \
|
|
|
|
| 11 |
--preprocessing_num_workers="64" \
|
| 12 |
--per_device_train_batch_size="64" \
|
| 13 |
--per_device_eval_batch_size="64" \
|
| 14 |
+
--learning_rate="2e-4" \
|
| 15 |
--warmup_steps="1000" \
|
| 16 |
--overwrite_output_dir \
|
| 17 |
--num_train_epochs="8" \
|
run_mlm_flax.py
CHANGED
|
@@ -324,6 +324,7 @@ if __name__ == "__main__":
|
|
| 324 |
logger.info(f"Training/evaluation parameters {training_args}")
|
| 325 |
|
| 326 |
# Set seed before initializing model.
|
|
|
|
| 327 |
set_seed(training_args.seed)
|
| 328 |
|
| 329 |
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
|
|
@@ -587,6 +588,7 @@ if __name__ == "__main__":
|
|
| 587 |
|
| 588 |
train_time = 0
|
| 589 |
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
|
|
|
|
| 590 |
for epoch in epochs:
|
| 591 |
# ======================== Training ================================
|
| 592 |
train_start = time.time()
|
|
@@ -609,6 +611,14 @@ if __name__ == "__main__":
|
|
| 609 |
model_inputs = shard(model_inputs.data)
|
| 610 |
state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
|
| 611 |
train_metrics.append(train_metric)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 612 |
|
| 613 |
train_time += time.time() - train_start
|
| 614 |
|
|
|
|
| 324 |
logger.info(f"Training/evaluation parameters {training_args}")
|
| 325 |
|
| 326 |
# Set seed before initializing model.
|
| 327 |
+
training_args.seed = 42
|
| 328 |
set_seed(training_args.seed)
|
| 329 |
|
| 330 |
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
|
|
|
|
| 588 |
|
| 589 |
train_time = 0
|
| 590 |
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
|
| 591 |
+
save_checkpoint=True
|
| 592 |
for epoch in epochs:
|
| 593 |
# ======================== Training ================================
|
| 594 |
train_start = time.time()
|
|
|
|
| 611 |
model_inputs = shard(model_inputs.data)
|
| 612 |
state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
|
| 613 |
train_metrics.append(train_metric)
|
| 614 |
+
if save_checkpoint and (train_metric['loss'] < 5.).all():
|
| 615 |
+
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
| 616 |
+
model.save_pretrained(
|
| 617 |
+
'/home/khandelia1000/checkpoints/',
|
| 618 |
+
params=params,
|
| 619 |
+
push_to_hub=False
|
| 620 |
+
)
|
| 621 |
+
save_checkpoint = False
|
| 622 |
|
| 623 |
train_time += time.time() - train_start
|
| 624 |
|