griffingoodwin04 committed on
Commit
06eb589
·
1 Parent(s): 0649ca2

enhance multi-GPU support

Browse files
Files changed (1) hide show
  1. forecasting/inference/inference.py +17 -6
forecasting/inference/inference.py CHANGED
@@ -85,6 +85,10 @@ def evaluate_model_on_dataset(model, dataset, batch_size=16, times=None, config_
85
  num_workers = config_data.get('num_workers', 4) if config_data else 4
86
  pin_memory = config_data.get('pin_memory', True) if config_data else True
87
 
 
 
 
 
88
  loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers,
89
  pin_memory=pin_memory, shuffle=False,
90
  multiprocessing_context='spawn' if num_workers > 0 else None)
@@ -114,11 +118,16 @@ def evaluate_model_on_dataset(model, dataset, batch_size=16, times=None, config_
114
  # FP16 on V100 can spike peak memory due to FP32 fallbacks in attention.
115
  # CRITICAL: ViTLocal defaults to return_attention=True, which uses massive memory
116
  # Only compute attention weights if we're saving them (save_weights=True)
 
 
 
 
 
117
  with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):
118
  if save_weights:
119
- pred = model(aia_imgs, return_attention=True)
120
  else:
121
- pred = model(aia_imgs, return_attention=False)
122
 
123
  # Extract outputs (ViTLocal returns tuple of (predictions, attention_weights, flux_contributions) when return_attention=True)
124
  if isinstance(pred, tuple) and len(pred) >= 3:
@@ -361,12 +370,14 @@ def load_model_from_config(config_data):
361
 
362
  model.eval()
363
 
364
- if (config_data.get('multi_gpu', False)
365
- and torch.cuda.is_available()
366
- and torch.cuda.device_count() > 1):
 
 
367
  n_gpus = torch.cuda.device_count()
368
  model = torch.nn.DataParallel(model)
369
- print(f"Using DataParallel across {n_gpus} GPUs")
370
 
371
  return model
372
 
 
85
  num_workers = config_data.get('num_workers', 4) if config_data else 4
86
  pin_memory = config_data.get('pin_memory', True) if config_data else True
87
 
88
+ n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1
89
+ use_multi_gpu = str(config_data.get('multi_gpu', False) if config_data else False).lower() == 'true'
90
+ # Unwrapped model used as fallback for batches smaller than n_gpus
91
+ base_model = model.module if isinstance(model, torch.nn.DataParallel) else model
92
  loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers,
93
  pin_memory=pin_memory, shuffle=False,
94
  multiprocessing_context='spawn' if num_workers > 0 else None)
 
118
  # FP16 on V100 can spike peak memory due to FP32 fallbacks in attention.
119
  # CRITICAL: ViTLocal defaults to return_attention=True, which uses massive memory
120
  # Only compute attention weights if we're saving them (save_weights=True)
121
+ # Fall back to single GPU for batches smaller than n_gpus to avoid
122
+ # DataParallel crashing when some replicas receive empty inputs.
123
+ active_model = (base_model
124
+ if use_multi_gpu and aia_imgs.shape[0] < n_gpus
125
+ else model)
126
  with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):
127
  if save_weights:
128
+ pred = active_model(aia_imgs, return_attention=True)
129
  else:
130
+ pred = active_model(aia_imgs, return_attention=False)
131
 
132
  # Extract outputs (ViTLocal returns tuple of (predictions, attention_weights, flux_contributions) when return_attention=True)
133
  if isinstance(pred, tuple) and len(pred) >= 3:
 
370
 
371
  model.eval()
372
 
373
+ raw_multi_gpu = config_data.get('multi_gpu', False) if config_data else False
374
+ use_multi_gpu = (str(raw_multi_gpu).lower() == 'true'
375
+ and torch.cuda.is_available()
376
+ and torch.cuda.device_count() > 1)
377
+ if use_multi_gpu:
378
  n_gpus = torch.cuda.device_count()
379
  model = torch.nn.DataParallel(model)
380
+ print(f"Using DataParallel across {n_gpus} GPUs — ensure batch_size >= {n_gpus}")
381
 
382
  return model
383