ttoosi commited on
Commit
75aa396
·
verified ·
1 Parent(s): 82fbda0

Update app.py

Browse files

checking grads

Files changed (1) hide show
  1. app.py +4 -0
app.py CHANGED
@@ -108,6 +108,10 @@ def simple_generative_inference(image, mode, model, n_iterations=10, step_size=0
108
 
109
  # Access gradient
110
  grad = image_tensor.grad # Gradient is now retained
 
 
 
 
111
  grad_norm = grad.view(grad.shape[0], -1).norm(dim=1, keepdim=True).view(grad.shape[0], 1, 1, 1)
112
  grad = grad / (grad_norm + 1e-10) # Avoid division by zero
113
 
 
108
 
109
  # Access gradient
110
  grad = image_tensor.grad # Gradient is now retained
111
+ if torch.isnan(grad).any(): # Check for NaN values in the gradient
112
+ print("Warning: Gradient contains NaN values. Aborting inference.")
113
+ return None, None
114
+
115
  grad_norm = grad.view(grad.shape[0], -1).norm(dim=1, keepdim=True).view(grad.shape[0], 1, 1, 1)
116
  grad = grad / (grad_norm + 1e-10) # Avoid division by zero
117