another attempt with train files locally
Browse files
- run_flax.sh +2 -1
- run_mlm_flax.py +1 -2
run_flax.sh
CHANGED
|
@@ -3,7 +3,8 @@
|
|
| 3 |
--model_type="roberta" \
|
| 4 |
--config_name="./" \
|
| 5 |
--tokenizer_name="./" \
|
| 6 |
-
--
|
|
|
|
| 7 |
--max_seq_length="128" \
|
| 8 |
--weight_decay="0.01" \
|
| 9 |
--per_device_train_batch_size="232" \
|
|
|
|
| 3 |
--model_type="roberta" \
|
| 4 |
--config_name="./" \
|
| 5 |
--tokenizer_name="./" \
|
| 6 |
+
--train_file="/mnt/disks/flaxdisk/smallcorpus/train-shard-0001-of-0001.json" \
|
| 7 |
+
--validation_file="/mnt/disks/flaxdisk/smallcorpus/validation-shard-0001-of-0001.json" \
|
| 8 |
--max_seq_length="128" \
|
| 9 |
--weight_decay="0.01" \
|
| 10 |
--per_device_train_batch_size="232" \
|
run_mlm_flax.py
CHANGED
|
@@ -317,10 +317,9 @@ if __name__ == "__main__":
|
|
| 317 |
#
|
| 318 |
# In distributed training, the load_dataset function guarantees that only one local process can concurrently
|
| 319 |
# download the dataset.
|
| 320 |
-
chunksize = 10<<20
|
| 321 |
if data_args.dataset_name is not None:
|
| 322 |
# Downloading and loading a dataset from the hub.
|
| 323 |
-
datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir
|
| 324 |
|
| 325 |
if "validation" not in datasets.keys():
|
| 326 |
datasets["validation"] = load_dataset(
|
|
|
|
| 317 |
#
|
| 318 |
# In distributed training, the load_dataset function guarantees that only one local process can concurrently
|
| 319 |
# download the dataset.
|
|
|
|
| 320 |
if data_args.dataset_name is not None:
|
| 321 |
# Downloading and loading a dataset from the hub.
|
| 322 |
+
datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
|
| 323 |
|
| 324 |
if "validation" not in datasets.keys():
|
| 325 |
datasets["validation"] = load_dataset(
|