Commit ·
6ece3d2
1
Parent(s): 06eb589
enhance multi-GPU support
Browse files
forecasting/inference/inference.py
CHANGED
|
@@ -121,7 +121,7 @@ def evaluate_model_on_dataset(model, dataset, batch_size=16, times=None, config_
|
|
| 121 |
# Fall back to single GPU for batches smaller than n_gpus to avoid
|
| 122 |
# DataParallel crashing when some replicas receive empty inputs.
|
| 123 |
active_model = (base_model
|
| 124 |
-
if
|
| 125 |
else model)
|
| 126 |
with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):
|
| 127 |
if save_weights:
|
|
|
|
| 121 |
# Fall back to single GPU for batches smaller than n_gpus to avoid
|
| 122 |
# DataParallel crashing when some replicas receive empty inputs.
|
| 123 |
active_model = (base_model
|
| 124 |
+
if isinstance(model, torch.nn.DataParallel) and aia_imgs.shape[0] < n_gpus
|
| 125 |
else model)
|
| 126 |
with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):
|
| 127 |
if save_weights:
|