ttoosi committed (verified)
Commit fa937f9 · 1 Parent(s): 7212655

Upload 23 files
.gitattributes CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+ stimuli/Confetti_illusion.png filter=lfs diff=lfs merge=lfs -text
+ stimuli/CornsweetBlock.png filter=lfs diff=lfs merge=lfs -text
+ stimuli/EhresteinSingleColor.png filter=lfs diff=lfs merge=lfs -text
+ stimuli/figure_ground.png filter=lfs diff=lfs merge=lfs -text
+ stimuli/Neon_Color_Circle.jpg filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2023 GenerativeInferenceDemo
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,13 +1,83 @@
 ---
- title: Generative Inference Faces
- emoji: 🚀
- colorFrom: yellow
- colorTo: indigo
+ title: Generative Inference Demo
+ emoji: 🧠
+ colorFrom: indigo
+ colorTo: purple
 sdk: gradio
- sdk_version: 5.49.1
+ sdk_version: 5.23.1
 app_file: app.py
 pinned: false
- short_description: Generative inference on models trained for faces
+ license: mit
 ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Generative Inference Demo
+
+ This Gradio demo showcases how neural networks perceive visual illusions through generative inference. It uses both standard and adversarially robust ResNet50 models to reveal emergent perception of contours, figure-ground separation, and other visual phenomena.
+
+ ## Models
+
+ - **Robust ResNet50**: A model trained with adversarial examples (ε=3.0), exhibiting more human-like visual perception
+ - **Standard ResNet50**: A model trained without adversarial examples (ε=0.0)
+
+ ## Features
+
+ - Upload your own images or use the example illusions
+ - Choose between the robust and standard models
+ - Adjust the perturbation size (epsilon) and iteration count
+ - Visualize how perception emerges over time
+ - Includes classic illusions:
+   - Kanizsa shapes
+   - Face-vase illusions
+   - Figure-ground segmentation
+   - Neon color spreading
+
+ ## Usage
+
+ 1. Select an example image or upload your own
+ 2. Choose the model type (robust or standard)
+ 3. Adjust the epsilon and iteration parameters
+ 4. Click "Run Inference" to see how the model perceives the image
+
+ ## About
+
+ This demo is based on research showing how adversarially robust models develop more human-like visual representations. The generative inference process reveals these perceptual biases by optimizing the input to maximize the model's confidence.
+
+ ## Installation
+
+ To run this demo locally:
+
+ ```bash
+ # Clone the repository
+ git clone [repo-url]
+ cd GenerativeInferenceDemo
+
+ # Install dependencies
+ pip install -r requirements.txt
+
+ # Run the app
+ python app.py
+ ```
+
+ The web app will be available at http://localhost:7860 (or another port if 7860 is busy).
+
+ ## About the Models
+
+ - **Robust ResNet50**: A model trained with adversarial examples, making it more robust to small perturbations. Such models often exhibit more human-like visual perception.
+ - **Standard ResNet50**: A standard ImageNet-trained ResNet50 model.
+
+ ## How It Works
+
+ 1. The algorithm starts with an input image
+ 2. It iteratively updates the image to increase the model's confidence in its predictions
+ 3. These updates are constrained to a small neighborhood (controlled by epsilon) around the original image
+ 4. The resulting changes reveal how the network "sees" the image (a minimal sketch of this loop follows below)
+
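For readers who want the mechanics, here is a minimal, hypothetical PyTorch sketch of that loop; the function and parameter names are illustrative, not the exact implementation in `inference.py`:

```python
import torch

def generative_inference(model, image, eps=0.5, step_size=0.1, n_itr=50):
    """Nudge `image` ([1, 3, H, W], values in [0, 1]) to raise the model's
    confidence while staying within an eps-ball of the original."""
    orig = image.clone().detach()
    x = image.clone().detach().requires_grad_(True)
    for _ in range(n_itr):
        conf = torch.softmax(model(x), dim=1).max()   # top-class confidence
        grad = torch.autograd.grad(conf, x)[0]
        with torch.no_grad():
            x += step_size * grad / (grad.norm() + 1e-10)              # ascent step
            x.copy_((orig + (x - orig).clamp(-eps, eps)).clamp(0, 1))  # project back
    return x.detach()
```

The demo's actual loop (see `inference.py` below) layers noise schedules and feature-space targets on top of this skeleton.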
+ ## Citation
+
+ If you use this work in your research, please cite the original paper:
+
+ [Citation information will be added here]
+
+ ## License
+
+ This project is licensed under the MIT License - see the LICENSE file for details.
app.py CHANGED
@@ -1,7 +1,412 @@
 import gradio as gr
-
- def greet(name):
-     return "Hello " + name + "!!"
-
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
+ import torch
+ import numpy as np
+ from PIL import Image
+ try:
+     from spaces import GPU
+ except ImportError:
+     # Define a no-op decorator if running locally
+     def GPU(func):
+         return func
+
+ import os
+ import argparse
+ from inference import GenerativeInferenceModel, get_inference_configs
+
+ # Parse command line arguments
+ parser = argparse.ArgumentParser(description='Run Generative Inference Demo')
+ parser.add_argument('--port', type=int, default=7860, help='Port to run the server on')
+ args = parser.parse_args()
+
+ # Create model directories if they don't exist
+ os.makedirs("models", exist_ok=True)
+ os.makedirs("stimuli", exist_ok=True)
+
+ # Check if running on Hugging Face Spaces
+ if "SPACE_ID" in os.environ:
+     default_port = int(os.environ.get("PORT", 7860))
+ else:
+     default_port = 8861  # Local default port
+
+ # Initialize model
+ model = GenerativeInferenceModel()
+
+ # Define example images and their parameters, with values updated from the research
+ examples = [
+     {
+         "image": os.path.join("stimuli", "Neon_Color_Circle.jpg"),
+         "name": "Neon Color Spreading",
+         "wiki": "https://en.wikipedia.org/wiki/Neon_color_spreading",
+         "papers": [
+             "[Color Assimilation](https://doi.org/10.1016/j.visres.2000.200.1)",
+             "[Perceptual Filling-in](https://doi.org/10.1016/j.tics.2003.08.003)"
+         ],
+         "method": "Prior-Guided Drift Diffusion",
+         "reverse_diff": {
+             "model": "resnet50_robust",
+             "layer": "layer3",
+             "initial_noise": 0.8,
+             "diffusion_noise": 0.003,
+             "step_size": 1.0,
+             "iterations": 101,
+             "epsilon": 20.0
+         }
+     },
+     {
+         "image": os.path.join("stimuli", "Kanizsa_square.jpg"),
+         "name": "Kanizsa Square",
+         "wiki": "https://en.wikipedia.org/wiki/Kanizsa_triangle",
+         "papers": [
+             "[Gestalt Psychology](https://en.wikipedia.org/wiki/Gestalt_psychology)",
+             "[Neural Mechanisms](https://doi.org/10.1016/j.tics.2003.08.003)"
+         ],
+         "method": "Prior-Guided Drift Diffusion",
+         "reverse_diff": {
+             "model": "resnet50_robust",
+             "layer": "all",
+             "initial_noise": 0.0,
+             "diffusion_noise": 0.005,
+             "step_size": 0.64,
+             "iterations": 100,
+             "epsilon": 5.0
+         }
+     },
+     {
+         "image": os.path.join("stimuli", "CornsweetBlock.png"),
+         "name": "Cornsweet Illusion",
+         "wiki": "https://en.wikipedia.org/wiki/Cornsweet_illusion",
+         "papers": [
+             "[Brightness Perception](https://doi.org/10.1016/j.visres.2000.200.1)",
+             "[Edge Effects](https://doi.org/10.1016/j.tics.2003.08.003)"
+         ],
+ "instructions": "Both blocks are gray in color (the same), use your finger to cover the middle line. Hit 'Load Parameters' and then hit 'Run Generative Inference' to see how the model sees the blocks.",
+         "method": "Prior-Guided Drift Diffusion",
+         "reverse_diff": {
+             "model": "resnet50_robust",
+             "layer": "layer3",
+             "initial_noise": 0.5,
+             "diffusion_noise": 0.005,
+             "step_size": 0.8,
+             "iterations": 51,
+             "epsilon": 20.0
+         }
+     },
+     {
+         "image": os.path.join("stimuli", "face_vase.png"),
+         "name": "Rubin's Face-Vase (Object Prior)",
+         "wiki": "https://en.wikipedia.org/wiki/Rubin_vase",
+         "papers": [
+             "[Figure-Ground Perception](https://en.wikipedia.org/wiki/Figure-ground_(perception))",
+             "[Bistable Perception](https://doi.org/10.1016/j.tics.2003.08.003)"
+         ],
+         "method": "Prior-Guided Drift Diffusion",
+         "reverse_diff": {
+             "model": "resnet50_robust",
+             "layer": "avgpool",
+             "initial_noise": 0.9,
+             "diffusion_noise": 0.003,
+             "step_size": 0.58,
+             "iterations": 100,
+             "epsilon": 0.81
+         }
+     },
+     {
+         "image": os.path.join("stimuli", "Confetti_illusion.png"),
+         "name": "Confetti Illusion",
+         "wiki": "https://www.youtube.com/watch?v=SvEiEi8O7QE",
+         "papers": [
+             "[Color Perception](https://doi.org/10.1016/j.visres.2000.200.1)",
+             "[Context Effects](https://doi.org/10.1016/j.tics.2003.08.003)"
+         ],
+         "method": "Prior-Guided Drift Diffusion",
+         "reverse_diff": {
+             "model": "resnet50_robust",
+             "layer": "layer3",
+             "initial_noise": 0.1,
+             "diffusion_noise": 0.003,
+             "step_size": 0.5,
+             "iterations": 101,
+             "epsilon": 20.0
+         }
+     },
+     {
+         "image": os.path.join("stimuli", "EhresteinSingleColor.png"),
+         "name": "Ehrenstein Illusion",
+         "wiki": "https://en.wikipedia.org/wiki/Ehrenstein_illusion",
+         "papers": [
+             "[Subjective Contours](https://doi.org/10.1016/j.visres.2000.200.1)",
+             "[Neural Processing](https://doi.org/10.1016/j.tics.2003.08.003)"
+         ],
+         "method": "Prior-Guided Drift Diffusion",
+         "reverse_diff": {
+             "model": "resnet50_robust",
+             "layer": "layer3",
+             "initial_noise": 0.5,
+             "diffusion_noise": 0.005,
+             "step_size": 0.8,
+             "iterations": 101,
+             "epsilon": 20.0
+         }
+     },
+     {
+         "image": os.path.join("stimuli", "GroupingByContinuity.png"),
+         "name": "Grouping by Continuity",
+         "wiki": "https://en.wikipedia.org/wiki/Principles_of_grouping",
+         "papers": [
+             "[Gestalt Principles](https://en.wikipedia.org/wiki/Gestalt_psychology)",
+             "[Visual Organization](https://doi.org/10.1016/j.tics.2003.08.003)"
+         ],
+         "method": "Prior-Guided Drift Diffusion",
+         "reverse_diff": {
+             "model": "resnet50_robust",
+             "layer": "layer3",
+             "initial_noise": 0.0,
+             "diffusion_noise": 0.005,
+             "step_size": 0.4,
+             "iterations": 101,
+             "epsilon": 4.0
+         }
+     },
+     {
+         "image": os.path.join("stimuli", "figure_ground.png"),
+         "name": "Figure-Ground Illusion",
+         "wiki": "https://en.wikipedia.org/wiki/Figure-ground_(perception)",
+         "papers": [
+             "[Gestalt Principles](https://en.wikipedia.org/wiki/Gestalt_psychology)",
+             "[Perceptual Organization](https://doi.org/10.1016/j.tics.2003.08.003)"
+         ],
+         "method": "Prior-Guided Drift Diffusion",
+         "reverse_diff": {
+             "model": "resnet50_robust",
+             "layer": "layer3",
+             "initial_noise": 0.1,
+             "diffusion_noise": 0.003,
+             "step_size": 0.5,
+             "iterations": 101,
+             "epsilon": 3.0
+         }
+     }
+ ]
+
+ @GPU
+ def run_inference(image, model_type, inference_type, eps_value, num_iterations,
+                   initial_noise=0.05, diffusion_noise=0.3, step_size=0.8, model_layer="layer3"):
+     # Check that an image is provided
+     if image is None:
+         return None, "Please upload an image before running inference."
+
+     # Convert eps to float
+     eps = float(eps_value)
+
+     # Load the inference configuration for the selected type
+     config = get_inference_configs(inference_type=inference_type, eps=eps, n_itr=int(num_iterations))
+
+     # Handle parameters specific to Prior-Guided Drift Diffusion
+     if inference_type == "Prior-Guided Drift Diffusion":
+         config['initial_inference_noise_ratio'] = float(initial_noise)
+         config['diffusion_noise_ratio'] = float(diffusion_noise)
+         config['step_size'] = float(step_size)
+         config['top_layer'] = model_layer
+
+     # Run generative inference
+     result = model.inference(image, model_type, config)
+
+     # Extract results based on the return type
+     if isinstance(result, tuple):
+         # Old format returning (output_image, all_steps)
+         output_image, all_steps = result
+     else:
+         # New format returning a dictionary
+         output_image = result['final_image']
+         all_steps = result['steps']
+
+     # Create animation frames
+     frames = []
+     for i, step_image in enumerate(all_steps):
+         # Convert tensor to PIL image
+         step_pil = Image.fromarray((step_image.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8))
+         frames.append(step_pil)
+
+     # Convert the final output image to PIL
+     final_image = Image.fromarray((output_image.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8))
+
+     # Return the final inferred image and the animation frames
+     return final_image, frames
+
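Outside the Gradio UI, `run_inference` can be exercised directly. A hypothetical smoke test (the image path and values mirror the Neon Color Spreading preset above; the output filenames are illustrative):

```python
from PIL import Image

img = Image.open("stimuli/Neon_Color_Circle.jpg").convert("RGB")
final, frames = run_inference(
    img, "resnet50_robust", "Prior-Guided Drift Diffusion",
    eps_value=20.0, num_iterations=101,
    initial_noise=0.8, diffusion_noise=0.003,
    step_size=1.0, model_layer="layer3",
)
final.save("neon_inferred.png")      # final percept
frames[0].save("neon_step_0.png")    # first visualized step
```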
+ # Helper function to apply example parameters
+ def apply_example(example):
+     return [
+         example["image"],
+         "resnet50_robust",                           # Model type
+         example["method"],                           # Inference type
+         example["reverse_diff"]["epsilon"],          # Epsilon value
+         example["reverse_diff"]["iterations"],       # Number of iterations
+         example["reverse_diff"]["initial_noise"],    # Initial (drift) noise
+         example["reverse_diff"]["diffusion_noise"],  # Diffusion noise value
+         example["reverse_diff"]["step_size"],        # Step size
+         example["reverse_diff"]["layer"],            # Model layer
+         gr.Group(visible=True)                       # Show the parameters section
+     ]
+
+ # Define the interface
+ with gr.Blocks(title="Generative Inference Demo", css="""
+ .purple-button {
+     background-color: #8B5CF6 !important;
+     color: white !important;
+     border: none !important;
+ }
+ .purple-button:hover {
+     background-color: #7C3AED !important;
+ }
+ """) as demo:
+     gr.Markdown("# Generative Inference Demo")
+     gr.Markdown("This demo showcases how neural networks can perceive visual illusions and develop Gestalt principles of perceptual organization through generative inference.")
+
+     gr.Markdown("""
+     **How to use this demo:**
+     - **Load pre-configured examples**: Click on any visual illusion below and hit "Load Parameters" to automatically set up the optimal parameters for that illusion
+     - **Run the inference**: After loading parameters or setting your own, hit "Run Inference" to start the generative inference process
+     - **You can also upload your own images** and experiment with different parameters to see how they affect the generative inference process
+     """)
+
+     # Main processing interface
+     with gr.Row():
+         with gr.Column(scale=1):
+             # Inputs
+             image_input = gr.Image(label="Input Image", type="pil", value=os.path.join("stimuli", "Neon_Color_Circle.jpg"))
+
+             # Run Inference button right below the image
+             run_button = gr.Button("🪄 Run Generative Inference", variant="primary", elem_classes="purple-button")
+
+             # Parameters toggle button
+             params_button = gr.Button("⚙️ Play with the parameters", variant="secondary")
+
+             # Parameters section (initially hidden)
+             with gr.Group(visible=False) as params_section:
+                 with gr.Row():
+                     model_choice = gr.Dropdown(
+                         choices=["resnet50_robust", "standard_resnet50"],  # "resnet50_robust_face" - hidden for deployment
+                         value="resnet50_robust",
+                         label="Model"
+                     )
+
+                     inference_type = gr.Dropdown(
+                         choices=["Prior-Guided Drift Diffusion", "IncreaseConfidence"],
+                         value="Prior-Guided Drift Diffusion",
+                         label="Inference Method"
+                     )
+
+                 with gr.Row():
+                     eps_slider = gr.Slider(minimum=0.0, maximum=40.0, value=20.0, step=0.01, label="Epsilon (Stimulus Fidelity)")
+                     iterations_slider = gr.Slider(minimum=1, maximum=600, value=101, step=1, label="Number of Iterations")
+
+                 with gr.Row():
+                     initial_noise_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.8, step=0.01,
+                                                      label="Drift Noise")
+                     diffusion_noise_slider = gr.Slider(minimum=0.0, maximum=0.05, value=0.003, step=0.001,
+                                                        label="Diffusion Noise")
+
+                 with gr.Row():
+                     step_size_slider = gr.Slider(minimum=0.01, maximum=2.0, value=1.0, step=0.01,
+                                                  label="Update Rate")
+                     layer_choice = gr.Dropdown(
+                         choices=["all", "conv1", "bn1", "relu", "maxpool", "layer1", "layer2", "layer3", "layer4", "avgpool"],
+                         value="layer3",
+                         label="Model Layer"
+                     )
+
+         with gr.Column(scale=2):
+             # Outputs
+             output_image = gr.Image(label="Final Inferred Image")
+             output_frames = gr.Gallery(label="Inference Steps", columns=5, rows=2)
+
+     # Examples section with integrated explanations
+     gr.Markdown("## Visual Illusion Examples")
+     gr.Markdown("Select an illusion to load its parameters and see how generative inference reveals perceptual effects.")
+
+     # For each example, create a row with the image and explanation side by side
+     for i, ex in enumerate(examples):
+         with gr.Row():
+             # Left column for the image
+             with gr.Column(scale=1):
+                 # Display the example image
+                 example_img = gr.Image(value=ex["image"], type="filepath", label=f"{ex['name']}")
+                 load_btn = gr.Button("Load Parameters", variant="primary")
+
+                 # Set up the load button to apply this example's parameters
+                 load_btn.click(
+                     fn=lambda ex=ex: apply_example(ex),
+                     outputs=[
+                         image_input, model_choice, inference_type,
+                         eps_slider, iterations_slider,
+                         initial_noise_slider, diffusion_noise_slider,
+                         step_size_slider, layer_choice, params_section
+                     ]
+                 )
+
+             # Right column for the explanation
+             with gr.Column(scale=2):
+                 gr.Markdown(f"### {ex['name']}")
+                 gr.Markdown(f"[Read more on Wikipedia]({ex['wiki']})")
+
+                 # Show instructions if they exist
+                 if "instructions" in ex:
+                     gr.Markdown(f"**Instructions:** {ex['instructions']}")
+
+         if i < len(examples) - 1:  # Don't add a separator after the last example
+             gr.Markdown("---")
+
+     # Set up the event handler for the main inference
+     run_button.click(
+         fn=run_inference,
+         inputs=[
+             image_input, model_choice, inference_type,
+             eps_slider, iterations_slider,
+             initial_noise_slider, diffusion_noise_slider,
+             step_size_slider, layer_choice
+         ],
+         outputs=[output_image, output_frames]
+     )
+
+     # Reveal the parameters section (one-way: the button shows it, never hides it)
+     def toggle_params():
+         return gr.Group(visible=True)
+
+     params_button.click(
+         fn=toggle_params,
+         outputs=[params_section]
+     )
+
+     # About section
+     gr.Markdown("""
+     ## About Generative Inference
+
+     Generative inference is a technique that reveals how neural networks perceive visual stimuli. This demo primarily uses the Prior-Guided Drift Diffusion method.
+
+     ### Prior-Guided Drift Diffusion
+     Moves away from a noisy representation of the input image
+
+     ### IncreaseConfidence
+     Moves away from the least likely class identified at iteration 0 (fast perception)
+
+     ### Parameters:
+     - **Drift Noise**: Controls the amount of noise added to the image at the beginning
+     - **Diffusion Noise**: Controls the amount of noise added at each optimization step
+     - **Update Rate**: Learning rate for the optimization process
+     - **Number of Iterations**: How many optimization steps to perform
+     - **Model Layer**: Select a specific layer of the ResNet50 model to extract features from
+     - **Epsilon (Stimulus Fidelity)**: Controls the size of the perturbation during optimization
+
+     **Generative Inference was developed by [Tahereh Toosi](https://toosi.github.io).**
+     """)
+
+ # Launch the demo
+ if __name__ == "__main__":
+     print(f"Starting server on port {args.port}")
+     demo.launch(
+         server_name="0.0.0.0",
+         server_port=args.port,
+         share=False,
+         debug=True
+     )
face_vase.png ADDED
huggingface-metadata.json ADDED
@@ -0,0 +1,11 @@
+ {
+     "title": "Generative Inference Demo",
+     "emoji": "🧠",
+     "colorFrom": "indigo",
+     "colorTo": "purple",
+     "sdk": "gradio",
+     "sdk_version": "3.32.0",
+     "app_file": "app.py",
+     "pinned": false,
+     "license": "mit"
+ }
inference.py ADDED
@@ -0,0 +1,923 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torchvision.models as models
+ import torchvision.transforms as transforms
+ from torchvision.models.resnet import ResNet50_Weights
+ from PIL import Image
+ import numpy as np
+ import os
+ import requests
+ import time
+ import copy
+ from collections import OrderedDict
+ from pathlib import Path
+
+ # Check for available hardware acceleration
+ if torch.cuda.is_available():
+     device = torch.device("cuda")
+ elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+     device = torch.device("mps")  # Use Apple Metal Performance Shaders on M-series Macs
+ else:
+     device = torch.device("cpu")
+ print(f"Using device: {device}")
+
+ # Constants
+ MODEL_URLS = {
+     'resnet50_robust': 'https://huggingface.co/madrylab/robust-imagenet-models/resolve/main/resnet50_l2_eps3.ckpt',
+     'resnet50_standard': 'https://huggingface.co/madrylab/robust-imagenet-models/resolve/main/resnet50_l2_eps0.ckpt',
+     # Use /resolve/ (not /blob/, which serves the HTML page) so requests downloads the raw checkpoint
+     'resnet50_robust_face': 'https://huggingface.co/ttoosi/resnet50_robust_face/resolve/main/100_checkpoint.pt'
+ }
+
+ IMAGENET_MEAN = [0.485, 0.456, 0.406]
+ IMAGENET_STD = [0.229, 0.224, 0.225]
+
+ # Define the transforms based on whether normalization is on or off
+ def get_transform(input_size=224, normalize=False, norm_mean=IMAGENET_MEAN, norm_std=IMAGENET_STD):
+     if normalize:
+         return transforms.Compose([
+             transforms.Resize(input_size),
+             transforms.CenterCrop(input_size),
+             transforms.ToTensor(),
+             transforms.Normalize(norm_mean, norm_std),
+         ])
+     else:
+         return transforms.Compose([
+             transforms.Resize(input_size),
+             transforms.CenterCrop(input_size),
+             transforms.ToTensor(),
+         ])
+
+ # Default transform without normalization
+ transform = transforms.Compose([
+     transforms.Resize(224),
+     transforms.CenterCrop(224),
+     transforms.ToTensor(),
+ ])
+
+ normalize_transform = transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)
+
+ def extract_middle_layers(model, layer_index):
+     """
+     Extract a subset of the model up to a specific layer.
+
+     Args:
+         model: The neural network model
+         layer_index: String 'all' for the full model, or a layer identifier.
+             For ResNet: a module name such as 'layer3' (the demo passes names)
+             For ViT: strings like 'encoder.layers.encoder_layer_3'
+
+     Returns:
+         A modified model that outputs features from the specified layer
+     """
+     if isinstance(layer_index, str) and layer_index == 'all':
+         return model
+
+     # Special case for ViT's encoder layers with a DataParallel wrapper
+     if isinstance(layer_index, str) and layer_index.startswith('encoder.layers.encoder_layer_'):
+         try:
+             target_layer_idx = int(layer_index.split('_')[-1])
+
+             # Create a deep copy of the model to avoid modifying the original
+             new_model = copy.deepcopy(model)
+
+             # For models wrapped in DataParallel
+             if hasattr(new_model, 'module'):
+                 # Create a subset of encoder layers up to the specified index
+                 encoder_layers = nn.Sequential()
+                 for i in range(target_layer_idx + 1):
+                     layer_name = f"encoder_layer_{i}"
+                     if hasattr(new_model.module.encoder.layers, layer_name):
+                         encoder_layers.add_module(layer_name,
+                                                   getattr(new_model.module.encoder.layers, layer_name))
+
+                 # Replace the encoder layers with our truncated version
+                 new_model.module.encoder.layers = encoder_layers
+
+                 # Remove the heads since we're stopping at the encoder layer
+                 new_model.module.heads = nn.Identity()
+
+                 return new_model
+             else:
+                 # Direct model access (not DataParallel)
+                 encoder_layers = nn.Sequential()
+                 for i in range(target_layer_idx + 1):
+                     layer_name = f"encoder_layer_{i}"
+                     if hasattr(new_model.encoder.layers, layer_name):
+                         encoder_layers.add_module(layer_name,
+                                                   getattr(new_model.encoder.layers, layer_name))
+
+                 # Replace the encoder layers with our truncated version
+                 new_model.encoder.layers = encoder_layers
+
+                 # Remove the heads since we're stopping at the encoder layer
+                 new_model.heads = nn.Identity()
+
+                 return new_model
+
+         except (ValueError, IndexError) as e:
+             raise ValueError(f"Invalid ViT layer specification: {layer_index}. Error: {e}")
+
+     # Handling for whole ViT blocks
+     elif hasattr(model, 'blocks') or (hasattr(model, 'module') and hasattr(model.module, 'blocks')):
+         # Check for a DataParallel wrapper
+         base_model = model.module if hasattr(model, 'module') else model
+
+         # Create a deep copy to avoid modifying the original
+         new_model = copy.deepcopy(model)
+         base_new_model = new_model.module if hasattr(new_model, 'module') else new_model
+
+         # Keep the desired number of transformer blocks
+         if isinstance(layer_index, int):
+             # Truncate the blocks
+             base_new_model.blocks = base_new_model.blocks[:layer_index + 1]
+
+         return new_model
+
+     else:
+         # Original ResNet/VGG handling
+         modules = list(model.named_children())
+         print(f"DEBUG - extract_middle_layers - Looking for '{layer_index}' in {[name for name, _ in modules]}")
+
+         cutoff_idx = next((i for i, (name, _) in enumerate(modules)
+                            if name == str(layer_index)), None)
+
+         if cutoff_idx is not None:
+             # Keep modules up to and including the target
+             new_model = nn.Sequential(OrderedDict(modules[:cutoff_idx + 1]))
+             return new_model
+         else:
+             raise ValueError(f"Module {layer_index} not found in model")
+
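A quick, hypothetical usage sketch for the ResNet branch (the printed shape is what torchvision's ResNet50 produces for a 224×224 input):

```python
import torch
import torchvision.models as models

resnet = models.resnet50()
# Keep everything up to and including layer3; the output is a feature map,
# not class logits.
layer3_model = extract_middle_layers(resnet, 'layer3')
feats = layer3_model(torch.zeros(1, 3, 224, 224))
print(feats.shape)  # torch.Size([1, 1024, 14, 14])
```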
+ # Get ImageNet labels
+ def get_imagenet_labels():
+     url = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
+     response = requests.get(url)
+     if response.status_code == 200:
+         return response.json()
+     else:
+         raise RuntimeError("Failed to fetch ImageNet labels")
+
+ # Download a model checkpoint if needed
+ def download_model(model_type):
+     if model_type not in MODEL_URLS or MODEL_URLS[model_type] is None:
+         return None  # Use PyTorch's pretrained model
+
+     # Handle the special case for the face model
+     if model_type == 'resnet50_robust_face':
+         model_path = Path("models/resnet50_robust_face_100_checkpoint.pt")
+     else:
+         model_path = Path(f"models/{model_type}.pt")
+
+     if not model_path.exists():
+         print(f"Downloading {model_type} model...")
+         url = MODEL_URLS[model_type]
+         response = requests.get(url, stream=True)
+         if response.status_code == 200:
+             with open(model_path, 'wb') as f:
+                 for chunk in response.iter_content(chunk_size=8192):
+                     f.write(chunk)
+             print(f"Model downloaded and saved to {model_path}")
+         else:
+             raise RuntimeError(f"Failed to download model: {response.status_code}")
+     return model_path
+
+ class NormalizeByChannelMeanStd(nn.Module):
+     def __init__(self, mean, std):
+         super(NormalizeByChannelMeanStd, self).__init__()
+         if not isinstance(mean, torch.Tensor):
+             mean = torch.tensor(mean)
+         if not isinstance(std, torch.Tensor):
+             std = torch.tensor(std)
+         self.register_buffer("mean", mean)
+         self.register_buffer("std", std)
+
+     def forward(self, tensor):
+         return self.normalize_fn(tensor, self.mean, self.std)
+
+     def normalize_fn(self, tensor, mean, std):
+         """Differentiable version of torchvision.functional.normalize"""
+         # Here we assume the color channel is at dim=1
+         mean = mean[None, :, None, None]
+         std = std[None, :, None, None]
+         return tensor.sub(mean).div(std)
+
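Baking the normalization into the model as an `nn.Module` keeps it differentiable, so gradient-based inference can optimize raw [0, 1] pixels directly. A quick, hypothetical sanity check against torchvision's functional `normalize`, using the constants defined above:

```python
import torch
import torchvision.transforms.functional as TF

norm = NormalizeByChannelMeanStd(IMAGENET_MEAN, IMAGENET_STD)
x = torch.rand(2, 3, 224, 224, requires_grad=True)
y = norm(x)
# Same arithmetic as the torchvision transform, but gradients reach x:
assert torch.allclose(y[0], TF.normalize(x[0], IMAGENET_MEAN, IMAGENET_STD))
y.sum().backward()  # populates x.grad
```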
+ class InferStep:
+     def __init__(self, orig_image, eps, step_size):
+         self.orig_image = orig_image
+         self.eps = eps
+         self.step_size = step_size
+
+     def project(self, x):
+         # Clamp the perturbation to the L-infinity eps-ball and keep pixels in [0, 1]
+         diff = x - self.orig_image
+         diff = torch.clamp(diff, -self.eps, self.eps)
+         return torch.clamp(self.orig_image + diff, 0, 1)
+
+     def step(self, x, grad):
+         # L2-normalize the gradient per sample, then scale by the step size
+         l = len(x.shape) - 1
+         grad_norm = torch.norm(grad.view(grad.shape[0], -1), dim=1).view(-1, *([1] * l))
+         scaled_grad = grad / (grad_norm + 1e-10)
+         return scaled_grad * self.step_size
+
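`project` enforces an L-infinity budget of `eps` around the original image, while `step` returns an L2-normalized gradient step. A minimal, hypothetical sketch of one update (the real loop in `inference()` below adds noise schedules on top):

```python
import torch

orig = torch.rand(1, 3, 224, 224)   # original image in [0, 1]
grad = torch.randn_like(orig)       # stand-in for an autograd gradient
stepper = InferStep(orig, eps=0.1, step_size=0.5)

x = stepper.project(orig + stepper.step(orig, grad))
# Every pixel stays within eps of the original and inside [0, 1]:
assert (x - orig).abs().max() <= 0.1 + 1e-6
```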
+ def get_iterations_to_show(n_itr):
+     """Generate a dynamic list of iterations to visualize based on the total iteration count."""
+     if n_itr <= 50:
+         return [1, 5, 10, 20, 30, 40, 50, n_itr]
+     elif n_itr <= 100:
+         return [1, 5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, n_itr]
+     elif n_itr <= 200:
+         return [1, 5, 10, 20, 30, 40, 50, 75, 100, 125, 150, 175, 200, n_itr]
+     elif n_itr <= 500:
+         return [1, 5, 10, 20, 30, 40, 50, 75, 100, 150, 200, 250, 300, 350, 400, 450, 500, n_itr]
+     else:
+         # For very large iteration counts, show more evenly distributed points
+         return [1, 5, 10, 20, 30, 40, 50, 75, 100, 150, 200, 250, 300, 350, 400, 450, 500,
+                 int(n_itr * 0.6), int(n_itr * 0.7), int(n_itr * 0.8), int(n_itr * 0.9), n_itr]
+
+ def get_inference_configs(inference_type='IncreaseConfidence', eps=0.5, n_itr=50, step_size=1.0):
+     """Generate an inference configuration with customizable parameters.
+
+     Args:
+         inference_type (str): Type of inference ('IncreaseConfidence' or 'Prior-Guided Drift Diffusion')
+         eps (float): Maximum perturbation size
+         n_itr (int): Number of iterations
+         step_size (float): Step size for each iteration
+     """
+
+     # Base configuration common to all inference types
+     config = {
+         'loss_infer': inference_type,          # How to guide the optimization
+         'n_itr': n_itr,                        # Number of iterations
+         'eps': eps,                            # Maximum perturbation size
+         'step_size': step_size,                # Step size for each iteration
+         'diffusion_noise_ratio': 0.0,          # No diffusion noise by default
+         'initial_inference_noise_ratio': 0.0,  # No initial noise by default
+         'top_layer': 'all',                    # Use all layers of the model
+         'inference_normalization': False,      # Whether to normalize during inference (off by default)
+         'recognition_normalization': False,    # Whether to normalize during recognition (off by default)
+         'iterations_to_show': get_iterations_to_show(n_itr),  # Dynamic iterations to visualize
+         'misc_info': {'keep_grads': False}     # Additional configuration
+     }
+
+     # Customize based on the inference type
+     if inference_type == 'IncreaseConfidence':
+         config['loss_function'] = 'CE'   # Cross Entropy
+
+     elif inference_type == 'Prior-Guided Drift Diffusion':
+         config['loss_function'] = 'MSE'  # Mean Squared Error
+         config['initial_inference_noise_ratio'] = 0.05  # Initial noise for diffusion
+         config['diffusion_noise_ratio'] = 0.01          # Add noise during diffusion
+
+     elif inference_type == 'GradModulation':
+         config['loss_function'] = 'CE'   # Cross Entropy
+         config['misc_info']['grad_modulation'] = 0.5  # Gradient modulation strength
+
+     elif inference_type == 'CompositionalFusion':
+         config['loss_function'] = 'CE'   # Cross Entropy
+         config['misc_info']['positive_classes'] = []  # Classes to maximize
+         config['misc_info']['negative_classes'] = []  # Classes to minimize
+
+     return config
+
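As a usage example, the Neon Color Spreading preset in `app.py` effectively builds its configuration like this (the four overrides are the ones `run_inference` applies after calling the helper):

```python
config = get_inference_configs(
    inference_type='Prior-Guided Drift Diffusion', eps=20.0, n_itr=101)
config['initial_inference_noise_ratio'] = 0.8  # drift noise
config['diffusion_noise_ratio'] = 0.003        # per-step diffusion noise
config['step_size'] = 1.0                      # update rate
config['top_layer'] = 'layer3'                 # feature layer to match
```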
282
+ class GenerativeInferenceModel:
283
+ def __init__(self):
284
+ self.models = {}
285
+ self.normalizer = NormalizeByChannelMeanStd(IMAGENET_MEAN, IMAGENET_STD).to(device)
286
+ self.labels = get_imagenet_labels()
287
+
288
+ def verify_model_integrity(self, model, model_type):
289
+ """
290
+ Verify model integrity by running a test input through it.
291
+ Returns whether the model passes basic integrity check.
292
+ """
293
+ try:
294
+ print(f"\n=== Running model integrity check for {model_type} ===")
295
+ # Create a deterministic test input directly on the correct device
296
+ test_input = torch.zeros(1, 3, 224, 224, device=device)
297
+ test_input[0, 0, 100:124, 100:124] = 0.5 # Red square
298
+
299
+ # Run forward pass
300
+ with torch.no_grad():
301
+ output = model(test_input)
302
+
303
+ # Check output shape
304
+ if output.shape != (1, 1000):
305
+ print(f"❌ Unexpected output shape: {output.shape}, expected (1, 1000)")
306
+ return False
307
+
308
+ # Get top prediction
309
+ probs = torch.nn.functional.softmax(output, dim=1)
310
+ confidence, prediction = torch.max(probs, 1)
311
+
312
+ # Calculate basic statistics on output
313
+ mean = output.mean().item()
314
+ std = output.std().item()
315
+ min_val = output.min().item()
316
+ max_val = output.max().item()
317
+
318
+ print(f"Model integrity check results:")
319
+ print(f"- Output shape: {output.shape}")
320
+ print(f"- Top prediction: Class {prediction.item()} with {confidence.item()*100:.2f}% confidence")
321
+ print(f"- Output statistics: mean={mean:.3f}, std={std:.3f}, min={min_val:.3f}, max={max_val:.3f}")
322
+
323
+ # Basic sanity checks
324
+ if torch.isnan(output).any():
325
+ print("❌ Model produced NaN outputs")
326
+ return False
327
+
328
+ if output.std().item() < 0.1:
329
+ print("⚠️ Low output variance, model may not be discriminative")
330
+
331
+ print("✅ Model passes basic integrity check")
332
+ return True
333
+
334
+ except Exception as e:
335
+ print(f"❌ Model integrity check failed with error: {e}")
336
+ # Rather than failing completely, we'll continue
337
+ return True
338
+
339
+ def load_model(self, model_type):
340
+ """Load model from checkpoint or use pretrained model."""
341
+ if model_type in self.models:
342
+ print(f"Using cached {model_type} model")
343
+ return self.models[model_type]
344
+
345
+ # Record loading time for performance analysis
346
+ start_time = time.time()
347
+ model_path = download_model(model_type)
348
+
349
+ # Create a sequential model with normalizer and ResNet50
350
+ resnet = models.resnet50()
351
+ model = nn.Sequential(
352
+ self.normalizer, # Normalizer is part of the model sequence
353
+ resnet
354
+ )
355
+
356
+ # Load the model checkpoint
357
+ if model_path:
358
+ print(f"Loading {model_type} model from {model_path}...")
359
+ try:
360
+ checkpoint = torch.load(model_path, map_location=device)
361
+
362
+ # Print checkpoint structure for better understanding
363
+ print("\n=== Analyzing checkpoint structure ===")
364
+ if isinstance(checkpoint, dict):
365
+ print(f"Checkpoint contains keys: {list(checkpoint.keys())}")
366
+
367
+ # Examine 'model' structure if it exists
368
+ if 'model' in checkpoint and isinstance(checkpoint['model'], dict):
369
+ model_dict = checkpoint['model']
370
+ # Get sample of keys to understand structure
371
+ first_keys = list(model_dict.keys())[:5]
372
+ print(f"'model' contains keys like: {first_keys}")
373
+
374
+ # Check for common prefixes in the model dict
375
+ prefixes = set()
376
+ for key in list(model_dict.keys())[:100]: # Check first 100 keys
377
+ parts = key.split('.')
378
+ if len(parts) > 1:
379
+ prefixes.add(parts[0])
380
+ if prefixes:
381
+ print(f"Common prefixes in model dict: {prefixes}")
382
+ else:
383
+ print(f"Checkpoint is not a dictionary, but a {type(checkpoint)}")
384
+
385
+ # Handle different checkpoint formats
386
+ if 'model' in checkpoint:
387
+ # Format from madrylab robust models
388
+ state_dict = checkpoint['model']
389
+ print("Using 'model' key from checkpoint")
390
+ elif 'state_dict' in checkpoint:
391
+ state_dict = checkpoint['state_dict']
392
+ print("Using 'state_dict' key from checkpoint")
393
+ else:
394
+ # Direct state dict
395
+ state_dict = checkpoint
396
+ print("Using checkpoint directly as state_dict")
397
+
398
+ # Handle prefix in state dict keys for ResNet part
399
+ resnet_state_dict = {}
400
+ prefixes_to_try = ['', 'module.', 'model.', 'attacker.model.']
401
+ resnet_keys = set(resnet.state_dict().keys())
402
+
403
+ # First check if we can find keys directly in the attacker.model path
404
+ print("\n=== Phase 1: Checking for specific model structures ===")
405
+
406
+ # Check for 'module.model' structure (seen in actual checkpoint)
407
+ module_model_keys = [key for key in state_dict.keys() if key.startswith('module.model.')]
408
+ if module_model_keys:
409
+ print(f"Found 'module.model' structure with {len(module_model_keys)} parameters")
410
+ # Extract all parameters from module.model
411
+ for source_key, value in state_dict.items():
412
+ if source_key.startswith('module.model.'):
413
+ target_key = source_key[len('module.model.'):]
414
+ resnet_state_dict[target_key] = value
415
+
416
+ print(f"Extracted {len(resnet_state_dict)} parameters from module.model")
417
+
418
+ # Check for 'attacker.model' structure
419
+ attacker_model_keys = [key for key in state_dict.keys() if key.startswith('attacker.model.')]
420
+ if attacker_model_keys:
421
+ print(f"Found 'attacker.model' structure with {len(attacker_model_keys)} parameters")
422
+ # Extract all parameters from attacker.model
423
+ for source_key, value in state_dict.items():
424
+ if source_key.startswith('attacker.model.'):
425
+ target_key = source_key[len('attacker.model.'):]
426
+ resnet_state_dict[target_key] = value
427
+
428
+ print(f"Extracted {len(resnet_state_dict)} parameters from attacker.model")
429
+
430
+ # Check if 'model' (not attacker.model) exists as a fallback
431
+ model_keys = [key for key in state_dict.keys() if key.startswith('model.') and not key.startswith('attacker.model.')]
432
+ if model_keys and len(resnet_state_dict) < len(resnet_keys):
433
+ print(f"Found additional 'model.' structure with {len(model_keys)} parameters")
434
+ # Try to complete missing parameters
435
+ for source_key, value in state_dict.items():
436
+ if source_key.startswith('model.'):
437
+ target_key = source_key[len('model.'):]
438
+ if target_key in resnet_keys and target_key not in resnet_state_dict:
439
+ resnet_state_dict[target_key] = value
440
+
441
+ else:
442
+ # Check for other known structures
443
+ structure_found = False
444
+
445
+ # Check for 'model.' prefix
446
+ model_keys = [key for key in state_dict.keys() if key.startswith('model.')]
447
+ if model_keys:
448
+ print(f"Found 'model.' structure with {len(model_keys)} parameters")
449
+ for source_key, value in state_dict.items():
450
+ if source_key.startswith('model.'):
451
+ target_key = source_key[len('model.'):]
452
+ resnet_state_dict[target_key] = value
453
+ structure_found = True
454
+
455
+ # Check for ResNet parameters at the top level
456
+ top_level_resnet_keys = 0
457
+ for key in resnet_keys:
458
+ if key in state_dict:
459
+ top_level_resnet_keys += 1
460
+
461
+ if top_level_resnet_keys > 0:
462
+ print(f"Found {top_level_resnet_keys} ResNet parameters at top level")
463
+ for target_key in resnet_keys:
464
+ if target_key in state_dict:
465
+ resnet_state_dict[target_key] = state_dict[target_key]
466
+ structure_found = True
467
+
468
+ # If no structure was recognized, try the prefix mapping approach
469
+ if not structure_found:
470
+ print("No standard model structure found, trying prefix mappings...")
471
+ for target_key in resnet_keys:
472
+ for prefix in prefixes_to_try:
473
+ source_key = prefix + target_key
474
+ if source_key in state_dict:
475
+ resnet_state_dict[target_key] = state_dict[source_key]
476
+ break
477
+
478
+ # If we still can't find enough keys, try a final approach of removing prefixes
479
+ if len(resnet_state_dict) < len(resnet_keys):
480
+ print(f"Found only {len(resnet_state_dict)}/{len(resnet_keys)} parameters, trying prefix removal...")
481
+
482
+ # Track matches found through prefix removal
483
+ prefix_matches = {prefix: 0 for prefix in ['module.', 'model.', 'attacker.model.', 'attacker.']}
484
+ layer_matches = {} # Track matches by layer type
485
+
486
+ # Count parameter keys by layer type for analysis
487
+ for key in resnet_keys:
488
+ layer_name = key.split('.')[0] if '.' in key else key
489
+ if layer_name not in layer_matches:
490
+ layer_matches[layer_name] = {'total': 0, 'matched': 0}
491
+ layer_matches[layer_name]['total'] += 1
492
+
493
+ # Try keys with common prefixes
494
+ for source_key, value in state_dict.items():
495
+ # Skip if already found
496
+ target_key = source_key
497
+ matched_prefix = None
498
+
499
+ # Try removing various prefixes
500
+ for prefix in ['module.', 'model.', 'attacker.model.', 'attacker.']:
501
+ if source_key.startswith(prefix):
502
+ target_key = source_key[len(prefix):]
503
+ matched_prefix = prefix
504
+ break
505
+
506
+ # If the target key is in the ResNet keys, add it to the state dict
507
+ if target_key in resnet_keys and target_key not in resnet_state_dict:
508
+ resnet_state_dict[target_key] = value
509
+
510
+ # Update match statistics
511
+ if matched_prefix:
512
+ prefix_matches[matched_prefix] += 1
513
+
514
+ # Update layer matches
515
+ layer_name = target_key.split('.')[0] if '.' in target_key else target_key
516
+ if layer_name in layer_matches:
517
+ layer_matches[layer_name]['matched'] += 1
518
+
519
+ # Print detailed prefix removal statistics
520
+ print("\n=== Prefix Removal Statistics ===")
521
+ total_matches = sum(prefix_matches.values())
522
+ print(f"Total parameters matched through prefix removal: {total_matches}/{len(resnet_keys)} ({(total_matches/len(resnet_keys))*100:.1f}%)")
523
+
524
+ # Show matches by prefix
525
+ print("\nMatches by prefix:")
526
+ for prefix, count in sorted(prefix_matches.items(), key=lambda x: x[1], reverse=True):
527
+ if count > 0:
528
+ print(f" {prefix}: {count} parameters")
529
+
530
+ # Show matches by layer type
531
+ print("\nMatches by layer type:")
532
+ for layer, stats in sorted(layer_matches.items(), key=lambda x: x[1]['total'], reverse=True):
533
+ match_percent = (stats['matched'] / stats['total']) * 100 if stats['total'] > 0 else 0
534
+ print(f" {layer}: {stats['matched']}/{stats['total']} ({match_percent:.1f}%)")
535
+
536
+ # Check for specific important layers (conv1, layer1, etc.)
537
+ critical_layers = ['conv1', 'bn1', 'layer1', 'layer2', 'layer3', 'layer4', 'fc']
538
+ print("\nStatus of critical layers:")
539
+ for layer in critical_layers:
540
+ if layer in layer_matches:
541
+ match_percent = (layer_matches[layer]['matched'] / layer_matches[layer]['total']) * 100
542
+ status = "✅ COMPLETE" if layer_matches[layer]['matched'] == layer_matches[layer]['total'] else "⚠️ INCOMPLETE"
543
+ print(f" {layer}: {layer_matches[layer]['matched']}/{layer_matches[layer]['total']} ({match_percent:.1f}%) - {status}")
544
+ else:
545
+ print(f" {layer}: Not found in model")
546
+
547
+ # Load the ResNet state dict
548
+ if resnet_state_dict:
549
+ try:
550
+ # Use strict=False to allow missing keys
551
+ result = resnet.load_state_dict(resnet_state_dict, strict=False)
552
+ missing_keys, unexpected_keys = result
553
+
554
+ # Generate detailed information with better formatting
555
+ loading_report = []
556
+ loading_report.append(f"\n===== MODEL LOADING REPORT: {model_type} =====")
557
+ loading_report.append(f"Total parameters in checkpoint: {len(resnet_state_dict):,}")
558
+ loading_report.append(f"Total parameters in model: {len(resnet.state_dict()):,}")
559
+ loading_report.append(f"Missing keys: {len(missing_keys):,} parameters")
560
+ loading_report.append(f"Unexpected keys: {len(unexpected_keys):,} parameters")
561
+
562
+ # Calculate percentage of parameters loaded
563
+ loaded_keys = set(resnet_state_dict.keys()) - set(unexpected_keys)
564
+ loaded_percent = (len(loaded_keys) / len(resnet.state_dict())) * 100
565
+
566
+ # Determine loading success status
567
+ if loaded_percent >= 99.5:
568
+ status = "✅ COMPLETE - All important parameters loaded"
569
+ elif loaded_percent >= 90:
570
+ status = "🟡 PARTIAL - Most parameters loaded, should still function"
571
+ elif loaded_percent >= 50:
572
+ status = "⚠️ INCOMPLETE - Many parameters missing, may not function properly"
573
+ else:
574
+ status = "❌ FAILED - Critical parameters missing, will not function properly"
575
+
576
+ loading_report.append(f"Successfully loaded: {len(loaded_keys):,} parameters ({loaded_percent:.1f}%)")
577
+ loading_report.append(f"Loading status: {status}")
578
+
579
+ # If loading is severely incomplete, fall back to PyTorch's pretrained model
580
+ if loaded_percent < 50:
581
+ loading_report.append("\n⚠️ WARNING: Loading from checkpoint is too incomplete.")
582
+ loading_report.append("⚠️ Falling back to PyTorch's pretrained model to avoid broken inference.")
583
+
584
+ # Create a new ResNet model with pretrained weights
585
+ resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
586
+ model = nn.Sequential(self.normalizer, resnet)
587
+ loading_report.append("✅ Successfully loaded PyTorch's pretrained ResNet50 model")
588
+
589
+ # Show missing keys by layer type
590
+ if missing_keys:
591
+ loading_report.append("\nMissing keys by layer type:")
592
+ layer_types = {}
593
+ for key in missing_keys:
594
+ # Extract layer type (e.g., 'conv', 'bn', 'layer1', etc.)
595
+ parts = key.split('.')
596
+ if len(parts) > 0:
597
+ layer_type = parts[0]
598
+ if layer_type not in layer_types:
599
+ layer_types[layer_type] = 0
600
+ layer_types[layer_type] += 1
601
+
602
+ # Add counts by layer type
603
+ for layer_type, count in sorted(layer_types.items(), key=lambda x: x[1], reverse=True):
604
+ loading_report.append(f" {layer_type}: {count:,} parameters")
605
+
606
+ loading_report.append("\nFirst 10 missing keys:")
607
+ for i, key in enumerate(sorted(missing_keys)[:10]):
608
+ loading_report.append(f" {i+1}. {key}")
609
+
610
+ # Show unexpected keys if any
611
+ if unexpected_keys:
612
+ loading_report.append("\nFirst 10 unexpected keys:")
613
+ for i, key in enumerate(sorted(unexpected_keys)[:10]):
614
+ loading_report.append(f" {i+1}. {key}")
615
+
616
+ loading_report.append("========================================")
617
+
618
+ # Convert report to string and print it
619
+ report_text = "\n".join(loading_report)
620
+ print(report_text)
621
+
622
+ # Also save to a file for reference
623
+ os.makedirs("logs", exist_ok=True)
624
+ with open(f"logs/model_loading_{model_type}.log", "w") as f:
625
+ f.write(report_text)
626
+
627
+ # Look for normalizer parameters as well
628
+ if any(key.startswith('attacker.normalize.') for key in state_dict.keys()):
629
+ norm_state_dict = {}
630
+ for key, value in state_dict.items():
631
+ if key.startswith('attacker.normalize.'):
632
+ norm_key = key[len('attacker.normalize.'):]
633
+ norm_state_dict[norm_key] = value
634
+
635
+ if norm_state_dict:
636
+ try:
637
+ self.normalizer.load_state_dict(norm_state_dict, strict=False)
638
+ print("Successfully loaded normalizer parameters")
639
+ except Exception as e:
640
+ print(f"Warning: Could not load normalizer parameters: {e}")
641
+ except Exception as e:
642
+ print(f"Warning: Error loading ResNet parameters: {e}")
643
+ # Fall back to loading without normalizer
644
+ model = resnet # Use just the ResNet model without normalizer
645
+ except Exception as e:
646
+ print(f"Error loading model checkpoint: {e}")
647
+ # Fallback to PyTorch's pretrained model
648
+ print("Falling back to PyTorch's pretrained model")
649
+ resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
650
+ model = nn.Sequential(self.normalizer, resnet)
651
+ else:
652
+ # Fallback to PyTorch's pretrained model
653
+ print("No checkpoint available, using PyTorch's pretrained model")
654
+ resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
655
+ model = nn.Sequential(self.normalizer, resnet)
656
+
657
+ model = model.to(device)
658
+ model.eval() # Set to evaluation mode
659
+
660
+ # Verify model integrity
661
+ self.verify_model_integrity(model, model_type)
662
+
663
+ # Store the model for future use
664
+ self.models[model_type] = model
665
+ end_time = time.time()
666
+ load_time = end_time - start_time
667
+ print(f"Model {model_type} loaded in {load_time:.2f} seconds")
668
+ return model
669
+
670
+ def inference(self, image, model_type, config):
671
+ """Run generative inference on the image."""
672
+ # Time the entire inference process
673
+ inference_start = time.time()
674
+
675
+ # Load model if not already loaded
676
+ model = self.load_model(model_type)
677
+
678
+ # Check if image is a file path
679
+ if isinstance(image, str):
680
+ if os.path.exists(image):
681
+ image = Image.open(image).convert('RGB')
682
+ else:
683
+ raise ValueError(f"Image path does not exist: {image}")
684
+ elif isinstance(image, torch.Tensor):
685
+ raise ValueError(f"Image type {type(image)}, looks like already a transformed tensor")
686
+
687
+ # Prepare image tensor - match original code's conditional transform
688
+ load_start = time.time()
689
+ use_norm = config['inference_normalization'] == 'on'
690
+ custom_transform = get_transform(
691
+ input_size=224,
692
+ normalize=use_norm,
693
+ norm_mean=IMAGENET_MEAN,
694
+ norm_std=IMAGENET_STD
695
+ )
696
+
697
+ # Special handling for GradModulation as in original
698
+ if config['loss_infer'] == 'GradModulation' and 'misc_info' in config and 'grad_modulation' in config['misc_info']:
699
+ grad_modulation = config['misc_info']['grad_modulation']
700
+ image_tensor = custom_transform(image).unsqueeze(0).to(device)
701
+ image_tensor = image_tensor * (1-grad_modulation) + grad_modulation * torch.randn_like(image_tensor).to(device)
702
+ else:
703
+ image_tensor = custom_transform(image).unsqueeze(0).to(device)
704
+
705
+ image_tensor.requires_grad = True
706
+ print(f"Image loaded and processed in {time.time() - load_start:.2f} seconds")
707
+
708
+ # Check model structure
709
+ is_sequential = isinstance(model, nn.Sequential)
710
+
711
+ # Get original predictions
712
+ with torch.no_grad():
713
+ # If the model is sequential with a normalizer, skip the normalization step
714
+ if is_sequential and isinstance(model[0], NormalizeByChannelMeanStd):
715
+ print("Model is sequential with normalization")
716
+ # Get the core model part (typically at index 1 in Sequential)
717
+ core_model = model[1]
718
+ if config['inference_normalization']:
719
+ output_original = model(image_tensor) # Model includes normalization
720
+ else:
721
+ output_original = core_model(image_tensor) # Model includes normalization
722
+
723
+ else:
724
+ print("Model is not sequential with normalization")
725
+ # Use manual normalization for non-sequential models
726
+ if config['inference_normalization']:
727
+ normalized_tensor = normalize_transform(image_tensor)
728
+ output_original = model(normalized_tensor)
729
+ else:
730
+ output_original = model(image_tensor)
731
+ core_model = model
732
+
733
+ probs_orig = F.softmax(output_original, dim=1)
734
+ conf_orig, classes_orig = torch.max(probs_orig, 1)
735
+
736
+ # Get least confident classes
737
+ _, least_confident_classes = torch.topk(probs_orig, k=100, largest=False)
738
+
739
+ # Initialize inference step
740
+ infer_step = InferStep(image_tensor, config['eps'], config['step_size'])
741
+
742
+ # Storage for inference steps
743
+ # Create a new tensor that requires gradients
744
+ x = image_tensor.clone().detach().requires_grad_(True)
745
+ all_steps = [image_tensor[0].detach().cpu()]
746
+
747
+     # For Prior-Guided Drift Diffusion, extract the selected layer and initialize with noisy features
+     noisy_features = None
+     layer_model = None
+     if config['loss_infer'] == 'Prior-Guided Drift Diffusion':
+         print(f"Setting up Prior-Guided Drift Diffusion with layer {config['top_layer']} and noise {config['initial_inference_noise_ratio']}...")
+
+         # Extract model up to the specified layer
+         try:
+             # Start by finding the actual model to use
+             base_model = model
+
+             # Handle DataParallel wrapper if present
+             if hasattr(base_model, 'module'):
+                 base_model = base_model.module
+
+             # Log the initial model structure
+             print(f"DEBUG - Initial model structure: {type(base_model)}")
+
+             # If we have a Sequential model (likely our normalizer + model structure)
+             if isinstance(base_model, nn.Sequential):
+                 print(f"DEBUG - Sequential model with {len(list(base_model.children()))} children")
+
+                 # If this is our NormalizeByChannelMeanStd + ResNet pattern
+                 if len(list(base_model.children())) >= 2:
+                     # The actual ResNet model is the second component (index 1)
+                     actual_model = list(base_model.children())[1]
+                     print(f"DEBUG - Using ResNet component: {type(actual_model)}")
+                     print(f"DEBUG - Available layers: {[name for name, _ in actual_model.named_children()]}")
+
+                     # Extract from the actual ResNet
+                     layer_model = extract_middle_layers(actual_model, config['top_layer'])
+                 else:
+                     # Just a single-component Sequential
+                     layer_model = extract_middle_layers(base_model, config['top_layer'])
+             else:
+                 # Not Sequential; might be the model directly
+                 print(f"DEBUG - Available layers: {[name for name, _ in base_model.named_children()]}")
+                 layer_model = extract_middle_layers(base_model, config['top_layer'])
+
+             print(f"Successfully extracted model up to layer: {config['top_layer']}")
+         except ValueError as e:
+             print(f"Layer extraction failed: {e}. Using full model.")
+             layer_model = model
+
+         # Add noise to the image - exactly match original code
+         added_noise = config['initial_inference_noise_ratio'] * torch.randn_like(image_tensor).to(device)
+         noisy_image_tensor = image_tensor + added_noise
+
+         # Compute noisy features; detach so the target is a constant and its graph
+         # is not traversed (and freed) on the first backward pass
+         noisy_features = layer_model(noisy_image_tensor).detach()
+
+         print(f"Noisy features computed for Prior-Guided Drift Diffusion target with shape: {noisy_features.shape if hasattr(noisy_features, 'shape') else 'unknown'}")
+
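+     # Sketch (added) of the update the loop below performs for Prior-Guided Drift
+     # Diffusion; the notation is ours, not from the original source:
+     #   x_{t+1} = Proj_eps( x_t + step(grad_x || f_L(x_t) - f_L(x_0 + s*xi) ||^2) + s_d*zeta )
+     # where f_L is the network truncated at config['top_layer'], x_0 is the input,
+     # s = initial_inference_noise_ratio, s_d = diffusion_noise_ratio, xi and zeta
+     # are standard Gaussian draws, and Proj_eps is InferStep's projection back into
+     # the eps-ball. The sign and scale of the step are delegated to InferStep.step.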
+     # Main inference loop
+     print(f"Starting inference loop with {config['n_itr']} iterations for {config['loss_infer']}...")
+     loop_start = time.time()
+     for i in range(config['n_itr']):
+         # Reset gradients: re-leaf x so grads are fresh and the graph built in the
+         # previous iteration can be freed (added fix; the original only cleared x.grad)
+         x = x.detach().requires_grad_(True)
+
+         # Forward pass - use layer_model for Prior-Guided Drift Diffusion, full model otherwise
+         if config['loss_infer'] == 'Prior-Guided Drift Diffusion' and layer_model is not None:
+             # Use the extracted layer model; normalization is handled at transform
+             # time, not during the forward pass
+             output = layer_model(x)
+         else:
+             # Standard forward pass with the full model
+             output = model(x)
+
+         # Calculate loss and gradients based on inference type
+         try:
+             if config['loss_infer'] == 'Prior-Guided Drift Diffusion':
+                 # Use MSE loss to match the noisy features
+                 assert config['loss_function'] == 'MSE', "Prior-Guided Drift Diffusion loss function must be MSE"
+                 if noisy_features is not None:
+                     loss = F.mse_loss(output, noisy_features)
+                     grad = torch.autograd.grad(loss, x)[0]
+                 else:
+                     raise ValueError("Noisy features not computed for Prior-Guided Drift Diffusion")
+
+             else:  # Default 'IncreaseConfidence' approach
+                 # Get the least confident classes
+                 num_classes = min(10, least_confident_classes.size(1))
+                 target_classes = least_confident_classes[0, :num_classes]
+
+                 # Create targets for the least confident classes
+                 targets = torch.tensor([idx.item() for idx in target_classes], device=device)
+
+                 # Use a combined loss to increase confidence in those classes
+                 loss = 0
+                 for target in targets:
+                     # Create a one-hot target
+                     one_hot = torch.zeros_like(output)
+                     one_hot[0, target] = 1
+                     # MSE between the softmax and the one-hot target pulls
+                     # probability mass toward that class
+                     loss = loss + F.mse_loss(F.softmax(output, dim=1), one_hot)
+
+                 grad = torch.autograd.grad(loss, x, retain_graph=True)[0]
+
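+             # NOTE (added): for 'IncreaseConfidence', driving the softmax toward
+             # one-hot targets for the initially least-confident classes asks the
+             # optimizer to synthesize evidence for them; that synthesized evidence
+             # is what this demo visualizes.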
+             if grad is None:
+                 print("Warning: Direct gradient calculation failed")
+                 # Fall back to random perturbation
+                 random_noise = (torch.rand_like(x) - 0.5) * 2 * config['step_size']
+                 x = infer_step.project(x + random_noise)
+             else:
+                 # Update the image with the gradient, exactly as in the original code
+                 adjusted_grad = infer_step.step(x, grad)
+
+                 # Add diffusion noise if specified
+                 diffusion_noise = config['diffusion_noise_ratio'] * torch.randn_like(x).to(device)
+
+                 # Apply gradient and noise in one operation before projecting
+                 x = infer_step.project(x.clone() + adjusted_grad + diffusion_noise)
+
+         except Exception as e:
+             print(f"Error in gradient calculation: {e}")
+             # Fall back to random perturbation
+             random_noise = (torch.rand_like(x) - 0.5) * 2 * config['step_size']
+             x = infer_step.project(x.clone() + random_noise)
+
+         # Store this step if it is in iterations_to_show
+         if i + 1 in config['iterations_to_show'] or i + 1 == config['n_itr']:
+             all_steps.append(x[0].detach().cpu())
+
+     # Print some info about the inference
+     with torch.no_grad():
+         if is_sequential and isinstance(model[0], NormalizeByChannelMeanStd):
+             if config['inference_normalization']:
+                 final_output = model(x)
+             else:
+                 final_output = core_model(x)
+         else:
+             if config['inference_normalization']:
+                 normalized_x = normalize_transform(x)
+                 final_output = model(normalized_x)
+             else:
+                 final_output = model(x)
+
+         final_probs = F.softmax(final_output, dim=1)
+         final_conf, final_classes = torch.max(final_probs, 1)
+
+     # Calculate timing information
+     loop_time = time.time() - loop_start
+     total_time = time.time() - inference_start
+     avg_iter_time = loop_time / config['n_itr'] if config['n_itr'] > 0 else 0
+
+     print(f"Original top class: {classes_orig.item()} ({conf_orig.item():.4f})")
+     print(f"Final top class: {final_classes.item()} ({final_conf.item():.4f})")
+     print(f"Inference loop completed in {loop_time:.2f} seconds ({avg_iter_time:.4f} sec/iteration)")
+     print(f"Total inference time: {total_time:.2f} seconds")
+
+     # Return results in a format compatible with both old and new code
+     return {
+         'final_image': x[0].detach().cpu(),
+         'steps': all_steps,
+         'original_class': classes_orig.item(),
+         'original_confidence': conf_orig.item(),
+         'final_class': final_classes.item(),
+         'final_confidence': final_conf.item()
+     }
+
+ # Utility function to show inference steps
+ def show_inference_steps(steps, figsize=(15, 10)):
+     import matplotlib.pyplot as plt
+
+     n_steps = len(steps)
+     fig, axes = plt.subplots(1, n_steps, figsize=figsize)
+     if n_steps == 1:
+         axes = [axes]  # subplots returns a bare Axes when there is a single panel
+
+     for i, step_img in enumerate(steps):
+         # CHW tensor -> HWC array, clipped to the displayable [0, 1] range
+         img = step_img.permute(1, 2, 0).numpy().clip(0, 1)
+         axes[i].imshow(img)
+         axes[i].set_title(f"Step {i}")
+         axes[i].axis('off')
+
+     plt.tight_layout()
+     return fig
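
A minimal usage sketch of the pieces above. The enclosing function's name and signature are not visible in this hunk, so `run_inference` is a hypothetical stand-in; the result keys and `show_inference_steps` are taken from the code as written:

```python
from PIL import Image

# `run_inference`, `model`, and `config` are stand-ins for objects defined
# elsewhere in app.py; only the result keys and show_inference_steps come
# from the diff above.
image = Image.open("stimuli/Kanizsa_square.jpg").convert("RGB")
results = run_inference(image, model, config)

print(f"top class {results['original_class']} ({results['original_confidence']:.3f}) "
      f"-> {results['final_class']} ({results['final_confidence']:.3f})")

fig = show_inference_steps(results['steps'])
fig.savefig("inference_steps.png", dpi=150)
```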
logs/model_loading_resnet50_robust.log ADDED
@@ -0,0 +1,9 @@
+
+ ===== MODEL LOADING REPORT: resnet50_robust =====
+ Total parameters in checkpoint: 320
+ Total parameters in model: 320
+ Missing keys: 0 parameters
+ Unexpected keys: 0 parameters
+ Successfully loaded: 320 parameters (100.0%)
+ Loading status: ✅ COMPLETE - All important parameters loaded
+ ========================================
logs/model_loading_resnet50_robust_face.log ADDED
@@ -0,0 +1,46 @@
+
+ ===== MODEL LOADING REPORT: resnet50_robust_face =====
+ Total parameters in checkpoint: 320
+ Total parameters in model: 320
+ Missing keys: 267 parameters
+ Unexpected keys: 320 parameters
+ Successfully loaded: 0 parameters (0.0%)
+ Loading status: ❌ FAILED - Critical parameters missing, will not function properly
+
+ ⚠️ WARNING: Loading from checkpoint is too incomplete.
+ ⚠️ Falling back to PyTorch's pretrained model to avoid broken inference.
+ ✅ Successfully loaded PyTorch's pretrained ResNet50 model
+
+ Missing keys by layer type:
+   layer3: 95 parameters
+   layer2: 65 parameters
+   layer1: 50 parameters
+   layer4: 50 parameters
+   bn1: 4 parameters
+   fc: 2 parameters
+   conv1: 1 parameters
+
+ First 10 missing keys:
+   1. bn1.bias
+   2. bn1.running_mean
+   3. bn1.running_var
+   4. bn1.weight
+   5. conv1.weight
+   6. fc.bias
+   7. fc.weight
+   8. layer1.0.bn1.bias
+   9. layer1.0.bn1.running_mean
+   10. layer1.0.bn1.running_var
+
+ First 10 unexpected keys:
+   1. model.bn1.bias
+   2. model.bn1.num_batches_tracked
+   3. model.bn1.running_mean
+   4. model.bn1.running_var
+   5. model.bn1.weight
+   6. model.conv1.weight
+   7. model.fc.bias
+   8. model.fc.weight
+   9. model.layer1.0.bn1.bias
+   10. model.layer1.0.bn1.num_batches_tracked
+ ========================================
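
The failure above is a classic key-prefix mismatch: every checkpoint key carries a `model.` prefix (see the unexpected keys), while the plain ResNet50 expects bare names (see the missing keys). A minimal repair sketch; the file name comes from this commit, everything else is illustrative, and the classifier head is skipped because the face model's class count may not match torchvision's default:

```python
import torch
from torchvision.models import resnet50

ckpt = torch.load("models/resnet50_robust_face_100_checkpoint.pt", map_location="cpu")
state_dict = ckpt.get("model", ckpt)  # some checkpoints nest weights under a key

# Strip the "model." prefix and drop the fc head, whose shape depends on the
# number of face classes (load_state_dict errors on shape mismatch even with
# strict=False).
cleaned = {k.removeprefix("model."): v for k, v in state_dict.items()
           if not k.startswith("model.fc.")}

net = resnet50()
missing, unexpected = net.load_state_dict(cleaned, strict=False)
print(f"missing={len(missing)}, unexpected={len(unexpected)}")
```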
logs/model_loading_robust_resnet50.log ADDED
@@ -0,0 +1,9 @@
+
+ ===== MODEL LOADING REPORT: robust_resnet50 =====
+ Total parameters in checkpoint: 320
+ Total parameters in model: 320
+ Missing keys: 0 parameters
+ Unexpected keys: 0 parameters
+ Successfully loaded: 320 parameters (100.0%)
+ Loading status: ✅ COMPLETE - All important parameters loaded
+ ========================================
logs/model_loading_standard_resnet50.log ADDED
@@ -0,0 +1,9 @@
+
+ ===== MODEL LOADING REPORT: standard_resnet50 =====
+ Total parameters in checkpoint: 320
+ Total parameters in model: 320
+ Missing keys: 0 parameters
+ Unexpected keys: 0 parameters
+ Successfully loaded: 320 parameters (100.0%)
+ Loading status: ✅ COMPLETE - All important parameters loaded
+ ========================================
models/resnet50_robust.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:380b14e6f9750bffa1447cf7017f65da4dc5ce71a3dd112f107515dcf7b14d9d
+ size 204818947
models/resnet50_robust_face_100_checkpoint.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c48a5c16ca0d5ac4cb20f1b98e2128838746f18b658728ac661f1ffd589c37bf
+ size 196695413
models/robust_resnet50.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:380b14e6f9750bffa1447cf7017f65da4dc5ce71a3dd112f107515dcf7b14d9d
+ size 204818947
models/standard_resnet50.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:72d4a99582db5d7fa86c3fd2a089f0bfd6a10f69d635bca51f6ad72ac6b458f0
+ size 204818947
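
Note that `models/resnet50_robust.pt` and `models/robust_resnet50.pt` point at the same LFS blob (identical SHA256 and size), i.e. one checkpoint stored under two names; only `models/standard_resnet50.pt` and the face checkpoint are distinct. A minimal sketch for fetching a checkpoint from this Space without cloning it, with the repo id left as a placeholder:

```python
from huggingface_hub import hf_hub_download
import torch

# Placeholder repo id - substitute the actual Space id before running.
path = hf_hub_download(
    repo_id="<user>/<space-name>",
    filename="models/robust_resnet50.pt",
    repo_type="space",
)
checkpoint = torch.load(path, map_location="cpu")
print(type(checkpoint))
```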
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ torch
+ torchvision
+ numpy
+ pillow
+ gradio
+ matplotlib
+ requests
+ tqdm
+ huggingface_hub
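
All dependencies above are unpinned, so a plain `pip install -r requirements.txt` resolves to the latest releases at install time; once a working combination is found, pinning versions (e.g. `torch==2.2.*`, version numbers illustrative) keeps local runs reproducible.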
stimuli/Confetti_illusion.png ADDED
Git LFS Details
  • SHA256: ca3230d513cea1f14c40f16ec1a67013c4f3d011599502dd3360d42798726a1f
  • Pointer size: 131 Bytes
  • Size of remote file: 160 kB
stimuli/CornsweetBlock.png ADDED
Git LFS Details
  • SHA256: 7de2b448a2c55bd7da84ab59a8f1277a6f449ca25b7628f3afac718c101fc80f
  • Pointer size: 131 Bytes
  • Size of remote file: 166 kB
stimuli/EhresteinSingleColor.png ADDED
Git LFS Details
  • SHA256: 7fb3280fdf4aca8dc896b6fa6663464380a00733b9c9961da4190d6ea8ee44e5
  • Pointer size: 131 Bytes
  • Size of remote file: 263 kB
stimuli/GroupingByContinuity.png ADDED
stimuli/Kanizsa_square.jpg ADDED
stimuli/Neon_Color_Circle.jpg ADDED
Git LFS Details
  • SHA256: d010ceedd0041254905de0ed825ad9669fa2a2635527913f8a7563fa08e8e017
  • Pointer size: 131 Bytes
  • Size of remote file: 107 kB
stimuli/face_vase.png ADDED
stimuli/figure_ground.png ADDED
Git LFS Details
  • SHA256: adce211b50da48f88a1cdde467e5fddade8da98f2bcc4c00d625852d7cc15fd2
  • Pointer size: 131 Bytes
  • Size of remote file: 301 kB