Spaces:
Build error
Build error
Update app.py
Browse fileschecking grads
app.py
CHANGED
|
@@ -108,6 +108,10 @@ def simple_generative_inference(image, mode, model, n_iterations=10, step_size=0
|
|
| 108 |
|
| 109 |
# Access gradient
|
| 110 |
grad = image_tensor.grad # Gradient is now retained
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
grad_norm = grad.view(grad.shape[0], -1).norm(dim=1, keepdim=True).view(grad.shape[0], 1, 1, 1)
|
| 112 |
grad = grad / (grad_norm + 1e-10) # Avoid division by zero
|
| 113 |
|
|
|
|
| 108 |
|
| 109 |
# Access gradient
|
| 110 |
grad = image_tensor.grad # Gradient is now retained
|
| 111 |
+
if torch.isnan(grad).any(): # Check for NaN values in the gradient
|
| 112 |
+
print("Warning: Gradient contains NaN values. Aborting inference.")
|
| 113 |
+
return None, None
|
| 114 |
+
|
| 115 |
grad_norm = grad.view(grad.shape[0], -1).norm(dim=1, keepdim=True).view(grad.shape[0], 1, 1, 1)
|
| 116 |
grad = grad / (grad_norm + 1e-10) # Avoid division by zero
|
| 117 |
|