ttoosi committed on
Commit
3aed008
·
verified ·
1 Parent(s): d239a2b

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +41 -5
inference.py CHANGED
@@ -280,12 +280,16 @@ class InferStep:
280
  self.step_size = step_size
281
 
282
  def project(self, x):
283
- # Per-pixel deviation from original; clamp by eps (or eps_map). No add/subtract of mask.
284
- diff = x - self.orig_image
 
 
285
  if self.use_adaptive_eps:
286
- diff = torch.clamp(diff, -self.eps_map, self.eps_map)
 
287
  else:
288
- diff = torch.clamp(diff, -self.eps, self.eps)
 
289
  return torch.clamp(self.orig_image + diff, 0, 1)
290
 
291
  def step(self, x, grad):
@@ -334,7 +338,8 @@ def get_inference_configs(inference_type='IncreaseConfidence', eps=0.5, n_itr=50
334
  'inference_normalization': 'off', # 'on' or 'off' (match psychiatry implementation)
335
  'recognition_normalization': 'off',
336
  'iterations_to_show': get_iterations_to_show(n_itr), # Dynamic iterations to visualize
337
- 'misc_info': {'keep_grads': False} # Additional configuration
 
338
  }
339
 
340
  # Customize based on inference type
@@ -829,6 +834,30 @@ class GenerativeInferenceModel:
829
  adaptive_step_config=adaptive_step,
830
  )
831
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
832
  # Storage for inference steps
833
  # Create a new tensor that requires gradients
834
  x = image_tensor.clone().detach().requires_grad_(True)
@@ -934,6 +963,13 @@ class GenerativeInferenceModel:
934
 
935
  grad = torch.autograd.grad(loss, x, retain_graph=True)[0]
936
 
 
 
 
 
 
 
 
937
  if grad is None:
938
  print("Warning: Direct gradient calculation failed")
939
  # Fall back to random perturbation
 
280
  self.step_size = step_size
281
 
282
  def project(self, x):
283
+ # Per-pixel L2 (color) constraint: scale diff so color_norm <= eps (or eps_map).
284
+ # Constraint is per-pixel L2 on RGB (spherical), not L∞ per channel.
285
+ diff = x - self.orig_image # (B, C, H, W)
286
+ color_norm = torch.norm(diff, dim=1, keepdim=True) # (B, 1, H, W)
287
  if self.use_adaptive_eps:
288
+ # eps_map has shape (1, 1, H, W), broadcasts to (B, 1, H, W)
289
+ scale = torch.clamp(self.eps_map / (color_norm + 1e-10), max=1.0)
290
  else:
291
+ scale = torch.clamp(self.eps / (color_norm + 1e-10), max=1.0)
292
+ diff = diff * scale
293
  return torch.clamp(self.orig_image + diff, 0, 1)
294
 
295
  def step(self, x, grad):
 
338
  'inference_normalization': 'off', # 'on' or 'off' (match psychiatry implementation)
339
  'recognition_normalization': 'off',
340
  'iterations_to_show': get_iterations_to_show(n_itr), # Dynamic iterations to visualize
341
+ 'misc_info': {'keep_grads': False}, # Additional configuration
342
+ 'biased_inference': {'enable': False, 'class': None} # Bias perception toward a target class (ImageNet label)
343
  }
344
 
345
  # Customize based on inference type
 
834
  adaptive_step_config=adaptive_step,
835
  )
836
 
837
+ # Biased inference: resolve class name to index (ImageNet simple labels, case-insensitive)
838
+ biased_inference_config = config.get('biased_inference', {'enable': False, 'class': None})
839
+ biased_class_index = None
840
+ biased_class_tensor = None
841
+ if biased_inference_config.get('enable', False):
842
+ class_name = biased_inference_config.get('class') or biased_inference_config.get('class_name')
843
+ if class_name:
844
+ try:
845
+ biased_class_index = next(
846
+ i for i, label in enumerate(self.labels)
847
+ if label.lower() == class_name.lower()
848
+ )
849
+ biased_class_tensor = torch.tensor(
850
+ [biased_class_index], device=device, dtype=torch.long
851
+ )
852
+ print(f"Biased inference: biasing toward class '{self.labels[biased_class_index]}' (index {biased_class_index})")
853
+ except StopIteration:
854
+ raise ValueError(
855
+ f"biased_inference class '{class_name}' not found in ImageNet simple labels. "
856
+ "Use a label from imagenet-simple-labels (e.g. 'goldfish', 'tabby cat')."
857
+ )
858
+ else:
859
+ print("Biased inference enabled but no class specified; ignoring.")
860
+
861
  # Storage for inference steps
862
  # Create a new tensor that requires gradients
863
  x = image_tensor.clone().detach().requires_grad_(True)
 
963
 
964
  grad = torch.autograd.grad(loss, x, retain_graph=True)[0]
965
 
966
+ # Biased inference: subtract gradient that pushes toward target class
967
+ if biased_inference_config.get('enable', False) and biased_class_tensor is not None:
968
+ output_full = model(x)
969
+ loss_biased = F.cross_entropy(output_full, biased_class_tensor)
970
+ grad_biased = torch.autograd.grad(loss_biased, x)[0]
971
+ grad = grad - grad_biased
972
+
973
  if grad is None:
974
  print("Warning: Direct gradient calculation failed")
975
  # Fall back to random perturbation