Commit ·
06eb589
1
Parent(s): 0649ca2
enhance multi-GPU support
Browse files
forecasting/inference/inference.py
CHANGED
|
@@ -85,6 +85,10 @@ def evaluate_model_on_dataset(model, dataset, batch_size=16, times=None, config_
|
|
| 85 |
num_workers = config_data.get('num_workers', 4) if config_data else 4
|
| 86 |
pin_memory = config_data.get('pin_memory', True) if config_data else True
|
| 87 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers,
|
| 89 |
pin_memory=pin_memory, shuffle=False,
|
| 90 |
multiprocessing_context='spawn' if num_workers > 0 else None)
|
|
@@ -114,11 +118,16 @@ def evaluate_model_on_dataset(model, dataset, batch_size=16, times=None, config_
|
|
| 114 |
# FP16 on V100 can spike peak memory due to FP32 fallbacks in attention.
|
| 115 |
# CRITICAL: ViTLocal defaults to return_attention=True, which uses massive memory
|
| 116 |
# Only compute attention weights if we're saving them (save_weights=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):
|
| 118 |
if save_weights:
|
| 119 |
-
pred = model(aia_imgs, return_attention=True)
|
| 120 |
else:
|
| 121 |
-
pred = model(aia_imgs, return_attention=False)
|
| 122 |
|
| 123 |
# Extract outputs (ViTLocal returns tuple of (predictions, attention_weights, flux_contributions) when return_attention=True)
|
| 124 |
if isinstance(pred, tuple) and len(pred) >= 3:
|
|
@@ -361,12 +370,14 @@ def load_model_from_config(config_data):
|
|
| 361 |
|
| 362 |
model.eval()
|
| 363 |
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
|
|
|
|
|
|
| 367 |
n_gpus = torch.cuda.device_count()
|
| 368 |
model = torch.nn.DataParallel(model)
|
| 369 |
-
print(f"Using DataParallel across {n_gpus} GPUs")
|
| 370 |
|
| 371 |
return model
|
| 372 |
|
|
|
|
| 85 |
num_workers = config_data.get('num_workers', 4) if config_data else 4
|
| 86 |
pin_memory = config_data.get('pin_memory', True) if config_data else True
|
| 87 |
|
| 88 |
+
n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1
|
| 89 |
+
use_multi_gpu = str(config_data.get('multi_gpu', False) if config_data else False).lower() == 'true'
|
| 90 |
+
# Unwrapped model used as fallback for batches smaller than n_gpus
|
| 91 |
+
base_model = model.module if isinstance(model, torch.nn.DataParallel) else model
|
| 92 |
loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers,
|
| 93 |
pin_memory=pin_memory, shuffle=False,
|
| 94 |
multiprocessing_context='spawn' if num_workers > 0 else None)
|
|
|
|
| 118 |
# FP16 on V100 can spike peak memory due to FP32 fallbacks in attention.
|
| 119 |
# CRITICAL: ViTLocal defaults to return_attention=True, which uses massive memory
|
| 120 |
# Only compute attention weights if we're saving them (save_weights=True)
|
| 121 |
+
# Fall back to single GPU for batches smaller than n_gpus to avoid
|
| 122 |
+
# DataParallel crashing when some replicas receive empty inputs.
|
| 123 |
+
active_model = (base_model
|
| 124 |
+
if use_multi_gpu and aia_imgs.shape[0] < n_gpus
|
| 125 |
+
else model)
|
| 126 |
with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):
|
| 127 |
if save_weights:
|
| 128 |
+
pred = active_model(aia_imgs, return_attention=True)
|
| 129 |
else:
|
| 130 |
+
pred = active_model(aia_imgs, return_attention=False)
|
| 131 |
|
| 132 |
# Extract outputs (ViTLocal returns tuple of (predictions, attention_weights, flux_contributions) when return_attention=True)
|
| 133 |
if isinstance(pred, tuple) and len(pred) >= 3:
|
|
|
|
| 370 |
|
| 371 |
model.eval()
|
| 372 |
|
| 373 |
+
raw_multi_gpu = config_data.get('multi_gpu', False) if config_data else False
|
| 374 |
+
use_multi_gpu = (str(raw_multi_gpu).lower() == 'true'
|
| 375 |
+
and torch.cuda.is_available()
|
| 376 |
+
and torch.cuda.device_count() > 1)
|
| 377 |
+
if use_multi_gpu:
|
| 378 |
n_gpus = torch.cuda.device_count()
|
| 379 |
model = torch.nn.DataParallel(model)
|
| 380 |
+
print(f"Using DataParallel across {n_gpus} GPUs — ensure batch_size >= {n_gpus}")
|
| 381 |
|
| 382 |
return model
|
| 383 |
|