LogicGoInfotechSpaces committed
Commit 80d11e4 · 1 Parent(s): bfac43d

Fix CUDA operations to check device: only use fused ops when input is on CUDA
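For context, _USE_FUSED (and _USE_EXT in upfirdn2d.py) is a module-level flag whose definition sits outside this diff; the sketch below assumes the conventional try/except-import form. The flag only records that the compiled CUDA extension imported successfully; it says nothing about which device a particular tensor lives on, which is why each call site now also checks .is_cuda.

# Sketch, assuming the conventional flag pattern; not part of this diff.
try:
    from . import fused  # compiled CUDA extension
    _USE_FUSED = True
except ImportError:
    _USE_FUSED = False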

models/psp/stylegan2/op/fused_act.py CHANGED
@@ -32,7 +32,7 @@ class FusedLeakyReLUFunctionBackward(Function):
 
         empty = grad_output.new_empty(0)
 
-        if _USE_FUSED:
+        if _USE_FUSED and grad_output.is_cuda:
             grad_input = fused.fused_bias_act(
                 grad_output, empty, out, 3, 1, negative_slope, scale
             )
@@ -65,7 +65,8 @@ class FusedLeakyReLUFunctionBackward(Function):
 class FusedLeakyReLUFunction(Function):
     @staticmethod
     def forward(ctx, input, bias, negative_slope, scale):
-        if _USE_FUSED:
+        # Only use fused operations if CUDA is available and input is on CUDA
+        if _USE_FUSED and input.is_cuda:
             empty = input.new_empty(0)
             out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
         else:
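The else: branches are truncated in this diff. For a CPU tensor the forward pass falls back to a pure-PyTorch path; a minimal sketch of what such a fallback computes (bias add, leaky ReLU, magnitude-preserving scale, mirroring fused.fused_bias_act with act mode 3), assuming the reference StyleGAN2 behavior:

import torch.nn.functional as F

def fused_leaky_relu_fallback(input, bias, negative_slope=0.2, scale=2 ** 0.5):
    # Broadcast the per-channel bias over the remaining dimensions,
    # then apply leaky ReLU followed by the scale factor.
    rest_dim = [1] * (input.ndim - bias.ndim - 1)
    return (
        F.leaky_relu(
            input + bias.view(1, bias.shape[0], *rest_dim),
            negative_slope=negative_slope,
        )
        * scale
    )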
models/psp/stylegan2/op/upfirdn2d.py CHANGED
@@ -33,7 +33,7 @@ class UpFirDn2dBackward(Function):
 
         grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
 
-        if _USE_EXT:
+        if _USE_EXT and grad_output.is_cuda:
             grad_input = upfirdn2d_op.upfirdn2d(
                 grad_output,
                 grad_kernel,
@@ -76,7 +76,7 @@ class UpFirDn2dBackward(Function):
 
         gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
 
-        if _USE_EXT:
+        if _USE_EXT and gradgrad_input.is_cuda:
             gradgrad_out = upfirdn2d_op.upfirdn2d(
                 gradgrad_input,
                 kernel,
@@ -131,7 +131,7 @@ class UpFirDn2d(Function):
 
         ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
 
-        if _USE_EXT:
+        if _USE_EXT and input.is_cuda:
             out = upfirdn2d_op.upfirdn2d(
                 input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
             )
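Likewise for upfirdn2d: when _USE_EXT is false or the tensor is on CPU, the op needs a native path. Below is a simplified sketch of upsample-FIR-downsample in pure PyTorch, assuming symmetric up/down factors and non-negative padding; the repository's actual fallback handles the full signature shown above.

import torch
import torch.nn.functional as F

def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)):
    # input: (N, C, H, W); kernel: 2-D FIR filter.
    n, c, in_h, in_w = input.shape
    kh, kw = kernel.shape

    # 1) Upsample by zero-insertion: place each sample on an up-spaced grid.
    out = input.new_zeros(n, c, in_h * up, in_w * up)
    out[:, :, ::up, ::up] = input

    # 2) Pad, then 3) filter with the flipped kernel (true convolution),
    # applied independently per channel via grouped conv2d.
    out = F.pad(out, [pad[0], pad[1], pad[0], pad[1]])
    w = torch.flip(kernel, [0, 1]).view(1, 1, kh, kw).repeat(c, 1, 1, 1)
    out = F.conv2d(out, w, groups=c)

    # 4) Downsample by striding.
    return out[:, :, ::down, ::down]

A quick shape check: with up=2, down=1, pad=(1, 1) and a 4-tap kernel, an (N, C, H, W) input comes out as (N, C, 2H - 1, 2W - 1).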