ttoosi commited on
Commit
67f0c8b
·
2 Parent(s): fd13ca9f45a5a8

Merge local demo into new Space history

Browse files
Files changed (10) hide show
  1. .gitattributes +5 -0
  2. .gitignore +41 -0
  3. Dockerfile +31 -0
  4. LICENSE +21 -0
  5. README.md +125 -8
  6. app.py +905 -0
  7. face_vase.png +0 -0
  8. huggingface-metadata.json +11 -0
  9. inference.py +1049 -0
  10. requirements.txt +9 -0
.gitattributes CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ stimuli/figure_ground.png filter=lfs diff=lfs merge=lfs -text
37
+ stimuli/NeonColorSaeedi.jpg filter=lfs diff=lfs merge=lfs -text
38
+ stimuli/Neon_Color_Circle.jpg filter=lfs diff=lfs merge=lfs -text
39
+ stimuli/Confetti_illusion.png filter=lfs diff=lfs merge=lfs -text
40
+ stimuli/CornsweetBlock.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python cache files
2
+ __pycache__/
3
+ *.pyc
4
+ *.pyo
5
+ *.pyd
6
+ .Python
7
+
8
+ # Log files
9
+ logs/
10
+ *.log
11
+
12
+ # Model files (will be downloaded at runtime)
13
+ models/
14
+ *.pt
15
+ *.pth
16
+ *.ckpt
17
+
18
+ # Large image files (uploaded via web interface)
19
+ stimuli/*.jpg
20
+ stimuli/*.png
21
+ stimuli/*.jpeg
22
+
23
+ # Environment files
24
+ .env
25
+ .venv
26
+ env/
27
+ venv/
28
+
29
+ # IDE files
30
+ .vscode/
31
+ .idea/
32
+ *.swp
33
+ *.swo
34
+
35
+ # OS files
36
+ .DS_Store
37
+ Thumbs.db
38
+
39
+ # Temporary files
40
+ *.tmp
41
+ *.temp
Dockerfile ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.9-slim
2
+
3
+ WORKDIR /code
4
+
5
+ # Install system dependencies
6
+ RUN apt-get update && apt-get install -y \
7
+ build-essential \
8
+ git \
9
+ libgl1-mesa-glx \
10
+ libglib2.0-0 \
11
+ && rm -rf /var/lib/apt/lists/*
12
+
13
+ # Copy only requirements to leverage Docker caching
14
+ COPY ./requirements.txt /code/requirements.txt
15
+
16
+ # Install Python dependencies
17
+ RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
18
+
19
+ # Copy all code and data
20
+ COPY . /code/
21
+
22
+ # Create necessary directories
23
+ RUN mkdir -p /code/models
24
+ RUN mkdir -p /code/stimuli
25
+
26
+ # Make sure stimuli and models are writable
27
+ RUN chmod -R 777 /code/models
28
+ RUN chmod -R 777 /code/stimuli
29
+
30
+ # Set up the command to run the app
31
+ CMD ["python", "app.py"]
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 GenerativeInferenceDemo
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,14 +1,131 @@
1
  ---
2
- title: Hallucination Prediction Simple
3
- emoji: 👀
4
- colorFrom: green
5
- colorTo: yellow
6
  sdk: gradio
7
- sdk_version: 6.11.0
8
- python_version: '3.12'
9
  app_file: app.py
10
  pinned: false
11
- short_description: Simplified interface for generative inference by TToosi
12
  ---
13
 
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: Human Hallucination Prediction
3
+ emoji: 👁️
4
+ colorFrom: blue
5
+ colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 5.23.1
 
8
  app_file: app.py
9
  pinned: false
10
+ license: mit
11
  ---
12
 
13
+ # Human Hallucination Prediction
14
+
15
+ This Gradio demo predicts whether humans will experience visual hallucinations or illusions when viewing specific images. Using adversarially robust neural networks, this tool can forecast perceptual phenomena like illusory contours, figure-ground reversals, and other Gestalt effects before humans report them.
16
+
17
+ ## How It Works
18
+
19
+ This tool uses **generative inference** with adversarially robust neural networks to predict human visual hallucinations. Robust models trained with adversarial examples develop more human-like perceptual biases, allowing them to predict when humans will perceive:
20
+
21
+ - **Illusory contours** (Kanizsa shapes, Ehrenstein illusion)
22
+ - **Figure-ground ambiguity** (Rubin's vase, bistable images)
23
+ - **Color spreading effects** (Neon color illusion)
24
+ - **Gestalt grouping** (Continuity, proximity)
25
+ - **Brightness illusions** (Cornsweet effect)
26
+
27
+ ## Features
28
+
29
+ - **Predict hallucinations** from uploaded images or example illusions
30
+ - **Visualize the prediction process** step-by-step
31
+ - **Compare different models** (robust vs. standard)
32
+ - **Adjust prediction parameters** for different perceptual phenomena
33
+ - **Pre-configured examples** of classic visual illusions
34
+
35
+ ## Usage
36
+
37
+ 1. **Choose an input**: Pick a pre-configured example illusion from the dropdown, or upload your own image.
38
+ 2. **Load Parameters**: Click **"Load Parameters"** to fill in optimal prediction settings for that example (or adjust them manually).
39
+ 3. **Select the affected part of the visual field**: Set where the model should focus by clicking on the input image or the mask preview to define the mask center, then adjust **Mask center X/Y**, **Mask radius**, and **Mask sigma** in the Adaptive Gaussian mask section if needed. The preview circle shows the region that will receive stronger constraint during inference.
40
+ 4. **Run inference**: Click **"Run Generative Inference"** to start the prediction. Progress and intermediate steps are shown in the interface.
41
+ 5. **View results**: Inspect the predicted perceptual effects, visualizations, and any generated outputs in the result panels.
42
+
43
+ ## Scientific Background
44
+
45
+ This demo is based on research showing that adversarially robust neural networks develop perceptual representations similar to human vision. By using generative inference (optimizing images to maximize model confidence), we can reveal what perceptual structures the network expects to see—which often matches what humans hallucinate or perceive in ambiguous images.
46
+
47
+ ## Prerequisites
48
+
49
+ - **Python** 3.8 or higher
50
+ - **pip** (Python package manager)
51
+
52
+ ## Installation
53
+
54
+ 1. **Clone the repository**
55
+
56
+ ```bash
57
+ git clone https://huggingface.co/spaces/ttoosi/Human_Hallucination_Prediction
58
+ cd Human_Hallucination_Prediction
59
+ ```
60
+
61
+ 2. **Create a virtual environment** (recommended)
62
+
63
+ ```bash
64
+ python -m venv venv
65
+ source venv/bin/activate # On Windows: venv\Scripts\activate
66
+ ```
67
+
68
+ 3. **Install dependencies**
69
+
70
+ ```bash
71
+ pip install -r requirements.txt
72
+ ```
73
+
74
+ 4. **Run the app**
75
+
76
+ ```bash
77
+ python app.py
78
+ ```
79
+
80
+ Optional: specify a port with `--port` (default is 7860):
81
+
82
+ ```bash
83
+ python app.py --port 8861
84
+ ```
85
+
86
+ The web app will be available at **http://localhost:7860** (or the port you specified).
87
+
88
+ **Note:** Model weights (e.g. robust ResNet50) are downloaded automatically from Hugging Face on first run and cached in the `models/` directory. The app also creates a `stimuli/` directory for example images.
89
+
90
+ ### Running with Docker
91
+
92
+ ```bash
93
+ docker build -t human-hallucination-prediction .
94
+ docker run -p 7860:7860 human-hallucination-prediction
95
+ ```
96
+
97
+ Then open http://localhost:7860 in your browser.
98
+
99
+ ## The Prediction Process
100
+
101
+ 1. **Input**: You provide an ambiguous or illusion-inducing image (or use a built-in example).
102
+ 2. **Generative inference**: The adversarially robust network iteratively updates the image to maximize its confidence, guided by your chosen parameters (model, layer, noise, step size, etc.).
103
+ 3. **Prediction**: The resulting changes reveal the perceptual structures the network expects—which correspond to what humans tend to hallucinate or perceive in such images.
104
+ 4. **Visualization**: The interface shows the predicted hallucination and intermediate steps as the optimization runs.
105
+
106
+ ## Models
107
+
108
+ - **Robust ResNet50**: Trained with adversarial examples (ε=3.0), develops human-like perceptual biases
109
+ - **Standard ResNet50**: Standard ImageNet training without adversarial robustness
110
+
111
+ ## Citation
112
+
113
+ If you use this work in your research, please cite:
114
+
115
+ ```bibtex
116
+ @article{toosi2024hallucination,
117
+ title={Predicting Human Visual Hallucinations with Robust Neural Networks},
118
+ author={Toosi, Tahereh},
119
+ year={2024}
120
+ }
121
+ ```
122
+
123
+ ## About
124
+
125
+ **Developed by [Tahereh Toosi](https://toosi.github.io)**
126
+
127
+ This demo shows how adversarially robust neural networks can predict human perceptual hallucinations before they occur.
128
+
129
+ ## License
130
+
131
+ This project is licensed under the MIT License - see the LICENSE file for details.
app.py ADDED
@@ -0,0 +1,905 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ from PIL import Image, ImageDraw
5
try:
    from spaces import GPU
except ImportError:
    def GPU(func=None, **_options):
        """No-op stand-in for ``spaces.GPU`` when the ``spaces`` package is missing.

        Supports both decorator spellings so the same source runs locally and
        on Hugging Face Spaces:

        * ``@GPU`` - ``func`` is the decorated callable and is returned as-is.
        * ``@GPU(duration=...)`` - ``func`` is ``None``; a pass-through
          decorator is returned and the keyword options are ignored.
        """
        if func is None:
            # Called as a decorator factory: hand back an identity decorator.
            return lambda wrapped: wrapped
        return func
11
+
12
+ import os
13
+ import re
14
+ import json
15
+ import argparse
16
+ from datetime import datetime
17
+ from inference import GenerativeInferenceModel, get_inference_configs, get_imagenet_labels
18
+
19
+ # Parse command line arguments
20
+ parser = argparse.ArgumentParser(description='Run Generative Inference Demo')
21
+ parser.add_argument('--port', type=int, default=7860, help='Port to run the server on')
22
+ args = parser.parse_args()
23
+
24
+ # Create model directories if they don't exist
25
+ os.makedirs("models", exist_ok=True)
26
+ os.makedirs("stimuli", exist_ok=True)
27
+ SAVED_RUNS_DIR = "saved_runs"
28
+ os.makedirs(SAVED_RUNS_DIR, exist_ok=True)
29
+
30
+ # Load ImageNet labels for biased-inference dropdown (1000 classes)
31
+ IMAGENET_LABELS = get_imagenet_labels()
32
+
33
+ # Check if running on Hugging Face Spaces
34
+ if "SPACE_ID" in os.environ:
35
+ default_port = int(os.environ.get("PORT", 7860))
36
+ else:
37
+ default_port = 8861 # Local default port
38
+
39
+ # Initialize model
40
+ model = GenerativeInferenceModel()
41
+
42
+ # Define example images and their parameters with updated values from the research
43
+ examples = [
44
+ {
45
+ "image": os.path.join("stimuli", "farm1.jpg"),
46
+ "name": "farm1",
47
+ "wiki": "https://en.wikipedia.org/wiki/Visual_perception",
48
+ "papers": [
49
+ "[Adversarially Robust Vision](https://github.com/MadryLab/robustness)",
50
+ "[Generative Inference](https://doi.org/10.1016/j.tics.2003.08.003)"
51
+ ],
52
+ "method": "Prior-Guided Drift Diffusion",
53
+ "reverse_diff": {
54
+ "model": "resnet50_robust",
55
+ "layer": "all",
56
+ "initial_noise": 0.0,
57
+ "diffusion_noise": 0.02,
58
+ "step_size": 1.0,
59
+ "iterations": 501,
60
+ "epsilon": 40.0
61
+ },
62
+ "inference_normalization": "off",
63
+ "use_adaptive_eps": False,
64
+ "use_adaptive_step": False,
65
+ "mask_center_x": 0.0,
66
+ "mask_center_y": 0.0,
67
+ "mask_radius": 0.2,
68
+ "mask_sigma": 0.3,
69
+ "eps_max_mult": 300.0,
70
+ "eps_min_mult": 1.0,
71
+ "step_max_mult": 10.0,
72
+ "step_min_mult": 1.0,
73
+ },
74
+ {
75
+ "image": os.path.join("stimuli", "ArtGallery1.jpg"),
76
+ "name": "ArtGallery1",
77
+ "wiki": "https://en.wikipedia.org/wiki/Visual_perception",
78
+ "papers": [
79
+ "[Adversarially Robust Vision](https://github.com/MadryLab/robustness)",
80
+ "[Generative Inference](https://doi.org/10.1016/j.tics.2003.08.003)"
81
+ ],
82
+ "method": "Prior-Guided Drift Diffusion",
83
+ "reverse_diff": {
84
+ "model": "resnet50_robust",
85
+ "layer": "layer4",
86
+ "initial_noise": 0.5,
87
+ "diffusion_noise": 0.002,
88
+ "step_size": 0.1,
89
+ "iterations": 501,
90
+ "epsilon": 40.0
91
+ },
92
+ "inference_normalization": "off",
93
+ "use_adaptive_eps": False,
94
+ "use_adaptive_step": True,
95
+ "mask_center_x": 0.0,
96
+ "mask_center_y": -1.0,
97
+ "mask_radius": 0.1,
98
+ "mask_sigma": 0.2,
99
+ "eps_max_mult": 30.0,
100
+ "eps_min_mult": 1.0,
101
+ "step_max_mult": 100.0,
102
+ "step_min_mult": 1.0,
103
+ },
104
+ {
105
+ "image": os.path.join("stimuli", "urbanoffice1.jpg"),
106
+ "name": "UrbanOffice1",
107
+ "wiki": "https://en.wikipedia.org/wiki/Visual_perception",
108
+ "papers": [
109
+ "[Adversarially Robust Vision](https://github.com/MadryLab/robustness)",
110
+ "[Generative Inference](https://doi.org/10.1016/j.tics.2003.08.003)"
111
+ ],
112
+ "method": "Prior-Guided Drift Diffusion",
113
+ "reverse_diff": {
114
+ "model": "resnet50_robust",
115
+ "layer": "all",
116
+ "initial_noise": 1.0,
117
+ "diffusion_noise": 0.002,
118
+ "step_size": 1.0,
119
+ "iterations": 500,
120
+ "epsilon": 40.0
121
+ },
122
+ "inference_normalization": "off",
123
+ "use_adaptive_eps": False,
124
+ "use_adaptive_step": True,
125
+ "mask_center_x": 0.5,
126
+ "mask_center_y": 0.0,
127
+ "mask_radius": 0.2,
128
+ "mask_sigma": 0.2,
129
+ "eps_max_mult": 20.0,
130
+ "eps_min_mult": 1.0,
131
+ "step_max_mult": 50.0,
132
+ "step_min_mult": 0.2,
133
+ },
134
+ {
135
+ "image": os.path.join("stimuli", "Neon_Color_Circle.jpg"),
136
+ "name": "Neon Color Spreading",
137
+ "wiki": "https://en.wikipedia.org/wiki/Neon_color_spreading",
138
+ "papers": [
139
+ "[Color Assimilation](https://doi.org/10.1016/j.visres.2000.200.1)",
140
+ "[Perceptual Filling-in](https://doi.org/10.1016/j.tics.2003.08.003)"
141
+ ],
142
+ "method": "Prior-Guided Drift Diffusion",
143
+ "reverse_diff": {
144
+ "model": "resnet50_robust",
145
+ "layer": "layer3",
146
+ "initial_noise": 0.8,
147
+ "diffusion_noise": 0.003,
148
+ "step_size": 1.0,
149
+ "iterations": 101,
150
+ "epsilon": 20.0
151
+ },
152
+ "use_adaptive_eps": False,
153
+ "use_adaptive_step": False,
154
+ "mask_center_x": 0.0,
155
+ "mask_center_y": 0.0,
156
+ "mask_radius": 0.2,
157
+ "mask_sigma": 1.0,
158
+ "eps_max_mult": 1.0,
159
+ "eps_min_mult": 1.0,
160
+ "step_max_mult": 1.0,
161
+ "step_min_mult": 1.0,
162
+ },
163
+ {
164
+ "image": os.path.join("stimuli", "Kanizsa_square.jpg"),
165
+ "name": "Kanizsa Square",
166
+ "wiki": "https://en.wikipedia.org/wiki/Kanizsa_triangle",
167
+ "papers": [
168
+ "[Gestalt Psychology](https://en.wikipedia.org/wiki/Gestalt_psychology)",
169
+ "[Neural Mechanisms](https://doi.org/10.1016/j.tics.2003.08.003)"
170
+ ],
171
+ "method": "Prior-Guided Drift Diffusion",
172
+ "reverse_diff": {
173
+ "model": "resnet50_robust",
174
+ "layer": "all",
175
+ "initial_noise": 0.0,
176
+ "diffusion_noise": 0.005,
177
+ "step_size": 0.64,
178
+ "iterations": 100,
179
+ "epsilon": 5.0
180
+ },
181
+ "use_adaptive_eps": False,
182
+ "use_adaptive_step": False,
183
+ "mask_center_x": 0.0,
184
+ "mask_center_y": 0.0,
185
+ "mask_radius": 0.2,
186
+ "mask_sigma": 1.0,
187
+ "eps_max_mult": 1.0,
188
+ "eps_min_mult": 1.0,
189
+ "step_max_mult": 1.0,
190
+ "step_min_mult": 1.0,
191
+ },
192
+ {
193
+ "image": os.path.join("stimuli", "CornsweetBlock.png"),
194
+ "name": "Cornsweet Illusion",
195
+ "wiki": "https://en.wikipedia.org/wiki/Cornsweet_illusion",
196
+ "papers": [
197
+ "[Brightness Perception](https://doi.org/10.1016/j.visres.2000.200.1)",
198
+ "[Edge Effects](https://doi.org/10.1016/j.tics.2003.08.003)"
199
+ ],
200
+ "instructions": "Both blocks are gray in color (the same), use your finger to cover the middle line. Hit 'Load Parameters' and then hit 'Run Generative Inference' to see how the model sees the blocks.",
201
+ "method": "Prior-Guided Drift Diffusion",
202
+ "reverse_diff": {
203
+ "model": "resnet50_robust",
204
+ "layer": "layer3",
205
+ "initial_noise": 0.5,
206
+ "diffusion_noise": 0.005,
207
+ "step_size": 0.8,
208
+ "iterations": 51,
209
+ "epsilon": 20.0
210
+ },
211
+ "use_adaptive_eps": False,
212
+ "use_adaptive_step": False,
213
+ "mask_center_x": 0.0,
214
+ "mask_center_y": 0.0,
215
+ "mask_radius": 0.2,
216
+ "mask_sigma": 1.0,
217
+ "eps_max_mult": 1.0,
218
+ "eps_min_mult": 1.0,
219
+ "step_max_mult": 1.0,
220
+ "step_min_mult": 1.0,
221
+ },
222
+ {
223
+ "image": os.path.join("stimuli", "face_vase.png"),
224
+ "name": "Rubin's Face-Vase (Object Prior)",
225
+ "wiki": "https://en.wikipedia.org/wiki/Rubin_vase",
226
+ "papers": [
227
+ "[Figure-Ground Perception](https://en.wikipedia.org/wiki/Figure-ground_(perception))",
228
+ "[Bistable Perception](https://doi.org/10.1016/j.tics.2003.08.003)"
229
+ ],
230
+ "method": "Prior-Guided Drift Diffusion",
231
+ "reverse_diff": {
232
+ "model": "resnet50_robust",
233
+ "layer": "avgpool",
234
+ "initial_noise": 0.9,
235
+ "diffusion_noise": 0.003,
236
+ "step_size": 0.58,
237
+ "iterations": 100,
238
+ "epsilon": 0.81
239
+ },
240
+ "use_adaptive_eps": False,
241
+ "use_adaptive_step": False,
242
+ "mask_center_x": 0.0,
243
+ "mask_center_y": 0.0,
244
+ "mask_radius": 0.2,
245
+ "mask_sigma": 1.0,
246
+ "eps_max_mult": 1.0,
247
+ "eps_min_mult": 1.0,
248
+ "step_max_mult": 1.0,
249
+ "step_min_mult": 1.0,
250
+ },
251
+ {
252
+ "image": os.path.join("stimuli", "Confetti_illusion.png"),
253
+ "name": "Confetti Illusion",
254
+ "wiki": "https://www.youtube.com/watch?v=SvEiEi8O7QE",
255
+ "papers": [
256
+ "[Color Perception](https://doi.org/10.1016/j.visres.2000.200.1)",
257
+ "[Context Effects](https://doi.org/10.1016/j.tics.2003.08.003)"
258
+ ],
259
+ "method": "Prior-Guided Drift Diffusion",
260
+ "reverse_diff": {
261
+ "model": "resnet50_robust",
262
+ "layer": "layer3",
263
+ "initial_noise": 0.1,
264
+ "diffusion_noise": 0.003,
265
+ "step_size": 0.5,
266
+ "iterations": 101,
267
+ "epsilon": 20.0
268
+ },
269
+ "use_adaptive_eps": False,
270
+ "use_adaptive_step": False,
271
+ "mask_center_x": 0.0,
272
+ "mask_center_y": 0.0,
273
+ "mask_radius": 0.2,
274
+ "mask_sigma": 1.0,
275
+ "eps_max_mult": 1.0,
276
+ "eps_min_mult": 1.0,
277
+ "step_max_mult": 1.0,
278
+ "step_min_mult": 1.0,
279
+ },
280
+ {
281
+ "image": os.path.join("stimuli", "EhresteinSingleColor.png"),
282
+ "name": "Ehrenstein Illusion",
283
+ "wiki": "https://en.wikipedia.org/wiki/Ehrenstein_illusion",
284
+ "papers": [
285
+ "[Subjective Contours](https://doi.org/10.1016/j.visres.2000.200.1)",
286
+ "[Neural Processing](https://doi.org/10.1016/j.tics.2003.08.003)"
287
+ ],
288
+ "method": "Prior-Guided Drift Diffusion",
289
+ "reverse_diff": {
290
+ "model": "resnet50_robust",
291
+ "layer": "layer3",
292
+ "initial_noise": 0.5,
293
+ "diffusion_noise": 0.005,
294
+ "step_size": 0.8,
295
+ "iterations": 101,
296
+ "epsilon": 20.0
297
+ },
298
+ "use_adaptive_eps": False,
299
+ "use_adaptive_step": False,
300
+ "mask_center_x": 0.0,
301
+ "mask_center_y": 0.0,
302
+ "mask_radius": 0.2,
303
+ "mask_sigma": 1.0,
304
+ "eps_max_mult": 1.0,
305
+ "eps_min_mult": 1.0,
306
+ "step_max_mult": 1.0,
307
+ "step_min_mult": 1.0,
308
+ },
309
+ {
310
+ "image": os.path.join("stimuli", "GroupingByContinuity.png"),
311
+ "name": "Grouping by Continuity",
312
+ "wiki": "https://en.wikipedia.org/wiki/Principles_of_grouping",
313
+ "papers": [
314
+ "[Gestalt Principles](https://en.wikipedia.org/wiki/Gestalt_psychology)",
315
+ "[Visual Organization](https://doi.org/10.1016/j.tics.2003.08.003)"
316
+ ],
317
+ "method": "Prior-Guided Drift Diffusion",
318
+ "reverse_diff": {
319
+ "model": "resnet50_robust",
320
+ "layer": "layer3",
321
+ "initial_noise": 0.0,
322
+ "diffusion_noise": 0.005,
323
+ "step_size": 0.4,
324
+ "iterations": 101,
325
+ "epsilon": 4.0
326
+ },
327
+ "use_adaptive_eps": False,
328
+ "use_adaptive_step": False,
329
+ "mask_center_x": 0.0,
330
+ "mask_center_y": 0.0,
331
+ "mask_radius": 0.2,
332
+ "mask_sigma": 1.0,
333
+ "eps_max_mult": 1.0,
334
+ "eps_min_mult": 1.0,
335
+ "step_max_mult": 1.0,
336
+ "step_min_mult": 1.0,
337
+ },
338
+ {
339
+ "image": os.path.join("stimuli", "figure_ground.png"),
340
+ "name": "Figure-Ground Illusion",
341
+ "wiki": "https://en.wikipedia.org/wiki/Figure-ground_(perception)",
342
+ "papers": [
343
+ "[Gestalt Principles](https://en.wikipedia.org/wiki/Gestalt_psychology)",
344
+ "[Perceptual Organization](https://doi.org/10.1016/j.tics.2003.08.003)"
345
+ ],
346
+ "method": "Prior-Guided Drift Diffusion",
347
+ "reverse_diff": {
348
+ "model": "resnet50_robust",
349
+ "layer": "layer3",
350
+ "initial_noise": 0.1,
351
+ "diffusion_noise": 0.003,
352
+ "step_size": 0.5,
353
+ "iterations": 101,
354
+ "epsilon": 3.0
355
+ },
356
+ "use_adaptive_eps": False,
357
+ "use_adaptive_step": False,
358
+ "mask_center_x": 0.0,
359
+ "mask_center_y": 0.0,
360
+ "mask_radius": 0.2,
361
+ "mask_sigma": 1.0,
362
+ "eps_max_mult": 1.0,
363
+ "eps_min_mult": 1.0,
364
+ "step_max_mult": 1.0,
365
+ "step_min_mult": 1.0,
366
+ }
367
+ ]
368
+
369
+ def _input_image_stem(image):
370
+ """Return a safe filename stem from the input image: known name or 'user_img'."""
371
+ if image is None:
372
+ return "user_img"
373
+ path = None
374
+ if isinstance(image, str) and (os.path.isfile(image) or os.path.exists(image)):
375
+ path = image
376
+ if isinstance(image, dict) and image.get("path") and os.path.exists(image.get("path", "")):
377
+ path = image["path"]
378
+ if path:
379
+ name = os.path.splitext(os.path.basename(path))[0]
380
+ # Safe for filenames: alphanumeric, underscore, hyphen only; max length
381
+ safe = re.sub(r"[^\w\-]", "_", name).strip("_") or "user_img"
382
+ return safe[:80] if len(safe) > 80 else safe
383
+ return "user_img"
384
+
385
+
386
+ def _get_image_path_for_stem(img):
387
+ """Extract file path from Gradio image value (path string, dict with path, or PIL) for stem tracking."""
388
+ if img is None:
389
+ return ""
390
+ if isinstance(img, str) and (os.path.isfile(img) or os.path.exists(img)):
391
+ return img
392
+ if isinstance(img, dict) and img.get("path"):
393
+ p = img["path"]
394
+ if isinstance(p, str) and os.path.exists(p):
395
+ return p
396
+ return ""
397
+
398
+
399
def _update_tracked_image_path(img):
    """Return the image's file path if it looks like a bundled stimulus, else ``""``.

    The path comes from ``_get_image_path_for_stem``. It is kept only when the
    text ``"stimuli"`` occurs somewhere in it (a plain substring test on the
    whole path); otherwise the empty string is returned so the filename stem
    later falls back to ``'user_img'``.
    """
    tracked = _get_image_path_for_stem(img)
    return tracked if "stimuli" in tracked else ""
405
+
406
+
407
+ def _config_to_json_serializable(c):
408
+ """Return a copy of config with only JSON-serializable values."""
409
+ if isinstance(c, dict):
410
+ return {k: _config_to_json_serializable(v) for k, v in c.items()}
411
+ if isinstance(c, (list, tuple)):
412
+ return [_config_to_json_serializable(x) for x in c]
413
+ if isinstance(c, (bool, int, float, str, type(None))):
414
+ return c
415
+ if hasattr(c, "item"): # e.g. numpy scalar
416
+ return c.item()
417
+ return str(c)
418
+
419
+
420
@GPU
def run_inference(image, model_type, inference_type, eps_value, num_iterations,
                  initial_noise=0.05, diffusion_noise=0.3, step_size=0.8, model_layer="layer3",
                  use_adaptive_eps=False, use_adaptive_step=False,
                  mask_center_x=0.0, mask_center_y=0.0, mask_radius=0.3, mask_sigma=0.2,
                  eps_max_mult=4.0, eps_min_mult=1.0, step_max_mult=4.0, step_min_mult=1.0,
                  use_biased_inference=False, biased_class_name="",
                  current_image_path=""):
    """Run generative inference on ``image`` and package the results for the UI.

    Builds a config via ``get_inference_configs``, overlays the UI-selected
    options (Prior-Guided Drift Diffusion parameters, adaptive epsilon / step
    size masks, biased inference), calls ``model.inference`` and converts the
    returned tensors into PIL images. When at least one step frame exists, an
    animated GIF and a JSON dump of the config are written to
    ``SAVED_RUNS_DIR``.

    Args:
        image: Input image forwarded unchanged to ``model.inference``.
        model_type: Model identifier forwarded to ``model.inference``.
        inference_type: Inference method name; the diffusion-specific
            parameters apply only for ``"Prior-Guided Drift Diffusion"``.
        eps_value: Epsilon, coerced with ``float``.
        num_iterations: Iteration count, coerced with ``int``.
        initial_noise, diffusion_noise, step_size, model_layer: Diffusion
            parameters written into the config.
        use_adaptive_eps, use_adaptive_step: Enable the Gaussian-mask based
            adaptive epsilon / step size config sections.
        mask_center_x, mask_center_y, mask_radius, mask_sigma: Mask geometry
            shared by both adaptive sections.
        eps_max_mult, eps_min_mult, step_max_mult, step_min_mult: Multipliers
            for the respective adaptive section.
        use_biased_inference, biased_class_name: Enable biasing toward the
            named class when both are set.
        current_image_path: Tracked path of the input; when non-blank it is
            used instead of ``image`` to derive the saved-file stem.

    Returns:
        ``(final_image, frames, save_status, files_for_download)``:
        the final PIL image, the list of per-step PIL frames, a status
        message, and ``[gif_path, config_path]`` or ``None`` when nothing was
        saved. If ``image`` is ``None`` the tuple is
        ``(None, [], <prompt message>, None)``.
    """
    if image is None:
        return None, [], "Please upload an image before running inference.", None

    def to_pil(tensor):
        # The permute(1, 2, 0) implies a channel-first layout; values are
        # presumably in [0, 1] - TODO confirm against inference.py.
        # detach() avoids the numpy() failure on tensors that track gradients,
        # and clipping keeps out-of-range values from wrapping in the uint8 cast.
        array = tensor.detach().permute(1, 2, 0).cpu().numpy()
        return Image.fromarray((np.clip(array, 0.0, 1.0) * 255).astype(np.uint8))

    def mask_section(base_key, base_value, max_mult, min_mult):
        # Both adaptive sections share the mask geometry and differ only in
        # the base value and multipliers. Built lazily so the mask arguments
        # are coerced only when a section is actually enabled.
        return {
            'enabled': True,
            base_key: base_value,
            'center_x': float(mask_center_x),
            'center_y': float(mask_center_y),
            'flat_radius': float(mask_radius),
            'sigma': float(mask_sigma),
            'max_multiplier': float(max_mult),
            'min_multiplier': float(min_mult),
        }

    eps = float(eps_value)
    step_size_f = float(step_size)
    # Gradio may deliver None for checkboxes; bool(None) is False.
    use_adaptive_eps = bool(use_adaptive_eps)
    use_adaptive_step = bool(use_adaptive_step)
    use_biased_inference = bool(use_biased_inference)
    biased_class_name = (biased_class_name or "").strip()

    config = get_inference_configs(inference_type=inference_type, eps=eps, n_itr=int(num_iterations))

    if inference_type == "Prior-Guided Drift Diffusion":
        config['initial_inference_noise_ratio'] = float(initial_noise)
        config['diffusion_noise_ratio'] = float(diffusion_noise)
        config['step_size'] = step_size_f
        config['top_layer'] = model_layer

    # Normalization is always off (the option was removed from the UI).
    config['inference_normalization'] = 'off'
    config['recognition_normalization'] = 'off'

    config['adaptive_epsilon'] = (
        mask_section('base_epsilon', eps, eps_max_mult, eps_min_mult)
        if use_adaptive_eps else None
    )
    config['adaptive_step_size'] = (
        mask_section('base_step_size', step_size_f, step_max_mult, step_min_mult)
        if use_adaptive_step else None
    )

    if use_biased_inference and biased_class_name:
        config['biased_inference'] = {'enable': True, 'class': biased_class_name}
    else:
        config['biased_inference'] = config.get('biased_inference') or {'enable': False, 'class': None}

    result = model.inference(image, model_type, config)

    if isinstance(result, tuple):
        # Old format: (output_image, all_steps).
        output_image, all_steps = result
    else:
        # New format: dictionary.
        output_image = result['final_image']
        all_steps = result['steps']

    frames = [to_pil(step_image) for step_image in all_steps]
    final_image = to_pil(output_image)

    save_status = ""
    files_for_download = None
    if frames:
        # Prefer the tracked path (a PIL input has lost its filename).
        has_tracked_path = bool(current_image_path and current_image_path.strip())
        stem = _input_image_stem(current_image_path if has_tracked_path else image)
        unique_id = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{stem}"
        gif_path = os.path.join(SAVED_RUNS_DIR, f"{unique_id}.gif")
        config_path = os.path.join(SAVED_RUNS_DIR, f"{unique_id}_config.json")
        try:
            frames[0].save(
                gif_path,
                save_all=True,
                append_images=frames[1:],
                loop=0,
                duration=200,
            )
            save_config = {
                "model_type": model_type,
                "input_image_name": stem,
                **_config_to_json_serializable(config),
            }
            with open(config_path, "w", encoding="utf-8") as f:
                json.dump(save_config, f, indent=2)
            files_for_download = [gif_path, config_path]
            save_status = "**Download results** — Use the links below to save the GIF and config to your device (your browser may ask where to save)."
        except Exception as e:
            # Saving is best-effort: report the failure but still return the images.
            save_status = f"Save failed: {e}"

    return final_image, frames, save_status, files_for_download
543
+
544
+ def _image_to_pil(img):
545
+ """Convert Gradio image value (PIL, numpy, path, or dict) to PIL Image; return None if invalid."""
546
+ if img is None:
547
+ return None
548
+ if isinstance(img, Image.Image):
549
+ return img
550
+ if isinstance(img, np.ndarray):
551
+ return Image.fromarray(img.astype(np.uint8))
552
+ if isinstance(img, dict) and "path" in img:
553
+ return Image.open(img["path"]).convert("RGB")
554
+ if isinstance(img, str) and os.path.exists(img):
555
+ return Image.open(img).convert("RGB")
556
+ return None
557
+
558
+
559
def _image_size(img):
    """Return the ``(width, height)`` of a Gradio image value.

    The value is resolved through ``_image_to_pil``; when that yields no
    image, the placeholder size ``(1, 1)`` is returned.
    """
    resolved = _image_to_pil(img)
    return (1, 1) if resolved is None else resolved.size
565
+
566
+
567
def image_click_to_center(evt: gr.SelectData):
    """Translate an image click into a normalized mask center.

    The click position in ``evt.index`` is divided by the size of the image
    in ``evt.value`` and rescaled to the range -1..1 on both axes.

    Returns:
        (center_x, center_y); (0.0, 0.0) when the event carries no usable
        position or the image size is not positive.
    """
    idx = evt.index
    if idx is None or (isinstance(idx, (list, tuple)) and len(idx) < 2):
        return 0.0, 0.0

    px, py = float(idx[0]), float(idx[1])
    # NOTE(review): if evt.value is None for image select events, the size
    # falls back to (1, 1) and every click clamps to the edge - confirm
    # against the Gradio version in use.
    width, height = _image_size(evt.value)
    if width <= 0 or height <= 0:
        return 0.0, 0.0

    def _to_unit(position, extent):
        # Pixel offset -> [-1, 1], clamped so out-of-range clicks stay valid.
        return max(-1.0, min(1.0, (position / extent) * 2.0 - 1.0))

    return _to_unit(px, width), _to_unit(py, height)
580
+
581
+
582
def draw_mask_overlay(image, center_x, center_y, radius):
    """Render the mask center and radius on top of a copy of the image.

    Args:
        image: Gradio image value (anything ``_image_to_pil`` accepts).
        center_x: Normalized horizontal center, -1 (left) to 1 (right).
        center_y: Normalized vertical center, -1 (top) to 1 (bottom).
        radius: Normalized radius, scaled by half of the shorter image side.

    Returns:
        A new RGB PIL image with the overlay, or None if the image is invalid.
    """
    base = _image_to_pil(image)
    if base is None:
        return None

    canvas = base.convert("RGB").copy()
    width, height = canvas.size
    short_side = min(width, height)

    # Normalized [-1, 1] coordinates -> pixel coordinates.
    cx = (float(center_x) + 1.0) / 2.0 * width
    cy = (float(center_y) + 1.0) / 2.0 * height
    ring_radius = float(radius) * short_side / 2.0
    dot_radius = max(2, short_side // 80)

    def _bbox(half_extent):
        # Bounding box of a circle of the given radius around the center.
        return [cx - half_extent, cy - half_extent, cx + half_extent, cy + half_extent]

    pen = ImageDraw.Draw(canvas)
    # Outline marking the flat region of the mask.
    pen.ellipse(
        _bbox(ring_radius),
        outline="#E11D48",
        width=2 * max(2, short_side // 150),
    )
    # Filled dot marking the center itself.
    pen.ellipse(_bbox(dot_radius), fill="#E11D48", outline="#FFF")
    return canvas
603
+
604
+
605
+ # Helper function to apply example parameters (adaptive mask off by default unless example defines it)
606
def apply_example(example):
    """Build the list of UI values for one example entry.

    The order of the returned list must match the ``outputs`` list wired to
    the "Load Parameters" buttons. Adaptive-mask and biased-inference options
    fall back to defaults when the example does not define them.
    """
    reverse_diff = example["reverse_diff"]
    center_x = example.get("mask_center_x", 0.0)
    center_y = example.get("mask_center_y", 0.0)
    mask_radius = example.get("mask_radius", 0.3)
    preview = draw_mask_overlay(example["image"], center_x, center_y, mask_radius)

    values = [
        example["image"],
        reverse_diff.get("model", "resnet50_robust"),
        example["method"],
    ]
    # Core inference settings, all required in the example definition.
    values += [
        reverse_diff[key]
        for key in ("epsilon", "iterations", "initial_noise", "diffusion_noise", "step_size", "layer")
    ]
    # Optional adaptive-mask settings.
    values += [
        example.get("use_adaptive_eps", False),
        example.get("use_adaptive_step", False),
        center_x,
        center_y,
        mask_radius,
        example.get("mask_sigma", 0.2),
        example.get("eps_max_mult", 4.0),
        example.get("eps_min_mult", 1.0),
        example.get("step_max_mult", 4.0),
        example.get("step_min_mult", 1.0),
    ]
    # Optional biased-inference settings.
    values += [
        example.get("use_biased_inference", False),
        example.get("biased_class_name", ""),
    ]
    values += [
        example["image"],  # keep path for save filename (e.g. UrbanOffice1 -> urbanoffice1)
        preview,
        gr.Group(visible=True),
    ]
    return values
638
+
639
+ # Define the interface
640
+ with gr.Blocks(title="Human Hallucination Prediction", css="""
641
+ .purple-button {
642
+ background-color: #8B5CF6 !important;
643
+ color: white !important;
644
+ border: none !important;
645
+ }
646
+ .purple-button:hover {
647
+ background-color: #7C3AED !important;
648
+ }
649
+ """) as demo:
650
+ gr.Markdown("# Human Hallucination Prediction")
651
+ gr.Markdown("**Predict what visual hallucinations humans may experience** using neural networks.")
652
+
653
+ gr.Markdown("""
654
+ **How to predict hallucinations:**
655
+ 1. **Select an example image** below and click "Load Parameters" to set the prediction settings
656
+ 2. **Click "Run Generative Inference"** to predict what hallucination humans may perceive
657
+ 3. **View the prediction**: Watch as the model reveals the perceptual structures it expects—matching what humans typically hallucinate
658
+ 4. **You can upload your own images**
659
+ 5. **You can download the results** as a .gif file together with the configs.json
660
+ """)
661
+ with gr.Row():
662
+ with gr.Column(scale=1):
663
+ # Inputs (track path so save filenames use stimulus name when from example)
664
+ default_image_path = os.path.join("stimuli", "urbanoffice1.jpg")
665
+ image_input = gr.Image(label="Input Image (click to set mask center)", type="pil", value=default_image_path)
666
+ current_image_path_state = gr.State(value=default_image_path)
667
+ mask_preview = gr.Image(
668
+ label="Mask center preview (click to set center — circle shows mask)",
669
+ type="pil",
670
+ interactive=False,
671
+ )
672
+ # Run Inference button right below the image
673
+ run_button = gr.Button("🔮 Predict Hallucination", variant="primary", elem_classes="purple-button")
674
+
675
+ # Parameters toggle button
676
+ params_button = gr.Button("⚙️ Play with the parameters", variant="secondary")
677
+
678
+ # Parameters section (initially hidden)
679
+ with gr.Group(visible=False) as params_section:
680
+ with gr.Row():
681
+ model_choice = gr.Dropdown(
682
+ choices=["resnet50_robust", "standard_resnet50"], # "resnet50_robust_face" - hidden for deployment
683
+ value="resnet50_robust",
684
+ label="Model"
685
+ )
686
+
687
+ inference_type = gr.Dropdown(
688
+ choices=["Prior-Guided Drift Diffusion", "IncreaseConfidence"],
689
+ value="Prior-Guided Drift Diffusion",
690
+ label="Inference Method"
691
+ )
692
+
693
+ with gr.Row():
694
+ eps_slider = gr.Slider(minimum=0.0, maximum=40.0, value=40.0, step=0.01, label="Epsilon (Stimulus Fidelity)")
695
+ iterations_slider = gr.Slider(minimum=1, maximum=600, value=500, step=1, label="Number of Iterations")
696
+
697
+ with gr.Row():
698
+ initial_noise_slider = gr.Slider(minimum=0.0, maximum=1.0, value=1.0, step=0.01,
699
+ label="Drift Noise")
700
+ diffusion_noise_slider = gr.Slider(minimum=0.0, maximum=0.05, value=0.002, step=0.001,
701
+ label="Diffusion Noise")
702
+
703
+ with gr.Row():
704
+ step_size_slider = gr.Slider(minimum=0.01, maximum=2.0, value=1.0, step=0.01,
705
+ label="Update Rate")
706
+ layer_choice = gr.Dropdown(
707
+ choices=["all", "conv1", "bn1", "relu", "maxpool", "layer1", "layer2", "layer3", "layer4", "avgpool"],
708
+ value="all",
709
+ label="Model Layer"
710
+ )
711
+
712
+ gr.Markdown("### 🎯 Adaptive Gaussian mask (spatially varying constraint)")
713
+ gr.Markdown("Define where on the image the mask is centered and how large its radius is. Coordinates: **-1** = left/top, **1** = right/bottom, **0** = center.")
714
+ with gr.Row():
715
+ use_adaptive_eps_check = gr.Checkbox(value=False, label="Use adaptive epsilon (stronger/weaker constraint by region)")
716
+ use_adaptive_step_check = gr.Checkbox(value=True, label="Use adaptive step size (stronger/weaker updates by region)")
717
+ with gr.Row():
718
+ mask_center_x_slider = gr.Slider(minimum=-1.0, maximum=1.0, value=0.5, step=0.05, label="Mask center X")
719
+ mask_center_y_slider = gr.Slider(minimum=-1.0, maximum=1.0, value=0.0, step=0.05, label="Mask center Y")
720
+ with gr.Row():
721
+ mask_radius_slider = gr.Slider(minimum=0.01, maximum=1.0, value=0.2, step=0.01, label="Mask radius (flat region size)")
722
+ mask_sigma_slider = gr.Slider(minimum=0.05, maximum=1.0, value=0.2, step=0.01, label="Mask sigma (fall-off outside radius)")
723
+ with gr.Row():
724
+ eps_max_mult_slider = gr.Slider(minimum=0.1, maximum=350.0, value=20.0, step=0.1, label="Epsilon: multiplier at center")
725
+ eps_min_mult_slider = gr.Slider(minimum=0.1, maximum=10.0, value=1.0, step=0.1, label="Epsilon: multiplier at periphery")
726
+ with gr.Row():
727
+ step_max_mult_slider = gr.Slider(minimum=0.1, maximum=150.0, value=50.0, step=0.1, label="Step size: multiplier at center")
728
+ step_min_mult_slider = gr.Slider(minimum=0.1, maximum=10.0, value=0.2, step=0.1, label="Step size: multiplier at periphery")
729
+ gr.Markdown("### 🎯 Biased inference")
730
+ gr.Markdown("Bias the prediction toward a specific ImageNet category (1000 classes).")
731
+ with gr.Row():
732
+ use_biased_inference_check = gr.Checkbox(value=False, label="Use biased inference (bias toward a target class)")
733
+ biased_class_dropdown = gr.Dropdown(
734
+ choices=[("— No bias —", "")] + [(label, label) for label in sorted(IMAGENET_LABELS)],
735
+ value="",
736
+ label="Biased toward category",
737
+ allow_custom_value=False,
738
+ filterable=True,
739
+ )
740
+ with gr.Column(scale=2):
741
+ # Outputs
742
+ output_image = gr.Image(label="Predicted Hallucination")
743
+ output_frames = gr.Gallery(label="Hallucination Prediction Process", columns=5, rows=2)
744
+ save_status_md = gr.Markdown(value="")
745
+ download_files = gr.File(label="Download results (GIF + config)", file_count="multiple")
746
+
747
+ # Examples section with integrated explanations
748
+ gr.Markdown("## Examples")
749
+ gr.Markdown("Select an example and click Load Parameters to apply its settings")
750
+
751
+ # For each example, create a row with the image and explanation side by side
752
+ for i, ex in enumerate(examples):
753
+ with gr.Row():
754
+ # Left column for the image
755
+ with gr.Column(scale=1):
756
+ # Display the example image
757
+ example_img = gr.Image(value=ex["image"], type="filepath", label=f"{ex['name']}")
758
+ load_btn = gr.Button(f"Load Parameters", variant="primary")
759
+
760
+ # Set up the load button to apply this example's parameters
761
+ load_btn.click(
762
+ fn=lambda ex=ex: apply_example(ex),
763
+ outputs=[
764
+ image_input, model_choice, inference_type,
765
+ eps_slider, iterations_slider,
766
+ initial_noise_slider, diffusion_noise_slider,
767
+ step_size_slider, layer_choice,
768
+ use_adaptive_eps_check, use_adaptive_step_check,
769
+ mask_center_x_slider, mask_center_y_slider,
770
+ mask_radius_slider, mask_sigma_slider,
771
+ eps_max_mult_slider, eps_min_mult_slider,
772
+ step_max_mult_slider, step_min_mult_slider,
773
+ use_biased_inference_check, biased_class_dropdown,
774
+ current_image_path_state,
775
+ mask_preview,
776
+ params_section,
777
+ ],
778
+ )
779
+
780
+ # Right column for the explanation
781
+ with gr.Column(scale=2):
782
+ gr.Markdown(f"### {ex['name']}")
783
+ if ex["name"] not in ("farm1", "ArtGallery1", "UrbanOffice1"):
784
+ gr.Markdown(f"[Read more on Wikipedia]({ex['wiki']})")
785
+
786
+ # Show instructions if they exist
787
+ if "instructions" in ex:
788
+ gr.Markdown(f"**Instructions:** {ex['instructions']}")
789
+
790
+
791
+ if i < len(examples) - 1: # Don't add separator after the last example
792
+ gr.Markdown("---")
793
+
794
+ # Set up event handler for the main inference
795
+ run_button.click(
796
+ fn=run_inference,
797
+ inputs=[
798
+ image_input, model_choice, inference_type,
799
+ eps_slider, iterations_slider,
800
+ initial_noise_slider, diffusion_noise_slider,
801
+ step_size_slider, layer_choice,
802
+ use_adaptive_eps_check, use_adaptive_step_check,
803
+ mask_center_x_slider, mask_center_y_slider,
804
+ mask_radius_slider, mask_sigma_slider,
805
+ eps_max_mult_slider, eps_min_mult_slider,
806
+ step_max_mult_slider, step_min_mult_slider,
807
+ use_biased_inference_check, biased_class_dropdown,
808
+ current_image_path_state,
809
+ ],
810
+ outputs=[output_image, output_frames, save_status_md, download_files]
811
+ )
812
+
813
+ # Toggle parameters visibility
814
def toggle_params():
    """Return an update that makes the parameters group visible.

    Despite the name, this never hides the group: every call returns
    ``visible=True``.
    """
    return gr.Group(visible=True)
816
+
817
+ params_button.click(
818
+ fn=toggle_params,
819
+ outputs=[params_section],
820
+ )
821
+
822
+ # Click on input image or mask preview to set Gaussian mask center
823
def _mask_preview_inputs():
    """Return the components fed to ``draw_mask_overlay``: image, center x, center y, radius."""
    return [image_input, mask_center_x_slider, mask_center_y_slider, mask_radius_slider]
825
+
826
+ image_input.select(
827
+ fn=image_click_to_center,
828
+ outputs=[mask_center_x_slider, mask_center_y_slider],
829
+ )
830
+ mask_preview.select(
831
+ fn=image_click_to_center,
832
+ outputs=[mask_center_x_slider, mask_center_y_slider],
833
+ )
834
+ # Update mask preview when image or center/radius change
835
+ image_input.change(
836
+ fn=draw_mask_overlay,
837
+ inputs=_mask_preview_inputs(),
838
+ outputs=[mask_preview],
839
+ )
840
+ # Keep tracked path for save filename: known stimulus name or clear so stem becomes 'user_img'
841
+ image_input.change(
842
+ fn=_update_tracked_image_path,
843
+ inputs=[image_input],
844
+ outputs=[current_image_path_state],
845
+ )
846
+ mask_center_x_slider.change(
847
+ fn=draw_mask_overlay,
848
+ inputs=_mask_preview_inputs(),
849
+ outputs=[mask_preview],
850
+ )
851
+ mask_center_y_slider.change(
852
+ fn=draw_mask_overlay,
853
+ inputs=_mask_preview_inputs(),
854
+ outputs=[mask_preview],
855
+ )
856
+ mask_radius_slider.change(
857
+ fn=draw_mask_overlay,
858
+ inputs=_mask_preview_inputs(),
859
+ outputs=[mask_preview],
860
+ )
861
+ # Populate mask preview on load
862
+ demo.load(
863
+ fn=draw_mask_overlay,
864
+ inputs=_mask_preview_inputs(),
865
+ outputs=[mask_preview],
866
+ )
867
+
868
+ # About section
869
+ gr.Markdown("""
870
+ ## 🧠 About Hallucination Prediction
871
+
872
+ This tool predicts human visual hallucinations using **generative inference** with adversarially robust neural networks. Robust models develop human-like perceptual biases, allowing them to forecast what perceptual structures humans will experience.
873
+
874
+ ### Prediction Methods:
875
+
876
+ **Prior-Guided Drift Diffusion** (Primary Method)
877
+ Starting from a noisy representation, the model converges toward what it expects to perceive—revealing predicted hallucinations
878
+
879
+ **IncreaseConfidence**
880
+ Moving away from unlikely interpretations to reveal the most probable perceptual experience
881
+
882
+ ### Parameters:
883
+ - **Drift Noise**: Initial uncertainty in the prediction process
884
+ - **Diffusion Noise**: Stochastic exploration during prediction
885
+ - **Update Rate**: Speed of convergence to the predicted hallucination
886
+ - **Number of Iterations**: How many prediction steps to perform
887
+ - **Model Layer**: Which perceptual level to predict from (early edges vs. high-level objects)
888
+ - **Epsilon (Stimulus Fidelity)**: How closely the prediction must match the input stimulus
889
+
890
+ ### Why Does This Work?
891
+
892
+ Adversarially robust neural networks develop perceptual representations similar to human vision. When we use generative inference to reveal what these networks "expect" to see, it matches what humans hallucinate in ambiguous images—allowing us to predict human perception.
893
+
894
+ **Developed by [Tahereh Toosi](https://toosi.github.io)**
895
+ """)
896
+
897
+ # Launch the demo
898
+ if __name__ == "__main__":
899
+ print(f"Starting server on port {args.port}")
900
+ demo.launch(
901
+ server_name="0.0.0.0",
902
+ server_port=args.port,
903
+ share=False,
904
+ debug=True
905
+ )
face_vase.png ADDED
huggingface-metadata.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "title": "Human Hallucination Prediction",
3
+ "emoji": "👁️",
4
+ "colorFrom": "blue",
5
+ "colorTo": "purple",
6
+ "sdk": "gradio",
7
+ "sdk_version": "5.23.1",
8
+ "app_file": "app.py",
9
+ "pinned": false,
10
+ "license": "mit"
11
+ }
inference.py ADDED
@@ -0,0 +1,1049 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torchvision.models as models
5
+ import torchvision.transforms as transforms
6
+ from torchvision.models.resnet import ResNet50_Weights
7
+ from PIL import Image
8
+ import numpy as np
9
+ import os
10
+ import requests
11
+ import time
12
+ import copy
13
+ from collections import OrderedDict
14
+ from pathlib import Path
15
+
16
+ # Check for available hardware acceleration
17
+ if torch.cuda.is_available():
18
+ device = torch.device("cuda")
19
+ elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
20
+ device = torch.device("mps") # Use Apple Metal Performance Shaders for M-series Macs
21
+ else:
22
+ device = torch.device("cpu")
23
+ print(f"Using device: {device}")
24
+
25
+ # Constants
26
+ MODEL_URLS = {
27
+ # Use /resolve/main/ for raw file download; /blob/main/ returns HTML and causes "invalid load key, '<'"
28
+ 'resnet50_robust': 'https://huggingface.co/ttoosi/robust_resnet50_eps_3.00_epoch_best/resolve/main/resnet50_robust.pt',
29
+ 'resnet50_standard': 'https://huggingface.co/madrylab/robust-imagenet-models/resolve/main/resnet50_l2_eps0.ckpt',
30
+ 'resnet50_robust_face': 'https://huggingface.co/ttoosi/resnet50_robust_face/resolve/main/100_checkpoint.pt'
31
+ }
32
+
33
+ IMAGENET_MEAN = [0.485, 0.456, 0.406]
34
+ IMAGENET_STD = [0.229, 0.224, 0.225]
35
+
36
+ # Define the transforms based on whether normalization is on or off
37
def get_transform(input_size=224, normalize=False, norm_mean=IMAGENET_MEAN, norm_std=IMAGENET_STD):
    """Build the preprocessing pipeline: resize, center-crop, convert to tensor.

    Args:
        input_size: Size passed to both the resize and the center crop.
        normalize: When true, append channel normalization as the last step.
        norm_mean: Per-channel mean used when ``normalize`` is true.
        norm_std: Per-channel std used when ``normalize`` is true.
    """
    steps = [
        transforms.Resize(input_size),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
    ]
    if normalize:
        steps.append(transforms.Normalize(norm_mean, norm_std))
    return transforms.Compose(steps)
51
+
52
+ # Default transform without normalization
53
+ transform = transforms.Compose([
54
+ transforms.Resize(224),
55
+ transforms.CenterCrop(224),
56
+ transforms.ToTensor(),
57
+ ])
58
+
59
+ normalize_transform = transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)
60
+
61
+
62
def create_flat_top_gaussian_mask(image_shape, center_x, center_y, flat_radius, sigma, max_multiplier=4.0, min_multiplier=1.0, device=None):
    """
    Build a flat-top Gaussian multiplier map for adaptive epsilon / step size.

    Pixel positions are expressed on a normalized grid where -1 is the
    left/top edge, 1 is the right/bottom edge and 0 is the center.

    Args:
        image_shape: (batch, channels, height, width); only H and W are used
        center_x: Mask center x (normalized -1 to 1)
        center_y: Mask center y (normalized -1 to 1)
        flat_radius: Radius (normalized) inside which the map is constant
        sigma: Std of the Gaussian fall-off beyond the flat radius
        max_multiplier: Value inside the flat region
        min_multiplier: Value approached far away from the center
        device: torch device for the result (CPU when None)
    Returns:
        Tensor of shape (1, 1, H, W)
    """
    height, width = image_shape[2], image_shape[3]
    target_device = torch.device("cpu") if device is None else device

    # Column vector of rows and row vector of columns; broadcasting yields (H, W).
    rows = torch.linspace(-1, 1, height, device=target_device).view(height, 1)
    cols = torch.linspace(-1, 1, width, device=target_device).view(1, width)
    dist = torch.sqrt((cols - center_x) ** 2 + (rows - center_y) ** 2)

    # Gaussian decay measured from the rim of the flat region, not the center.
    span = max_multiplier - min_multiplier
    falloff = min_multiplier + span * torch.exp(-((dist - flat_radius) ** 2) / (2 * sigma ** 2))
    plateau = torch.full_like(dist, max_multiplier)

    multiplier = torch.where(dist <= flat_radius, plateau, falloff)
    return multiplier[None, None]  # (1, 1, H, W)
92
+
93
+
94
def extract_middle_layers(model, layer_index):
    """
    Return a version of the model truncated at the requested layer.

    Args:
        model: The neural network model (optionally wrapped so that the real
            network lives under ``.module``)
        layer_index: One of
            - 'all': the model is returned unchanged
            - 'encoder.layers.encoder_layer_<k>': keep encoder layers 0..k and
              replace the heads with an identity
            - an int, for models exposing ``blocks``: keep blocks 0..layer_index
            - otherwise: the name of a direct child module; children up to and
              including it are kept

    Returns:
        The original model for 'all', a truncated deep copy for the
        encoder/blocks cases, or a new ``nn.Sequential`` of the kept children.

    Raises:
        ValueError: for a malformed encoder layer specification or when no
            child module has the requested name.
    """
    if isinstance(layer_index, str) and layer_index == 'all':
        return model

    # Encoder-layer specification, e.g. 'encoder.layers.encoder_layer_3'.
    if isinstance(layer_index, str) and layer_index.startswith('encoder.layers.encoder_layer_'):
        try:
            last_layer = int(layer_index.split('_')[-1])

            # Work on a copy so the caller's model is left intact.
            truncated = copy.deepcopy(model)
            # Unwrap a DataParallel-style wrapper when present.
            root = truncated.module if hasattr(truncated, 'module') else truncated

            kept_layers = nn.Sequential()
            for idx in range(last_layer + 1):
                layer_name = f"encoder_layer_{idx}"
                if hasattr(root.encoder.layers, layer_name):
                    kept_layers.add_module(layer_name, getattr(root.encoder.layers, layer_name))

            root.encoder.layers = kept_layers
            # The classification heads are dropped since we stop at the encoder.
            root.heads = nn.Identity()
            return truncated
        except (ValueError, IndexError) as e:
            raise ValueError(f"Invalid ViT layer specification: {layer_index}. Error: {e}")

    # Models organised as a list of transformer blocks.
    if hasattr(model, 'blocks') or (hasattr(model, 'module') and hasattr(model.module, 'blocks')):
        truncated = copy.deepcopy(model)
        inner = truncated.module if hasattr(truncated, 'module') else truncated
        # Only integer indices truncate; any other value returns the full copy.
        if isinstance(layer_index, int):
            inner.blocks = inner.blocks[:layer_index+1]
        return truncated

    # Generic case: cut the list of direct children at the named module.
    modules = list(model.named_children())
    print(f"DEBUG - extract_middle_layers - Looking for '{layer_index}' in {[name for name, _ in modules]}")

    child_names = [child_name for child_name, _ in modules]
    wanted = str(layer_index)
    if wanted not in child_names:
        raise ValueError(f"Module {layer_index} not found in model")

    # Keep modules up to and including the first child with the requested name.
    return nn.Sequential(OrderedDict(modules[:child_names.index(wanted) + 1]))
185
+
186
+ # Get ImageNet labels
187
def get_imagenet_labels(timeout=30):
    """Fetch the ImageNet simple-label list from GitHub.

    Args:
        timeout: Seconds to wait for the server before giving up. Without a
            timeout a stalled connection would block the caller indefinitely.

    Returns:
        The decoded JSON payload of the labels file.

    Raises:
        RuntimeError: if the server answers with a non-200 status.
        requests.RequestException: on connection problems or timeout.
    """
    url = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
    response = requests.get(url, timeout=timeout)
    if response.status_code != 200:
        raise RuntimeError("Failed to fetch ImageNet labels")
    return response.json()
194
+
195
+ # Download model if needed
196
def download_model(model_type):
    """Ensure the checkpoint for ``model_type`` exists locally and return its path.

    Args:
        model_type: Key into ``MODEL_URLS``.

    Returns:
        Path to the checkpoint under ``models/``, or None when no URL is
        registered for the model type (the caller then uses PyTorch's
        pretrained model).

    Raises:
        RuntimeError: if the server answers with a non-200 status.
        requests.RequestException: on connection problems or timeout.
    """
    if model_type not in MODEL_URLS or MODEL_URLS[model_type] is None:
        return None  # Use PyTorch's pretrained model

    # Handle special case for face model
    if model_type == 'resnet50_robust_face':
        model_path = Path("models/resnet50_robust_face_100_checkpoint.pt")
    else:
        model_path = Path(f"models/{model_type}.pt")

    if model_path.exists():
        return model_path

    print(f"Downloading {model_type} model...")
    url = MODEL_URLS[model_type]
    # The target directory may not exist yet on a fresh checkout.
    model_path.parent.mkdir(parents=True, exist_ok=True)
    # Stream into a temporary file and rename only after a complete download,
    # so an interrupted transfer never leaves a truncated checkpoint that a
    # later call would mistake for a valid one.
    partial_path = model_path.with_name(model_path.name + ".part")
    with requests.get(url, stream=True, timeout=60) as response:
        if response.status_code != 200:
            raise RuntimeError(f"Failed to download model: {response.status_code}")
        try:
            with open(partial_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
            partial_path.replace(model_path)
        finally:
            # No-op after a successful rename; removes leftovers on failure.
            partial_path.unlink(missing_ok=True)
    print(f"Model downloaded and saved to {model_path}")
    return model_path
218
+
219
class NormalizeByChannelMeanStd(nn.Module):
    """Module that normalizes a batch per channel with a fixed mean and std.

    The statistics are stored as buffers so they follow the module across
    devices.
    """

    def __init__(self, mean, std):
        super().__init__()
        # Accept plain sequences as well as ready-made tensors.
        mean_tensor = mean if isinstance(mean, torch.Tensor) else torch.tensor(mean)
        std_tensor = std if isinstance(std, torch.Tensor) else torch.tensor(std)
        self.register_buffer("mean", mean_tensor)
        self.register_buffer("std", std_tensor)

    def forward(self, tensor):
        return self.normalize_fn(tensor, self.mean, self.std)

    def normalize_fn(self, tensor, mean, std):
        """Differentiable counterpart of torchvision's functional normalize."""
        # The color channel is assumed to sit at dim=1.
        channel_mean = mean[None, :, None, None]
        channel_std = std[None, :, None, None]
        return tensor.sub(channel_mean).div(channel_std)
238
+
239
class InferStep:
    """Projection and step rules for the iterative inference loop.

    Holds the reference image together with either scalar or spatially
    varying (flat-top Gaussian) epsilon and step-size settings.
    """

    def __init__(self, orig_image, eps, step_size, adaptive_eps_config=None, adaptive_step_config=None):
        self.orig_image = orig_image

        # Adaptive epsilon (spatially varying constraint)
        self.use_adaptive_eps = bool(adaptive_eps_config and adaptive_eps_config.get("enabled", False))
        if self.use_adaptive_eps:
            base_eps = float(adaptive_eps_config.get("base_epsilon", eps))
            self.eps_map = self._scaled_mask(orig_image, adaptive_eps_config, base_eps)
        else:
            self.eps = eps

        # Adaptive step size (spatially varying update strength)
        self.use_adaptive_step_size = bool(adaptive_step_config and adaptive_step_config.get("enabled", False))
        if self.use_adaptive_step_size:
            base_step = float(adaptive_step_config.get("base_step_size", step_size))
            self.step_size_map = self._scaled_mask(orig_image, adaptive_step_config, base_step)
        else:
            self.step_size = step_size

    @staticmethod
    def _scaled_mask(image, mask_config, base_value):
        """Build the flat-top Gaussian map from a config dict, scaled by base_value."""
        target_device = image.device
        mask = create_flat_top_gaussian_mask(
            image.shape,
            float(mask_config.get("center_x", 0.0)),
            float(mask_config.get("center_y", 0.0)),
            float(mask_config.get("flat_radius", 0.3)),
            float(mask_config.get("sigma", 0.2)),
            float(mask_config.get("max_multiplier", 4.0)),
            float(mask_config.get("min_multiplier", 1.0)),
            device=target_device,
        )
        return (mask * base_value).to(target_device)

    def project(self, x):
        """Pull x back so each pixel's RGB offset from the original has L2 norm <= eps.

        The constraint is a per-pixel L2 ball over the color channels, not an
        L-infinity bound per channel. The result is clamped to [0, 1].
        """
        delta = x - self.orig_image  # (B, C, H, W)
        pixel_norm = torch.norm(delta, dim=1, keepdim=True)  # (B, 1, H, W)
        # A (1, 1, H, W) map broadcasts over the batch; a scalar applies everywhere.
        budget = self.eps_map if self.use_adaptive_eps else self.eps
        shrink = torch.clamp(budget / (pixel_norm + 1e-10), max=1.0)
        return torch.clamp(self.orig_image + delta * shrink, 0, 1)

    def step(self, x, grad):
        """Return the update: gradient normalized per sample, times the step size (or map)."""
        trailing_dims = [1] * (len(x.shape) - 1)
        sample_norm = torch.norm(grad.view(grad.shape[0], -1), dim=1).view(-1, *trailing_dims)
        unit_grad = grad / (sample_norm + 1e-10)
        # The direction is the same everywhere; only the magnitude is scaled.
        rate = self.step_size_map if self.use_adaptive_step_size else self.step_size
        return unit_grad * rate
303
+
304
def get_iterations_to_show(n_itr):
    """Return the iteration numbers at which intermediate results are shown.

    The checkpoints get denser as the total grows. The result is sorted,
    free of duplicates, never exceeds ``n_itr`` and always contains ``n_itr``
    itself, so the final iteration is always shown.

    Args:
        n_itr: Total number of iterations.

    Returns:
        Ascending list of iteration numbers.
    """
    if n_itr <= 50:
        candidates = [1, 5, 10, 20, 30, 40, 50]
    elif n_itr <= 100:
        candidates = [1, 5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
    elif n_itr <= 200:
        candidates = [1, 5, 10, 20, 30, 40, 50, 75, 100, 125, 150, 175, 200]
    else:
        candidates = [1, 5, 10, 20, 30, 40, 50, 75, 100, 150, 200, 250, 300, 350, 400, 450, 500]
        if n_itr > 500:
            # For very large runs, add evenly spread points over the tail.
            candidates += [int(n_itr * 0.6), int(n_itr * 0.7), int(n_itr * 0.8), int(n_itr * 0.9)]

    # Drop checkpoints beyond the run length, dedupe, and keep them ordered.
    return sorted({i for i in candidates if i <= n_itr} | {n_itr})
318
+
319
def get_inference_configs(inference_type='IncreaseConfidence', eps=0.5, n_itr=50, step_size=1.0):
    """Assemble the configuration dictionary for an inference run.

    Args:
        inference_type (str): Type of inference ('IncreaseConfidence' or 'Prior-Guided Drift Diffusion')
        eps (float): Maximum perturbation size
        n_itr (int): Number of iterations
        step_size (float): Step size for each iteration

    Returns:
        dict: the shared base settings plus the entries specific to the
        inference type. Unrecognized types get the base settings only.
    """
    # Settings shared by every inference type.
    config = dict(
        loss_infer=inference_type,  # How to guide the optimization
        n_itr=n_itr,  # Number of iterations
        eps=eps,  # Maximum perturbation size
        step_size=step_size,  # Step size for each iteration
        diffusion_noise_ratio=0.0,  # No diffusion noise
        initial_inference_noise_ratio=0.0,  # No initial noise
        top_layer='all',  # Use all layers of the model
        inference_normalization='off',  # 'on' or 'off' (match psychiatry implementation)
        recognition_normalization='off',
        iterations_to_show=get_iterations_to_show(n_itr),  # Dynamic iterations to visualize
        misc_info={'keep_grads': False},  # Additional configuration
        biased_inference={'enable': False, 'class': None},  # Bias perception toward a target class (ImageNet label)
    )

    cross_entropy_types = ('IncreaseConfidence', 'GradModulation', 'CompositionalFusion')

    if inference_type == 'Prior-Guided Drift Diffusion':
        config['loss_function'] = 'MSE'  # Mean Square Error
        config.update(
            initial_inference_noise_ratio=0.05,  # Initial noise for diffusion
            diffusion_noise_ratio=0.01,  # Add noise during diffusion
        )
    elif inference_type in cross_entropy_types:
        config['loss_function'] = 'CE'  # Cross Entropy
        extras = config['misc_info']
        if inference_type == 'GradModulation':
            extras['grad_modulation'] = 0.5  # Gradient modulation strength
        elif inference_type == 'CompositionalFusion':
            extras['positive_classes'] = []  # Classes to maximize
            extras['negative_classes'] = []  # Classes to minimize

    return config
364
+
365
class GenerativeInferenceModel:
    """Load ResNet-50 classifiers and run iterative, gradient-driven inference on images.

    Loaded models are cached per ``model_type`` in ``self.models``. Each model
    is normally ``nn.Sequential(self.normalizer, resnet50)``.

    NOTE(review): ``self.normalizer`` is a single module instance shared by
    every cached model; normalizer parameters found in one checkpoint are
    loaded into that shared instance and therefore affect all models.
    """

    def __init__(self):
        self.models = {}
        self.normalizer = NormalizeByChannelMeanStd(IMAGENET_MEAN, IMAGENET_STD).to(device)
        self.labels = get_imagenet_labels()

    # ------------------------------------------------------------------
    # Model loading
    # ------------------------------------------------------------------
    def verify_model_integrity(self, model, model_type):
        """Run a deterministic test input through ``model`` as a smoke test.

        Args:
            model: Callable mapping a (1, 3, 224, 224) tensor to logits.
            model_type (str): Name used in the log output only.

        Returns:
            bool: True when the output has shape (1, 1000) and contains no
            NaN. False otherwise, including when the check itself raises
            (the previous implementation reported success in that case).
        """
        try:
            print(f"\n=== Running model integrity check for {model_type} ===")
            # Deterministic test input created directly on the target device.
            test_input = torch.zeros(1, 3, 224, 224, device=device)
            test_input[0, 0, 100:124, 100:124] = 0.5  # Red square

            with torch.no_grad():
                output = model(test_input)

            if output.shape != (1, 1000):
                print(f"❌ Unexpected output shape: {output.shape}, expected (1, 1000)")
                return False

            probs = torch.nn.functional.softmax(output, dim=1)
            confidence, prediction = torch.max(probs, 1)

            mean = output.mean().item()
            std = output.std().item()
            min_val = output.min().item()
            max_val = output.max().item()

            print("Model integrity check results:")
            print(f"- Output shape: {output.shape}")
            print(f"- Top prediction: Class {prediction.item()} with {confidence.item()*100:.2f}% confidence")
            print(f"- Output statistics: mean={mean:.3f}, std={std:.3f}, min={min_val:.3f}, max={max_val:.3f}")

            if torch.isnan(output).any():
                print("❌ Model produced NaN outputs")
                return False

            if std < 0.1:
                # Advisory only: does not fail the check.
                print("⚠️ Low output variance, model may not be discriminative")

            print("✅ Model passes basic integrity check")
            return True

        except Exception as e:
            # Best effort: never abort model loading because of the smoke
            # test, but report the failure truthfully to the caller.
            print(f"❌ Model integrity check failed with error: {e}")
            return False

    def _pretrained_model(self):
        """Return the shared normalizer followed by torchvision's ImageNet ResNet-50."""
        resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
        return nn.Sequential(self.normalizer, resnet)

    @staticmethod
    def _describe_checkpoint(checkpoint):
        """Print a short description of the checkpoint layout (diagnostics only)."""
        print("\n=== Analyzing checkpoint structure ===")
        if not isinstance(checkpoint, dict):
            print(f"Checkpoint is not a dictionary, but a {type(checkpoint)}")
            return

        print(f"Checkpoint contains keys: {list(checkpoint.keys())}")
        model_dict = checkpoint.get('model')
        if isinstance(model_dict, dict):
            keys = list(model_dict.keys())
            print(f"'model' contains keys like: {keys[:5]}")
            # Only the first 100 keys are sampled for the prefix summary.
            prefixes = {key.split('.')[0] for key in keys[:100] if '.' in key}
            if prefixes:
                print(f"Common prefixes in model dict: {prefixes}")

    @staticmethod
    def _select_state_dict(checkpoint):
        """Pick the parameter dictionary out of a loaded checkpoint.

        Handles ``{'model': ...}`` (madrylab robust models),
        ``{'state_dict': ...}``, a bare state dict, and a pickled module
        (anything exposing ``state_dict()``).
        """
        candidate = checkpoint
        if isinstance(checkpoint, dict):
            if 'model' in checkpoint:
                candidate = checkpoint['model']
                print("Using 'model' key from checkpoint")
            elif 'state_dict' in checkpoint:
                candidate = checkpoint['state_dict']
                print("Using 'state_dict' key from checkpoint")
            else:
                print("Using checkpoint directly as state_dict")

        # A whole module was pickled instead of its parameters.
        if not isinstance(candidate, dict) and hasattr(candidate, 'state_dict'):
            print("Checkpoint holds a module object; using its state_dict()")
            candidate = candidate.state_dict()
        return candidate

    @staticmethod
    def _map_checkpoint_keys(state_dict, resnet_keys):
        """Translate checkpoint keys into plain ResNet parameter names.

        Args:
            state_dict (dict): Checkpoint parameters keyed by (possibly
                prefixed) names.
            resnet_keys: Names expected by the ResNet ``state_dict``.

        Returns:
            dict: Parameters keyed by un-prefixed names. It may contain names
            outside ``resnet_keys`` when a whole prefixed sub-tree was copied.
        """
        resnet_keys = set(resnet_keys)

        def strip(prefix):
            # All entries under `prefix`, re-keyed without it.
            return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}

        mapped = {}
        print("\n=== Phase 1: Checking for specific model structures ===")

        # 'module.model.*' (seen in actual checkpoint)
        module_model = strip('module.model.')
        if module_model:
            print(f"Found 'module.model' structure with {len(module_model)} parameters")
            mapped.update(module_model)
            print(f"Extracted {len(mapped)} parameters from module.model")

        # NOTE(review): the original indentation was lost in this listing; the
        # "other known structures" branch is read as the `else` of the
        # 'attacker.model' check below — confirm against the repository.
        attacker_model = strip('attacker.model.')
        plain_model = strip('model.')
        if attacker_model:
            print(f"Found 'attacker.model' structure with {len(attacker_model)} parameters")
            mapped.update(attacker_model)
            print(f"Extracted {len(mapped)} parameters from attacker.model")

            # Use plain 'model.*' entries only to complete missing parameters.
            if plain_model and resnet_keys.difference(mapped):
                print(f"Found additional 'model.' structure with {len(plain_model)} parameters")
                for target_key, value in plain_model.items():
                    if target_key in resnet_keys and target_key not in mapped:
                        mapped[target_key] = value
        else:
            structure_found = False

            if plain_model:
                print(f"Found 'model.' structure with {len(plain_model)} parameters")
                mapped.update(plain_model)
                structure_found = True

            top_level = {k: v for k, v in state_dict.items() if k in resnet_keys}
            if top_level:
                print(f"Found {len(top_level)} ResNet parameters at top level")
                mapped.update(top_level)
                structure_found = True

            if not structure_found:
                print("No standard model structure found, trying prefix mappings...")
                for target_key in resnet_keys:
                    for prefix in ('', 'module.', 'model.', 'attacker.model.'):
                        source_key = prefix + target_key
                        if source_key in state_dict:
                            mapped[target_key] = state_dict[source_key]
                            break

        # Last resort: strip one known prefix from every checkpoint key.
        # Completeness is judged on the ResNet names actually covered, not on
        # the raw dictionary size (which may include unrelated extra keys).
        missing = resnet_keys.difference(mapped)
        if missing:
            print(f"Found only {len(resnet_keys) - len(missing)}/{len(resnet_keys)} parameters, trying prefix removal...")

            strip_prefixes = ('module.', 'model.', 'attacker.model.', 'attacker.')
            prefix_matches = dict.fromkeys(strip_prefixes, 0)
            for source_key, value in state_dict.items():
                matched_prefix = next((p for p in strip_prefixes if source_key.startswith(p)), None)
                target_key = source_key[len(matched_prefix):] if matched_prefix else source_key
                if target_key in resnet_keys and target_key not in mapped:
                    mapped[target_key] = value
                    if matched_prefix:
                        prefix_matches[matched_prefix] += 1

            print("\n=== Prefix Removal Statistics ===")
            total_matches = sum(prefix_matches.values())
            print(f"Total parameters matched through prefix removal: {total_matches}/{len(resnet_keys)} ({(total_matches/len(resnet_keys))*100:.1f}%)")

            print("\nMatches by prefix:")
            for prefix, count in sorted(prefix_matches.items(), key=lambda item: item[1], reverse=True):
                if count > 0:
                    print(f" {prefix}: {count} parameters")

            # Per-layer coverage over everything mapped so far, so layers
            # resolved in earlier phases are not reported as incomplete.
            layer_matches = {}
            for key in resnet_keys:
                stats = layer_matches.setdefault(key.split('.')[0], {'total': 0, 'matched': 0})
                stats['total'] += 1
                if key in mapped:
                    stats['matched'] += 1

            print("\nMatches by layer type:")
            for layer, stats in sorted(layer_matches.items(), key=lambda item: item[1]['total'], reverse=True):
                match_percent = (stats['matched'] / stats['total']) * 100 if stats['total'] > 0 else 0
                print(f" {layer}: {stats['matched']}/{stats['total']} ({match_percent:.1f}%)")

            print("\nStatus of critical layers:")
            for layer in ('conv1', 'bn1', 'layer1', 'layer2', 'layer3', 'layer4', 'fc'):
                stats = layer_matches.get(layer)
                if stats is None:
                    print(f" {layer}: Not found in model")
                    continue
                match_percent = (stats['matched'] / stats['total']) * 100
                status = "✅ COMPLETE" if stats['matched'] == stats['total'] else "⚠️ INCOMPLETE"
                print(f" {layer}: {stats['matched']}/{stats['total']} ({match_percent:.1f}%) - {status}")

        return mapped

    def _build_from_checkpoint(self, model_path, model_type):
        """Build normalizer + ResNet-50 from the checkpoint at ``model_path``.

        Falls back to torchvision's pretrained weights when no parameter can
        be mapped or fewer than 50% of the parameters load. Exceptions
        propagate to ``load_model``, which applies the same fallback.
        """
        # NOTE(security): weights_only=False unpickles arbitrary Python
        # objects; only load checkpoints from a trusted source.
        checkpoint = torch.load(model_path, map_location=device, weights_only=False)
        self._describe_checkpoint(checkpoint)
        state_dict = self._select_state_dict(checkpoint)

        resnet = models.resnet50()
        n_model_params = len(resnet.state_dict())
        resnet_state_dict = self._map_checkpoint_keys(state_dict, resnet.state_dict().keys())

        if not resnet_state_dict:
            # Previously this case silently kept a randomly initialised ResNet.
            print("⚠️ No ResNet parameters could be mapped from the checkpoint.")
            print("Falling back to PyTorch's pretrained model")
            return self._pretrained_model()

        # strict=False tolerates missing/unexpected keys; a shape mismatch
        # still raises and is handled by the caller.
        missing_keys, unexpected_keys = resnet.load_state_dict(resnet_state_dict, strict=False)
        model = nn.Sequential(self.normalizer, resnet)

        report = [
            f"\n===== MODEL LOADING REPORT: {model_type} =====",
            f"Total parameters in checkpoint: {len(resnet_state_dict):,}",
            f"Total parameters in model: {n_model_params:,}",
            f"Missing keys: {len(missing_keys):,} parameters",
            f"Unexpected keys: {len(unexpected_keys):,} parameters",
        ]

        loaded_keys = set(resnet_state_dict) - set(unexpected_keys)
        loaded_percent = (len(loaded_keys) / n_model_params) * 100

        if loaded_percent >= 99.5:
            status = "✅ COMPLETE - All important parameters loaded"
        elif loaded_percent >= 90:
            status = "🟡 PARTIAL - Most parameters loaded, should still function"
        elif loaded_percent >= 50:
            status = "⚠️ INCOMPLETE - Many parameters missing, may not function properly"
        else:
            status = "❌ FAILED - Critical parameters missing, will not function properly"

        report.append(f"Successfully loaded: {len(loaded_keys):,} parameters ({loaded_percent:.1f}%)")
        report.append(f"Loading status: {status}")

        if loaded_percent < 50:
            report.append("\n⚠️ WARNING: Loading from checkpoint is too incomplete.")
            report.append("⚠️ Falling back to PyTorch's pretrained model to avoid broken inference.")
            model = self._pretrained_model()
            report.append("✅ Successfully loaded PyTorch's pretrained ResNet50 model")

        if missing_keys:
            report.append("\nMissing keys by layer type:")
            layer_types = {}
            for key in missing_keys:
                layer_type = key.split('.')[0]
                layer_types[layer_type] = layer_types.get(layer_type, 0) + 1
            for layer_type, count in sorted(layer_types.items(), key=lambda item: item[1], reverse=True):
                report.append(f" {layer_type}: {count:,} parameters")

            report.append("\nFirst 10 missing keys:")
            for i, key in enumerate(sorted(missing_keys)[:10]):
                report.append(f" {i+1}. {key}")

        if unexpected_keys:
            report.append("\nFirst 10 unexpected keys:")
            for i, key in enumerate(sorted(unexpected_keys)[:10]):
                report.append(f" {i+1}. {key}")

        report.append("========================================")
        report_text = "\n".join(report)
        print(report_text)

        # The log file is a convenience; failing to write it must not change
        # which model is returned. UTF-8 is explicit because the report
        # contains non-ASCII status symbols.
        try:
            os.makedirs("logs", exist_ok=True)
            with open(f"logs/model_loading_{model_type}.log", "w", encoding="utf-8") as f:
                f.write(report_text)
        except OSError as e:
            print(f"Warning: Could not write model loading log: {e}")

        # Optional normalizer parameters stored alongside the network.
        norm_prefix = 'attacker.normalize.'
        norm_state_dict = {
            key[len(norm_prefix):]: value
            for key, value in state_dict.items()
            if key.startswith(norm_prefix)
        }
        if norm_state_dict:
            try:
                self.normalizer.load_state_dict(norm_state_dict, strict=False)
                print("Successfully loaded normalizer parameters")
            except Exception as e:
                print(f"Warning: Could not load normalizer parameters: {e}")

        return model

    def load_model(self, model_type):
        """Load model from checkpoint or use pretrained model.

        Args:
            model_type (str): Model identifier, passed to ``download_model``
                and used as the cache key.

        Returns:
            The model in eval mode on ``device``. Results are cached, so
            repeated calls with the same ``model_type`` return the same object.
        """
        if model_type in self.models:
            print(f"Using cached {model_type} model")
            return self.models[model_type]

        # Record loading time for performance analysis
        start_time = time.time()
        model_path = download_model(model_type)

        if model_path:
            print(f"Loading {model_type} model from {model_path}...")
            try:
                model = self._build_from_checkpoint(model_path, model_type)
            except Exception as e:
                # Deliberate best effort: any checkpoint problem degrades to
                # the stock pretrained network instead of failing the request.
                print(f"Error loading model checkpoint: {e}")
                print("Falling back to PyTorch's pretrained model")
                model = self._pretrained_model()
        else:
            print("No checkpoint available, using PyTorch's pretrained model")
            model = self._pretrained_model()

        model = model.to(device)
        model.eval()  # Set to evaluation mode

        # Diagnostic only; the outcome is printed but does not block loading.
        self.verify_model_integrity(model, model_type)

        self.models[model_type] = model
        load_time = time.time() - start_time
        print(f"Model {model_type} loaded in {load_time:.2f} seconds")
        return model

    # ------------------------------------------------------------------
    # Inference
    # ------------------------------------------------------------------
    @staticmethod
    def _classify(model, tensor, normalization_on, wraps_normalizer):
        """Return logits for ``tensor`` using the reporting-time branching.

        With a normalizer-wrapped Sequential, the full model is used when
        'inference_normalization' is 'on' and only the second component
        (``model[1]``) otherwise. For any other model, the input is passed
        through ``normalize_transform`` first when normalization is on.
        """
        if wraps_normalizer:
            return model(tensor) if normalization_on else model[1](tensor)
        if normalization_on:
            return model(normalize_transform(tensor))
        return model(tensor)

    def inference(self, image, model_type, config):
        """Run generative inference on the image.

        Args:
            image: A PIL image or a path to an image file. Tensors are
                rejected.
            model_type (str): Model identifier forwarded to ``load_model``.
            config (dict): Configuration as produced by
                ``get_inference_configs``; optional keys 'adaptive_epsilon',
                'adaptive_step_size' and 'biased_inference' are honoured.

        Returns:
            dict: 'final_image', 'steps', 'original_class',
            'original_confidence', 'final_class', 'final_confidence'.

        Raises:
            ValueError: If ``image`` is a non-existent path or a tensor, or if
                biased inference names a class missing from ``self.labels``.
        """
        inference_start = time.time()
        model = self.load_model(model_type)

        if isinstance(image, str):
            if not os.path.exists(image):
                raise ValueError(f"Image path does not exist: {image}")
            image = Image.open(image).convert('RGB')
        elif isinstance(image, torch.Tensor):
            raise ValueError(f"Image type {type(image)}, looks like already a transformed tensor")

        # Prepare image tensor; normalization in the transform is conditional.
        load_start = time.time()
        normalization_on = config['inference_normalization'] == 'on'
        custom_transform = get_transform(
            input_size=224,
            normalize=normalization_on,
            norm_mean=IMAGENET_MEAN,
            norm_std=IMAGENET_STD
        )
        image_tensor = custom_transform(image).unsqueeze(0).to(device)

        # GradModulation blends the input with Gaussian noise up front.
        if config['loss_infer'] == 'GradModulation' and 'misc_info' in config and 'grad_modulation' in config['misc_info']:
            grad_modulation = config['misc_info']['grad_modulation']
            image_tensor = image_tensor * (1 - grad_modulation) + grad_modulation * torch.randn_like(image_tensor)

        image_tensor.requires_grad = True
        print(f"Image loaded and processed in {time.time() - load_start:.2f} seconds")

        wraps_normalizer = isinstance(model, nn.Sequential) and isinstance(model[0], NormalizeByChannelMeanStd)
        if wraps_normalizer:
            print("Model is sequential with normalization")
        else:
            print("Model is not sequential with normalization")

        # Original predictions
        with torch.no_grad():
            output_original = self._classify(model, image_tensor, normalization_on, wraps_normalizer)
            probs_orig = F.softmax(output_original, dim=1)
            conf_orig, classes_orig = torch.max(probs_orig, 1)
            # The 100 least confident classes of the original prediction.
            _, least_confident_classes = torch.topk(probs_orig, k=100, largest=False)

        # Initialize inference step (adaptive Gaussian mask from config)
        adaptive_eps = config.get("adaptive_epsilon")
        adaptive_step = config.get("adaptive_step_size")
        if adaptive_eps and adaptive_eps.get("enabled"):
            print(f"Applying adaptive epsilon mask: center=({adaptive_eps.get('center_x'):.2f}, {adaptive_eps.get('center_y'):.2f}), radius={adaptive_eps.get('flat_radius'):.2f}")
        if adaptive_step and adaptive_step.get("enabled"):
            print(f"Applying adaptive step-size mask: center=({adaptive_step.get('center_x'):.2f}, {adaptive_step.get('center_y'):.2f}), radius={adaptive_step.get('flat_radius'):.2f}")
        infer_step = InferStep(
            image_tensor,
            config["eps"],
            config["step_size"],
            adaptive_eps_config=adaptive_eps,
            adaptive_step_config=adaptive_step,
        )

        # Biased inference: resolve class name to index (case-insensitive).
        biased_inference_config = config.get('biased_inference', {'enable': False, 'class': None})
        biased_class_tensor = None
        if biased_inference_config.get('enable', False):
            class_name = biased_inference_config.get('class') or biased_inference_config.get('class_name')
            if class_name:
                wanted = class_name.lower()
                biased_class_index = next(
                    (i for i, label in enumerate(self.labels) if label.lower() == wanted),
                    None,
                )
                if biased_class_index is None:
                    raise ValueError(
                        f"biased_inference class '{class_name}' not found in ImageNet simple labels. "
                        "Use a label from imagenet-simple-labels (e.g. 'goldfish', 'tabby cat')."
                    )
                biased_class_tensor = torch.tensor(
                    [biased_class_index], device=device, dtype=torch.long
                )
                print(f"Biased inference: biasing toward class '{self.labels[biased_class_index]}' (index {biased_class_index})")
            else:
                print("Biased inference enabled but no class specified; ignoring.")

        # Working copy that is updated by the loop; step 0 is the input image.
        x = image_tensor.clone().detach().requires_grad_(True)
        all_steps = [image_tensor[0].detach().cpu()]

        is_pgdd = config['loss_infer'] == 'Prior-Guided Drift Diffusion'

        # For Prior-Guided Drift Diffusion, the target is the feature response
        # of the selected layer to a noisy copy of the input.
        noisy_features = None
        layer_model = None
        if is_pgdd:
            print(f"Setting up Prior-Guided Drift Diffusion with layer {config['top_layer']} and noise {config['initial_inference_noise_ratio']}...")

            try:
                base_model = model
                # Handle DataParallel wrapper if present
                if hasattr(base_model, 'module'):
                    base_model = base_model.module

                print(f"DEBUG - Initial model structure: {type(base_model)}")

                if isinstance(base_model, nn.Sequential):
                    children = list(base_model.children())
                    print(f"DEBUG - Sequential model with {len(children)} children")

                    if len(children) >= 2:
                        # Normalizer + ResNet pattern: the network is component 1.
                        actual_model = children[1]
                        print(f"DEBUG - Using ResNet component: {type(actual_model)}")
                        print(f"DEBUG - Available layers: {[name for name, _ in actual_model.named_children()]}")
                        layer_model = extract_middle_layers(actual_model, config['top_layer'])
                    else:
                        layer_model = extract_middle_layers(base_model, config['top_layer'])
                else:
                    print(f"DEBUG - Available layers: {[name for name, _ in base_model.named_children()]}")
                    layer_model = extract_middle_layers(base_model, config['top_layer'])

                print(f"Successfully extracted model up to layer: {config['top_layer']}")
            except ValueError as e:
                print(f"Layer extraction failed: {e}. Using full model.")
                layer_model = model

            added_noise = config['initial_inference_noise_ratio'] * torch.randn_like(image_tensor)
            noisy_image_tensor = image_tensor + added_noise
            noisy_features = layer_model(noisy_image_tensor)

            print(f"Noisy features computed for Prior-Guided Drift Diffusion target with shape: {noisy_features.shape if hasattr(noisy_features, 'shape') else 'unknown'}")

        # Loop-invariant: the (up to) 10 least confident class indices.
        num_classes = min(10, least_confident_classes.size(1))
        target_indices = [idx.item() for idx in least_confident_classes[0, :num_classes]]

        print(f"Starting inference loop with {config['n_itr']} iterations for {config['loss_infer']}...")
        loop_start = time.time()
        for i in range(config['n_itr']):
            x.grad = None

            # Forward pass: layer sub-model for PGDD, full model otherwise.
            if is_pgdd and layer_model is not None:
                output = layer_model(x)
            else:
                output = model(x)

            try:
                if is_pgdd:
                    # Explicit check instead of `assert` so it survives `-O`;
                    # like any other failure it triggers the fallback below.
                    if config['loss_function'] != 'MSE':
                        raise ValueError("Reverse Diffusion loss function must be MSE")
                    if noisy_features is None:
                        raise ValueError("Noisy features not computed for Prior-Guided Drift Diffusion")
                    loss = F.mse_loss(output, noisy_features)
                    grad = torch.autograd.grad(loss, x)[0]
                else:
                    # Default 'IncreaseConfidence' approach: sum of MSE losses
                    # between the softmax output and a one-hot vector for each
                    # of the least confident classes.
                    loss = 0
                    for target in target_indices:
                        one_hot = torch.zeros_like(output)
                        one_hot[0, target] = 1
                        loss = loss + F.mse_loss(F.softmax(output, dim=1), one_hot)
                    grad = torch.autograd.grad(loss, x, retain_graph=True)[0]

                # Biased inference: subtract the cross-entropy gradient of the
                # target class (computed on the full model).
                # NOTE(review): assumed to apply to both loss modes; the
                # original nesting is not recoverable from this listing.
                if biased_inference_config.get('enable', False) and biased_class_tensor is not None:
                    output_full = model(x)
                    loss_biased = F.cross_entropy(output_full, biased_class_tensor)
                    grad_biased = torch.autograd.grad(loss_biased, x)[0]
                    grad = grad - grad_biased

                # torch.autograd.grad either returns a tensor or raises, so no
                # `grad is None` case needs handling here.
                adjusted_grad = infer_step.step(x, grad)
                diffusion_noise = config['diffusion_noise_ratio'] * torch.randn_like(x)
                # Apply gradient and noise in one operation before projecting.
                x = infer_step.project(x.clone() + adjusted_grad + diffusion_noise)

            except Exception as e:
                # Deliberate best effort: keep iterating with a random
                # perturbation when the gradient step fails.
                print(f"Error in gradient calculation: {e}")
                random_noise = (torch.rand_like(x) - 0.5) * 2 * config['step_size']
                x = infer_step.project(x.clone() + random_noise)

            if i + 1 in config['iterations_to_show'] or i + 1 == config['n_itr']:
                all_steps.append(x[0].detach().cpu())

        # Final predictions, using the same branching as the original ones.
        with torch.no_grad():
            final_output = self._classify(model, x, normalization_on, wraps_normalizer)
            final_probs = F.softmax(final_output, dim=1)
            final_conf, final_classes = torch.max(final_probs, 1)

        loop_time = time.time() - loop_start
        total_time = time.time() - inference_start
        avg_iter_time = loop_time / config['n_itr'] if config['n_itr'] > 0 else 0

        print(f"Original top class: {classes_orig.item()} ({conf_orig.item():.4f})")
        print(f"Final top class: {final_classes.item()} ({final_conf.item():.4f})")
        print(f"Inference loop completed in {loop_time:.2f} seconds ({avg_iter_time:.4f} sec/iteration)")
        print(f"Total inference time: {total_time:.2f} seconds")

        return {
            'final_image': x[0].detach().cpu(),
            'steps': all_steps,
            'original_class': classes_orig.item(),
            'original_confidence': conf_orig.item(),
            'final_class': final_classes.item(),
            'final_confidence': final_conf.item()
        }
1034
+
1035
+ # Utility function to show inference steps
1036
def show_inference_steps(steps, figsize=(15, 10)):
    """Plot the stored inference steps side by side in a single row.

    Args:
        steps: Non-empty sequence of image tensors; each is permuted with
            ``permute(1, 2, 0)`` before display.
            # assumes channel-first (C, H, W) CPU tensors — TODO confirm
        figsize (tuple): Figure size forwarded to ``plt.subplots``.

    Returns:
        The matplotlib figure containing one axis per step.

    Raises:
        ValueError: If ``steps`` is empty.
    """
    import matplotlib.pyplot as plt

    n_steps = len(steps)
    if n_steps == 0:
        raise ValueError("show_inference_steps requires at least one step image")

    # squeeze=False keeps `axes` two-dimensional even for a single step;
    # without it plt.subplots(1, 1) returns a bare Axes that cannot be indexed.
    fig, axes = plt.subplots(1, n_steps, figsize=figsize, squeeze=False)

    for i, (ax, step_img) in enumerate(zip(axes[0], steps)):
        img = step_img.permute(1, 2, 0).numpy()
        ax.imshow(img)
        ax.set_title(f"Step {i}")
        ax.axis('off')

    plt.tight_layout()
    return fig
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ numpy
4
+ pillow
5
+ gradio
6
+ matplotlib
7
+ requests
8
+ tqdm
9
+ huggingface_hub