ttoosi commited on
Commit
82fbda0
·
verified ·
1 Parent(s): e98a2b8

Update app.py

Browse files

push the targets outside loop

Files changed (1) hide show
  1. app.py +13 -8
app.py CHANGED
@@ -74,29 +74,34 @@ def simple_generative_inference(image, mode, model, n_iterations=10, step_size=0
74
  image_tensor = transform(image).unsqueeze(0).requires_grad_(True)
75
  image_tensor.retain_grad() # Ensure gradients are retained for non-leaf tensor
76
 
 
 
 
 
 
 
 
 
 
77
  for _ in range(n_iterations):
 
 
78
  # Zero gradients
79
- if image_tensor.grad is not None:
80
- image_tensor.grad.zero_()
81
 
82
  # Forward pass
83
  output = model(image_tensor)
84
 
85
  # Define inference loss based on mode
86
  if mode == "increase confidence":
87
- probs = torch.nn.functional.softmax(output, dim=1)
88
- _, least_likely_indices = torch.topk(probs, k=2, largest=False)
89
  losses = []
90
  for idx in least_likely_indices[0]:
91
  target = torch.full((1,), idx, dtype=torch.long, device=output.device)
92
  loss = torch.nn.CrossEntropyLoss()(output, target)
93
  losses.append(loss)
94
  loss = torch.stack(losses).mean()
95
- elif mode == "ReverseDiffuse":
96
- noisy_image = image_tensor + torch.randn_like(image_tensor) * noise_ratio
97
  loss = torch.nn.functional.mse_loss(image_tensor, noisy_image)
98
- else:
99
- raise ValueError("Invalid mode selected. Choose 'increase confidence' or 'ReverseDiffuse'.")
100
 
101
  # Backward pass
102
  loss.backward()
 
74
  image_tensor = transform(image).unsqueeze(0).requires_grad_(True)
75
  image_tensor.retain_grad() # Ensure gradients are retained for non-leaf tensor
76
 
77
+ if mode == "increase confidence":
78
+ output = model(image_tensor)
79
+ probs = torch.nn.functional.softmax(output, dim=1)
80
+ _, least_likely_indices = torch.topk(probs, k=2, largest=False)
81
+ elif mode == "ReverseDiffuse":
82
+ noisy_image = image_tensor + torch.randn_like(image_tensor) * noise_ratio
83
+ else:
84
+ raise ValueError("Invalid mode selected. Choose 'increase confidence' or 'ReverseDiffuse'.")
85
+
86
  for _ in range(n_iterations):
87
+
88
+ image_tensor
89
  # Zero gradients
90
+ model.zero_grad()
 
91
 
92
  # Forward pass
93
  output = model(image_tensor)
94
 
95
  # Define inference loss based on mode
96
  if mode == "increase confidence":
 
 
97
  losses = []
98
  for idx in least_likely_indices[0]:
99
  target = torch.full((1,), idx, dtype=torch.long, device=output.device)
100
  loss = torch.nn.CrossEntropyLoss()(output, target)
101
  losses.append(loss)
102
  loss = torch.stack(losses).mean()
103
+ elif mode == "ReverseDiffuse":
 
104
  loss = torch.nn.functional.mse_loss(image_tensor, noisy_image)
 
 
105
 
106
  # Backward pass
107
  loss.backward()