pere committed on
Commit
00e6514
·
1 Parent(s): 50b0086

another attempt with train files locally

Browse files
Files changed (2) hide show
  1. run_flax.sh +2 -1
  2. run_mlm_flax.py +1 -2
run_flax.sh CHANGED
@@ -3,7 +3,8 @@
3
  --model_type="roberta" \
4
  --config_name="./" \
5
  --tokenizer_name="./" \
6
- --dataset_name="NbAiLab/NCC_small" \
 
7
  --max_seq_length="128" \
8
  --weight_decay="0.01" \
9
  --per_device_train_batch_size="232" \
 
3
  --model_type="roberta" \
4
  --config_name="./" \
5
  --tokenizer_name="./" \
6
+ --train_file="/mnt/disks/flaxdisk/smallcorpus/train-shard-0001-of-0001.json" \
7
+ --validation_file="/mnt/disks/flaxdisk/smallcorpus/validation-shard-0001-of-0001.json" \
8
  --max_seq_length="128" \
9
  --weight_decay="0.01" \
10
  --per_device_train_batch_size="232" \
run_mlm_flax.py CHANGED
@@ -317,10 +317,9 @@ if __name__ == "__main__":
317
  #
318
  # In distributed training, the load_dataset function guarantees that only one local process can concurrently
319
  # download the dataset.
320
- chunksize = 10<<20
321
  if data_args.dataset_name is not None:
322
  # Downloading and loading a dataset from the hub.
323
- datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, chunksize=chunksize)
324
 
325
  if "validation" not in datasets.keys():
326
  datasets["validation"] = load_dataset(
 
317
  #
318
  # In distributed training, the load_dataset function guarantees that only one local process can concurrently
319
  # download the dataset.
 
320
  if data_args.dataset_name is not None:
321
  # Downloading and loading a dataset from the hub.
322
+ datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
323
 
324
  if "validation" not in datasets.keys():
325
  datasets["validation"] = load_dataset(