pere committed on
Commit
3c02d6a
·
1 Parent(s): b2299d9

Update num workers

Browse files
events.out.tfevents.1672907612.t1v-n-29919176-w-3.477755.0.v2 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:04fa0820a8fc41435cd7772ace5ca6361aceebfe834d4daac2363ce4da933ed0
3
+ size 40
run_mlm_flax_stream.py CHANGED
@@ -564,7 +564,7 @@ if __name__ == "__main__":
564
  train_metrics = []
565
  eval_metrics = []
566
 
567
- training_iter = iter(torch.utils.data.DataLoader(tokenized_datasets.with_format("torch"), batch_size=1, shuffle=False, num_workers=dataset.n_shards, collate_fn=lambda x: x))
568
 
569
  max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
570
  eval_samples = advance_iter_and_group_samples(training_iter, data_args.num_eval_samples, max_seq_length)
 
564
  train_metrics = []
565
  eval_metrics = []
566
 
567
+ training_iter = iter(torch.utils.data.DataLoader(tokenized_datasets.with_format("torch"), batch_size=1, shuffle=False, num_workers=max(33,dataset.n_shards), collate_fn=lambda x: x))
568
 
569
  max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
570
  eval_samples = advance_iter_and_group_samples(training_iter, data_args.num_eval_samples, max_seq_length)