Update num workers
Browse files
events.out.tfevents.1672907612.t1v-n-29919176-w-3.477755.0.v2
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:04fa0820a8fc41435cd7772ace5ca6361aceebfe834d4daac2363ce4da933ed0
|
| 3 |
+
size 40
|
run_mlm_flax_stream.py
CHANGED
|
@@ -564,7 +564,7 @@ if __name__ == "__main__":
|
|
| 564 |
train_metrics = []
|
| 565 |
eval_metrics = []
|
| 566 |
|
| 567 |
-
training_iter = iter(torch.utils.data.DataLoader(tokenized_datasets.with_format("torch"), batch_size=1, shuffle=False, num_workers=dataset.n_shards, collate_fn=lambda x: x))
|
| 568 |
|
| 569 |
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
|
| 570 |
eval_samples = advance_iter_and_group_samples(training_iter, data_args.num_eval_samples, max_seq_length)
|
|
|
|
| 564 |
train_metrics = []
|
| 565 |
eval_metrics = []
|
| 566 |
|
| 567 |
+
training_iter = iter(torch.utils.data.DataLoader(tokenized_datasets.with_format("torch"), batch_size=1, shuffle=False, num_workers=max(33,dataset.n_shards), collate_fn=lambda x: x))
|
| 568 |
|
| 569 |
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
|
| 570 |
eval_samples = advance_iter_and_group_samples(training_iter, data_args.num_eval_samples, max_seq_length)
|