ttoosi commited on
Commit
ff3c24a
·
verified ·
1 Parent(s): 5050ebe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -3
app.py CHANGED
@@ -71,9 +71,9 @@ def simple_generative_inference(image, mode, model, n_iterations=10, step_size=0
71
  transforms.ToTensor(),
72
  transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
73
  ])
74
- image_tensor = transform(image).unsqueeze(0)
75
 
76
- image_tensor = image_tensor.detach().clone()
77
  image_tensor.requires_grad = True
78
  # image_tensor.retain_grad() # Ensure gradients are retained for non-leaf tensor
79
 
@@ -119,7 +119,9 @@ def simple_generative_inference(image, mode, model, n_iterations=10, step_size=0
119
 
120
  # Update image tensor
121
  image_tensor = image_tensor + step_size * grad
122
- image_tensor = torch.clamp(image_tensor, -eps, eps) # Keep within range
 
 
123
  image_tensor = image_tensor.clone().detach().requires_grad_(True) # Ensure it's a new leaf tensor
124
  image_tensor.retain_grad()
125
 
 
71
  transforms.ToTensor(),
72
  transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
73
  ])
74
+ image_original = transform(image).unsqueeze(0)
75
 
76
+ image_tensor = image_original.detach().clone()
77
  image_tensor.requires_grad = True
78
  # image_tensor.retain_grad() # Ensure gradients are retained for non-leaf tensor
79
 
 
119
 
120
  # Update image tensor
121
  image_tensor = image_tensor + step_size * grad
122
+ delta = image_tensor - image_original
123
+ delta = torch.clamp(delta, -eps, eps) # Keep within range
124
+ image_tensor = torch.clamp(image_original + diff, 0, 1)
125
  image_tensor = image_tensor.clone().detach().requires_grad_(True) # Ensure it's a new leaf tensor
126
  image_tensor.retain_grad()
127