Spaces:
Sleeping
Sleeping
another torch.compile fix
Browse files- comfy_extras/nodes_freelunch.py +10 -10
comfy_extras/nodes_freelunch.py
CHANGED
|
@@ -41,18 +41,18 @@ class FreeU:
|
|
| 41 |
scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)}
|
| 42 |
on_cpu_devices = {}
|
| 43 |
|
|
|
|
|
|
|
| 44 |
def output_block_patch(h, hsp, transformer_options):
|
| 45 |
scale = scale_dict.get(int(h.shape[1]), None)
|
| 46 |
if scale is not None:
|
| 47 |
h[:,:h.shape[1] // 2] = h[:,:h.shape[1] // 2] * scale[0]
|
| 48 |
-
|
| 49 |
-
device_str = str(hsp.device)
|
| 50 |
-
if device_str not in on_cpu_devices:
|
| 51 |
try:
|
| 52 |
hsp = Fourier_filter(hsp, threshold=1, scale=scale[1])
|
| 53 |
except:
|
| 54 |
-
logging.warning("Device {} does not support the torch.fft functions used in the FreeU node, switching to CPU.".format(
|
| 55 |
-
on_cpu_devices[
|
| 56 |
hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
|
| 57 |
else:
|
| 58 |
hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
|
|
@@ -82,6 +82,8 @@ class FreeU_V2:
|
|
| 82 |
scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)}
|
| 83 |
on_cpu_devices = {}
|
| 84 |
|
|
|
|
|
|
|
| 85 |
def output_block_patch(h, hsp, transformer_options):
|
| 86 |
scale = scale_dict.get(int(h.shape[1]), None)
|
| 87 |
if scale is not None:
|
|
@@ -93,14 +95,12 @@ class FreeU_V2:
|
|
| 93 |
|
| 94 |
h[:,:h.shape[1] // 2] = h[:,:h.shape[1] // 2] * ((scale[0] - 1 ) * hidden_mean + 1)
|
| 95 |
|
| 96 |
-
|
| 97 |
-
device_str = str(hsp.device)
|
| 98 |
-
if device_str not in on_cpu_devices:
|
| 99 |
try:
|
| 100 |
hsp = Fourier_filter(hsp, threshold=1, scale=scale[1])
|
| 101 |
except:
|
| 102 |
-
logging.warning("Device {} does not support the torch.fft functions used in the FreeU node, switching to CPU.".format(
|
| 103 |
-
on_cpu_devices[
|
| 104 |
hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
|
| 105 |
else:
|
| 106 |
hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
|
|
|
|
| 41 |
scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)}
|
| 42 |
on_cpu_devices = {}
|
| 43 |
|
| 44 |
+
# Disable torch.compile for this function to avoid device access issues
|
| 45 |
+
@torch.compiler.disable
|
| 46 |
def output_block_patch(h, hsp, transformer_options):
|
| 47 |
scale = scale_dict.get(int(h.shape[1]), None)
|
| 48 |
if scale is not None:
|
| 49 |
h[:,:h.shape[1] // 2] = h[:,:h.shape[1] // 2] * scale[0]
|
| 50 |
+
if hsp.device not in on_cpu_devices:
|
|
|
|
|
|
|
| 51 |
try:
|
| 52 |
hsp = Fourier_filter(hsp, threshold=1, scale=scale[1])
|
| 53 |
except:
|
| 54 |
+
logging.warning("Device {} does not support the torch.fft functions used in the FreeU node, switching to CPU.".format(hsp.device))
|
| 55 |
+
on_cpu_devices[hsp.device] = True
|
| 56 |
hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
|
| 57 |
else:
|
| 58 |
hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
|
|
|
|
| 82 |
scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)}
|
| 83 |
on_cpu_devices = {}
|
| 84 |
|
| 85 |
+
# Disable torch.compile for this function to avoid device access issues
|
| 86 |
+
@torch.compiler.disable
|
| 87 |
def output_block_patch(h, hsp, transformer_options):
|
| 88 |
scale = scale_dict.get(int(h.shape[1]), None)
|
| 89 |
if scale is not None:
|
|
|
|
| 95 |
|
| 96 |
h[:,:h.shape[1] // 2] = h[:,:h.shape[1] // 2] * ((scale[0] - 1 ) * hidden_mean + 1)
|
| 97 |
|
| 98 |
+
if hsp.device not in on_cpu_devices:
|
|
|
|
|
|
|
| 99 |
try:
|
| 100 |
hsp = Fourier_filter(hsp, threshold=1, scale=scale[1])
|
| 101 |
except:
|
| 102 |
+
logging.warning("Device {} does not support the torch.fft functions used in the FreeU node, switching to CPU.".format(hsp.device))
|
| 103 |
+
on_cpu_devices[hsp.device] = True
|
| 104 |
hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
|
| 105 |
else:
|
| 106 |
hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
|