ttoosi committed on
Commit 9ca430f · verified · 1 Parent(s): 4003868

Upload 8 files

uploading base files
Files changed (8)
  1. Dockerfile +31 -0
  2. LICENSE +21 -0
  3. README.md +88 -6
  4. app.py +593 -0
  5. face_vase.png +0 -0
  6. huggingface-metadata.json +11 -0
  7. inference.py +1012 -0
  8. requirements.txt +9 -0
Dockerfile ADDED
@@ -0,0 +1,31 @@
+ FROM python:3.9-slim
+
+ WORKDIR /code
+
+ # Install system dependencies
+ RUN apt-get update && apt-get install -y \
+     build-essential \
+     git \
+     libgl1-mesa-glx \
+     libglib2.0-0 \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Copy only requirements to leverage Docker caching
+ COPY ./requirements.txt /code/requirements.txt
+
+ # Install Python dependencies
+ RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
+
+ # Copy all code and data
+ COPY . /code/
+
+ # Create necessary directories
+ RUN mkdir -p /code/models
+ RUN mkdir -p /code/stimuli
+
+ # Make sure stimuli and models are writable
+ RUN chmod -R 777 /code/models
+ RUN chmod -R 777 /code/stimuli
+
+ # Set up the command to run the app
+ CMD ["python", "app.py"]
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2023 GenerativeInferenceDemo
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,14 +1,96 @@
  ---
  title: Human Hallucination Prediction
- emoji: 📊
+ emoji: 👁️
  colorFrom: blue
- colorTo: blue
+ colorTo: purple
  sdk: gradio
- sdk_version: 6.5.1
- python_version: '3.12'
+ sdk_version: 5.23.1
  app_file: app.py
  pinned: false
- short_description: Predicting human perception using generative inference
+ license: mit
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Human Hallucination Prediction
+
+ This Gradio demo predicts whether humans will experience visual hallucinations or illusions when viewing specific images. Using adversarially robust neural networks, this tool can forecast perceptual phenomena like illusory contours, figure-ground reversals, and other Gestalt effects before humans report them.
+
+ ## How It Works
+
+ This tool uses **generative inference** with adversarially robust neural networks to predict human visual hallucinations. Robust models trained with adversarial examples develop more human-like perceptual biases, allowing them to predict when humans will perceive:
+
+ - **Illusory contours** (Kanizsa shapes, Ehrenstein illusion)
+ - **Figure-ground ambiguity** (Rubin's vase, bistable images)
+ - **Color spreading effects** (Neon color illusion)
+ - **Gestalt grouping** (Continuity, proximity)
+ - **Brightness illusions** (Cornsweet effect)
+
+ ## Features
+
+ - **Predict hallucinations** from uploaded images or example illusions
+ - **Visualize the prediction process** step-by-step
+ - **Compare different models** (robust vs. standard)
+ - **Adjust prediction parameters** for different perceptual phenomena
+ - **Pre-configured examples** of classic visual illusions
+
+ ## Usage
+
+ 1. **Select an example illusion** or upload your own image
+ 2. **Click "Load Parameters"** to set optimal prediction settings
+ 3. **Click "Run Generative Inference"** to predict the hallucination
+ 4. **View the results**: The model will show what perceptual effects it predicts humans will experience
+
+ ## Scientific Background
+
+ This demo is based on research showing that adversarially robust neural networks develop perceptual representations similar to human vision. By using generative inference (optimizing images to maximize model confidence), we can reveal what perceptual structures the network expects to see—which often matches what humans hallucinate or perceive in ambiguous images.
+
+ ## Installation
+
+ To run this demo locally:
+
+ ```bash
+ # Clone the repository
+ git clone https://huggingface.co/spaces/ttoosi/Human_Hallucination_Prediction
+ cd Human_Hallucination_Prediction
+
+ # Install dependencies
+ pip install -r requirements.txt
+
+ # Run the app
+ python app.py
+ ```
+
+ The web app will be available at http://localhost:7860.
+
+ ## The Prediction Process
+
+ 1. **Input**: Start with an ambiguous or illusion-inducing image
+ 2. **Generative Inference**: The robust neural network iteratively modifies the image to maximize its confidence
+ 3. **Prediction**: The modifications reveal what perceptual structures the network expects—predicting what humans will hallucinate
+ 4. **Visualization**: View the predicted hallucination emerging step-by-step
+
+ ## Models
+
+ - **Robust ResNet50**: Trained with adversarial examples (ε=3.0), develops human-like perceptual biases
+ - **Standard ResNet50**: Standard ImageNet training without adversarial robustness
+
+ ## Citation
+
+ If you use this work in your research, please cite:
+
+ ```bibtex
+ @article{toosi2024hallucination,
+   title={Predicting Human Visual Hallucinations with Robust Neural Networks},
+   author={Toosi, Tahereh},
+   year={2024}
+ }
+ ```
+
+ ## About
+
+ **Developed by [Tahereh Toosi](https://toosi.github.io)**
+
+ This demo demonstrates how adversarially robust neural networks can predict human perceptual hallucinations before they occur.
+
+ ## License
+
+ This project is licensed under the MIT License - see the LICENSE file for details.
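The README's "Prediction Process" describes generative inference as iteratively modifying an image to maximize the model's confidence while staying close to the input stimulus. As a rough illustration of that idea (not the repository's actual `inference.py` implementation, which adds drift/diffusion noise and per-layer objectives), a minimal gradient-ascent sketch might look like this; `generative_inference` and its parameters are hypothetical names:

```python
import torch

def generative_inference(model, image, steps=100, step_size=0.5, epsilon=20.0):
    """Sketch: nudge `image` toward the model's most confident interpretation,
    projected back into an L2 ball of radius `epsilon` around the input
    (loosely analogous to the demo's "Epsilon (Stimulus Fidelity)" control)."""
    x0 = image.detach()
    x = x0.clone()
    for _ in range(steps):
        x.requires_grad_(True)
        logits = model(x.unsqueeze(0))
        # Ascend on the top-class confidence
        confidence = logits.softmax(dim=-1).max()
        (grad,) = torch.autograd.grad(confidence, x)
        with torch.no_grad():
            # Normalized-gradient step, then project into the epsilon-ball
            x = x + step_size * grad / (grad.norm() + 1e-12)
            delta = x - x0
            if delta.norm() > epsilon:
                delta = delta * (epsilon / delta.norm())
            x = (x0 + delta).clamp(0.0, 1.0)
    return x.detach()
```

The actual demo exposes additional knobs (initial noise, diffusion noise, target layer) on top of this basic loop.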
app.py ADDED
@@ -0,0 +1,593 @@
+ import gradio as gr
+ import torch
+ import numpy as np
+ from PIL import Image, ImageDraw
+ try:
+     from spaces import GPU
+ except ImportError:
+     # Define a no-op decorator if running locally
+     def GPU(func):
+         return func
+
+ import os
+ import argparse
+ from inference import GenerativeInferenceModel, get_inference_configs
+
+ # Parse command line arguments
+ parser = argparse.ArgumentParser(description='Run Generative Inference Demo')
+ parser.add_argument('--port', type=int, default=7860, help='Port to run the server on')
+ args = parser.parse_args()
+
+ # Create model directories if they don't exist
+ os.makedirs("models", exist_ok=True)
+ os.makedirs("stimuli", exist_ok=True)
+
+ # Check if running on Hugging Face Spaces
+ if "SPACE_ID" in os.environ:
+     default_port = int(os.environ.get("PORT", 7860))
+ else:
+     default_port = 8861  # Local default port
+
+ # Initialize model
+ model = GenerativeInferenceModel()
+
+ # Define example images and their parameters with updated values from the research
+ examples = [
+     {
+         "image": os.path.join("stimuli", "Neon_Color_Circle.jpg"),
+         "name": "Neon Color Spreading",
+         "wiki": "https://en.wikipedia.org/wiki/Neon_color_spreading",
+         "papers": [
+             "[Color Assimilation](https://doi.org/10.1016/j.visres.2000.200.1)",
+             "[Perceptual Filling-in](https://doi.org/10.1016/j.tics.2003.08.003)"
+         ],
+         "method": "Prior-Guided Drift Diffusion",
+         "reverse_diff": {
+             "model": "resnet50_robust",
+             "layer": "layer3",
+             "initial_noise": 0.8,
+             "diffusion_noise": 0.003,
+             "step_size": 1.0,
+             "iterations": 101,
+             "epsilon": 20.0
+         }
+     },
+     {
+         "image": os.path.join("stimuli", "Kanizsa_square.jpg"),
+         "name": "Kanizsa Square",
+         "wiki": "https://en.wikipedia.org/wiki/Kanizsa_triangle",
+         "papers": [
+             "[Gestalt Psychology](https://en.wikipedia.org/wiki/Gestalt_psychology)",
+             "[Neural Mechanisms](https://doi.org/10.1016/j.tics.2003.08.003)"
+         ],
+         "method": "Prior-Guided Drift Diffusion",
+         "reverse_diff": {
+             "model": "resnet50_robust",
+             "layer": "all",
+             "initial_noise": 0.0,
+             "diffusion_noise": 0.005,
+             "step_size": 0.64,
+             "iterations": 100,
+             "epsilon": 5.0
+         }
+     },
+     {
+         "image": os.path.join("stimuli", "CornsweetBlock.png"),
+         "name": "Cornsweet Illusion",
+         "wiki": "https://en.wikipedia.org/wiki/Cornsweet_illusion",
+         "papers": [
+             "[Brightness Perception](https://doi.org/10.1016/j.visres.2000.200.1)",
+             "[Edge Effects](https://doi.org/10.1016/j.tics.2003.08.003)"
+         ],
+         "instructions": "Both blocks are the same shade of gray; cover the middle line with your finger to check. Hit 'Load Parameters' and then 'Run Generative Inference' to see how the model sees the blocks.",
+         "method": "Prior-Guided Drift Diffusion",
+         "reverse_diff": {
+             "model": "resnet50_robust",
+             "layer": "layer3",
+             "initial_noise": 0.5,
+             "diffusion_noise": 0.005,
+             "step_size": 0.8,
+             "iterations": 51,
+             "epsilon": 20.0
+         }
+     },
+     {
+         "image": os.path.join("stimuli", "face_vase.png"),
+         "name": "Rubin's Face-Vase (Object Prior)",
+         "wiki": "https://en.wikipedia.org/wiki/Rubin_vase",
+         "papers": [
+             "[Figure-Ground Perception](https://en.wikipedia.org/wiki/Figure-ground_(perception))",
+             "[Bistable Perception](https://doi.org/10.1016/j.tics.2003.08.003)"
+         ],
+         "method": "Prior-Guided Drift Diffusion",
+         "reverse_diff": {
+             "model": "resnet50_robust",
+             "layer": "avgpool",
+             "initial_noise": 0.9,
+             "diffusion_noise": 0.003,
+             "step_size": 0.58,
+             "iterations": 100,
+             "epsilon": 0.81
+         }
+     },
+     {
+         "image": os.path.join("stimuli", "Confetti_illusion.png"),
+         "name": "Confetti Illusion",
+         "wiki": "https://www.youtube.com/watch?v=SvEiEi8O7QE",
+         "papers": [
+             "[Color Perception](https://doi.org/10.1016/j.visres.2000.200.1)",
+             "[Context Effects](https://doi.org/10.1016/j.tics.2003.08.003)"
+         ],
+         "method": "Prior-Guided Drift Diffusion",
+         "reverse_diff": {
+             "model": "resnet50_robust",
+             "layer": "layer3",
+             "initial_noise": 0.1,
+             "diffusion_noise": 0.003,
+             "step_size": 0.5,
+             "iterations": 101,
+             "epsilon": 20.0
+         }
+     },
+     {
+         "image": os.path.join("stimuli", "EhresteinSingleColor.png"),
+         "name": "Ehrenstein Illusion",
+         "wiki": "https://en.wikipedia.org/wiki/Ehrenstein_illusion",
+         "papers": [
+             "[Subjective Contours](https://doi.org/10.1016/j.visres.2000.200.1)",
+             "[Neural Processing](https://doi.org/10.1016/j.tics.2003.08.003)"
+         ],
+         "method": "Prior-Guided Drift Diffusion",
+         "reverse_diff": {
+             "model": "resnet50_robust",
+             "layer": "layer3",
+             "initial_noise": 0.5,
+             "diffusion_noise": 0.005,
+             "step_size": 0.8,
+             "iterations": 101,
+             "epsilon": 20.0
+         }
+     },
+     {
+         "image": os.path.join("stimuli", "GroupingByContinuity.png"),
+         "name": "Grouping by Continuity",
+         "wiki": "https://en.wikipedia.org/wiki/Principles_of_grouping",
+         "papers": [
+             "[Gestalt Principles](https://en.wikipedia.org/wiki/Gestalt_psychology)",
+             "[Visual Organization](https://doi.org/10.1016/j.tics.2003.08.003)"
+         ],
+         "method": "Prior-Guided Drift Diffusion",
+         "reverse_diff": {
+             "model": "resnet50_robust",
+             "layer": "layer3",
+             "initial_noise": 0.0,
+             "diffusion_noise": 0.005,
+             "step_size": 0.4,
+             "iterations": 101,
+             "epsilon": 4.0
+         }
+     },
+     {
+         "image": os.path.join("stimuli", "figure_ground.png"),
+         "name": "Figure-Ground Illusion",
+         "wiki": "https://en.wikipedia.org/wiki/Figure-ground_(perception)",
+         "papers": [
+             "[Gestalt Principles](https://en.wikipedia.org/wiki/Gestalt_psychology)",
+             "[Perceptual Organization](https://doi.org/10.1016/j.tics.2003.08.003)"
+         ],
+         "method": "Prior-Guided Drift Diffusion",
+         "reverse_diff": {
+             "model": "resnet50_robust",
+             "layer": "layer3",
+             "initial_noise": 0.1,
+             "diffusion_noise": 0.003,
+             "step_size": 0.5,
+             "iterations": 101,
+             "epsilon": 3.0
+         }
+     }
+ ]
+
+ @GPU
+ def run_inference(image, model_type, inference_type, eps_value, num_iterations,
+                   initial_noise=0.05, diffusion_noise=0.3, step_size=0.8, model_layer="layer3",
+                   use_adaptive_eps=False, use_adaptive_step=False,
+                   mask_center_x=0.0, mask_center_y=0.0, mask_radius=0.3, mask_sigma=0.2,
+                   eps_max_mult=4.0, eps_min_mult=1.0, step_max_mult=4.0, step_min_mult=1.0):
+     # Check if image is provided
+     if image is None:
+         return None, "Please upload an image before running inference."
+
+     # Convert eps to float
+     eps = float(eps_value)
+     step_size_f = float(step_size)
+     # Coerce checkbox values (Gradio can send bool, None, or other)
+     use_adaptive_eps = bool(use_adaptive_eps) if use_adaptive_eps is not None else False
+     use_adaptive_step = bool(use_adaptive_step) if use_adaptive_step is not None else False
+
+     # Load inference configuration based on the selected type
+     config = get_inference_configs(inference_type=inference_type, eps=eps, n_itr=int(num_iterations))
+
+     # Handle Prior-Guided Drift Diffusion specific parameters
+     if inference_type == "Prior-Guided Drift Diffusion":
+         config['initial_inference_noise_ratio'] = float(initial_noise)
+         config['diffusion_noise_ratio'] = float(diffusion_noise)
+         config['step_size'] = step_size_f
+         config['top_layer'] = model_layer
+
+     # Adaptive epsilon (Gaussian mask)
+     if use_adaptive_eps:
+         config['adaptive_epsilon'] = {
+             'enabled': True,
+             'base_epsilon': eps,
+             'center_x': float(mask_center_x),
+             'center_y': float(mask_center_y),
+             'flat_radius': float(mask_radius),
+             'sigma': float(mask_sigma),
+             'max_multiplier': float(eps_max_mult),
+             'min_multiplier': float(eps_min_mult),
+         }
+     else:
+         config['adaptive_epsilon'] = None
+
+     # Adaptive step size (same Gaussian mask location/size, different multipliers)
+     if use_adaptive_step:
+         config['adaptive_step_size'] = {
+             'enabled': True,
+             'base_step_size': step_size_f,
+             'center_x': float(mask_center_x),
+             'center_y': float(mask_center_y),
+             'flat_radius': float(mask_radius),
+             'sigma': float(mask_sigma),
+             'max_multiplier': float(step_max_mult),
+             'min_multiplier': float(step_min_mult),
+         }
+     else:
+         config['adaptive_step_size'] = None
+
+     # Run generative inference
+     result = model.inference(image, model_type, config)
+
+     # Extract results based on return type
+     if isinstance(result, tuple):
+         # Old format returning (output_image, all_steps)
+         output_image, all_steps = result
+     else:
+         # New format returning dictionary
+         output_image = result['final_image']
+         all_steps = result['steps']
+
+     # Create animation frames
+     frames = []
+     for i, step_image in enumerate(all_steps):
+         # Convert tensor to PIL image
+         step_pil = Image.fromarray((step_image.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8))
+         frames.append(step_pil)
+
+     # Convert the final output image to PIL
+     final_image = Image.fromarray((output_image.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8))
+
+     # Return the final inferred image and the animation frames directly
+     return final_image, frames
+
+ def _image_to_pil(img):
+     """Convert Gradio image value (PIL, numpy, path, or dict) to PIL Image; return None if invalid."""
+     if img is None:
+         return None
+     if isinstance(img, Image.Image):
+         return img
+     if isinstance(img, np.ndarray):
+         return Image.fromarray(img.astype(np.uint8))
+     if isinstance(img, dict) and "path" in img:
+         return Image.open(img["path"]).convert("RGB")
+     if isinstance(img, str) and os.path.exists(img):
+         return Image.open(img).convert("RGB")
+     return None
+
+
+ def _image_size(img):
+     """Return (width, height) from Gradio image value, or (1, 1) if unknown."""
+     pil = _image_to_pil(img)
+     if pil is not None:
+         return pil.size
+     return (1, 1)
+
+
+ def image_click_to_center(evt: gr.SelectData):
+     """Convert click (x, y) on image to normalized mask center (-1 to 1). Returns (center_x, center_y)."""
+     if evt.index is None or (isinstance(evt.index, (list, tuple)) and len(evt.index) < 2):
+         return 0.0, 0.0
+     x, y = float(evt.index[0]), float(evt.index[1])
+     w, h = _image_size(evt.value)
+     if w <= 0 or h <= 0:
+         return 0.0, 0.0
+     center_x = (x / w) * 2.0 - 1.0
+     center_y = (y / h) * 2.0 - 1.0
+     center_x = max(-1.0, min(1.0, center_x))
+     center_y = max(-1.0, min(1.0, center_y))
+     return center_x, center_y
+
+
+ def draw_mask_overlay(image, center_x, center_y, radius):
+     """Draw the Gaussian mask center and radius on a copy of the image. Returns PIL or None."""
+     pil = _image_to_pil(image)
+     if pil is None:
+         return None
+     img = pil.convert("RGB").copy()
+     w, h = img.size
+     cx_px = (float(center_x) + 1.0) / 2.0 * w
+     cy_px = (float(center_y) + 1.0) / 2.0 * h
+     radius_px = float(radius) * min(w, h) / 2.0
+     draw = ImageDraw.Draw(img)
+     # Circle for radius
+     draw.ellipse(
+         [cx_px - radius_px, cy_px - radius_px, cx_px + radius_px, cy_px + radius_px],
+         outline="#E11D48",
+         width=max(2, min(w, h) // 150),
+     )
+     # Center dot
+     r = max(2, min(w, h) // 80)
+     draw.ellipse([cx_px - r, cy_px - r, cx_px + r, cy_px + r], fill="#E11D48", outline="#FFF")
+     return img
+
+
+ # Helper function to apply example parameters (adaptive mask off by default for examples)
+ def apply_example(example):
+     return [
+         example["image"],
+         "resnet50_robust",
+         example["method"],
+         example["reverse_diff"]["epsilon"],
+         example["reverse_diff"]["iterations"],
+         example["reverse_diff"]["initial_noise"],
+         example["reverse_diff"]["diffusion_noise"],
+         example["reverse_diff"]["step_size"],
+         example["reverse_diff"]["layer"],
+         False, False,  # use_adaptive_eps, use_adaptive_step
+         0.0, 0.0, 0.3, 0.2,  # mask center x/y, radius, sigma
+         4.0, 1.0, 4.0, 1.0,  # eps max/min mult, step max/min mult
+         gr.Group(visible=True),
+     ]
+
+ # Define the interface
+ with gr.Blocks(title="Human Hallucination Prediction", css="""
+ .purple-button {
+     background-color: #8B5CF6 !important;
+     color: white !important;
+     border: none !important;
+ }
+ .purple-button:hover {
+     background-color: #7C3AED !important;
+ }
+ """) as demo:
+     gr.Markdown("# 👁️ Human Hallucination Prediction")
+     gr.Markdown("**Predict what visual hallucinations humans will experience** using adversarially robust neural networks. This demo forecasts perceptual phenomena like illusory contours, figure-ground reversals, and Gestalt effects before humans report them.")
+
+     gr.Markdown("""
+     **How to predict hallucinations:**
+     1. **Select an example illusion** below and click "Load Parameters" to set optimal prediction settings
+     2. **Click "Run Generative Inference"** to predict what hallucination humans will perceive
+     3. **View the prediction**: Watch as the model reveals the perceptual structures it expects—matching what humans typically hallucinate
+     4. **Upload your own images** to test if they will induce hallucinations in human observers
+     """)
+
+     # Main processing interface
+     with gr.Row():
+         with gr.Column(scale=1):
+             # Inputs
+             image_input = gr.Image(label="Input Image (click to set mask center)", type="pil", value=os.path.join("stimuli", "urbanoffice1.jpg"))
+             mask_preview = gr.Image(
+                 label="Mask center preview (click to set center — circle shows mask)",
+                 type="pil",
+                 interactive=False,
+             )
+             # Run Inference button right below the image
+             run_button = gr.Button("🔮 Predict Hallucination", variant="primary", elem_classes="purple-button")
+
+             # Parameters toggle button
+             params_button = gr.Button("⚙️ Play with the parameters", variant="secondary")
+
+             # Parameters section (initially hidden)
+             with gr.Group(visible=False) as params_section:
+                 with gr.Row():
+                     model_choice = gr.Dropdown(
+                         choices=["resnet50_robust", "standard_resnet50"],  # "resnet50_robust_face" - hidden for deployment
+                         value="resnet50_robust",
+                         label="Model"
+                     )
+
+                     inference_type = gr.Dropdown(
+                         choices=["Prior-Guided Drift Diffusion", "IncreaseConfidence"],
+                         value="Prior-Guided Drift Diffusion",
+                         label="Inference Method"
+                     )
+
+                 with gr.Row():
+                     eps_slider = gr.Slider(minimum=0.0, maximum=40.0, value=10.0, step=0.01, label="Epsilon (Stimulus Fidelity)")
+                     iterations_slider = gr.Slider(minimum=1, maximum=600, value=501, step=1, label="Number of Iterations")
+
+                 with gr.Row():
+                     initial_noise_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.1, step=0.01,
+                                                      label="Drift Noise")
+                     diffusion_noise_slider = gr.Slider(minimum=0.0, maximum=0.05, value=0.002, step=0.001,
+                                                        label="Diffusion Noise")
+
+                 with gr.Row():
+                     step_size_slider = gr.Slider(minimum=0.01, maximum=2.0, value=0.1, step=0.01,
+                                                  label="Update Rate")
+                     layer_choice = gr.Dropdown(
+                         choices=["all", "conv1", "bn1", "relu", "maxpool", "layer1", "layer2", "layer3", "layer4", "avgpool"],
+                         value="layer3",
+                         label="Model Layer"
+                     )
+
+                 gr.Markdown("### 🎯 Adaptive Gaussian mask (spatially varying constraint)")
+                 gr.Markdown("Define where on the image the mask is centered and how large its radius is. Coordinates: **-1** = left/top, **1** = right/bottom, **0** = center.")
+                 with gr.Row():
+                     use_adaptive_eps_check = gr.Checkbox(value=False, label="Use adaptive epsilon (stronger/weaker constraint by region)")
+                     use_adaptive_step_check = gr.Checkbox(value=True, label="Use adaptive step size (stronger/weaker updates by region)")
+                 with gr.Row():
+                     mask_center_x_slider = gr.Slider(minimum=-1.0, maximum=1.0, value=0.0, step=0.05, label="Mask center X")
+                     mask_center_y_slider = gr.Slider(minimum=-1.0, maximum=1.0, value=0.0, step=0.05, label="Mask center Y")
+                 with gr.Row():
+                     mask_radius_slider = gr.Slider(minimum=0.01, maximum=1.0, value=0.2, step=0.01, label="Mask radius (flat region size)")
+                     mask_sigma_slider = gr.Slider(minimum=0.05, maximum=0.5, value=0.2, step=0.01, label="Mask sigma (fall-off outside radius)")
+                 with gr.Row():
+                     eps_max_mult_slider = gr.Slider(minimum=0.1, maximum=15.0, value=4.0, step=0.1, label="Epsilon: multiplier at center")
+                     eps_min_mult_slider = gr.Slider(minimum=0.1, maximum=10.0, value=1.0, step=0.1, label="Epsilon: multiplier at periphery")
+                 with gr.Row():
+                     step_max_mult_slider = gr.Slider(minimum=0.1, maximum=15.0, value=10.0, step=0.1, label="Step size: multiplier at center")
+                     step_min_mult_slider = gr.Slider(minimum=0.1, maximum=10.0, value=1.0, step=0.1, label="Step size: multiplier at periphery")
+
+         with gr.Column(scale=2):
+             # Outputs
+             output_image = gr.Image(label="Predicted Hallucination")
+             output_frames = gr.Gallery(label="Hallucination Prediction Process", columns=5, rows=2)
+
+     # Examples section with integrated explanations
+     gr.Markdown("## 🎯 Visual Illusion Examples")
+     gr.Markdown("Select an illusion to predict what hallucination humans will experience when viewing it")
+
+     # For each example, create a row with the image and explanation side by side
+     for i, ex in enumerate(examples):
+         with gr.Row():
+             # Left column for the image
+             with gr.Column(scale=1):
+                 # Display the example image
+                 example_img = gr.Image(value=ex["image"], type="filepath", label=f"{ex['name']}")
+                 load_btn = gr.Button("Load Parameters", variant="primary")
+
+                 # Set up the load button to apply this example's parameters
+                 load_btn.click(
+                     fn=lambda ex=ex: apply_example(ex),
+                     outputs=[
+                         image_input, model_choice, inference_type,
+                         eps_slider, iterations_slider,
+                         initial_noise_slider, diffusion_noise_slider,
+                         step_size_slider, layer_choice,
+                         use_adaptive_eps_check, use_adaptive_step_check,
+                         mask_center_x_slider, mask_center_y_slider,
+                         mask_radius_slider, mask_sigma_slider,
+                         eps_max_mult_slider, eps_min_mult_slider,
+                         step_max_mult_slider, step_min_mult_slider,
+                         params_section,
+                     ],
+                 )
+
+             # Right column for the explanation
+             with gr.Column(scale=2):
+                 gr.Markdown(f"### {ex['name']}")
+                 gr.Markdown(f"[Read more on Wikipedia]({ex['wiki']})")
+
+                 # Show instructions if they exist
+                 if "instructions" in ex:
+                     gr.Markdown(f"**Instructions:** {ex['instructions']}")
+
+         if i < len(examples) - 1:  # Don't add separator after the last example
+             gr.Markdown("---")
+
+     # Set up event handler for the main inference
+     run_button.click(
+         fn=run_inference,
+         inputs=[
+             image_input, model_choice, inference_type,
+             eps_slider, iterations_slider,
+             initial_noise_slider, diffusion_noise_slider,
+             step_size_slider, layer_choice,
+             use_adaptive_eps_check, use_adaptive_step_check,
+             mask_center_x_slider, mask_center_y_slider,
+             mask_radius_slider, mask_sigma_slider,
+             eps_max_mult_slider, eps_min_mult_slider,
+             step_max_mult_slider, step_min_mult_slider,
+         ],
+         outputs=[output_image, output_frames]
+     )
+
+     # Toggle parameters visibility
+     def toggle_params():
+         return gr.Group(visible=True)
+
+     params_button.click(
+         fn=toggle_params,
+         outputs=[params_section],
+     )
+
+     # Click on input image or mask preview to set Gaussian mask center
+     def _mask_preview_inputs():
+         return [image_input, mask_center_x_slider, mask_center_y_slider, mask_radius_slider]
+
+     image_input.select(
+         fn=image_click_to_center,
+         outputs=[mask_center_x_slider, mask_center_y_slider],
+     )
+     mask_preview.select(
+         fn=image_click_to_center,
+         outputs=[mask_center_x_slider, mask_center_y_slider],
+     )
+     # Update mask preview when image or center/radius change
+     image_input.change(
+         fn=draw_mask_overlay,
+         inputs=_mask_preview_inputs(),
+         outputs=[mask_preview],
+     )
+     mask_center_x_slider.change(
+         fn=draw_mask_overlay,
+         inputs=_mask_preview_inputs(),
+         outputs=[mask_preview],
+     )
+     mask_center_y_slider.change(
+         fn=draw_mask_overlay,
+         inputs=_mask_preview_inputs(),
+         outputs=[mask_preview],
+     )
+     mask_radius_slider.change(
+         fn=draw_mask_overlay,
+         inputs=_mask_preview_inputs(),
+         outputs=[mask_preview],
+     )
+     # Populate mask preview on load
+     demo.load(
+         fn=draw_mask_overlay,
+         inputs=_mask_preview_inputs(),
+         outputs=[mask_preview],
+     )
+
+     # About section
+     gr.Markdown("""
+     ## 🧠 About Hallucination Prediction
+
+     This tool predicts human visual hallucinations using **generative inference** with adversarially robust neural networks. Robust models develop human-like perceptual biases, allowing them to forecast what perceptual structures humans will experience.
+
+     ### Prediction Methods:
+
+     **Prior-Guided Drift Diffusion** (Primary Method)
+     Starting from a noisy representation, the model converges toward what it expects to perceive—revealing predicted hallucinations
+
+     **IncreaseConfidence**
+     Moving away from unlikely interpretations to reveal the most probable perceptual experience
+
+     ### Parameters:
+     - **Drift Noise**: Initial uncertainty in the prediction process
+     - **Diffusion Noise**: Stochastic exploration during prediction
+     - **Update Rate**: Speed of convergence to the predicted hallucination
+     - **Number of Iterations**: How many prediction steps to perform
+     - **Model Layer**: Which perceptual level to predict from (early edges vs. high-level objects)
+     - **Epsilon (Stimulus Fidelity)**: How closely the prediction must match the input stimulus
+
+     ### Why Does This Work?
+
+     Adversarially robust neural networks develop perceptual representations similar to human vision. When we use generative inference to reveal what these networks "expect" to see, it matches what humans hallucinate in ambiguous images—allowing us to predict human perception.
+
+     **Developed by [Tahereh Toosi](https://toosi.github.io)**
+     """)
+
+ # Launch the demo
+ if __name__ == "__main__":
+     print(f"Starting server on port {args.port}")
+     demo.launch(
+         server_name="0.0.0.0",
+         server_port=args.port,
+         share=False,
+         debug=True
+     )
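The app's "Adaptive Gaussian mask" controls pass a center, flat radius, sigma, and min/max multipliers into the config that `inference.py` consumes (its implementation is not shown in this commit view). A plausible sketch of such a flat-topped Gaussian weight map, with a hypothetical `gaussian_mask` helper, might look like:

```python
import numpy as np

def gaussian_mask(h, w, center_x=0.0, center_y=0.0, flat_radius=0.3,
                  sigma=0.2, max_mult=4.0, min_mult=1.0):
    """Weight map that equals `max_mult` within `flat_radius` of the center
    and decays toward `min_mult` outside. Center coordinates are in the
    app's normalized convention: -1 = left/top, 1 = right/bottom."""
    ys, xs = np.mgrid[0:h, 0:w]
    nx = xs / (w - 1) * 2.0 - 1.0  # pixel -> [-1, 1]
    ny = ys / (h - 1) * 2.0 - 1.0
    dist = np.sqrt((nx - center_x) ** 2 + (ny - center_y) ** 2)
    # Only distance beyond the flat plateau drives the Gaussian fall-off
    excess = np.clip(dist - flat_radius, 0.0, None)
    weight = np.exp(-(excess ** 2) / (2.0 * sigma ** 2))
    return min_mult + (max_mult - min_mult) * weight
```

Multiplying the per-pixel epsilon or step size by such a map would concentrate the update budget around the clicked mask center, matching the "multiplier at center" / "multiplier at periphery" sliders.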
face_vase.png ADDED
huggingface-metadata.json ADDED
@@ -0,0 +1,11 @@
+ {
+     "title": "Human Hallucination Prediction",
+     "emoji": "👁️",
+     "colorFrom": "blue",
+     "colorTo": "purple",
+     "sdk": "gradio",
+     "sdk_version": "5.23.1",
+     "app_file": "app.py",
+     "pinned": false,
+     "license": "mit"
+ }
inference.py ADDED
@@ -0,0 +1,1012 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torchvision.models as models
+ import torchvision.transforms as transforms
+ from torchvision.models.resnet import ResNet50_Weights
+ from PIL import Image
+ import numpy as np
+ import os
+ import requests
+ import time
+ import copy
+ from collections import OrderedDict
+ from pathlib import Path
+ 
+ # Check for available hardware acceleration
+ if torch.cuda.is_available():
+     device = torch.device("cuda")
+ elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+     device = torch.device("mps")  # Use Apple Metal Performance Shaders for M-series Macs
+ else:
+     device = torch.device("cpu")
+ print(f"Using device: {device}")
+ 
+ # Constants
+ MODEL_URLS = {
+     'resnet50_robust': 'https://huggingface.co/madrylab/robust-imagenet-models/resolve/main/resnet50_l2_eps3.ckpt',
+     'resnet50_standard': 'https://huggingface.co/madrylab/robust-imagenet-models/resolve/main/resnet50_l2_eps0.ckpt',
+     # Use /resolve/ (not /blob/) so requests downloads the checkpoint file itself, not an HTML page
+     'resnet50_robust_face': 'https://huggingface.co/ttoosi/resnet50_robust_face/resolve/main/100_checkpoint.pt'
+ }
+ 
+ IMAGENET_MEAN = [0.485, 0.456, 0.406]
+ IMAGENET_STD = [0.229, 0.224, 0.225]
+ 
+ # Define the transforms based on whether normalization is on or off
+ def get_transform(input_size=224, normalize=False, norm_mean=IMAGENET_MEAN, norm_std=IMAGENET_STD):
+     if normalize:
+         return transforms.Compose([
+             transforms.Resize(input_size),
+             transforms.CenterCrop(input_size),
+             transforms.ToTensor(),
+             transforms.Normalize(norm_mean, norm_std),
+         ])
+     else:
+         return transforms.Compose([
+             transforms.Resize(input_size),
+             transforms.CenterCrop(input_size),
+             transforms.ToTensor(),
+         ])
+ 
+ # Default transform without normalization
+ transform = transforms.Compose([
+     transforms.Resize(224),
+     transforms.CenterCrop(224),
+     transforms.ToTensor(),
+ ])
+ 
+ normalize_transform = transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)
+ 
+ 
+ def create_flat_top_gaussian_mask(image_shape, center_x, center_y, flat_radius, sigma, max_multiplier=4.0, min_multiplier=1.0, device=None):
+     """
+     Create a flat-top Gaussian epsilon/step-size mask for adaptive constraint application.
+     Coordinates are normalized: -1 = left/top edge, 1 = right/bottom edge, 0 = center.
+ 
+     Args:
+         image_shape: (batch, channels, height, width)
+         center_x: Center x (normalized -1 to 1)
+         center_y: Center y (normalized -1 to 1)
+         flat_radius: Radius of flat-top region (normalized)
+         sigma: Std for Gaussian fall-off outside flat region
+         max_multiplier: Multiplier inside flat region
+         min_multiplier: Multiplier far from center
+         device: torch device for the output tensor
+     Returns:
+         Tensor (1, 1, H, W)
+     """
+     _, _, H, W = image_shape
+     if device is None:
+         device = torch.device("cpu")
+     y_coords = torch.linspace(-1, 1, H, device=device)
+     x_coords = torch.linspace(-1, 1, W, device=device)
+     y_coords, x_coords = torch.meshgrid(y_coords, x_coords, indexing='ij')
+     dist = torch.sqrt((x_coords - center_x) ** 2 + (y_coords - center_y) ** 2)
+     multiplier = torch.where(
+         dist <= flat_radius,
+         torch.full_like(dist, max_multiplier),
+         min_multiplier + (max_multiplier - min_multiplier) * torch.exp(-((dist - flat_radius) ** 2) / (2 * sigma ** 2))
+     )
+     return multiplier.unsqueeze(0).unsqueeze(0)  # (1, 1, H, W)
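The shape of this mask is easy to check without torch. Below is a NumPy re-derivation of the same radial profile (an illustrative sketch, not the function above): constant `max_multiplier` inside `flat_radius`, then a Gaussian decay toward `min_multiplier`.

```python
import numpy as np

def flat_top_profile(dist, flat_radius=0.3, sigma=0.2,
                     max_multiplier=4.0, min_multiplier=1.0):
    """Radial profile used by the flat-top Gaussian mask."""
    dist = np.asarray(dist, dtype=float)
    tail = min_multiplier + (max_multiplier - min_multiplier) * np.exp(
        -((dist - flat_radius) ** 2) / (2 * sigma ** 2))
    return np.where(dist <= flat_radius, max_multiplier, tail)

vals = flat_top_profile(np.array([0.0, 0.3, 0.5, 5.0]))
# Inside the flat top the profile equals max_multiplier; far away it
# approaches min_multiplier; at the boundary the transition is continuous.
```

Note the continuity: just outside `flat_radius` the exponential is 1, so the tail starts exactly at `max_multiplier` and decays from there.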
+ 
+ 
+ def extract_middle_layers(model, layer_index):
+     """
+     Extract a subset of the model up to a specific layer.
+ 
+     Args:
+         model: The neural network model
+         layer_index: String 'all' for the full model, or a layer identifier (string or int)
+                      For ResNet: integers 0-8 representing specific layers
+                      For ViT: strings like 'encoder.layers.encoder_layer_3'
+ 
+     Returns:
+         A modified model that outputs features from the specified layer
+     """
+     if isinstance(layer_index, str) and layer_index == 'all':
+         return model
+ 
+     # Special case for ViT's encoder layers with DataParallel wrapper
+     if isinstance(layer_index, str) and layer_index.startswith('encoder.layers.encoder_layer_'):
+         try:
+             target_layer_idx = int(layer_index.split('_')[-1])
+ 
+             # Create a deep copy of the model to avoid modifying the original
+             new_model = copy.deepcopy(model)
+ 
+             # For models wrapped in DataParallel
+             if hasattr(new_model, 'module'):
+                 # Create a subset of encoder layers up to the specified index
+                 encoder_layers = nn.Sequential()
+                 for i in range(target_layer_idx + 1):
+                     layer_name = f"encoder_layer_{i}"
+                     if hasattr(new_model.module.encoder.layers, layer_name):
+                         encoder_layers.add_module(layer_name,
+                                                   getattr(new_model.module.encoder.layers, layer_name))
+ 
+                 # Replace the encoder layers with our truncated version
+                 new_model.module.encoder.layers = encoder_layers
+ 
+                 # Remove the heads since we're stopping at the encoder layer
+                 new_model.module.heads = nn.Identity()
+ 
+                 return new_model
+             else:
+                 # Direct model access (not DataParallel)
+                 encoder_layers = nn.Sequential()
+                 for i in range(target_layer_idx + 1):
+                     layer_name = f"encoder_layer_{i}"
+                     if hasattr(new_model.encoder.layers, layer_name):
+                         encoder_layers.add_module(layer_name,
+                                                   getattr(new_model.encoder.layers, layer_name))
+ 
+                 # Replace the encoder layers with our truncated version
+                 new_model.encoder.layers = encoder_layers
+ 
+                 # Remove the heads since we're stopping at the encoder layer
+                 new_model.heads = nn.Identity()
+ 
+                 return new_model
+ 
+         except (ValueError, IndexError) as e:
+             raise ValueError(f"Invalid ViT layer specification: {layer_index}. Error: {e}")
+ 
+     # Handling for ViT whole blocks
+     elif hasattr(model, 'blocks') or (hasattr(model, 'module') and hasattr(model.module, 'blocks')):
+         # Create a deep copy to avoid modifying the original
+         new_model = copy.deepcopy(model)
+         base_new_model = new_model.module if hasattr(new_model, 'module') else new_model
+ 
+         # Keep only the desired number of transformer blocks
+         if isinstance(layer_index, int):
+             # Truncate the blocks
+             base_new_model.blocks = base_new_model.blocks[:layer_index + 1]
+ 
+         return new_model
+ 
+     else:
+         # Original ResNet/VGG handling
+         modules = list(model.named_children())
+         print(f"DEBUG - extract_middle_layers - Looking for '{layer_index}' in {[name for name, _ in modules]}")
+ 
+         cutoff_idx = next((i for i, (name, _) in enumerate(modules)
+                            if name == str(layer_index)), None)
+ 
+         if cutoff_idx is not None:
+             # Keep modules up to and including the target
+             new_model = nn.Sequential(OrderedDict(modules[:cutoff_idx + 1]))
+             return new_model
+         else:
+             raise ValueError(f"Module {layer_index} not found in model")
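The ResNet branch boils down to slicing an ordered list of named children at the requested name. A framework-free sketch of that cutoff logic (hypothetical names, no torch required):

```python
from collections import OrderedDict

def truncate_at(named_children, layer_index):
    """Keep (name, module) pairs up to and including `layer_index`,
    mirroring the ResNet branch of extract_middle_layers."""
    cutoff_idx = next((i for i, (name, _) in enumerate(named_children)
                       if name == str(layer_index)), None)
    if cutoff_idx is None:
        raise ValueError(f"Module {layer_index} not found in model")
    return OrderedDict(named_children[:cutoff_idx + 1])

children = [("conv1", "Conv"), ("bn1", "BN"), ("layer1", "Block"), ("fc", "Linear")]
kept = truncate_at(children, "layer1")
print(list(kept))  # ['conv1', 'bn1', 'layer1']
```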
+ 
+ # Get ImageNet labels
+ def get_imagenet_labels():
+     url = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
+     response = requests.get(url)
+     if response.status_code == 200:
+         return response.json()
+     else:
+         raise RuntimeError("Failed to fetch ImageNet labels")
+ 
+ # Download model if needed
+ def download_model(model_type):
+     if model_type not in MODEL_URLS or MODEL_URLS[model_type] is None:
+         return None  # Use PyTorch's pretrained model
+ 
+     # Handle special case for face model
+     if model_type == 'resnet50_robust_face':
+         model_path = Path("models/resnet50_robust_face_100_checkpoint.pt")
+     else:
+         model_path = Path(f"models/{model_type}.pt")
+ 
+     if not model_path.exists():
+         print(f"Downloading {model_type} model...")
+         url = MODEL_URLS[model_type]
+         response = requests.get(url, stream=True)
+         if response.status_code == 200:
+             with open(model_path, 'wb') as f:
+                 for chunk in response.iter_content(chunk_size=8192):
+                     f.write(chunk)
+             print(f"Model downloaded and saved to {model_path}")
+         else:
+             raise RuntimeError(f"Failed to download model: {response.status_code}")
+     return model_path
+ 
+ class NormalizeByChannelMeanStd(nn.Module):
+     def __init__(self, mean, std):
+         super(NormalizeByChannelMeanStd, self).__init__()
+         if not isinstance(mean, torch.Tensor):
+             mean = torch.tensor(mean)
+         if not isinstance(std, torch.Tensor):
+             std = torch.tensor(std)
+         self.register_buffer("mean", mean)
+         self.register_buffer("std", std)
+ 
+     def forward(self, tensor):
+         return self.normalize_fn(tensor, self.mean, self.std)
+ 
+     def normalize_fn(self, tensor, mean, std):
+         """Differentiable version of torchvision.functional.normalize"""
+         # here we assume the color channel is at dim=1
+         mean = mean[None, :, None, None]
+         std = std[None, :, None, None]
+         return tensor.sub(mean).div(std)
+ 
+ class InferStep:
+     def __init__(self, orig_image, eps, step_size, adaptive_eps_config=None, adaptive_step_config=None):
+         self.orig_image = orig_image
+         dev = orig_image.device
+ 
+         # Adaptive epsilon (spatially varying constraint)
+         if adaptive_eps_config and adaptive_eps_config.get("enabled", False):
+             self.use_adaptive_eps = True
+             base_eps = float(adaptive_eps_config.get("base_epsilon", eps))
+             self.eps_map = create_flat_top_gaussian_mask(
+                 orig_image.shape,
+                 float(adaptive_eps_config.get("center_x", 0.0)),
+                 float(adaptive_eps_config.get("center_y", 0.0)),
+                 float(adaptive_eps_config.get("flat_radius", 0.3)),
+                 float(adaptive_eps_config.get("sigma", 0.2)),
+                 float(adaptive_eps_config.get("max_multiplier", 4.0)),
+                 float(adaptive_eps_config.get("min_multiplier", 1.0)),
+                 device=dev,
+             ) * base_eps
+             self.eps_map = self.eps_map.to(dev)
+         else:
+             self.use_adaptive_eps = False
+             self.eps = eps
+ 
+         # Adaptive step size (spatially varying update strength)
+         if adaptive_step_config and adaptive_step_config.get("enabled", False):
+             self.use_adaptive_step_size = True
+             base_step = float(adaptive_step_config.get("base_step_size", step_size))
+             self.step_size_map = create_flat_top_gaussian_mask(
+                 orig_image.shape,
+                 float(adaptive_step_config.get("center_x", 0.0)),
+                 float(adaptive_step_config.get("center_y", 0.0)),
+                 float(adaptive_step_config.get("flat_radius", 0.3)),
+                 float(adaptive_step_config.get("sigma", 0.2)),
+                 float(adaptive_step_config.get("max_multiplier", 4.0)),
+                 float(adaptive_step_config.get("min_multiplier", 1.0)),
+                 device=dev,
+             ) * base_step
+             self.step_size_map = self.step_size_map.to(dev)
+         else:
+             self.use_adaptive_step_size = False
+             self.step_size = step_size
+ 
+     def project(self, x):
+         # Per-pixel deviation from original; clamp by eps (or eps_map). No add/subtract of mask.
+         diff = x - self.orig_image
+         if self.use_adaptive_eps:
+             diff = torch.clamp(diff, -self.eps_map, self.eps_map)
+         else:
+             diff = torch.clamp(diff, -self.eps, self.eps)
+         return torch.clamp(self.orig_image + diff, 0, 1)
+ 
+     def step(self, x, grad):
+         # Same gradient everywhere; scale magnitude by step_size (or step_size_map). Multiply, not add/subtract.
+         l = len(x.shape) - 1
+         grad_norm = torch.norm(grad.view(grad.shape[0], -1), dim=1).view(-1, *([1] * l))
+         scaled_grad = grad / (grad_norm + 1e-10)
+         if self.use_adaptive_step_size:
+             return scaled_grad * self.step_size_map
+         return scaled_grad * self.step_size
+ 
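The `project` method implements a per-pixel box constraint: the updated image may deviate from the original by at most eps at every pixel, and the result must stay a valid image. A NumPy sketch of the same projection (illustrative only; the real method operates on torch tensors):

```python
import numpy as np

def project(x, orig, eps):
    """Clamp the per-pixel deviation from `orig` to [-eps, eps], then
    clamp the result to the valid image range [0, 1]."""
    diff = np.clip(x - orig, -eps, eps)
    return np.clip(orig + diff, 0.0, 1.0)

orig = np.full((2, 2), 0.5)
x = np.array([[0.0, 0.4], [0.9, 2.0]])
proj = project(x, orig, eps=0.1)  # every pixel ends within 0.1 of orig
```

This is what the UI's Epsilon (Stimulus Fidelity) parameter controls: a smaller eps forces the inferred percept to stay closer to the input stimulus.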
+ def get_iterations_to_show(n_itr):
+     """Generate a dynamic list of iterations to visualize based on total iterations."""
+     if n_itr <= 50:
+         milestones = [1, 5, 10, 20, 30, 40, 50]
+     elif n_itr <= 100:
+         milestones = [1, 5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
+     elif n_itr <= 200:
+         milestones = [1, 5, 10, 20, 30, 40, 50, 75, 100, 125, 150, 175, 200]
+     elif n_itr <= 500:
+         milestones = [1, 5, 10, 20, 30, 40, 50, 75, 100, 150, 200, 250, 300, 350, 400, 450, 500]
+     else:
+         # For very large iteration counts, show more evenly distributed points
+         milestones = [1, 5, 10, 20, 30, 40, 50, 75, 100, 150, 200, 250, 300, 350, 400, 450, 500,
+                       int(n_itr*0.6), int(n_itr*0.7), int(n_itr*0.8), int(n_itr*0.9)]
+     # Always include the final iteration, drop milestones beyond n_itr,
+     # and remove duplicates (e.g. n_itr=30 would otherwise list 40 and 50)
+     return sorted(set(i for i in milestones + [n_itr] if i <= n_itr))
+ 
+ def get_inference_configs(inference_type='IncreaseConfidence', eps=0.5, n_itr=50, step_size=1.0):
+     """Generate an inference configuration with customizable parameters.
+ 
+     Args:
+         inference_type (str): Type of inference ('IncreaseConfidence' or 'Prior-Guided Drift Diffusion')
+         eps (float): Maximum perturbation size
+         n_itr (int): Number of iterations
+         step_size (float): Step size for each iteration
+     """
+ 
+     # Base configuration common to all inference types
+     config = {
+         'loss_infer': inference_type,  # How to guide the optimization
+         'n_itr': n_itr,  # Number of iterations
+         'eps': eps,  # Maximum perturbation size
+         'step_size': step_size,  # Step size for each iteration
+         'diffusion_noise_ratio': 0.0,  # No diffusion noise by default
+         'initial_inference_noise_ratio': 0.0,  # No initial noise by default
+         'top_layer': 'all',  # Use all layers of the model
+         'inference_normalization': False,  # Whether to normalize during inference
+         'recognition_normalization': False,  # Whether to normalize during recognition
+         'iterations_to_show': get_iterations_to_show(n_itr),  # Dynamic iterations to visualize
+         'misc_info': {'keep_grads': False}  # Additional configuration
+     }
+ 
+     # Customize based on inference type
+     if inference_type == 'IncreaseConfidence':
+         config['loss_function'] = 'CE'  # Cross Entropy
+ 
+     elif inference_type == 'Prior-Guided Drift Diffusion':
+         config['loss_function'] = 'MSE'  # Mean Squared Error
+         config['initial_inference_noise_ratio'] = 0.05  # Initial noise for diffusion
+         config['diffusion_noise_ratio'] = 0.01  # Add noise during diffusion
+ 
+     elif inference_type == 'GradModulation':
+         config['loss_function'] = 'CE'  # Cross Entropy
+         config['misc_info']['grad_modulation'] = 0.5  # Gradient modulation strength
+ 
+     elif inference_type == 'CompositionalFusion':
+         config['loss_function'] = 'CE'  # Cross Entropy
+         config['misc_info']['positive_classes'] = []  # Classes to maximize
+         config['misc_info']['negative_classes'] = []  # Classes to minimize
+ 
+     return config
+ 
+ class GenerativeInferenceModel:
+     def __init__(self):
+         self.models = {}
+         self.normalizer = NormalizeByChannelMeanStd(IMAGENET_MEAN, IMAGENET_STD).to(device)
+         self.labels = get_imagenet_labels()
+ 
+     def verify_model_integrity(self, model, model_type):
+         """
+         Verify model integrity by running a test input through it.
+         Returns whether the model passes a basic integrity check.
+         """
+         try:
+             print(f"\n=== Running model integrity check for {model_type} ===")
+             # Create a deterministic test input directly on the correct device
+             test_input = torch.zeros(1, 3, 224, 224, device=device)
+             test_input[0, 0, 100:124, 100:124] = 0.5  # red square in the red channel
+ 
+             # Run forward pass
+             with torch.no_grad():
+                 output = model(test_input)
+ 
+             # Check output shape
+             if output.shape != (1, 1000):
+                 print(f"❌ Unexpected output shape: {output.shape}, expected (1, 1000)")
+                 return False
+ 
+             # Get top prediction
+             probs = torch.nn.functional.softmax(output, dim=1)
+             confidence, prediction = torch.max(probs, 1)
+ 
+             # Calculate basic statistics on the output
+             mean = output.mean().item()
+             std = output.std().item()
+             min_val = output.min().item()
+             max_val = output.max().item()
+ 
+             print("Model integrity check results:")
+             print(f"- Output shape: {output.shape}")
+             print(f"- Top prediction: Class {prediction.item()} with {confidence.item()*100:.2f}% confidence")
+             print(f"- Output statistics: mean={mean:.3f}, std={std:.3f}, min={min_val:.3f}, max={max_val:.3f}")
+ 
+             # Basic sanity checks
+             if torch.isnan(output).any():
+                 print("❌ Model produced NaN outputs")
+                 return False
+ 
+             if std < 0.1:
+                 print("⚠️ Low output variance, model may not be discriminative")
+ 
+             print("✅ Model passes basic integrity check")
+             return True
+ 
+         except Exception as e:
+             print(f"❌ Model integrity check failed with error: {e}")
+             # Rather than failing completely, we'll continue
+             return True
+ 
+     def load_model(self, model_type):
+         """Load model from checkpoint or use pretrained model."""
+         if model_type in self.models:
+             print(f"Using cached {model_type} model")
+             return self.models[model_type]
+ 
+         # Record loading time for performance analysis
+         start_time = time.time()
+         model_path = download_model(model_type)
+ 
+         # Create a sequential model with normalizer and ResNet50
+         resnet = models.resnet50()
+         model = nn.Sequential(
+             self.normalizer,  # Normalizer is part of the model sequence
+             resnet
+         )
+ 
+         # Load the model checkpoint
+         if model_path:
+             print(f"Loading {model_type} model from {model_path}...")
+             try:
+                 checkpoint = torch.load(model_path, map_location=device)
+ 
+                 # Print checkpoint structure for better understanding
+                 print("\n=== Analyzing checkpoint structure ===")
+                 if isinstance(checkpoint, dict):
+                     print(f"Checkpoint contains keys: {list(checkpoint.keys())}")
+ 
+                     # Examine 'model' structure if it exists
+                     if 'model' in checkpoint and isinstance(checkpoint['model'], dict):
+                         model_dict = checkpoint['model']
+                         # Get a sample of keys to understand the structure
+                         first_keys = list(model_dict.keys())[:5]
+                         print(f"'model' contains keys like: {first_keys}")
+ 
+                         # Check for common prefixes in the model dict
+                         prefixes = set()
+                         for key in list(model_dict.keys())[:100]:  # Check first 100 keys
+                             parts = key.split('.')
+                             if len(parts) > 1:
+                                 prefixes.add(parts[0])
+                         if prefixes:
+                             print(f"Common prefixes in model dict: {prefixes}")
+                 else:
+                     print(f"Checkpoint is not a dictionary, but a {type(checkpoint)}")
+ 
+                 # Handle different checkpoint formats
+                 if 'model' in checkpoint:
+                     # Format from madrylab robust models
+                     state_dict = checkpoint['model']
+                     print("Using 'model' key from checkpoint")
+                 elif 'state_dict' in checkpoint:
+                     state_dict = checkpoint['state_dict']
+                     print("Using 'state_dict' key from checkpoint")
+                 else:
+                     # Direct state dict
+                     state_dict = checkpoint
+                     print("Using checkpoint directly as state_dict")
+ 
+                 # Handle prefixes in state dict keys for the ResNet part
+                 resnet_state_dict = {}
+                 prefixes_to_try = ['', 'module.', 'model.', 'attacker.model.']
+                 resnet_keys = set(resnet.state_dict().keys())
+ 
+                 # First check if we can find keys directly in known model paths
+                 print("\n=== Phase 1: Checking for specific model structures ===")
+ 
+                 # Check for 'module.model' structure (seen in actual checkpoints)
+                 module_model_keys = [key for key in state_dict.keys() if key.startswith('module.model.')]
+                 if module_model_keys:
+                     print(f"Found 'module.model' structure with {len(module_model_keys)} parameters")
+                     # Extract all parameters from module.model
+                     for source_key, value in state_dict.items():
+                         if source_key.startswith('module.model.'):
+                             target_key = source_key[len('module.model.'):]
+                             resnet_state_dict[target_key] = value
+ 
+                     print(f"Extracted {len(resnet_state_dict)} parameters from module.model")
+ 
+                 # Check for 'attacker.model' structure
+                 attacker_model_keys = [key for key in state_dict.keys() if key.startswith('attacker.model.')]
+                 if attacker_model_keys:
+                     print(f"Found 'attacker.model' structure with {len(attacker_model_keys)} parameters")
+                     # Extract all parameters from attacker.model
+                     for source_key, value in state_dict.items():
+                         if source_key.startswith('attacker.model.'):
+                             target_key = source_key[len('attacker.model.'):]
+                             resnet_state_dict[target_key] = value
+ 
+                     print(f"Extracted {len(resnet_state_dict)} parameters from attacker.model")
+ 
+                 # Check if 'model' (not attacker.model) exists as a fallback
+                 model_keys = [key for key in state_dict.keys() if key.startswith('model.') and not key.startswith('attacker.model.')]
+                 if model_keys and len(resnet_state_dict) < len(resnet_keys):
+                     print(f"Found additional 'model.' structure with {len(model_keys)} parameters")
+                     # Try to complete missing parameters
+                     for source_key, value in state_dict.items():
+                         if source_key.startswith('model.'):
+                             target_key = source_key[len('model.'):]
+                             if target_key in resnet_keys and target_key not in resnet_state_dict:
+                                 resnet_state_dict[target_key] = value
+ 
+                 else:
+                     # Check for other known structures
+                     structure_found = False
+ 
+                     # Check for 'model.' prefix
+                     model_keys = [key for key in state_dict.keys() if key.startswith('model.')]
+                     if model_keys:
+                         print(f"Found 'model.' structure with {len(model_keys)} parameters")
+                         for source_key, value in state_dict.items():
+                             if source_key.startswith('model.'):
+                                 target_key = source_key[len('model.'):]
+                                 resnet_state_dict[target_key] = value
+                         structure_found = True
+ 
+                     # Check for ResNet parameters at the top level
+                     top_level_resnet_keys = 0
+                     for key in resnet_keys:
+                         if key in state_dict:
+                             top_level_resnet_keys += 1
+ 
+                     if top_level_resnet_keys > 0:
+                         print(f"Found {top_level_resnet_keys} ResNet parameters at top level")
+                         for target_key in resnet_keys:
+                             if target_key in state_dict:
+                                 resnet_state_dict[target_key] = state_dict[target_key]
+                         structure_found = True
+ 
+                     # If no structure was recognized, try the prefix mapping approach
+                     if not structure_found:
+                         print("No standard model structure found, trying prefix mappings...")
+                         for target_key in resnet_keys:
+                             for prefix in prefixes_to_try:
+                                 source_key = prefix + target_key
+                                 if source_key in state_dict:
+                                     resnet_state_dict[target_key] = state_dict[source_key]
+                                     break
+ 
555
+ # If we still can't find enough keys, try a final approach of removing prefixes
556
+ if len(resnet_state_dict) < len(resnet_keys):
557
+ print(f"Found only {len(resnet_state_dict)}/{len(resnet_keys)} parameters, trying prefix removal...")
558
+
559
+ # Track matches found through prefix removal
560
+ prefix_matches = {prefix: 0 for prefix in ['module.', 'model.', 'attacker.model.', 'attacker.']}
561
+ layer_matches = {} # Track matches by layer type
562
+
563
+ # Count parameter keys by layer type for analysis
564
+ for key in resnet_keys:
565
+ layer_name = key.split('.')[0] if '.' in key else key
566
+ if layer_name not in layer_matches:
567
+ layer_matches[layer_name] = {'total': 0, 'matched': 0}
568
+ layer_matches[layer_name]['total'] += 1
569
+
570
+ # Try keys with common prefixes
571
+ for source_key, value in state_dict.items():
572
+ # Skip if already found
573
+ target_key = source_key
574
+ matched_prefix = None
575
+
576
+ # Try removing various prefixes
577
+ for prefix in ['module.', 'model.', 'attacker.model.', 'attacker.']:
578
+ if source_key.startswith(prefix):
579
+ target_key = source_key[len(prefix):]
580
+ matched_prefix = prefix
581
+ break
582
+
583
+ # If the target key is in the ResNet keys, add it to the state dict
584
+ if target_key in resnet_keys and target_key not in resnet_state_dict:
585
+ resnet_state_dict[target_key] = value
586
+
587
+ # Update match statistics
588
+ if matched_prefix:
589
+ prefix_matches[matched_prefix] += 1
590
+
591
+ # Update layer matches
592
+ layer_name = target_key.split('.')[0] if '.' in target_key else target_key
593
+ if layer_name in layer_matches:
594
+ layer_matches[layer_name]['matched'] += 1
595
+
596
+ # Print detailed prefix removal statistics
597
+ print("\n=== Prefix Removal Statistics ===")
598
+ total_matches = sum(prefix_matches.values())
599
+ print(f"Total parameters matched through prefix removal: {total_matches}/{len(resnet_keys)} ({(total_matches/len(resnet_keys))*100:.1f}%)")
600
+
601
+ # Show matches by prefix
602
+ print("\nMatches by prefix:")
603
+ for prefix, count in sorted(prefix_matches.items(), key=lambda x: x[1], reverse=True):
604
+ if count > 0:
605
+ print(f" {prefix}: {count} parameters")
606
+
607
+ # Show matches by layer type
608
+ print("\nMatches by layer type:")
609
+ for layer, stats in sorted(layer_matches.items(), key=lambda x: x[1]['total'], reverse=True):
610
+ match_percent = (stats['matched'] / stats['total']) * 100 if stats['total'] > 0 else 0
611
+ print(f" {layer}: {stats['matched']}/{stats['total']} ({match_percent:.1f}%)")
612
+
613
+ # Check for specific important layers (conv1, layer1, etc.)
614
+ critical_layers = ['conv1', 'bn1', 'layer1', 'layer2', 'layer3', 'layer4', 'fc']
615
+ print("\nStatus of critical layers:")
616
+ for layer in critical_layers:
617
+ if layer in layer_matches:
618
+ match_percent = (layer_matches[layer]['matched'] / layer_matches[layer]['total']) * 100
619
+ status = "✅ COMPLETE" if layer_matches[layer]['matched'] == layer_matches[layer]['total'] else "⚠️ INCOMPLETE"
620
+ print(f" {layer}: {layer_matches[layer]['matched']}/{layer_matches[layer]['total']} ({match_percent:.1f}%) - {status}")
621
+ else:
622
+ print(f" {layer}: Not found in model")
623
+
624
+ # Load the ResNet state dict
625
+ if resnet_state_dict:
626
+ try:
627
+ # Use strict=False to allow missing keys
628
+ result = resnet.load_state_dict(resnet_state_dict, strict=False)
629
+ missing_keys, unexpected_keys = result
630
+
631
+ # Generate detailed information with better formatting
632
+ loading_report = []
633
+ loading_report.append(f"\n===== MODEL LOADING REPORT: {model_type} =====")
634
+ loading_report.append(f"Total parameters in checkpoint: {len(resnet_state_dict):,}")
635
+ loading_report.append(f"Total parameters in model: {len(resnet.state_dict()):,}")
636
+ loading_report.append(f"Missing keys: {len(missing_keys):,} parameters")
637
+ loading_report.append(f"Unexpected keys: {len(unexpected_keys):,} parameters")
638
+
639
+ # Calculate percentage of parameters loaded
640
+ loaded_keys = set(resnet_state_dict.keys()) - set(unexpected_keys)
641
+ loaded_percent = (len(loaded_keys) / len(resnet.state_dict())) * 100
642
+
643
+ # Determine loading success status
644
+ if loaded_percent >= 99.5:
645
+ status = "✅ COMPLETE - All important parameters loaded"
646
+ elif loaded_percent >= 90:
647
+ status = "🟡 PARTIAL - Most parameters loaded, should still function"
648
+ elif loaded_percent >= 50:
649
+ status = "⚠️ INCOMPLETE - Many parameters missing, may not function properly"
650
+ else:
651
+ status = "❌ FAILED - Critical parameters missing, will not function properly"
652
+
653
+ loading_report.append(f"Successfully loaded: {len(loaded_keys):,} parameters ({loaded_percent:.1f}%)")
654
+ loading_report.append(f"Loading status: {status}")
655
+
656
+ # If loading is severely incomplete, fall back to PyTorch's pretrained model
657
+ if loaded_percent < 50:
658
+ loading_report.append("\n⚠️ WARNING: Loading from checkpoint is too incomplete.")
659
+ loading_report.append("⚠️ Falling back to PyTorch's pretrained model to avoid broken inference.")
660
+
661
+ # Create a new ResNet model with pretrained weights
662
+ resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
663
+ model = nn.Sequential(self.normalizer, resnet)
664
+ loading_report.append("✅ Successfully loaded PyTorch's pretrained ResNet50 model")
665
+
666
+ # Show missing keys by layer type
667
+ if missing_keys:
668
+ loading_report.append("\nMissing keys by layer type:")
669
+ layer_types = {}
670
+ for key in missing_keys:
671
+ # Extract layer type (e.g., 'conv', 'bn', 'layer1', etc.)
672
+ parts = key.split('.')
673
+ if len(parts) > 0:
674
+ layer_type = parts[0]
675
+ if layer_type not in layer_types:
676
+ layer_types[layer_type] = 0
677
+ layer_types[layer_type] += 1
678
+
679
+ # Add counts by layer type
680
+ for layer_type, count in sorted(layer_types.items(), key=lambda x: x[1], reverse=True):
681
+ loading_report.append(f" {layer_type}: {count:,} parameters")
682
+
683
+ loading_report.append("\nFirst 10 missing keys:")
684
+ for i, key in enumerate(sorted(missing_keys)[:10]):
685
+ loading_report.append(f" {i+1}. {key}")
686
+
687
+ # Show unexpected keys if any
688
+ if unexpected_keys:
689
+ loading_report.append("\nFirst 10 unexpected keys:")
690
+ for i, key in enumerate(sorted(unexpected_keys)[:10]):
691
+ loading_report.append(f" {i+1}. {key}")
692
+
693
+ loading_report.append("========================================")
694
+
695
+ # Convert report to string and print it
696
+ report_text = "\n".join(loading_report)
697
+ print(report_text)
698
+
699
+ # Also save to a file for reference
700
+ os.makedirs("logs", exist_ok=True)
701
+ with open(f"logs/model_loading_{model_type}.log", "w") as f:
702
+ f.write(report_text)
703
+
704
+                    # Look for normalizer parameters as well
+                    if any(key.startswith('attacker.normalize.') for key in state_dict.keys()):
+                        norm_state_dict = {}
+                        for key, value in state_dict.items():
+                            if key.startswith('attacker.normalize.'):
+                                norm_key = key[len('attacker.normalize.'):]
+                                norm_state_dict[norm_key] = value
+
+                        if norm_state_dict:
+                            try:
+                                self.normalizer.load_state_dict(norm_state_dict, strict=False)
+                                print("Successfully loaded normalizer parameters")
+                            except Exception as e:
+                                print(f"Warning: Could not load normalizer parameters: {e}")
+                except Exception as e:
+                    print(f"Warning: Error loading ResNet parameters: {e}")
+                    # Fall back to loading without the normalizer
+                    model = resnet  # Use just the ResNet model without the normalizer
+            except Exception as e:
+                print(f"Error loading model checkpoint: {e}")
+                # Fall back to PyTorch's pretrained model
+                print("Falling back to PyTorch's pretrained model")
+                resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
+                model = nn.Sequential(self.normalizer, resnet)
+        else:
+            # Fall back to PyTorch's pretrained model
+            print("No checkpoint available, using PyTorch's pretrained model")
+            resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
+            model = nn.Sequential(self.normalizer, resnet)
+
+        model = model.to(device)
+        model.eval()  # Set to evaluation mode
+
+        # Verify model integrity
+        self.verify_model_integrity(model, model_type)
+
+        # Store the model for future use
+        self.models[model_type] = model
+        end_time = time.time()
+        load_time = end_time - start_time
+        print(f"Model {model_type} loaded in {load_time:.2f} seconds")
+        return model
+
+    def inference(self, image, model_type, config):
+        """Run generative inference on the image."""
+        # Time the entire inference process
+        inference_start = time.time()
+
+        # Load model if not already loaded
+        model = self.load_model(model_type)
+
+        # Check if image is a file path
+        if isinstance(image, str):
+            if os.path.exists(image):
+                image = Image.open(image).convert('RGB')
+            else:
+                raise ValueError(f"Image path does not exist: {image}")
+        elif isinstance(image, torch.Tensor):
+            raise ValueError(f"Expected a PIL image or file path, got {type(image)}: input appears to be an already-transformed tensor")
+
+        # Prepare image tensor - match original code's conditional transform
+        load_start = time.time()
+        use_norm = config['inference_normalization'] == 'on'
+        custom_transform = get_transform(
+            input_size=224,
+            normalize=use_norm,
+            norm_mean=IMAGENET_MEAN,
+            norm_std=IMAGENET_STD
+        )
+
+        # Special handling for GradModulation as in original
+        if config['loss_infer'] == 'GradModulation' and 'misc_info' in config and 'grad_modulation' in config['misc_info']:
+            grad_modulation = config['misc_info']['grad_modulation']
+            image_tensor = custom_transform(image).unsqueeze(0).to(device)
+            image_tensor = image_tensor * (1 - grad_modulation) + grad_modulation * torch.randn_like(image_tensor).to(device)
+        else:
+            image_tensor = custom_transform(image).unsqueeze(0).to(device)
+
+        image_tensor.requires_grad = True
+        print(f"Image loaded and processed in {time.time() - load_start:.2f} seconds")
+
785
+ # Check model structure
786
+ is_sequential = isinstance(model, nn.Sequential)
787
+
788
+ # Get original predictions
789
+ with torch.no_grad():
790
+ # If the model is sequential with a normalizer, skip the normalization step
791
+ if is_sequential and isinstance(model[0], NormalizeByChannelMeanStd):
792
+ print("Model is sequential with normalization")
793
+ # Get the core model part (typically at index 1 in Sequential)
794
+ core_model = model[1]
795
+ if config['inference_normalization']:
796
+ output_original = model(image_tensor) # Model includes normalization
797
+ else:
798
+ output_original = core_model(image_tensor) # Model includes normalization
799
+
800
+ else:
801
+ print("Model is not sequential with normalization")
802
+ # Use manual normalization for non-sequential models
803
+ if config['inference_normalization']:
804
+ normalized_tensor = normalize_transform(image_tensor)
805
+ output_original = model(normalized_tensor)
806
+ else:
807
+ output_original = model(image_tensor)
808
+ core_model = model
809
+
810
+ probs_orig = F.softmax(output_original, dim=1)
811
+ conf_orig, classes_orig = torch.max(probs_orig, 1)
812
+
813
+ # Get least confident classes
814
+ _, least_confident_classes = torch.topk(probs_orig, k=100, largest=False)
815
+
816
+ # Initialize inference step (adaptive Gaussian mask from config)
817
+ adaptive_eps = config.get("adaptive_epsilon")
818
+ adaptive_step = config.get("adaptive_step_size")
819
+ if adaptive_eps and adaptive_eps.get("enabled"):
820
+ print(f"Applying adaptive epsilon mask: center=({adaptive_eps.get('center_x'):.2f}, {adaptive_eps.get('center_y'):.2f}), radius={adaptive_eps.get('flat_radius'):.2f}")
821
+ if adaptive_step and adaptive_step.get("enabled"):
822
+ print(f"Applying adaptive step-size mask: center=({adaptive_step.get('center_x'):.2f}, {adaptive_step.get('center_y'):.2f}), radius={adaptive_step.get('flat_radius'):.2f}")
823
+ infer_step = InferStep(
824
+ image_tensor,
825
+ config["eps"],
826
+ config["step_size"],
827
+ adaptive_eps_config=adaptive_eps,
828
+ adaptive_step_config=adaptive_step,
829
+ )
830
+
831
+ # Storage for inference steps
832
+ # Create a new tensor that requires gradients
833
+ x = image_tensor.clone().detach().requires_grad_(True)
834
+ all_steps = [image_tensor[0].detach().cpu()]
835
+
836
+ # For Prior-Guided Drift Diffusion, extract selected layer and initialize with noisy features
837
+ noisy_features = None
838
+ layer_model = None
839
+ if config['loss_infer'] == 'Prior-Guided Drift Diffusion':
840
+ print(f"Setting up Prior-Guided Drift Diffusion with layer {config['top_layer']} and noise {config['initial_inference_noise_ratio']}...")
841
+
842
+ # Extract model up to the specified layer
843
+ try:
844
+ # Start by finding the actual model to use
845
+ base_model = model
846
+
847
+ # Handle DataParallel wrapper if present
848
+ if hasattr(base_model, 'module'):
849
+ base_model = base_model.module
850
+
851
+ # Log the initial model structure
852
+ print(f"DEBUG - Initial model structure: {type(base_model)}")
853
+
854
+ # If we have a Sequential model (which is likely our normalizer + model structure)
855
+ if isinstance(base_model, nn.Sequential):
856
+ print(f"DEBUG - Sequential model with {len(list(base_model.children()))} children")
857
+
858
+ # If this is our NormalizeByChannelMeanStd + ResNet pattern
859
+ if len(list(base_model.children())) >= 2:
860
+ # The actual ResNet model is the second component (index 1)
861
+ actual_model = list(base_model.children())[1]
862
+ print(f"DEBUG - Using ResNet component: {type(actual_model)}")
863
+ print(f"DEBUG - Available layers: {[name for name, _ in actual_model.named_children()]}")
864
+
865
+ # Extract from the actual ResNet
866
+ layer_model = extract_middle_layers(actual_model, config['top_layer'])
867
+ else:
868
+ # Just a single component Sequential
869
+ layer_model = extract_middle_layers(base_model, config['top_layer'])
870
+ else:
871
+ # Not Sequential, might be direct model
872
+ print(f"DEBUG - Available layers: {[name for name, _ in base_model.named_children()]}")
873
+ layer_model = extract_middle_layers(base_model, config['top_layer'])
874
+
875
+ print(f"Successfully extracted model up to layer: {config['top_layer']}")
876
+ except ValueError as e:
877
+ print(f"Layer extraction failed: {e}. Using full model.")
878
+ layer_model = model
879
+
880
+ # Add noise to the image - exactly match original code
881
+ added_noise = config['initial_inference_noise_ratio'] * torch.randn_like(image_tensor).to(device)
882
+ noisy_image_tensor = image_tensor + added_noise
883
+
884
+ # Compute noisy features - simplified to match original code
885
+ noisy_features = layer_model(noisy_image_tensor)
886
+
887
+ print(f"Noisy features computed for Prior-Guided Drift Diffusion target with shape: {noisy_features.shape if hasattr(noisy_features, 'shape') else 'unknown'}")
888
+
889
+        # Main inference loop
+        print(f"Starting inference loop with {config['n_itr']} iterations for {config['loss_infer']}...")
+        loop_start = time.time()
+        for i in range(config['n_itr']):
+            # Reset gradients
+            x.grad = None
+
+            # Forward pass - use layer_model for Prior-Guided Drift Diffusion, full model otherwise
+            if config['loss_infer'] == 'Prior-Guided Drift Diffusion' and layer_model is not None:
+                # Use the extracted layer model for Prior-Guided Drift Diffusion
+                # In the original code, normalization is handled at transform time, not during the forward pass
+                output = layer_model(x)
+            else:
+                # Standard forward pass with full model
+                # Simplified to match original code's approach
+                output = model(x)
+
+            # Calculate loss and gradients based on inference type
+            try:
+                if config['loss_infer'] == 'Prior-Guided Drift Diffusion':
+                    # Use MSE loss to match the noisy features
+                    assert config['loss_function'] == 'MSE', "Prior-Guided Drift Diffusion loss function must be MSE"
+                    if noisy_features is not None:
+                        loss = F.mse_loss(output, noisy_features)
+                        grad = torch.autograd.grad(loss, x)[0]  # No retain_graph=True, to match original
+                    else:
+                        raise ValueError("Noisy features not computed for Prior-Guided Drift Diffusion")
+
+                else:  # Default 'IncreaseConfidence' approach
+                    # Get the least confident classes
+                    num_classes = min(10, least_confident_classes.size(1))
+                    target_classes = least_confident_classes[0, :num_classes]
+
+                    # Create targets for least confident classes
+                    targets = torch.tensor([idx.item() for idx in target_classes], device=device)
+
+                    # Use a combined loss to increase confidence
+                    loss = 0
+                    for target in targets:
+                        # Create one-hot target
+                        one_hot = torch.zeros_like(output)
+                        one_hot[0, target] = 1
+                        # Use loss to maximize confidence
+                        loss = loss + F.mse_loss(F.softmax(output, dim=1), one_hot)
+
+                    grad = torch.autograd.grad(loss, x, retain_graph=True)[0]
+
+                if grad is None:
+                    print("Warning: Direct gradient calculation failed")
+                    # Fall back to random perturbation
+                    random_noise = (torch.rand_like(x) - 0.5) * 2 * config['step_size']
+                    x = infer_step.project(x + random_noise)
+                else:
+                    # Update image with gradient - do this exactly as in original code
+                    adjusted_grad = infer_step.step(x, grad)
+
+                    # Add diffusion noise if specified
+                    diffusion_noise = config['diffusion_noise_ratio'] * torch.randn_like(x).to(device)
+
+                    # Apply gradient and noise in one operation before projecting, exactly as in original
+                    x = infer_step.project(x.clone() + adjusted_grad + diffusion_noise)
+
+            except Exception as e:
+                print(f"Error in gradient calculation: {e}")
+                # Fall back to random perturbation - match original code
+                random_noise = (torch.rand_like(x) - 0.5) * 2 * config['step_size']
+                x = infer_step.project(x.clone() + random_noise)
+
+            # Store step if in iterations_to_show
+            if i + 1 in config['iterations_to_show'] or i + 1 == config['n_itr']:
+                all_steps.append(x[0].detach().cpu())
+
+        # Print some info about the inference
+        with torch.no_grad():
+            if is_sequential and isinstance(model[0], NormalizeByChannelMeanStd):
+                if config['inference_normalization'] == 'on':
+                    final_output = model(x)
+                else:
+                    final_output = core_model(x)
+            else:
+                if config['inference_normalization'] == 'on':
+                    normalized_x = normalize_transform(x)
+                    final_output = model(normalized_x)
+                else:
+                    final_output = model(x)
+
+            final_probs = F.softmax(final_output, dim=1)
+            final_conf, final_classes = torch.max(final_probs, 1)
+
+        # Calculate timing information
+        loop_time = time.time() - loop_start
+        total_time = time.time() - inference_start
+        avg_iter_time = loop_time / config['n_itr'] if config['n_itr'] > 0 else 0
+
+        print(f"Original top class: {classes_orig.item()} ({conf_orig.item():.4f})")
+        print(f"Final top class: {final_classes.item()} ({final_conf.item():.4f})")
+        print(f"Inference loop completed in {loop_time:.2f} seconds ({avg_iter_time:.4f} sec/iteration)")
+        print(f"Total inference time: {total_time:.2f} seconds")
+
+        # Return results in a format compatible with both old and new code
+        return {
+            'final_image': x[0].detach().cpu(),
+            'steps': all_steps,
+            'original_class': classes_orig.item(),
+            'original_confidence': conf_orig.item(),
+            'final_class': final_classes.item(),
+            'final_confidence': final_conf.item()
+        }
+
+# Utility function to show inference steps
+def show_inference_steps(steps, figsize=(15, 10)):
+    import matplotlib.pyplot as plt
+
+    n_steps = len(steps)
+    fig, axes = plt.subplots(1, n_steps, figsize=figsize)
+    if n_steps == 1:
+        axes = [axes]  # plt.subplots returns a bare Axes object for a single column
+
+    for i, step_img in enumerate(steps):
+        img = step_img.permute(1, 2, 0).numpy().clip(0, 1)  # Clip to the valid display range
+        axes[i].imshow(img)
+        axes[i].set_title(f"Step {i}")
+        axes[i].axis('off')
+
+    plt.tight_layout()
+    return fig
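The per-iteration update in the inference loop above — gradient step, optional diffusion noise, then projection back to an ε-ball around the original image — can be sketched in isolation. The `step` and `project` helpers below are NumPy stand-ins for `InferStep.step` and `InferStep.project`, and they assume an L∞-style constraint (sign-gradient step, clipping to the ball), which is an assumption about the original implementation rather than a copy of it:

```python
import numpy as np

def step(grad, step_size):
    # Sign-gradient step (assumption: L-inf style, as in standard PGD-like code)
    return step_size * np.sign(grad)

def project(x, x_orig, eps):
    # Clip back into the eps-ball around the original image, then the valid pixel range
    x = np.clip(x, x_orig - eps, x_orig + eps)
    return np.clip(x, 0.0, 1.0)

rng = np.random.default_rng(0)
x_orig = rng.random((3, 8, 8))  # stand-in for the input image tensor
x = x_orig.copy()
eps, step_size, noise_ratio = 0.1, 0.02, 0.01

for _ in range(20):
    grad = rng.standard_normal(x.shape)                  # stand-in for the autograd gradient
    noise = noise_ratio * rng.standard_normal(x.shape)   # diffusion-noise term
    x = project(x + step(grad, step_size) + noise, x_orig, eps)

# The iterate never leaves the eps-ball, no matter how many steps run
assert np.abs(x - x_orig).max() <= eps + 1e-9
```

The projection is what keeps the drifting image recognizably tied to the input: the gradient and noise terms can push in any direction, but each iterate is clipped back to within `eps` of the original pixels.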
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ torch
+ torchvision
+ numpy
+ pillow
+ gradio
+ matplotlib
+ requests
+ tqdm
+ huggingface_hub
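The requirements above are unpinned, so a rebuild can silently pick up incompatible releases. Since the Dockerfile targets Python 3.9, one option is to pin known-good versions; the exact versions below are illustrative assumptions, not tested constraints:

```text
torch==2.0.1
torchvision==0.15.2
numpy==1.24.4
pillow==10.0.1
gradio==3.50.2
matplotlib==3.7.3
requests==2.31.0
tqdm==4.66.1
huggingface_hub==0.17.3
```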