pere committed on
Commit fcbb238 · 1 Parent(s): f4fe90b
Files changed (2)
  1. run_mlm_flax.py +16 -2
  2. run_step1.sh +2 -2
run_mlm_flax.py CHANGED
@@ -126,6 +126,12 @@ class DataTrainingArguments:
     overwrite_cache: bool = field(
         default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
     )
+    static_learning_rate: bool = field(
+        default=False, metadata={"help": "Use a non-decaying learning rate"}
+    )
+    auth_token: bool = field(
+        default=False, metadata={"help": "Use an authorisation token"}
+    )
     validation_split_percentage: Optional[int] = field(
         default=5,
         metadata={
@@ -327,12 +333,14 @@ if __name__ == "__main__":
             data_args.dataset_config_name,
             split=f"train[:{data_args.validation_split_percentage}%]",
             cache_dir=model_args.cache_dir,
+            use_auth_token=data_args.auth_token,
         )
         datasets["train"] = load_dataset(
             data_args.dataset_name,
             data_args.dataset_config_name,
             split=f"train[{data_args.validation_split_percentage}%:]",
             cache_dir=model_args.cache_dir,
+            use_auth_token=data_args.auth_token,
         )
     else:
         data_files = {}
@@ -481,7 +489,7 @@ if __name__ == "__main__":
 
     if model_args.model_name_or_path:
         model = FlaxAutoModelForMaskedLM.from_pretrained(
-            model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
+            model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
         )
     else:
         model = FlaxAutoModelForMaskedLM.from_config(
@@ -499,9 +507,15 @@ if __name__ == "__main__":
     warmup_fn = optax.linear_schedule(
         init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
     )
+
+    if data_args.static_learning_rate:
+        end_lr_value = training_args.learning_rate
+    else:
+        end_lr_value = 0
+
     decay_fn = optax.linear_schedule(
         init_value=training_args.learning_rate,
-        end_value=0,
+        end_value=end_lr_value,
         transition_steps=num_train_steps - training_args.warmup_steps,
     )
     linear_decay_lr_schedule_fn = optax.join_schedules(
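The schedule change in the last hunk is easy to sanity-check in isolation: when the decay segment's `end_value` equals its `init_value`, the joined schedule stays flat after warmup. Below is a minimal sketch of that behaviour (not part of the commit; the step counts and peak rate are made-up toy values):

```python
import optax

# Toy values standing in for training_args.learning_rate, warmup_steps,
# and num_train_steps from the script.
learning_rate = 2e-4
warmup_steps = 5
num_train_steps = 20
static_learning_rate = True  # mirrors the new data_args.static_learning_rate flag

# Linear warmup from 0 to the peak rate, as in the script.
warmup_fn = optax.linear_schedule(
    init_value=0.0, end_value=learning_rate, transition_steps=warmup_steps
)

# With the flag set, the "decay" segment starts and ends at the same value,
# so the overall schedule is constant after warmup.
end_lr_value = learning_rate if static_learning_rate else 0.0
decay_fn = optax.linear_schedule(
    init_value=learning_rate,
    end_value=end_lr_value,
    transition_steps=num_train_steps - warmup_steps,
)

schedule_fn = optax.join_schedules(
    schedules=[warmup_fn, decay_fn], boundaries=[warmup_steps]
)

for step in (0, 5, 10, 20):
    print(step, schedule_fn(step))  # rises to 2e-4, then stays there
```

Because only `end_value` changes, the warmup-plus-decay structure and the `optax.join_schedules` call are untouched, so the rest of the training loop needs no knowledge of the new flag.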
 
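The `auth_token` plumbing in the first two hunks simply forwards a boolean to `datasets.load_dataset`, which lets the script read a private Hub dataset using the token stored by `huggingface-cli login`. A hypothetical, self-contained example (the dataset name is a placeholder, not one used by this repo):

```python
from datasets import load_dataset

# "your-org/private-dataset" is a placeholder; use_auth_token=True makes the
# datasets library send the locally stored Hub token with its requests.
dataset = load_dataset(
    "your-org/private-dataset",
    split="train[:5%]",
    use_auth_token=True,
)
```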
run_step1.sh CHANGED
@@ -10,7 +10,6 @@
     --per_device_train_batch_size="256" \
     --per_device_eval_batch_size="256" \
     --learning_rate="2e-4" \
-    --end_learning_rate="2e-4" \
     --warmup_steps="5000" \
     --overwrite_output_dir \
     --num_train_epochs="1000" \
@@ -20,5 +19,6 @@
     --save_steps="5000" \
     --eval_steps="5000" \
     --preprocessing_num_workers="64" \
-    --use_auth_token
+    --use_auth_token="True" \
+    --static_learning_rate="True" \
     --push_to_hub
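The two new shell flags rely on the script parsing its arguments with `transformers.HfArgumentParser` (as the upstream Flax MLM example does), which in recent versions turns boolean dataclass fields into flags that also accept explicit values, so `--use_auth_token="True"` parses the same as a bare `--use_auth_token`. A small sketch of how the two added fields parse, assuming that parser:

```python
from dataclasses import dataclass, field
from transformers import HfArgumentParser

@dataclass
class DataTrainingArguments:
    # Same two fields the commit adds; all other fields omitted for brevity.
    static_learning_rate: bool = field(
        default=False, metadata={"help": "Use a non-decaying learning rate"}
    )
    auth_token: bool = field(
        default=False, metadata={"help": "Use an authorisation token"}
    )

parser = HfArgumentParser(DataTrainingArguments)
# HfArgumentParser accepts explicit boolean values for bool fields.
(data_args,) = parser.parse_args_into_dataclasses(
    args=["--static_learning_rate=True", "--auth_token=True"]
)
print(data_args.static_learning_rate, data_args.auth_token)  # True True
```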