Spaces:
Runtime error
Runtime error
Update ComfyUI/comfy/model_management.py
Browse files
ComfyUI/comfy/model_management.py
CHANGED
|
@@ -163,30 +163,35 @@ def is_ixuca():
|
|
| 163 |
return False
|
| 164 |
|
| 165 |
def get_torch_device():
|
|
|
|
| 166 |
global directml_enabled
|
| 167 |
global cpu_state
|
|
|
|
| 168 |
if directml_enabled:
|
| 169 |
global directml_device
|
| 170 |
return directml_device
|
|
|
|
| 171 |
if cpu_state == CPUState.MPS:
|
| 172 |
return torch.device("mps")
|
|
|
|
| 173 |
if cpu_state == CPUState.CPU:
|
| 174 |
return torch.device("cpu")
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
return torch.device(f"cuda:{torch.cuda.current_device()}")
|
| 186 |
else:
|
| 187 |
-
print("⚠️ No GPU found.
|
| 188 |
return torch.device("cpu")
|
| 189 |
|
|
|
|
| 190 |
def get_total_memory(dev=None, torch_total_too=False):
|
| 191 |
global directml_enabled
|
| 192 |
if dev is None:
|
|
|
|
| 163 |
return False
|
| 164 |
|
| 165 |
def get_torch_device():
|
| 166 |
+
import torch
|
| 167 |
global directml_enabled
|
| 168 |
global cpu_state
|
| 169 |
+
|
| 170 |
if directml_enabled:
|
| 171 |
global directml_device
|
| 172 |
return directml_device
|
| 173 |
+
|
| 174 |
if cpu_state == CPUState.MPS:
|
| 175 |
return torch.device("mps")
|
| 176 |
+
|
| 177 |
if cpu_state == CPUState.CPU:
|
| 178 |
return torch.device("cpu")
|
| 179 |
+
|
| 180 |
+
if is_intel_xpu():
|
| 181 |
+
return torch.device("xpu", torch.xpu.current_device())
|
| 182 |
+
elif is_ascend_npu():
|
| 183 |
+
return torch.device("npu", torch.npu.current_device())
|
| 184 |
+
elif is_mlu():
|
| 185 |
+
return torch.device("mlu", torch.mlu.current_device())
|
| 186 |
+
|
| 187 |
+
# Fallback to CUDA if available
|
| 188 |
+
if torch.cuda.is_available():
|
| 189 |
return torch.device(f"cuda:{torch.cuda.current_device()}")
|
| 190 |
else:
|
| 191 |
+
print("⚠️ No compatible GPU found. Using CPU.")
|
| 192 |
return torch.device("cpu")
|
| 193 |
|
| 194 |
+
|
| 195 |
def get_total_memory(dev=None, torch_total_too=False):
|
| 196 |
global directml_enabled
|
| 197 |
if dev is None:
|