Oysiyl commited on
Commit
4001d78
·
1 Parent(s): 413ecd6

another torch.compile fix

Browse files
Files changed (1) hide show
  1. comfy_extras/nodes_freelunch.py +10 -10
comfy_extras/nodes_freelunch.py CHANGED
@@ -41,18 +41,18 @@ class FreeU:
41
  scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)}
42
  on_cpu_devices = {}
43
 
 
 
44
  def output_block_patch(h, hsp, transformer_options):
45
  scale = scale_dict.get(int(h.shape[1]), None)
46
  if scale is not None:
47
  h[:,:h.shape[1] // 2] = h[:,:h.shape[1] // 2] * scale[0]
48
- # Convert device to string for torch.compile compatibility
49
- device_str = str(hsp.device)
50
- if device_str not in on_cpu_devices:
51
  try:
52
  hsp = Fourier_filter(hsp, threshold=1, scale=scale[1])
53
  except:
54
- logging.warning("Device {} does not support the torch.fft functions used in the FreeU node, switching to CPU.".format(device_str))
55
- on_cpu_devices[device_str] = True
56
  hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
57
  else:
58
  hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
@@ -82,6 +82,8 @@ class FreeU_V2:
82
  scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)}
83
  on_cpu_devices = {}
84
 
 
 
85
  def output_block_patch(h, hsp, transformer_options):
86
  scale = scale_dict.get(int(h.shape[1]), None)
87
  if scale is not None:
@@ -93,14 +95,12 @@ class FreeU_V2:
93
 
94
  h[:,:h.shape[1] // 2] = h[:,:h.shape[1] // 2] * ((scale[0] - 1 ) * hidden_mean + 1)
95
 
96
- # Convert device to string for torch.compile compatibility
97
- device_str = str(hsp.device)
98
- if device_str not in on_cpu_devices:
99
  try:
100
  hsp = Fourier_filter(hsp, threshold=1, scale=scale[1])
101
  except:
102
- logging.warning("Device {} does not support the torch.fft functions used in the FreeU node, switching to CPU.".format(device_str))
103
- on_cpu_devices[device_str] = True
104
  hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
105
  else:
106
  hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
 
41
  scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)}
42
  on_cpu_devices = {}
43
 
44
+ # Disable torch.compile for this function to avoid device access issues
45
+ @torch.compiler.disable
46
  def output_block_patch(h, hsp, transformer_options):
47
  scale = scale_dict.get(int(h.shape[1]), None)
48
  if scale is not None:
49
  h[:,:h.shape[1] // 2] = h[:,:h.shape[1] // 2] * scale[0]
50
+ if hsp.device not in on_cpu_devices:
 
 
51
  try:
52
  hsp = Fourier_filter(hsp, threshold=1, scale=scale[1])
53
  except:
54
+ logging.warning("Device {} does not support the torch.fft functions used in the FreeU node, switching to CPU.".format(hsp.device))
55
+ on_cpu_devices[hsp.device] = True
56
  hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
57
  else:
58
  hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
 
82
  scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)}
83
  on_cpu_devices = {}
84
 
85
+ # Disable torch.compile for this function to avoid device access issues
86
+ @torch.compiler.disable
87
  def output_block_patch(h, hsp, transformer_options):
88
  scale = scale_dict.get(int(h.shape[1]), None)
89
  if scale is not None:
 
95
 
96
  h[:,:h.shape[1] // 2] = h[:,:h.shape[1] // 2] * ((scale[0] - 1 ) * hidden_mean + 1)
97
 
98
+ if hsp.device not in on_cpu_devices:
 
 
99
  try:
100
  hsp = Fourier_filter(hsp, threshold=1, scale=scale[1])
101
  except:
102
+ logging.warning("Device {} does not support the torch.fft functions used in the FreeU node, switching to CPU.".format(hsp.device))
103
+ on_cpu_devices[hsp.device] = True
104
  hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
105
  else:
106
  hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)