techyygarry committed on
Commit
aecbe31
·
verified ·
1 Parent(s): 6f53b42

Update ComfyUI/comfy/model_management.py

Browse files
Files changed (1) hide show
  1. ComfyUI/comfy/model_management.py +16 -11
ComfyUI/comfy/model_management.py CHANGED
@@ -163,30 +163,35 @@ def is_ixuca():
163
  return False
164
 
165
def get_torch_device():
    """Return the ``torch.device`` models should run on.

    Resolution order: DirectML (if enabled) -> MPS -> forced-CPU state ->
    Intel XPU -> Ascend NPU -> Cambricon MLU -> CUDA -> CPU fallback.

    Returns:
        torch.device: the selected compute device.
    """
    global directml_enabled
    global cpu_state

    if directml_enabled:
        global directml_device
        return directml_device

    if cpu_state == CPUState.MPS:
        return torch.device("mps")

    if cpu_state == CPUState.CPU:
        return torch.device("cpu")

    # Vendor-specific accelerators, probed in fixed priority order.
    if is_intel_xpu():
        return torch.device("xpu", torch.xpu.current_device())
    elif is_ascend_npu():
        return torch.device("npu", torch.npu.current_device())
    elif is_mlu():
        return torch.device("mlu", torch.mlu.current_device())

    # BUG FIX: the original ended the branch above with an unconditional
    # `return torch.device(torch.cuda.current_device())`, which made this
    # availability check (and the CPU fallback) unreachable dead code and
    # raised on machines without CUDA. Check availability *before* calling
    # into torch.cuda.
    if torch.cuda.is_available():
        return torch.device(f"cuda:{torch.cuda.current_device()}")
    else:
        print("⚠️ No GPU found. Falling back to CPU.")
        return torch.device("cpu")
189
 
 
190
  def get_total_memory(dev=None, torch_total_too=False):
191
  global directml_enabled
192
  if dev is None:
 
163
  return False
164
 
165
def get_torch_device():
    """Select and return the ``torch.device`` used for model execution.

    Checks backends in a fixed priority order: DirectML (when enabled),
    MPS, an explicit CPU override, then the vendor accelerators
    (Intel XPU, Ascend NPU, Cambricon MLU), then CUDA; if nothing is
    available, falls back to CPU with a console warning.

    Returns:
        torch.device: the selected compute device.
    """
    import torch

    global directml_enabled
    global cpu_state

    if directml_enabled:
        global directml_device
        return directml_device

    if cpu_state == CPUState.MPS:
        return torch.device("mps")
    if cpu_state == CPUState.CPU:
        return torch.device("cpu")

    # Vendor accelerators, probed in priority order via a dispatch table;
    # getattr(torch, name) resolves the matching backend module
    # (torch.xpu / torch.npu / torch.mlu).
    vendor_backends = (
        (is_intel_xpu, "xpu"),
        (is_ascend_npu, "npu"),
        (is_mlu, "mlu"),
    )
    for is_available, backend_name in vendor_backends:
        if is_available():
            backend = getattr(torch, backend_name)
            return torch.device(backend_name, backend.current_device())

    # Fallback to CUDA if available
    if torch.cuda.is_available():
        return torch.device(f"cuda:{torch.cuda.current_device()}")

    print("⚠️ No compatible GPU found. Using CPU.")
    return torch.device("cpu")
193
 
194
+
195
  def get_total_memory(dev=None, torch_total_too=False):
196
  global directml_enabled
197
  if dev is None: