Commit b1b3841
Parent(s): 37fa322

Saving weights and logs of step 8

Files changed:
- README.md +0 -0
- config.json +0 -0
- create_config.py +0 -0
- flax_model.msgpack +3 -0
- run.sh +5 -2
- run.sh.save +16 -0
- run_mlm_flax.py +39 -35
- tokenizer.json +0 -0
- train_tokenizer.py +0 -0
README.md
CHANGED
File without changes

config.json
CHANGED
File without changes

create_config.py
CHANGED
File without changes
flax_model.msgpack
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a2e26684c7b415b88900d2b10f657004a4262d41aca55f28d52013b051535c43
+size 498796983
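Note that the three added lines are a Git LFS pointer, not the weights themselves: the roughly 499 MB msgpack payload is stored in LFS and resolved at checkout. A minimal sketch of loading such a Flax checkpoint with 🤗 Transformers, where "<user>/<model>" is a placeholder repo id, not this repository's actual id:

    # Minimal sketch: from_pretrained resolves the LFS pointer and loads flax_model.msgpack.
    # "<user>/<model>" is a placeholder repo id.
    from transformers import FlaxRobertaForMaskedLM

    model = FlaxRobertaForMaskedLM.from_pretrained("<user>/<model>")
    print(model.config.model_type)  # "roberta"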
run.sh
CHANGED
@@ -7,10 +7,13 @@
     --dataset_name="oscar" \
     --dataset_config_name="unshuffled_deduplicated_it" \
     --max_seq_length="128" \
-    --per_device_train_batch_size="
-    --per_device_eval_batch_size="
+    --per_device_train_batch_size="1" \
+    --per_device_eval_batch_size="1" \
     --learning_rate="3e-4" \
     --warmup_steps="1000" \
     --overwrite_output_dir \
     --num_train_epochs="8" \
+    --logging_steps="10" \
+    --save_steps="8" \
+    --eval_steps="15" \
     --push_to_hub
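The three new flags set the cadence of the training-loop changes below: log every 10 steps, save and push a checkpoint every 8 steps (hence this commit's "step 8" title), and evaluate every 15 steps. A sketch of how the flags surface in Python, simplified to a single TrainingArguments dataclass (the real script also parses model and data arguments):

    # Sketch: shell flags become TrainingArguments fields via HfArgumentParser.
    from transformers import HfArgumentParser, TrainingArguments

    parser = HfArgumentParser(TrainingArguments)
    (training_args,) = parser.parse_args_into_dataclasses(
        args=["--output_dir", "./", "--logging_steps", "10", "--save_steps", "8", "--eval_steps", "15"]
    )
    print(training_args.save_steps)  # 8, so the first checkpoint is pushed at step 8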
run.sh.save
ADDED
@@ -0,0 +1,16 @@
+/usr/bin/env bash
+./run_mlm_flax.py \
+    --output_dir="./" \
+    --model_type="roberta" \
+    --config_name="./" \
+    --tokenizer_name="./" \
+    --dataset_name="oscar" \
+    --dataset_config_name="unshuffled_deduplicated_it" \
+    --max_seq_length="128" \
+    --per_device_train_batch_size="4" \
+    --per_device_eval_batch_size="4" \
+    --learning_rate="3e-4" \
+    --warmup_steps="1000" \
+    --overwrite_output_dir \
+    --num_train_epochs="8" \
+    --push_to_hub
run_mlm_flax.py
CHANGED
@@ -297,6 +297,10 @@ if __name__ == "__main__":
         if extension == "txt":
             extension = "text"
         datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
+
+    datasets["train"] = datasets["train"].select(range(10000))
+    datasets["validation"] = datasets["validation"].select(range(1000))
+
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.
     # Load pretrained model and tokenizer
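The added select(range(...)) calls cap OSCAR at 10,000 train and 1,000 validation rows so a debug run finishes quickly. Dataset.select takes rows by index and returns a new dataset; a minimal sketch on toy data (the toy "text" column is illustrative, not the OSCAR schema):

    # Minimal sketch of Dataset.select on toy data.
    from datasets import Dataset

    ds = Dataset.from_dict({"text": [f"sentence {i}" for i in range(100)]})
    small = ds.select(range(10))  # keeps rows 0..9
    print(len(small))  # 10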
@@ -512,7 +516,7 @@ if __name__ == "__main__":
             model_inputs = shard(model_inputs.data)
             state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
             train_metrics.append(train_metric)
-            cur_step = epoch * num_train_samples + step
+            cur_step = epoch * (num_train_samples // train_batch_size) + step
             if cur_step % training_args.logging_steps == 0 and cur_step > 0:
                 # Save metrics
                 train_metric = jax_utils.unreplicate(train_metric)
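The old formula advanced cur_step by num_train_samples per epoch, i.e. it counted samples, so the logging_steps/save_steps/eval_steps comparisons ran against a counter that leapt by the whole dataset size at each epoch boundary. The fix divides by the global batch size so cur_step counts optimizer steps. A worked example; the 8-device count is an assumption (e.g. a TPU v3-8), and 10,000 samples matches the select call above:

    # Worked example of the corrected global-step formula.
    num_train_samples = 10_000   # after datasets["train"].select(range(10000))
    train_batch_size = 1 * 8     # per_device_train_batch_size * assumed 8 devices
    steps_per_epoch = num_train_samples // train_batch_size  # 1250

    for epoch, step in [(0, 0), (0, 1249), (1, 0)]:
        # old: epoch * num_train_samples + step  -> 0, 1249, 10000 (leaps per epoch)
        print(epoch * steps_per_epoch + step)    # new: 0, 1249, 1250 (contiguous)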
@@ -523,37 +527,37 @@ if __name__ == "__main__":
                     f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
                 )
                 train_metrics = []
-            [34 removed lines (old lines 526-559); their content is not rendered in the source view]
+            if cur_step % training_args.eval_steps == 0 and step > 0:
+                # ======================== Evaluating ==============================
+                num_eval_samples = len(tokenized_datasets["validation"])
+                eval_samples_idx = jnp.arange(num_eval_samples)
+                eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
+                eval_metrics = []
+                for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
+                    samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
+                    model_inputs = data_collator(samples, pad_to_multiple_of=16)
+                    # Model forward
+                    model_inputs = shard(model_inputs.data)
+                    metrics = p_eval_step(state.params, model_inputs)
+                    eval_metrics.append(metrics)
+                # normalize eval metrics
+                eval_metrics = get_metrics(eval_metrics)
+                eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
+                eval_normalizer = eval_metrics.pop("normalizer")
+                eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
+                # Update progress bar
+                epochs.desc = f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
+                # Save metrics
+                if has_tensorboard and jax.process_index() == 0:
+                    cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size)
+                    write_eval_metric(summary_writer, eval_metrics, cur_step)
+            if cur_step % training_args.save_steps == 0 and step > 0:
+                # save checkpoint after each epoch and push checkpoint to the hub
+                if jax.process_index() == 0:
+                    params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
+                    model.save_pretrained(
+                        training_args.output_dir,
+                        params=params,
+                        push_to_hub=training_args.push_to_hub,
+                        commit_message=f"Saving weights and logs of step {cur_step}",
+                    )
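With this hunk, evaluation and checkpointing move from once per epoch into the step loop, gated on eval_steps and save_steps. The metric handling is a sum-then-normalize pattern: per-batch metric dicts are summed, then every entry is divided by the pooled "normalizer" (the masked-token count), giving per-token averages. A minimal sketch of that pattern with toy numbers, mirroring the get_metrics + jax.tree_map calls above:

    # Toy sketch of the sum-then-normalize pattern (values are illustrative).
    import jax
    import jax.numpy as jnp

    eval_metrics = [
        {"loss": jnp.array(4.0), "accuracy": jnp.array(6.0), "normalizer": jnp.array(8.0)},
        {"loss": jnp.array(2.0), "accuracy": jnp.array(5.0), "normalizer": jnp.array(8.0)},
    ]
    # stack each leaf across batches and sum, as get_metrics + jax.tree_map(jnp.sum, ...) do
    summed = jax.tree_map(lambda *xs: jnp.sum(jnp.stack(xs)), *eval_metrics)
    normalizer = summed.pop("normalizer")                      # total masked tokens
    averaged = jax.tree_map(lambda x: x / normalizer, summed)  # per-token averages
    print(averaged["loss"], averaged["accuracy"])              # 0.375 0.6875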
tokenizer.json
CHANGED
File without changes

train_tokenizer.py
CHANGED
File without changes