Alikestocode committed on
Commit
75aac04
·
1 Parent(s): 34ee4d1

Improve vLLM device detection: force torch CUDA reinit

Browse files
Files changed (1) hide show
  1. app.py +12 -0
app.py CHANGED
@@ -215,11 +215,20 @@ def load_vllm_model(model_name: str):
215
  # Ensure CUDA_VISIBLE_DEVICES is set correctly for vLLM device detection
216
  # ZeroGPU uses MIG UUIDs, but vLLM needs numeric device index
217
  # IMPORTANT: Set this BEFORE creating LLM() instance, as vLLM checks device during init
 
218
  cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES", "")
219
  if not cuda_visible or not cuda_visible.isdigit():
220
  # If CUDA_VISIBLE_DEVICES is a MIG UUID or empty, use "0" for single GPU
221
  os.environ["CUDA_VISIBLE_DEVICES"] = "0"
222
  print(f" → Set CUDA_VISIBLE_DEVICES=0 (was: {cuda_visible})")
 
 
 
 
 
 
 
 
223
 
224
  # Force torch to see the correct device after setting CUDA_VISIBLE_DEVICES
225
  # This ensures vLLM's device detection works correctly
@@ -228,6 +237,9 @@ def load_vllm_model(model_name: str):
228
  # Verify device is accessible
229
  device_name = torch.cuda.get_device_name(0)
230
  print(f" → Verified CUDA device accessible: {device_name}")
 
 
 
231
 
232
  # Add quantization if specified (vLLM auto-detects AWQ via llm-compressor)
233
  if quantization == "awq":
 
215
  # Ensure CUDA_VISIBLE_DEVICES is set correctly for vLLM device detection
216
  # ZeroGPU uses MIG UUIDs, but vLLM needs numeric device index
217
  # IMPORTANT: Set this BEFORE creating LLM() instance, as vLLM checks device during init
218
+ # Also need to ensure torch sees the change immediately
219
  cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES", "")
220
  if not cuda_visible or not cuda_visible.isdigit():
221
  # If CUDA_VISIBLE_DEVICES is a MIG UUID or empty, use "0" for single GPU
222
  os.environ["CUDA_VISIBLE_DEVICES"] = "0"
223
  print(f" → Set CUDA_VISIBLE_DEVICES=0 (was: {cuda_visible})")
224
+ # Force torch to reinitialize CUDA context after changing CUDA_VISIBLE_DEVICES
225
+ # This ensures vLLM sees the correct device
226
+ try:
227
+ import torch
228
+ if hasattr(torch.cuda, '_lazy_init'):
229
+ torch.cuda._lazy_init()
230
+ except Exception:
231
+ pass
232
 
233
  # Force torch to see the correct device after setting CUDA_VISIBLE_DEVICES
234
  # This ensures vLLM's device detection works correctly
 
237
  # Verify device is accessible
238
  device_name = torch.cuda.get_device_name(0)
239
  print(f" → Verified CUDA device accessible: {device_name}")
240
+ # Explicitly set default device to ensure vLLM can detect it
241
+ torch.cuda.set_device(0)
242
+ print(f" → Set torch.cuda default device to 0")
243
 
244
  # Add quantization if specified (vLLM auto-detects AWQ via llm-compressor)
245
  if quantization == "awq":