Update inference.py
Browse files- inference.py +41 -5
inference.py
CHANGED
|
@@ -280,12 +280,16 @@ class InferStep:
|
|
| 280 |
self.step_size = step_size
|
| 281 |
|
| 282 |
def project(self, x):
|
| 283 |
-
# Per-pixel
|
| 284 |
-
|
|
|
|
|
|
|
| 285 |
if self.use_adaptive_eps:
|
| 286 |
-
|
|
|
|
| 287 |
else:
|
| 288 |
-
|
|
|
|
| 289 |
return torch.clamp(self.orig_image + diff, 0, 1)
|
| 290 |
|
| 291 |
def step(self, x, grad):
|
|
@@ -334,7 +338,8 @@ def get_inference_configs(inference_type='IncreaseConfidence', eps=0.5, n_itr=50
|
|
| 334 |
'inference_normalization': 'off', # 'on' or 'off' (match psychiatry implementation)
|
| 335 |
'recognition_normalization': 'off',
|
| 336 |
'iterations_to_show': get_iterations_to_show(n_itr), # Dynamic iterations to visualize
|
| 337 |
-
'misc_info': {'keep_grads': False} # Additional configuration
|
|
|
|
| 338 |
}
|
| 339 |
|
| 340 |
# Customize based on inference type
|
|
@@ -829,6 +834,30 @@ class GenerativeInferenceModel:
|
|
| 829 |
adaptive_step_config=adaptive_step,
|
| 830 |
)
|
| 831 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 832 |
# Storage for inference steps
|
| 833 |
# Create a new tensor that requires gradients
|
| 834 |
x = image_tensor.clone().detach().requires_grad_(True)
|
|
@@ -934,6 +963,13 @@ class GenerativeInferenceModel:
|
|
| 934 |
|
| 935 |
grad = torch.autograd.grad(loss, x, retain_graph=True)[0]
|
| 936 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 937 |
if grad is None:
|
| 938 |
print("Warning: Direct gradient calculation failed")
|
| 939 |
# Fall back to random perturbation
|
|
|
|
| 280 |
self.step_size = step_size
|
| 281 |
|
| 282 |
def project(self, x):
|
| 283 |
+
# Per-pixel L2 (color) constraint: scale diff so color_norm <= eps (or eps_map).
|
| 284 |
+
# Constraint is per-pixel L2 on RGB (spherical), not L∞ per channel.
|
| 285 |
+
diff = x - self.orig_image # (B, C, H, W)
|
| 286 |
+
color_norm = torch.norm(diff, dim=1, keepdim=True) # (B, 1, H, W)
|
| 287 |
if self.use_adaptive_eps:
|
| 288 |
+
# eps_map has shape (1, 1, H, W), broadcasts to (B, 1, H, W)
|
| 289 |
+
scale = torch.clamp(self.eps_map / (color_norm + 1e-10), max=1.0)
|
| 290 |
else:
|
| 291 |
+
scale = torch.clamp(self.eps / (color_norm + 1e-10), max=1.0)
|
| 292 |
+
diff = diff * scale
|
| 293 |
return torch.clamp(self.orig_image + diff, 0, 1)
|
| 294 |
|
| 295 |
def step(self, x, grad):
|
|
|
|
| 338 |
'inference_normalization': 'off', # 'on' or 'off' (match psychiatry implementation)
|
| 339 |
'recognition_normalization': 'off',
|
| 340 |
'iterations_to_show': get_iterations_to_show(n_itr), # Dynamic iterations to visualize
|
| 341 |
+
'misc_info': {'keep_grads': False}, # Additional configuration
|
| 342 |
+
'biased_inference': {'enable': False, 'class': None} # Bias perception toward a target class (ImageNet label)
|
| 343 |
}
|
| 344 |
|
| 345 |
# Customize based on inference type
|
|
|
|
| 834 |
adaptive_step_config=adaptive_step,
|
| 835 |
)
|
| 836 |
|
| 837 |
+
# Biased inference: resolve class name to index (ImageNet simple labels, case-insensitive)
|
| 838 |
+
biased_inference_config = config.get('biased_inference', {'enable': False, 'class': None})
|
| 839 |
+
biased_class_index = None
|
| 840 |
+
biased_class_tensor = None
|
| 841 |
+
if biased_inference_config.get('enable', False):
|
| 842 |
+
class_name = biased_inference_config.get('class') or biased_inference_config.get('class_name')
|
| 843 |
+
if class_name:
|
| 844 |
+
try:
|
| 845 |
+
biased_class_index = next(
|
| 846 |
+
i for i, label in enumerate(self.labels)
|
| 847 |
+
if label.lower() == class_name.lower()
|
| 848 |
+
)
|
| 849 |
+
biased_class_tensor = torch.tensor(
|
| 850 |
+
[biased_class_index], device=device, dtype=torch.long
|
| 851 |
+
)
|
| 852 |
+
print(f"Biased inference: biasing toward class '{self.labels[biased_class_index]}' (index {biased_class_index})")
|
| 853 |
+
except StopIteration:
|
| 854 |
+
raise ValueError(
|
| 855 |
+
f"biased_inference class '{class_name}' not found in ImageNet simple labels. "
|
| 856 |
+
"Use a label from imagenet-simple-labels (e.g. 'goldfish', 'tabby cat')."
|
| 857 |
+
)
|
| 858 |
+
else:
|
| 859 |
+
print("Biased inference enabled but no class specified; ignoring.")
|
| 860 |
+
|
| 861 |
# Storage for inference steps
|
| 862 |
# Create a new tensor that requires gradients
|
| 863 |
x = image_tensor.clone().detach().requires_grad_(True)
|
|
|
|
| 963 |
|
| 964 |
grad = torch.autograd.grad(loss, x, retain_graph=True)[0]
|
| 965 |
|
| 966 |
+
# Biased inference: subtract gradient that pushes toward target class
|
| 967 |
+
if biased_inference_config.get('enable', False) and biased_class_tensor is not None:
|
| 968 |
+
output_full = model(x)
|
| 969 |
+
loss_biased = F.cross_entropy(output_full, biased_class_tensor)
|
| 970 |
+
grad_biased = torch.autograd.grad(loss_biased, x)[0]
|
| 971 |
+
grad = grad - grad_biased
|
| 972 |
+
|
| 973 |
if grad is None:
|
| 974 |
print("Warning: Direct gradient calculation failed")
|
| 975 |
# Fall back to random perturbation
|