test2
Browse files- run_mlm_flax.py +16 -2
- run_step1.sh +2 -2
run_mlm_flax.py
CHANGED
|
@@ -126,6 +126,12 @@ class DataTrainingArguments:
|
|
| 126 |
overwrite_cache: bool = field(
|
| 127 |
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
| 128 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
validation_split_percentage: Optional[int] = field(
|
| 130 |
default=5,
|
| 131 |
metadata={
|
|
@@ -327,12 +333,14 @@ if __name__ == "__main__":
|
|
| 327 |
data_args.dataset_config_name,
|
| 328 |
split=f"train[:{data_args.validation_split_percentage}%]",
|
| 329 |
cache_dir=model_args.cache_dir,
|
|
|
|
| 330 |
)
|
| 331 |
datasets["train"] = load_dataset(
|
| 332 |
data_args.dataset_name,
|
| 333 |
data_args.dataset_config_name,
|
| 334 |
split=f"train[{data_args.validation_split_percentage}%:]",
|
| 335 |
cache_dir=model_args.cache_dir,
|
|
|
|
| 336 |
)
|
| 337 |
else:
|
| 338 |
data_files = {}
|
|
@@ -481,7 +489,7 @@ if __name__ == "__main__":
|
|
| 481 |
|
| 482 |
if model_args.model_name_or_path:
|
| 483 |
model = FlaxAutoModelForMaskedLM.from_pretrained(
|
| 484 |
-
|
| 485 |
)
|
| 486 |
else:
|
| 487 |
model = FlaxAutoModelForMaskedLM.from_config(
|
|
@@ -499,9 +507,15 @@ if __name__ == "__main__":
|
|
| 499 |
warmup_fn = optax.linear_schedule(
|
| 500 |
init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
|
| 501 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 502 |
decay_fn = optax.linear_schedule(
|
| 503 |
init_value=training_args.learning_rate,
|
| 504 |
-
end_value=
|
| 505 |
transition_steps=num_train_steps - training_args.warmup_steps,
|
| 506 |
)
|
| 507 |
linear_decay_lr_schedule_fn = optax.join_schedules(
|
|
|
|
| 126 |
overwrite_cache: bool = field(
|
| 127 |
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
| 128 |
)
|
| 129 |
+
static_learning_rate: bool = field(
|
| 130 |
+
default=False, metadata={"help": "Use a non decaying learning rate"}
|
| 131 |
+
)
|
| 132 |
+
auth_token: bool = field(
|
| 133 |
+
default=False, metadata={"help": "Use authorisation token"}
|
| 134 |
+
)
|
| 135 |
validation_split_percentage: Optional[int] = field(
|
| 136 |
default=5,
|
| 137 |
metadata={
|
|
|
|
| 333 |
data_args.dataset_config_name,
|
| 334 |
split=f"train[:{data_args.validation_split_percentage}%]",
|
| 335 |
cache_dir=model_args.cache_dir,
|
| 336 |
+
use_auth_token=data_args.auth_token,
|
| 337 |
)
|
| 338 |
datasets["train"] = load_dataset(
|
| 339 |
data_args.dataset_name,
|
| 340 |
data_args.dataset_config_name,
|
| 341 |
split=f"train[{data_args.validation_split_percentage}%:]",
|
| 342 |
cache_dir=model_args.cache_dir,
|
| 343 |
+
use_auth_token=data_args.auth_token,
|
| 344 |
)
|
| 345 |
else:
|
| 346 |
data_files = {}
|
|
|
|
| 489 |
|
| 490 |
if model_args.model_name_or_path:
|
| 491 |
model = FlaxAutoModelForMaskedLM.from_pretrained(
|
| 492 |
+
model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
|
| 493 |
)
|
| 494 |
else:
|
| 495 |
model = FlaxAutoModelForMaskedLM.from_config(
|
|
|
|
| 507 |
warmup_fn = optax.linear_schedule(
|
| 508 |
init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
|
| 509 |
)
|
| 510 |
+
|
| 511 |
+
if data_args.static_learning_rate:
|
| 512 |
+
end_lr_value = training_args.learning_rate
|
| 513 |
+
else:
|
| 514 |
+
end_lr_value = 0
|
| 515 |
+
|
| 516 |
decay_fn = optax.linear_schedule(
|
| 517 |
init_value=training_args.learning_rate,
|
| 518 |
+
end_value=end_lr_value,
|
| 519 |
transition_steps=num_train_steps - training_args.warmup_steps,
|
| 520 |
)
|
| 521 |
linear_decay_lr_schedule_fn = optax.join_schedules(
|
run_step1.sh
CHANGED
|
@@ -10,7 +10,6 @@
|
|
| 10 |
--per_device_train_batch_size="256" \
|
| 11 |
--per_device_eval_batch_size="256" \
|
| 12 |
--learning_rate="2e-4" \
|
| 13 |
-
--end_learning_rate="2e-4" \
|
| 14 |
--warmup_steps="5000" \
|
| 15 |
--overwrite_output_dir \
|
| 16 |
--num_train_epochs="1000" \
|
|
@@ -20,5 +19,6 @@
|
|
| 20 |
--save_steps="5000" \
|
| 21 |
--eval_steps="5000" \
|
| 22 |
--preprocessing_num_workers="64" \
|
| 23 |
-
--use_auth_token
|
|
|
|
| 24 |
--push_to_hub
|
|
|
|
| 10 |
--per_device_train_batch_size="256" \
|
| 11 |
--per_device_eval_batch_size="256" \
|
| 12 |
--learning_rate="2e-4" \
|
|
|
|
| 13 |
--warmup_steps="5000" \
|
| 14 |
--overwrite_output_dir \
|
| 15 |
--num_train_epochs="1000" \
|
|
|
|
| 19 |
--save_steps="5000" \
|
| 20 |
--eval_steps="5000" \
|
| 21 |
--preprocessing_num_workers="64" \
|
| 22 |
+
--auth_token="True" \
|
| 23 |
+
--static_learning_rate="True" \
|
| 24 |
--push_to_hub
|