pere commited on
Commit
2f1afca
·
1 Parent(s): 055fdb6

change number of shards

Browse files
Files changed (1) hide show
  1. run_mlm_flax_stream.py +1 -1
run_mlm_flax_stream.py CHANGED
@@ -581,7 +581,7 @@ if __name__ == "__main__":
581
  shuffle_seed += 1
582
  tokenized_datasets.set_epoch(shuffle_seed)
583
 
584
- training_iter = iter(torch.utils.data.DataLoader(tokenized_datasets.with_format("torch"), batch_size=1, shuffle=False, num_workers=dataset.n_shards, collate_fn=lambda x: x))
585
 
586
  eval_samples = advance_iter_and_group_samples(training_iter, data_args.num_eval_samples, max_seq_length)
587
  samples = advance_iter_and_group_samples(training_iter, train_batch_size, max_seq_length)
 
581
  shuffle_seed += 1
582
  tokenized_datasets.set_epoch(shuffle_seed)
583
 
584
+ training_iter = iter(torch.utils.data.DataLoader(tokenized_datasets.with_format("torch"), batch_size=1, shuffle=False, num_workers=max(33,dataset.n_shards), collate_fn=lambda x: x))
585
 
586
  eval_samples = advance_iter_and_group_samples(training_iter, data_args.num_eval_samples, max_seq_length)
587
  samples = advance_iter_and_group_samples(training_iter, train_batch_size, max_seq_length)