Update comfy/model_management.py
Browse files
comfy/model_management.py
CHANGED
|
@@ -25,6 +25,8 @@ import sys
|
|
| 25 |
import platform
|
| 26 |
import weakref
|
| 27 |
import gc
|
|
|
|
|
|
|
| 28 |
|
| 29 |
class VRAMState(Enum):
|
| 30 |
DISABLED = 0 #No vram present: no need to move models to vram
|
|
@@ -117,16 +119,16 @@ def get_torch_device():
|
|
| 117 |
global directml_device
|
| 118 |
return directml_device
|
| 119 |
if cpu_state == CPUState.MPS:
|
| 120 |
-
return torch.device("
|
| 121 |
if cpu_state == CPUState.CPU:
|
| 122 |
return torch.device("cpu")
|
| 123 |
else:
|
| 124 |
if is_intel_xpu():
|
| 125 |
-
return torch.device("
|
| 126 |
elif is_ascend_npu():
|
| 127 |
-
return torch.device("
|
| 128 |
else:
|
| 129 |
-
return torch.device(
|
| 130 |
|
| 131 |
def get_total_memory(dev=None, torch_total_too=False):
|
| 132 |
global directml_enabled
|
|
@@ -790,7 +792,7 @@ def vae_dtype(device=None, allowed_dtypes=[]):
|
|
| 790 |
def get_autocast_device(dev):
|
| 791 |
if hasattr(dev, 'type'):
|
| 792 |
return dev.type
|
| 793 |
-
return "
|
| 794 |
|
| 795 |
def supports_dtype(device, dtype): #TODO
|
| 796 |
if dtype == torch.float32:
|
|
|
|
| 25 |
import platform
|
| 26 |
import weakref
|
| 27 |
import gc
|
| 28 |
+
import os
|
| 29 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
|
| 30 |
|
| 31 |
class VRAMState(Enum):
|
| 32 |
DISABLED = 0 #No vram present: no need to move models to vram
|
|
|
|
| 119 |
global directml_device
|
| 120 |
return directml_device
|
| 121 |
if cpu_state == CPUState.MPS:
|
| 122 |
+
return torch.device("cpu")
|
| 123 |
if cpu_state == CPUState.CPU:
|
| 124 |
return torch.device("cpu")
|
| 125 |
else:
|
| 126 |
if is_intel_xpu():
|
| 127 |
+
return torch.device("cpu")
|
| 128 |
elif is_ascend_npu():
|
| 129 |
+
return torch.device("cpu")
|
| 130 |
else:
|
| 131 |
+
return torch.device("cpu")
|
| 132 |
|
| 133 |
def get_total_memory(dev=None, torch_total_too=False):
|
| 134 |
global directml_enabled
|
|
|
|
| 792 |
def get_autocast_device(dev):
|
| 793 |
if hasattr(dev, 'type'):
|
| 794 |
return dev.type
|
| 795 |
+
return "cpu"
|
| 796 |
|
| 797 |
def supports_dtype(device, dtype): #TODO
|
| 798 |
if dtype == torch.float32:
|