nroggendorff commited on
Commit
dc942f1
·
verified ·
1 Parent(s): 87153fd

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +1 -1
train.py CHANGED
@@ -102,7 +102,7 @@ else:
102
  num_gpus = torch.cuda.device_count()
103
  models = [load_model(device_id=i) for i in range(num_gpus)]
104
 
105
- batch_size = 8
106
  shard_size = len(ds) // num_gpus
107
 
108
 
 
102
  num_gpus = torch.cuda.device_count()
103
  models = [load_model(device_id=i) for i in range(num_gpus)]
104
 
105
+ batch_size = 32
106
  shard_size = len(ds) // num_gpus
107
 
108