Upload 23 files
- .gitattributes +5 -0
- LICENSE +21 -0
- README.md +77 -7
- app.py +409 -4
- face_vase.png +0 -0
- huggingface-metadata.json +11 -0
- inference.py +923 -0
- logs/model_loading_resnet50_robust.log +9 -0
- logs/model_loading_resnet50_robust_face.log +46 -0
- logs/model_loading_robust_resnet50.log +9 -0
- logs/model_loading_standard_resnet50.log +9 -0
- models/resnet50_robust.pt +3 -0
- models/resnet50_robust_face_100_checkpoint.pt +3 -0
- models/robust_resnet50.pt +3 -0
- models/standard_resnet50.pt +3 -0
- requirements.txt +9 -0
- stimuli/Confetti_illusion.png +3 -0
- stimuli/CornsweetBlock.png +3 -0
- stimuli/EhresteinSingleColor.png +3 -0
- stimuli/GroupingByContinuity.png +0 -0
- stimuli/Kanizsa_square.jpg +0 -0
- stimuli/Neon_Color_Circle.jpg +3 -0
- stimuli/face_vase.png +0 -0
- stimuli/figure_ground.png +3 -0
.gitattributes
CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+stimuli/Confetti_illusion.png filter=lfs diff=lfs merge=lfs -text
+stimuli/CornsweetBlock.png filter=lfs diff=lfs merge=lfs -text
+stimuli/EhresteinSingleColor.png filter=lfs diff=lfs merge=lfs -text
+stimuli/figure_ground.png filter=lfs diff=lfs merge=lfs -text
+stimuli/Neon_Color_Circle.jpg filter=lfs diff=lfs merge=lfs -text
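These five new entries route the large stimulus images through Git LFS, so the binaries are stored as LFS pointers rather than directly in the git history; lines of this form are what `git lfs track <pattern>` appends to .gitattributes.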
LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2023 GenerativeInferenceDemo
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
README.md
CHANGED
@@ -1,13 +1,83 @@
 ---
-title: Generative Inference
-emoji:
-colorFrom:
-colorTo:
+title: Generative Inference Demo
+emoji: 🧠
+colorFrom: indigo
+colorTo: purple
 sdk: gradio
-sdk_version: 5.
+sdk_version: 5.23.1
 app_file: app.py
 pinned: false
-
+license: mit
 ---
 
-
+# Generative Inference Demo
+
+This Gradio demo showcases how neural networks perceive visual illusions through generative inference. The demo uses both standard and robust ResNet50 models to reveal emergent perception of contours, figure-ground separation, and other visual phenomena.
+
+## Models
+
+- **Robust ResNet50**: A model trained with adversarial examples (ε=3.0), exhibiting more human-like visual perception
+- **Standard ResNet50**: A model trained without adversarial examples (ε=0.0)
+
+## Features
+
+- Upload your own images or use example illusions
+- Choose between robust and standard models
+- Adjust perturbation size (epsilon) and iteration count
+- Visualize how perception emerges over time
+- Includes classic illusions:
+  - Kanizsa shapes
+  - Face-Vase illusions
+  - Figure-Ground segmentation
+  - Neon color spreading
+
+## Usage
+
+1. Select an example image or upload your own
+2. Choose the model type (robust or standard)
+3. Adjust epsilon and iteration parameters
+4. Click "Run Inference" to see how the model perceives the image
+
+## About
+
+This demo is based on research showing how adversarially robust models develop more human-like visual representations. The generative inference process reveals these perceptual biases by optimizing the input to maximize the model's confidence.
+
+## Installation
+
+To run this demo locally:
+
+```bash
+# Clone the repository
+git clone [repo-url]
+cd GenerativeInferenceDemo
+
+# Install dependencies
+pip install -r requirements.txt
+
+# Run the app
+python app.py
+```
+
+The web app will be available at http://localhost:7860 (or another port if 7860 is busy).
+
+## About the Models
+
+- **Robust ResNet50**: A model trained with adversarial examples, making it more robust to small perturbations. These models often exhibit more human-like visual perception.
+- **Standard ResNet50**: A standard ImageNet-trained ResNet50 model.
+
+## How It Works
+
+1. The algorithm starts with an input image
+2. It iteratively updates the image to increase the model's confidence in its predictions
+3. These updates are constrained to a small neighborhood (controlled by epsilon) around the original image
+4. The resulting changes reveal how the network "sees" the image
+
+## Citation
+
+If you use this work in your research, please cite the original paper:
+
+[Citation information will be added here]
+
+## License
+
+This project is licensed under the MIT License - see the LICENSE file for details.
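The four "How It Works" steps in the README condense to a short optimization loop. Below is a minimal illustrative sketch, not the repository's implementation (that lives in inference.py, added later in this commit): gradient ascent on the classifier's confidence, clamped to an L∞ epsilon-ball around the input. The function name and hyperparameter values are placeholders; it assumes a float tensor `x` of shape (1, 3, 224, 224) in [0, 1].

```python
import torch
import torchvision.models as models

model = models.resnet50(weights="IMAGENET1K_V1").eval()

def generative_inference(x, eps=0.5, step_size=0.1, n_itr=50):
    """Sketch of steps 1-4: push the image toward higher model confidence
    while staying within an L-infinity ball of radius eps around the input."""
    x_orig = x.clone()
    for _ in range(n_itr):
        x = x.detach().requires_grad_(True)
        # Confidence in the currently most likely class (step 2)
        confidence = model(x).softmax(dim=1).max(dim=1).values.sum()
        grad, = torch.autograd.grad(confidence, x)
        with torch.no_grad():
            x = x + step_size * grad.sign()             # raise confidence
            x = x_orig + (x - x_orig).clamp(-eps, eps)  # step 3: epsilon constraint
            x = x.clamp(0, 1)                           # keep a valid image
    return x.detach()
```

For simplicity this sketch uses a sign step; the repository's `InferStep` class in inference.py uses an L2-normalized gradient step instead.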
app.py
CHANGED
@@ -1,7 +1,412 @@
 import gradio as gr
+import torch
+import numpy as np
+from PIL import Image
+try:
+    from spaces import GPU
+except ImportError:
+    # Define a no-op decorator if running locally
+    def GPU(func):
+        return func
+
+import os
+import argparse
+from inference import GenerativeInferenceModel, get_inference_configs
 
-
-
+# Parse command line arguments
+parser = argparse.ArgumentParser(description='Run Generative Inference Demo')
+parser.add_argument('--port', type=int, default=7860, help='Port to run the server on')
+args = parser.parse_args()
 
-
-
+# Create model directories if they don't exist
+os.makedirs("models", exist_ok=True)
+os.makedirs("stimuli", exist_ok=True)
+
+# Check if running on Hugging Face Spaces
+if "SPACE_ID" in os.environ:
+    default_port = int(os.environ.get("PORT", 7860))
+else:
+    default_port = 8861  # Local default port
+
+# Initialize model
+model = GenerativeInferenceModel()
+
+# Define example images and their parameters with updated values from the research
+examples = [
+    {
+        "image": os.path.join("stimuli", "Neon_Color_Circle.jpg"),
+        "name": "Neon Color Spreading",
+        "wiki": "https://en.wikipedia.org/wiki/Neon_color_spreading",
+        "papers": [
+            "[Color Assimilation](https://doi.org/10.1016/j.visres.2000.200.1)",
+            "[Perceptual Filling-in](https://doi.org/10.1016/j.tics.2003.08.003)"
+        ],
+        "method": "Prior-Guided Drift Diffusion",
+        "reverse_diff": {
+            "model": "resnet50_robust",
+            "layer": "layer3",
+            "initial_noise": 0.8,
+            "diffusion_noise": 0.003,
+            "step_size": 1.0,
+            "iterations": 101,
+            "epsilon": 20.0
+        }
+    },
+    {
+        "image": os.path.join("stimuli", "Kanizsa_square.jpg"),
+        "name": "Kanizsa Square",
+        "wiki": "https://en.wikipedia.org/wiki/Kanizsa_triangle",
+        "papers": [
+            "[Gestalt Psychology](https://en.wikipedia.org/wiki/Gestalt_psychology)",
+            "[Neural Mechanisms](https://doi.org/10.1016/j.tics.2003.08.003)"
+        ],
+        "method": "Prior-Guided Drift Diffusion",
+        "reverse_diff": {
+            "model": "resnet50_robust",
+            "layer": "all",
+            "initial_noise": 0.0,
+            "diffusion_noise": 0.005,
+            "step_size": 0.64,
+            "iterations": 100,
+            "epsilon": 5.0
+        }
+    },
+    {
+        "image": os.path.join("stimuli", "CornsweetBlock.png"),
+        "name": "Cornsweet Illusion",
+        "wiki": "https://en.wikipedia.org/wiki/Cornsweet_illusion",
+        "papers": [
+            "[Brightness Perception](https://doi.org/10.1016/j.visres.2000.200.1)",
+            "[Edge Effects](https://doi.org/10.1016/j.tics.2003.08.003)"
+        ],
+        "instructions": "Both blocks are gray in color (the same), use your finger to cover the middle line. Hit 'Load Parameters' and then hit 'Run Generative Inference' to see how the model sees the blocks.",
+        "method": "Prior-Guided Drift Diffusion",
+        "reverse_diff": {
+            "model": "resnet50_robust",
+            "layer": "layer3",
+            "initial_noise": 0.5,
+            "diffusion_noise": 0.005,
+            "step_size": 0.8,
+            "iterations": 51,
+            "epsilon": 20.0
+        }
+    },
+    {
+        "image": os.path.join("stimuli", "face_vase.png"),
+        "name": "Rubin's Face-Vase (Object Prior)",
+        "wiki": "https://en.wikipedia.org/wiki/Rubin_vase",
+        "papers": [
+            "[Figure-Ground Perception](https://en.wikipedia.org/wiki/Figure-ground_(perception))",
+            "[Bistable Perception](https://doi.org/10.1016/j.tics.2003.08.003)"
+        ],
+        "method": "Prior-Guided Drift Diffusion",
+        "reverse_diff": {
+            "model": "resnet50_robust",
+            "layer": "avgpool",
+            "initial_noise": 0.9,
+            "diffusion_noise": 0.003,
+            "step_size": 0.58,
+            "iterations": 100,
+            "epsilon": 0.81
+        }
+    },
+    {
+        "image": os.path.join("stimuli", "Confetti_illusion.png"),
+        "name": "Confetti Illusion",
+        "wiki": "https://www.youtube.com/watch?v=SvEiEi8O7QE",
+        "papers": [
+            "[Color Perception](https://doi.org/10.1016/j.visres.2000.200.1)",
+            "[Context Effects](https://doi.org/10.1016/j.tics.2003.08.003)"
+        ],
+        "method": "Prior-Guided Drift Diffusion",
+        "reverse_diff": {
+            "model": "resnet50_robust",
+            "layer": "layer3",
+            "initial_noise": 0.1,
+            "diffusion_noise": 0.003,
+            "step_size": 0.5,
+            "iterations": 101,
+            "epsilon": 20.0
+        }
+    },
+    {
+        "image": os.path.join("stimuli", "EhresteinSingleColor.png"),
+        "name": "Ehrenstein Illusion",
+        "wiki": "https://en.wikipedia.org/wiki/Ehrenstein_illusion",
+        "papers": [
+            "[Subjective Contours](https://doi.org/10.1016/j.visres.2000.200.1)",
+            "[Neural Processing](https://doi.org/10.1016/j.tics.2003.08.003)"
+        ],
+        "method": "Prior-Guided Drift Diffusion",
+        "reverse_diff": {
+            "model": "resnet50_robust",
+            "layer": "layer3",
+            "initial_noise": 0.5,
+            "diffusion_noise": 0.005,
+            "step_size": 0.8,
+            "iterations": 101,
+            "epsilon": 20.0
+        }
+    },
+    {
+        "image": os.path.join("stimuli", "GroupingByContinuity.png"),
+        "name": "Grouping by Continuity",
+        "wiki": "https://en.wikipedia.org/wiki/Principles_of_grouping",
+        "papers": [
+            "[Gestalt Principles](https://en.wikipedia.org/wiki/Gestalt_psychology)",
+            "[Visual Organization](https://doi.org/10.1016/j.tics.2003.08.003)"
+        ],
+        "method": "Prior-Guided Drift Diffusion",
+        "reverse_diff": {
+            "model": "resnet50_robust",
+            "layer": "layer3",
+            "initial_noise": 0.0,
+            "diffusion_noise": 0.005,
+            "step_size": 0.4,
+            "iterations": 101,
+            "epsilon": 4.0
+        }
+    },
+    {
+        "image": os.path.join("stimuli", "figure_ground.png"),
+        "name": "Figure-Ground Illusion",
+        "wiki": "https://en.wikipedia.org/wiki/Figure-ground_(perception)",
+        "papers": [
+            "[Gestalt Principles](https://en.wikipedia.org/wiki/Gestalt_psychology)",
+            "[Perceptual Organization](https://doi.org/10.1016/j.tics.2003.08.003)"
+        ],
+        "method": "Prior-Guided Drift Diffusion",
+        "reverse_diff": {
+            "model": "resnet50_robust",
+            "layer": "layer3",
+            "initial_noise": 0.1,
+            "diffusion_noise": 0.003,
+            "step_size": 0.5,
+            "iterations": 101,
+            "epsilon": 3.0
+        }
+    }
+]
+
+@GPU
+def run_inference(image, model_type, inference_type, eps_value, num_iterations,
+                  initial_noise=0.05, diffusion_noise=0.3, step_size=0.8, model_layer="layer3"):
+    # Check if image is provided
+    if image is None:
+        return None, "Please upload an image before running inference."
+
+    # Convert eps to float
+    eps = float(eps_value)
+
+    # Load inference configuration based on the selected type
+    config = get_inference_configs(inference_type=inference_type, eps=eps, n_itr=int(num_iterations))
+
+    # Handle Prior-Guided Drift Diffusion specific parameters
+    if inference_type == "Prior-Guided Drift Diffusion":
+        config['initial_inference_noise_ratio'] = float(initial_noise)
+        config['diffusion_noise_ratio'] = float(diffusion_noise)
+        config['step_size'] = float(step_size)  # Added step size parameter
+        config['top_layer'] = model_layer
+
+    # Run generative inference
+    result = model.inference(image, model_type, config)
+
+    # Extract results based on return type
+    if isinstance(result, tuple):
+        # Old format returning (output_image, all_steps)
+        output_image, all_steps = result
+    else:
+        # New format returning dictionary
+        output_image = result['final_image']
+        all_steps = result['steps']
+
+    # Create animation frames
+    frames = []
+    for i, step_image in enumerate(all_steps):
+        # Convert tensor to PIL image
+        step_pil = Image.fromarray((step_image.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8))
+        frames.append(step_pil)
+
+    # Convert the final output image to PIL
+    final_image = Image.fromarray((output_image.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8))
+
+    # Return the final inferred image and the animation frames directly
+    return final_image, frames
+
+# Helper function to apply example parameters
+def apply_example(example):
+    return [
+        example["image"],
+        "resnet50_robust",  # Model type
+        example["method"],  # Inference type
+        example["reverse_diff"]["epsilon"],  # Epsilon value
+        example["reverse_diff"]["iterations"],  # Number of iterations
+        example["reverse_diff"]["initial_noise"],  # Initial noise
+        example["reverse_diff"]["diffusion_noise"],  # Diffusion noise value (corrected)
+        example["reverse_diff"]["step_size"],  # Step size (added)
+        example["reverse_diff"]["layer"],  # Model layer
+        gr.Group(visible=True)  # Show parameters section
+    ]
+
+# Define the interface
+with gr.Blocks(title="Generative Inference Demo", css="""
+.purple-button {
+    background-color: #8B5CF6 !important;
+    color: white !important;
+    border: none !important;
+}
+.purple-button:hover {
+    background-color: #7C3AED !important;
+}
+""") as demo:
+    gr.Markdown("# Generative Inference Demo")
+    gr.Markdown("This demo showcases how neural networks can perceive visual illusions and develop Gestalt principles of perceptual organization through generative inference.")
+
+    gr.Markdown("""
+    **How to use this demo:**
+    - **Load pre-configured examples**: Click on any visual illusion below and hit "Load Parameters" to automatically set up the optimal parameters for that illusion
+    - **Run the inference**: After loading parameters or setting your own, hit "Run Inference" to start the generative inference process
+    - **You can also upload your own images** and experiment with different parameters to see how they affect the generative inference process
+    """)
+
+    # Main processing interface
+    with gr.Row():
+        with gr.Column(scale=1):
+            # Inputs
+            image_input = gr.Image(label="Input Image", type="pil", value=os.path.join("stimuli", "Neon_Color_Circle.jpg"))
+
+            # Run Inference button right below the image
+            run_button = gr.Button("🪄 Run Generative Inference", variant="primary", elem_classes="purple-button")
+
+            # Parameters toggle button
+            params_button = gr.Button("⚙️ Play with the parameters", variant="secondary")
+
+            # Parameters section (initially hidden)
+            with gr.Group(visible=False) as params_section:
+                with gr.Row():
+                    model_choice = gr.Dropdown(
+                        choices=["resnet50_robust", "standard_resnet50"],  # "resnet50_robust_face" - hidden for deployment
+                        value="resnet50_robust",
+                        label="Model"
+                    )
+
+                    inference_type = gr.Dropdown(
+                        choices=["Prior-Guided Drift Diffusion", "IncreaseConfidence"],
+                        value="Prior-Guided Drift Diffusion",
+                        label="Inference Method"
+                    )
+
+                with gr.Row():
+                    eps_slider = gr.Slider(minimum=0.0, maximum=40.0, value=20.0, step=0.01, label="Epsilon (Stimulus Fidelity)")
+                    iterations_slider = gr.Slider(minimum=1, maximum=600, value=101, step=1, label="Number of Iterations")  # Updated max to 600
+
+                with gr.Row():
+                    initial_noise_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.8, step=0.01,
+                                                     label="Drift Noise")
+                    diffusion_noise_slider = gr.Slider(minimum=0.0, maximum=0.05, value=0.003, step=0.001,
+                                                       label="Diffusion Noise")  # Corrected name
+
+                with gr.Row():
+                    step_size_slider = gr.Slider(minimum=0.01, maximum=2.0, value=1.0, step=0.01,
+                                                 label="Update Rate")  # Added step size slider
+                    layer_choice = gr.Dropdown(
+                        choices=["all", "conv1", "bn1", "relu", "maxpool", "layer1", "layer2", "layer3", "layer4", "avgpool"],
+                        value="layer3",
+                        label="Model Layer"
+                    )
+
+        with gr.Column(scale=2):
+            # Outputs
+            output_image = gr.Image(label="Final Inferred Image")
+            output_frames = gr.Gallery(label="Inference Steps", columns=5, rows=2)
+
+    # Examples section with integrated explanations
+    gr.Markdown("## Visual Illusion Examples")
+    gr.Markdown("Select an illusion to load its parameters and see how generative inference reveals perceptual effects")
+
+    # For each example, create a row with the image and explanation side by side
+    for i, ex in enumerate(examples):
+        with gr.Row():
+            # Left column for the image
+            with gr.Column(scale=1):
+                # Display the example image
+                example_img = gr.Image(value=ex["image"], type="filepath", label=f"{ex['name']}")
+                load_btn = gr.Button(f"Load Parameters", variant="primary")
+
+                # Set up the load button to apply this example's parameters
+                load_btn.click(
+                    fn=lambda ex=ex: apply_example(ex),
+                    outputs=[
+                        image_input, model_choice, inference_type,
+                        eps_slider, iterations_slider,
+                        initial_noise_slider, diffusion_noise_slider,
+                        step_size_slider, layer_choice, params_section
+                    ]
+                )
+
+            # Right column for the explanation
+            with gr.Column(scale=2):
+                gr.Markdown(f"### {ex['name']}")
+                gr.Markdown(f"[Read more on Wikipedia]({ex['wiki']})")
+
+                # Show instructions if they exist
+                if "instructions" in ex:
+                    gr.Markdown(f"**Instructions:** {ex['instructions']}")
+
+
+        if i < len(examples) - 1:  # Don't add separator after the last example
+            gr.Markdown("---")
+
+    # Set up event handler for the main inference
+    run_button.click(
+        fn=run_inference,
+        inputs=[
+            image_input, model_choice, inference_type,
+            eps_slider, iterations_slider,
+            initial_noise_slider, diffusion_noise_slider,
+            step_size_slider, layer_choice
+        ],
+        outputs=[output_image, output_frames]
+    )
+
+    # Toggle parameters visibility
+    def toggle_params():
+        return gr.Group(visible=True)
+
+    params_button.click(
+        fn=toggle_params,
+        outputs=[params_section]
+    )
+
+    # About section
+    gr.Markdown("""
+    ## About Generative Inference
+
+    Generative inference is a technique that reveals how neural networks perceive visual stimuli. This demo primarily uses the Prior-Guided Drift Diffusion method.
+
+    ### Prior-Guided Drift Diffusion
+    Moving away from a noisy representation of the input images
+
+    ### IncreaseConfidence
+    Moving away from the least likely class identified at iteration 0 (fast perception)
+
+    ### Parameters:
+    - **Drift Noise**: Controls the amount of noise added to the image at the beginning
+    - **Diffusion Noise**: Controls the amount of noise added at each optimization step
+    - **Update Rate**: Learning rate for the optimization process
+    - **Number of Iterations**: How many optimization steps to perform
+    - **Model Layer**: Select a specific layer of the ResNet50 model to extract features from
+    - **Epsilon (Stimulus Fidelity)**: Controls the size of perturbation during optimization
+
+    **Generative Inference was developed by [Tahereh Toosi](https://toosi.github.io).**
+    """)
+
+# Launch the demo
+if __name__ == "__main__":
+    print(f"Starting server on port {args.port}")
+    demo.launch(
+        server_name="0.0.0.0",
+        server_port=args.port,
+        share=False,
+        debug=True
+    )
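Outside Gradio, the wiring in `run_inference` reduces to a direct call into inference.py. The following headless sketch mirrors that function, reusing the parameter values from the "Kanizsa Square" example dict above; the file paths assume the repository layout from this commit, and the snippet is illustrative rather than an official API.

```python
from PIL import Image
from inference import GenerativeInferenceModel, get_inference_configs

# Build the config exactly as run_inference() does for the Kanizsa example,
# then run inference without the UI.
model = GenerativeInferenceModel()
config = get_inference_configs(inference_type="Prior-Guided Drift Diffusion",
                               eps=5.0, n_itr=100)
config['initial_inference_noise_ratio'] = 0.0  # "initial_noise" in the example dict
config['diffusion_noise_ratio'] = 0.005        # "diffusion_noise"
config['step_size'] = 0.64
config['top_layer'] = "all"                    # "layer"

image = Image.open("stimuli/Kanizsa_square.jpg")
result = model.inference(image, "resnet50_robust", config)
# Handle both return formats, as run_inference() does
final_image = result['final_image'] if isinstance(result, dict) else result[0]
```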
face_vase.png
ADDED
huggingface-metadata.json
ADDED
@@ -0,0 +1,11 @@
+{
+  "title": "Generative Inference Demo",
+  "emoji": "🧠",
+  "colorFrom": "indigo",
+  "colorTo": "purple",
+  "sdk": "gradio",
+  "sdk_version": "3.32.0",
+  "app_file": "app.py",
+  "pinned": false,
+  "license": "mit"
+}
inference.py
ADDED
@@ -0,0 +1,923 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.models as models
+import torchvision.transforms as transforms
+from torchvision.models.resnet import ResNet50_Weights
+from PIL import Image
+import numpy as np
+import os
+import requests
+import time
+import copy
+from collections import OrderedDict
+from pathlib import Path
+
+# Check for available hardware acceleration
+if torch.cuda.is_available():
+    device = torch.device("cuda")
+elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+    device = torch.device("mps")  # Use Apple Metal Performance Shaders for M-series Macs
+else:
+    device = torch.device("cpu")
+print(f"Using device: {device}")
+
+# Constants
+MODEL_URLS = {
+    'resnet50_robust': 'https://huggingface.co/madrylab/robust-imagenet-models/resolve/main/resnet50_l2_eps3.ckpt',
+    'resnet50_standard': 'https://huggingface.co/madrylab/robust-imagenet-models/resolve/main/resnet50_l2_eps0.ckpt',
+    'resnet50_robust_face': 'https://huggingface.co/ttoosi/resnet50_robust_face/blob/main/100_checkpoint.pt'
+}
+
+IMAGENET_MEAN = [0.485, 0.456, 0.406]
+IMAGENET_STD = [0.229, 0.224, 0.225]
+
+# Define the transforms based on whether normalization is on or off
+def get_transform(input_size=224, normalize=False, norm_mean=IMAGENET_MEAN, norm_std=IMAGENET_STD):
+    if normalize:
+        return transforms.Compose([
+            transforms.Resize(input_size),
+            transforms.CenterCrop(input_size),
+            transforms.ToTensor(),
+            transforms.Normalize(norm_mean, norm_std),
+        ])
+    else:
+        return transforms.Compose([
+            transforms.Resize(input_size),
+            transforms.CenterCrop(input_size),
+            transforms.ToTensor(),
+        ])
+
+# Default transform without normalization
+transform = transforms.Compose([
+    transforms.Resize(224),
+    transforms.CenterCrop(224),
+    transforms.ToTensor(),
+])
+
+normalize_transform = transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)
+
+def extract_middle_layers(model, layer_index):
+    """
+    Extract a subset of the model up to a specific layer.
+
+    Args:
+        model: The neural network model
+        layer_index: String 'all' for the full model, or a layer identifier (string or int)
+            For ResNet: integers 0-8 representing specific layers
+            For ViT: strings like 'encoder.layers.encoder_layer_3'
+
+    Returns:
+        A modified model that outputs features from the specified layer
+    """
+    if isinstance(layer_index, str) and layer_index == 'all':
+        return model
+
+    # Special case for ViT's encoder layers with DataParallel wrapper
+    if isinstance(layer_index, str) and layer_index.startswith('encoder.layers.encoder_layer_'):
+        try:
+            target_layer_idx = int(layer_index.split('_')[-1])
+
+            # Create a deep copy of the model to avoid modifying the original
+            new_model = copy.deepcopy(model)
+
+            # For models wrapped in DataParallel
+            if hasattr(new_model, 'module'):
+                # Create a subset of encoder layers up to the specified index
+                encoder_layers = nn.Sequential()
+                for i in range(target_layer_idx + 1):
+                    layer_name = f"encoder_layer_{i}"
+                    if hasattr(new_model.module.encoder.layers, layer_name):
+                        encoder_layers.add_module(layer_name,
+                                                  getattr(new_model.module.encoder.layers, layer_name))
+
+                # Replace the encoder layers with our truncated version
+                new_model.module.encoder.layers = encoder_layers
+
+                # Remove the heads since we're stopping at the encoder layer
+                new_model.module.heads = nn.Identity()
+
+                return new_model
+            else:
+                # Direct model access (not DataParallel)
+                encoder_layers = nn.Sequential()
+                for i in range(target_layer_idx + 1):
+                    layer_name = f"encoder_layer_{i}"
+                    if hasattr(new_model.encoder.layers, layer_name):
+                        encoder_layers.add_module(layer_name,
+                                                  getattr(new_model.encoder.layers, layer_name))
+
+                # Replace the encoder layers with our truncated version
+                new_model.encoder.layers = encoder_layers
+
+                # Remove the heads since we're stopping at the encoder layer
+                new_model.heads = nn.Identity()
+
+                return new_model
+
+        except (ValueError, IndexError) as e:
+            raise ValueError(f"Invalid ViT layer specification: {layer_index}. Error: {e}")
+
+    # Handling for ViT whole blocks
+    elif hasattr(model, 'blocks') or (hasattr(model, 'module') and hasattr(model.module, 'blocks')):
+        # Check for DataParallel wrapper
+        base_model = model.module if hasattr(model, 'module') else model
+
+        # Create a deep copy to avoid modifying the original
+        new_model = copy.deepcopy(model)
+        base_new_model = new_model.module if hasattr(new_model, 'module') else new_model
+
+        # Add the desired number of transformer blocks
+        if isinstance(layer_index, int):
+            # Truncate the blocks
+            base_new_model.blocks = base_new_model.blocks[:layer_index+1]
+
+        return new_model
+
+    else:
+        # Original ResNet/VGG handling
+        modules = list(model.named_children())
+        print(f"DEBUG - extract_middle_layers - Looking for '{layer_index}' in {[name for name, _ in modules]}")
+
+        cutoff_idx = next((i for i, (name, _) in enumerate(modules)
+                           if name == str(layer_index)), None)
+
+        if cutoff_idx is not None:
+            # Keep modules up to and including the target
+            new_model = nn.Sequential(OrderedDict(modules[:cutoff_idx+1]))
+            return new_model
+        else:
+            raise ValueError(f"Module {layer_index} not found in model")
+
+# Get ImageNet labels
+def get_imagenet_labels():
+    url = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
+    response = requests.get(url)
+    if response.status_code == 200:
+        return response.json()
+    else:
+        raise RuntimeError("Failed to fetch ImageNet labels")
+
+# Download model if needed
+def download_model(model_type):
+    if model_type not in MODEL_URLS or MODEL_URLS[model_type] is None:
+        return None  # Use PyTorch's pretrained model
+
+    # Handle special case for face model
+    if model_type == 'resnet50_robust_face':
+        model_path = Path("models/resnet50_robust_face_100_checkpoint.pt")
+    else:
+        model_path = Path(f"models/{model_type}.pt")
+
+    if not model_path.exists():
+        print(f"Downloading {model_type} model...")
+        url = MODEL_URLS[model_type]
+        response = requests.get(url, stream=True)
+        if response.status_code == 200:
+            with open(model_path, 'wb') as f:
+                for chunk in response.iter_content(chunk_size=8192):
+                    f.write(chunk)
+            print(f"Model downloaded and saved to {model_path}")
+        else:
+            raise RuntimeError(f"Failed to download model: {response.status_code}")
+    return model_path
+
+class NormalizeByChannelMeanStd(nn.Module):
+    def __init__(self, mean, std):
+        super(NormalizeByChannelMeanStd, self).__init__()
+        if not isinstance(mean, torch.Tensor):
+            mean = torch.tensor(mean)
+        if not isinstance(std, torch.Tensor):
+            std = torch.tensor(std)
+        self.register_buffer("mean", mean)
+        self.register_buffer("std", std)
+
+    def forward(self, tensor):
+        return self.normalize_fn(tensor, self.mean, self.std)
+
+    def normalize_fn(self, tensor, mean, std):
+        """Differentiable version of torchvision.functional.normalize"""
+        # here we assume the color channel is at dim=1
+        mean = mean[None, :, None, None]
+        std = std[None, :, None, None]
+        return tensor.sub(mean).div(std)
+
+class InferStep:
+    def __init__(self, orig_image, eps, step_size):
+        self.orig_image = orig_image
+        self.eps = eps
+        self.step_size = step_size
+
+    def project(self, x):
+        diff = x - self.orig_image
+        diff = torch.clamp(diff, -self.eps, self.eps)
+        return torch.clamp(self.orig_image + diff, 0, 1)
+
+    def step(self, x, grad):
+        l = len(x.shape) - 1
+        grad_norm = torch.norm(grad.view(grad.shape[0], -1), dim=1).view(-1, *([1]*l))
+        scaled_grad = grad / (grad_norm + 1e-10)
+        return scaled_grad * self.step_size
+
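`InferStep` factors the update rule into two halves: an L2-normalized gradient step (`step`) and an elementwise clamp back onto the epsilon-ball around the original image (`project`). A minimal illustrative driver for one such loop follows; `feat` (e.g. a truncated model from `extract_middle_layers`), `target`, `x0`, and `n_itr` are assumed stand-ins, not names from this file.

```python
# Illustrative use of InferStep: feature-matching descent with projection.
stepper = InferStep(orig_image=x0, eps=20.0, step_size=0.64)
x = x0.clone()
for _ in range(n_itr):
    x = x.detach().requires_grad_(True)
    # Negated MSE, so ascending the loss matches the target features
    # (the 'MSE' loss that get_inference_configs below selects for
    # Prior-Guided Drift Diffusion)
    loss = -F.mse_loss(feat(x), target)
    grad, = torch.autograd.grad(loss, x)
    with torch.no_grad():
        x = stepper.project(x + stepper.step(x, grad))
```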
| 222 |
+
def get_iterations_to_show(n_itr):
|
| 223 |
+
"""Generate a dynamic list of iterations to show based on total iterations."""
|
| 224 |
+
if n_itr <= 50:
|
| 225 |
+
return [1, 5, 10, 20, 30, 40, 50, n_itr]
|
| 226 |
+
elif n_itr <= 100:
|
| 227 |
+
return [1, 5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, n_itr]
|
| 228 |
+
elif n_itr <= 200:
|
| 229 |
+
return [1, 5, 10, 20, 30, 40, 50, 75, 100, 125, 150, 175, 200, n_itr]
|
| 230 |
+
elif n_itr <= 500:
|
| 231 |
+
return [1, 5, 10, 20, 30, 40, 50, 75, 100, 150, 200, 250, 300, 350, 400, 450, 500, n_itr]
|
| 232 |
+
else:
|
| 233 |
+
# For very large iterations, show more evenly distributed points
|
| 234 |
+
return [1, 5, 10, 20, 30, 40, 50, 75, 100, 150, 200, 250, 300, 350, 400, 450, 500,
|
| 235 |
+
int(n_itr*0.6), int(n_itr*0.7), int(n_itr*0.8), int(n_itr*0.9), n_itr]
|
| 236 |
+
|
| 237 |
+
def get_inference_configs(inference_type='IncreaseConfidence', eps=0.5, n_itr=50, step_size=1.0):
|
| 238 |
+
"""Generate inference configuration with customizable parameters.
|
| 239 |
+
|
| 240 |
+
Args:
|
| 241 |
+
inference_type (str): Type of inference ('IncreaseConfidence' or 'Prior-Guided Drift Diffusion')
|
| 242 |
+
eps (float): Maximum perturbation size
|
| 243 |
+
n_itr (int): Number of iterations
|
| 244 |
+
step_size (float): Step size for each iteration
|
| 245 |
+
"""
|
| 246 |
+
|
| 247 |
+
# Base configuration common to all inference types
|
| 248 |
+
config = {
|
| 249 |
+
'loss_infer': inference_type, # How to guide the optimization
|
| 250 |
+
'n_itr': n_itr, # Number of iterations
|
| 251 |
+
'eps': eps, # Maximum perturbation size
|
| 252 |
+
'step_size': step_size, # Step size for each iteration
|
| 253 |
+
'diffusion_noise_ratio': 0.0, # No diffusion noise
|
| 254 |
+
'initial_inference_noise_ratio': 0.0, # No initial noise
|
| 255 |
+
'top_layer': 'all', # Use all layers of the model
|
| 256 |
+
'inference_normalization': False, # Apply normalization during inference
|
| 257 |
+
'recognition_normalization': False, # Apply normalization during recognition
|
| 258 |
+
'iterations_to_show': get_iterations_to_show(n_itr), # Dynamic iterations to visualize
|
| 259 |
+
'misc_info': {'keep_grads': False} # Additional configuration
|
| 260 |
+
}
|
| 261 |
+
|
| 262 |
+
# Customize based on inference type
|
| 263 |
+
if inference_type == 'IncreaseConfidence':
|
| 264 |
+
config['loss_function'] = 'CE' # Cross Entropy
|
| 265 |
+
|
| 266 |
+
elif inference_type == 'Prior-Guided Drift Diffusion':
|
| 267 |
+
config['loss_function'] = 'MSE' # Mean Square Error
|
| 268 |
+
config['initial_inference_noise_ratio'] = 0.05 # Initial noise for diffusion
|
| 269 |
+
config['diffusion_noise_ratio'] = 0.01 # Add noise during diffusion
|
| 270 |
+
|
| 271 |
+
elif inference_type == 'GradModulation':
|
| 272 |
+
config['loss_function'] = 'CE' # Cross Entropy
|
| 273 |
+
config['misc_info']['grad_modulation'] = 0.5 # Gradient modulation strength
|
| 274 |
+
|
| 275 |
+
elif inference_type == 'CompositionalFusion':
|
| 276 |
+
config['loss_function'] = 'CE' # Cross Entropy
|
| 277 |
+
config['misc_info']['positive_classes'] = [] # Classes to maximize
|
| 278 |
+
config['misc_info']['negative_classes'] = [] # Classes to minimize
|
| 279 |
+
|
| 280 |
+
return config
|
| 281 |
+
|
| 282 |
+
class GenerativeInferenceModel:
|
| 283 |
+
def __init__(self):
|
| 284 |
+
self.models = {}
|
| 285 |
+
self.normalizer = NormalizeByChannelMeanStd(IMAGENET_MEAN, IMAGENET_STD).to(device)
|
| 286 |
+
self.labels = get_imagenet_labels()
|
| 287 |
+
|
| 288 |
+
def verify_model_integrity(self, model, model_type):
|
| 289 |
+
"""
|
| 290 |
+
Verify model integrity by running a test input through it.
|
| 291 |
+
Returns whether the model passes basic integrity check.
|
| 292 |
+
"""
|
| 293 |
+
try:
|
| 294 |
+
print(f"\n=== Running model integrity check for {model_type} ===")
|
| 295 |
+
# Create a deterministic test input directly on the correct device
|
| 296 |
+
test_input = torch.zeros(1, 3, 224, 224, device=device)
|
| 297 |
+
test_input[0, 0, 100:124, 100:124] = 0.5 # Red square
|
| 298 |
+
|
| 299 |
+
# Run forward pass
|
| 300 |
+
with torch.no_grad():
|
| 301 |
+
output = model(test_input)
|
| 302 |
+
|
| 303 |
+
# Check output shape
|
| 304 |
+
if output.shape != (1, 1000):
|
| 305 |
+
print(f"❌ Unexpected output shape: {output.shape}, expected (1, 1000)")
|
| 306 |
+
return False
|
| 307 |
+
|
| 308 |
+
# Get top prediction
|
| 309 |
+
probs = torch.nn.functional.softmax(output, dim=1)
|
| 310 |
+
confidence, prediction = torch.max(probs, 1)
|
| 311 |
+
|
| 312 |
+
# Calculate basic statistics on output
|
| 313 |
+
mean = output.mean().item()
|
| 314 |
+
std = output.std().item()
|
| 315 |
+
min_val = output.min().item()
|
| 316 |
+
max_val = output.max().item()
|
| 317 |
+
|
| 318 |
+
print(f"Model integrity check results:")
|
| 319 |
+
print(f"- Output shape: {output.shape}")
|
| 320 |
+
print(f"- Top prediction: Class {prediction.item()} with {confidence.item()*100:.2f}% confidence")
|
| 321 |
+
print(f"- Output statistics: mean={mean:.3f}, std={std:.3f}, min={min_val:.3f}, max={max_val:.3f}")
|
| 322 |
+
|
| 323 |
+
# Basic sanity checks
|
| 324 |
+
if torch.isnan(output).any():
|
| 325 |
+
print("❌ Model produced NaN outputs")
|
| 326 |
+
return False
|
| 327 |
+
|
| 328 |
+
if output.std().item() < 0.1:
|
| 329 |
+
print("⚠️ Low output variance, model may not be discriminative")
|
| 330 |
+
|
| 331 |
+
print("✅ Model passes basic integrity check")
|
| 332 |
+
return True
|
| 333 |
+
|
| 334 |
+
except Exception as e:
|
| 335 |
+
print(f"❌ Model integrity check failed with error: {e}")
|
| 336 |
+
# Rather than failing completely, we'll continue
|
| 337 |
+
return True
|
| 338 |
+
|
| 339 |
+
def load_model(self, model_type):
|
| 340 |
+
"""Load model from checkpoint or use pretrained model."""
|
| 341 |
+
if model_type in self.models:
|
| 342 |
+
print(f"Using cached {model_type} model")
|
| 343 |
+
return self.models[model_type]
|
| 344 |
+
|
| 345 |
+
# Record loading time for performance analysis
|
| 346 |
+
start_time = time.time()
|
| 347 |
+
model_path = download_model(model_type)
|
| 348 |
+
|
| 349 |
+
# Create a sequential model with normalizer and ResNet50
|
| 350 |
+
resnet = models.resnet50()
|
| 351 |
+
model = nn.Sequential(
|
| 352 |
+
self.normalizer, # Normalizer is part of the model sequence
|
| 353 |
+
resnet
|
| 354 |
+
)
|
| 355 |
+
|
| 356 |
+
# Load the model checkpoint
|
| 357 |
+
if model_path:
|
| 358 |
+
print(f"Loading {model_type} model from {model_path}...")
|
| 359 |
+
try:
|
| 360 |
+
checkpoint = torch.load(model_path, map_location=device)
|
| 361 |
+
|
| 362 |
+
# Print checkpoint structure for better understanding
|
| 363 |
+
print("\n=== Analyzing checkpoint structure ===")
|
| 364 |
+
if isinstance(checkpoint, dict):
|
| 365 |
+
print(f"Checkpoint contains keys: {list(checkpoint.keys())}")
|
| 366 |
+
|
| 367 |
+
# Examine 'model' structure if it exists
|
| 368 |
+
if 'model' in checkpoint and isinstance(checkpoint['model'], dict):
|
| 369 |
+
model_dict = checkpoint['model']
|
| 370 |
+
# Get sample of keys to understand structure
|
| 371 |
+
first_keys = list(model_dict.keys())[:5]
|
| 372 |
+
print(f"'model' contains keys like: {first_keys}")
|
| 373 |
+
|
| 374 |
+
# Check for common prefixes in the model dict
|
| 375 |
+
prefixes = set()
|
| 376 |
+
for key in list(model_dict.keys())[:100]: # Check first 100 keys
|
| 377 |
+
parts = key.split('.')
|
| 378 |
+
if len(parts) > 1:
|
| 379 |
+
prefixes.add(parts[0])
|
| 380 |
+
if prefixes:
|
| 381 |
+
print(f"Common prefixes in model dict: {prefixes}")
|
| 382 |
+
else:
|
| 383 |
+
print(f"Checkpoint is not a dictionary, but a {type(checkpoint)}")
|
| 384 |
+
|
| 385 |
+
# Handle different checkpoint formats
|
| 386 |
+
if 'model' in checkpoint:
|
| 387 |
+
# Format from madrylab robust models
|
| 388 |
+
state_dict = checkpoint['model']
|
| 389 |
+
print("Using 'model' key from checkpoint")
|
| 390 |
+
elif 'state_dict' in checkpoint:
|
| 391 |
+
state_dict = checkpoint['state_dict']
|
| 392 |
+
print("Using 'state_dict' key from checkpoint")
|
| 393 |
+
else:
|
| 394 |
+
# Direct state dict
|
| 395 |
+
state_dict = checkpoint
|
| 396 |
+
print("Using checkpoint directly as state_dict")
|
| 397 |
+
|
| 398 |
+
# Handle prefix in state dict keys for ResNet part
|
| 399 |
+
resnet_state_dict = {}
|
| 400 |
+
prefixes_to_try = ['', 'module.', 'model.', 'attacker.model.']
|
| 401 |
+
resnet_keys = set(resnet.state_dict().keys())
|
| 402 |
+
|
| 403 |
+
# First check if we can find keys directly in the attacker.model path
|
| 404 |
+
print("\n=== Phase 1: Checking for specific model structures ===")
|
| 405 |
+
|
| 406 |
+
# Check for 'module.model' structure (seen in actual checkpoint)
|
| 407 |
+
module_model_keys = [key for key in state_dict.keys() if key.startswith('module.model.')]
|
| 408 |
+
if module_model_keys:
|
| 409 |
+
print(f"Found 'module.model' structure with {len(module_model_keys)} parameters")
|
| 410 |
+
# Extract all parameters from module.model
|
| 411 |
+
for source_key, value in state_dict.items():
|
| 412 |
+
if source_key.startswith('module.model.'):
|
| 413 |
+
target_key = source_key[len('module.model.'):]
|
| 414 |
+
resnet_state_dict[target_key] = value
|
| 415 |
+
|
| 416 |
+
print(f"Extracted {len(resnet_state_dict)} parameters from module.model")
|
| 417 |
+
|
| 418 |
+
# Check for 'attacker.model' structure
|
| 419 |
+
attacker_model_keys = [key for key in state_dict.keys() if key.startswith('attacker.model.')]
|
| 420 |
+
if attacker_model_keys:
|
| 421 |
+
print(f"Found 'attacker.model' structure with {len(attacker_model_keys)} parameters")
|
| 422 |
+
# Extract all parameters from attacker.model
|
| 423 |
+
for source_key, value in state_dict.items():
|
| 424 |
                    if source_key.startswith('attacker.model.'):
                        target_key = source_key[len('attacker.model.'):]
                        resnet_state_dict[target_key] = value

                print(f"Extracted {len(resnet_state_dict)} parameters from attacker.model")

                # Check whether a 'model.' (not 'attacker.model.') structure exists as a fallback
                model_keys = [key for key in state_dict.keys()
                              if key.startswith('model.') and not key.startswith('attacker.model.')]
                if model_keys and len(resnet_state_dict) < len(resnet_keys):
                    print(f"Found additional 'model.' structure with {len(model_keys)} parameters")
                    # Try to complete the missing parameters
                    for source_key, value in state_dict.items():
                        if source_key.startswith('model.'):
                            target_key = source_key[len('model.'):]
                            if target_key in resnet_keys and target_key not in resnet_state_dict:
                                resnet_state_dict[target_key] = value

            else:
                # Check for other known structures
                structure_found = False

                # Check for a 'model.' prefix
                model_keys = [key for key in state_dict.keys() if key.startswith('model.')]
                if model_keys:
                    print(f"Found 'model.' structure with {len(model_keys)} parameters")
                    for source_key, value in state_dict.items():
                        if source_key.startswith('model.'):
                            target_key = source_key[len('model.'):]
                            resnet_state_dict[target_key] = value
                    structure_found = True

                # Check for ResNet parameters at the top level
                top_level_resnet_keys = 0
                for key in resnet_keys:
                    if key in state_dict:
                        top_level_resnet_keys += 1

                if top_level_resnet_keys > 0:
                    print(f"Found {top_level_resnet_keys} ResNet parameters at top level")
                    for target_key in resnet_keys:
                        if target_key in state_dict:
                            resnet_state_dict[target_key] = state_dict[target_key]
                    structure_found = True

                # If no structure was recognized, try the prefix mapping approach
                if not structure_found:
                    print("No standard model structure found, trying prefix mappings...")
                    for target_key in resnet_keys:
                        for prefix in prefixes_to_try:
                            source_key = prefix + target_key
                            if source_key in state_dict:
                                resnet_state_dict[target_key] = state_dict[source_key]
                                break

            # If we still can't find enough keys, make a final pass that strips common prefixes
            if len(resnet_state_dict) < len(resnet_keys):
                print(f"Found only {len(resnet_state_dict)}/{len(resnet_keys)} parameters, trying prefix removal...")

                # Track matches found through prefix removal
                prefix_matches = {prefix: 0 for prefix in ['module.', 'model.', 'attacker.model.', 'attacker.']}
                layer_matches = {}  # Track matches by layer type

                # Count parameter keys by layer type for analysis
                for key in resnet_keys:
                    layer_name = key.split('.')[0] if '.' in key else key
                    if layer_name not in layer_matches:
                        layer_matches[layer_name] = {'total': 0, 'matched': 0}
                    layer_matches[layer_name]['total'] += 1

                # Try keys with common prefixes
                for source_key, value in state_dict.items():
                    # Strip any known wrapper prefix from the checkpoint key
                    target_key = source_key
                    matched_prefix = None

                    for prefix in ['module.', 'model.', 'attacker.model.', 'attacker.']:
                        if source_key.startswith(prefix):
                            target_key = source_key[len(prefix):]
                            matched_prefix = prefix
                            break

                    # If the stripped key belongs to the ResNet and is still missing, take it
                    if target_key in resnet_keys and target_key not in resnet_state_dict:
                        resnet_state_dict[target_key] = value

                        # Update match statistics
                        if matched_prefix:
                            prefix_matches[matched_prefix] += 1

                        # Update layer matches
                        layer_name = target_key.split('.')[0] if '.' in target_key else target_key
                        if layer_name in layer_matches:
                            layer_matches[layer_name]['matched'] += 1

                # Print detailed prefix-removal statistics
                print("\n=== Prefix Removal Statistics ===")
                total_matches = sum(prefix_matches.values())
                print(f"Total parameters matched through prefix removal: {total_matches}/{len(resnet_keys)} "
                      f"({(total_matches / len(resnet_keys)) * 100:.1f}%)")

                # Show matches by prefix
                print("\nMatches by prefix:")
                for prefix, count in sorted(prefix_matches.items(), key=lambda x: x[1], reverse=True):
                    if count > 0:
                        print(f"  {prefix}: {count} parameters")

                # Show matches by layer type
                print("\nMatches by layer type:")
                for layer, stats in sorted(layer_matches.items(), key=lambda x: x[1]['total'], reverse=True):
                    match_percent = (stats['matched'] / stats['total']) * 100 if stats['total'] > 0 else 0
                    print(f"  {layer}: {stats['matched']}/{stats['total']} ({match_percent:.1f}%)")

                # Check specific important layers (conv1, layer1, etc.)
                critical_layers = ['conv1', 'bn1', 'layer1', 'layer2', 'layer3', 'layer4', 'fc']
                print("\nStatus of critical layers:")
                for layer in critical_layers:
                    if layer in layer_matches:
                        match_percent = (layer_matches[layer]['matched'] / layer_matches[layer]['total']) * 100
                        status = "✅ COMPLETE" if layer_matches[layer]['matched'] == layer_matches[layer]['total'] else "⚠️ INCOMPLETE"
                        print(f"  {layer}: {layer_matches[layer]['matched']}/{layer_matches[layer]['total']} ({match_percent:.1f}%) - {status}")
                    else:
                        print(f"  {layer}: Not found in model")

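            # Illustrative example of what the remapping above recovers: a checkpoint
            # saved through a DataParallel or adversarial-training wrapper stores keys
            # such as 'module.layer1.0.conv1.weight' or 'attacker.model.conv1.weight';
            # stripping the wrapper prefix yields the native torchvision keys
            # ('layer1.0.conv1.weight', 'conv1.weight') that resnet.load_state_dict()
            # expects to match.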
                # Load the ResNet state dict
                if resnet_state_dict:
                    try:
                        # Use strict=False to allow missing keys
                        result = resnet.load_state_dict(resnet_state_dict, strict=False)
                        missing_keys, unexpected_keys = result

                        # Generate detailed information with better formatting
                        loading_report = []
                        loading_report.append(f"\n===== MODEL LOADING REPORT: {model_type} =====")
                        loading_report.append(f"Total parameters in checkpoint: {len(resnet_state_dict):,}")
                        loading_report.append(f"Total parameters in model: {len(resnet.state_dict()):,}")
                        loading_report.append(f"Missing keys: {len(missing_keys):,} parameters")
                        loading_report.append(f"Unexpected keys: {len(unexpected_keys):,} parameters")

                        # Calculate the percentage of parameters loaded
                        loaded_keys = set(resnet_state_dict.keys()) - set(unexpected_keys)
                        loaded_percent = (len(loaded_keys) / len(resnet.state_dict())) * 100

                        # Determine loading success status
                        if loaded_percent >= 99.5:
                            status = "✅ COMPLETE - All important parameters loaded"
                        elif loaded_percent >= 90:
                            status = "🟡 PARTIAL - Most parameters loaded, should still function"
                        elif loaded_percent >= 50:
                            status = "⚠️ INCOMPLETE - Many parameters missing, may not function properly"
                        else:
                            status = "❌ FAILED - Critical parameters missing, will not function properly"

                        loading_report.append(f"Successfully loaded: {len(loaded_keys):,} parameters ({loaded_percent:.1f}%)")
                        loading_report.append(f"Loading status: {status}")

                        # If loading is severely incomplete, fall back to PyTorch's pretrained model
                        if loaded_percent < 50:
                            loading_report.append("\n⚠️ WARNING: Loading from checkpoint is too incomplete.")
                            loading_report.append("⚠️ Falling back to PyTorch's pretrained model to avoid broken inference.")

                            # Create a new ResNet model with pretrained weights
                            resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
                            model = nn.Sequential(self.normalizer, resnet)
                            loading_report.append("✅ Successfully loaded PyTorch's pretrained ResNet50 model")

                        # Show missing keys by layer type
                        if missing_keys:
                            loading_report.append("\nMissing keys by layer type:")
                            layer_types = {}
                            for key in missing_keys:
                                # Extract the layer type (e.g., 'conv1', 'bn1', 'layer1', etc.)
                                parts = key.split('.')
                                if len(parts) > 0:
                                    layer_type = parts[0]
                                    if layer_type not in layer_types:
                                        layer_types[layer_type] = 0
                                    layer_types[layer_type] += 1

                            # Add counts by layer type
                            for layer_type, count in sorted(layer_types.items(), key=lambda x: x[1], reverse=True):
                                loading_report.append(f"  {layer_type}: {count:,} parameters")

                            loading_report.append("\nFirst 10 missing keys:")
                            for i, key in enumerate(sorted(missing_keys)[:10]):
                                loading_report.append(f"  {i+1}. {key}")

                        # Show unexpected keys if any
                        if unexpected_keys:
                            loading_report.append("\nFirst 10 unexpected keys:")
                            for i, key in enumerate(sorted(unexpected_keys)[:10]):
                                loading_report.append(f"  {i+1}. {key}")

                        loading_report.append("========================================")

                        # Convert the report to a string and print it
                        report_text = "\n".join(loading_report)
                        print(report_text)

                        # Also save to a file for reference
                        os.makedirs("logs", exist_ok=True)
                        with open(f"logs/model_loading_{model_type}.log", "w") as f:
                            f.write(report_text)

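                        # For reference, with strict=False load_state_dict() returns the
                        # (missing_keys, unexpected_keys) pair instead of raising. A
                        # checkpoint whose keys all carry a leftover 'model.' prefix
                        # therefore loads 0 parameters: every checkpoint key is
                        # "unexpected" and every model key is "missing" -- exactly the
                        # pattern recorded in logs/model_loading_resnet50_robust_face.log.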
                        # Look for normalizer parameters as well
                        if any(key.startswith('attacker.normalize.') for key in state_dict.keys()):
                            norm_state_dict = {}
                            for key, value in state_dict.items():
                                if key.startswith('attacker.normalize.'):
                                    norm_key = key[len('attacker.normalize.'):]
                                    norm_state_dict[norm_key] = value

                            if norm_state_dict:
                                try:
                                    self.normalizer.load_state_dict(norm_state_dict, strict=False)
                                    print("Successfully loaded normalizer parameters")
                                except Exception as e:
                                    print(f"Warning: Could not load normalizer parameters: {e}")
                    except Exception as e:
                        print(f"Warning: Error loading ResNet parameters: {e}")
                        # Fall back to using just the ResNet model without the normalizer
                        model = resnet
            except Exception as e:
                print(f"Error loading model checkpoint: {e}")
                # Fall back to PyTorch's pretrained model
                print("Falling back to PyTorch's pretrained model")
                resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
                model = nn.Sequential(self.normalizer, resnet)
        else:
            # Fall back to PyTorch's pretrained model
            print("No checkpoint available, using PyTorch's pretrained model")
            resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
            model = nn.Sequential(self.normalizer, resnet)

        model = model.to(device)
        model.eval()  # Set to evaluation mode

        # Verify model integrity
        self.verify_model_integrity(model, model_type)

        # Store the model for future use
        self.models[model_type] = model
        end_time = time.time()
        load_time = end_time - start_time
        print(f"Model {model_type} loaded in {load_time:.2f} seconds")
        return model

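    # Note: because load_model() caches the result in self.models, repeated
    # inference calls with the same model_type can reuse the already-initialized
    # network instead of re-reading the ~200 MB checkpoint from disk each time.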
    def inference(self, image, model_type, config):
        """Run generative inference on the image."""
        # Time the entire inference process
        inference_start = time.time()

        # Load the model if not already loaded
        model = self.load_model(model_type)

        # Check if the image is a file path
        if isinstance(image, str):
            if os.path.exists(image):
                image = Image.open(image).convert('RGB')
            else:
                raise ValueError(f"Image path does not exist: {image}")
        elif isinstance(image, torch.Tensor):
            raise ValueError(f"Got a {type(image)}; expected a PIL image or a file path, "
                             f"not an already-transformed tensor")

        # Prepare the image tensor - match the original code's conditional transform
        load_start = time.time()
        use_norm = config['inference_normalization'] == 'on'
        custom_transform = get_transform(
            input_size=224,
            normalize=use_norm,
            norm_mean=IMAGENET_MEAN,
            norm_std=IMAGENET_STD
        )

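        # get_transform is defined earlier in this file; the sketch below shows the
        # torchvision pipeline it is assumed to build (hypothetical reconstruction):
        #   transforms.Compose(
        #       [transforms.Resize((input_size, input_size)), transforms.ToTensor()]
        #       + ([transforms.Normalize(norm_mean, norm_std)] if normalize else [])
        #   )
        # so with normalize=False the tensor stays in the raw [0, 1] range.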
        # Special handling for GradModulation as in the original
        if config['loss_infer'] == 'GradModulation' and 'misc_info' in config and 'grad_modulation' in config['misc_info']:
            grad_modulation = config['misc_info']['grad_modulation']
            image_tensor = custom_transform(image).unsqueeze(0).to(device)
            image_tensor = image_tensor * (1 - grad_modulation) + grad_modulation * torch.randn_like(image_tensor).to(device)
        else:
            image_tensor = custom_transform(image).unsqueeze(0).to(device)

        image_tensor.requires_grad = True
        print(f"Image loaded and processed in {time.time() - load_start:.2f} seconds")

        # Check the model structure
        is_sequential = isinstance(model, nn.Sequential)

        # Get original predictions
        with torch.no_grad():
            # If the model is sequential with a normalizer, normalization is built in
            if is_sequential and isinstance(model[0], NormalizeByChannelMeanStd):
                print("Model is sequential with normalization")
                # Get the core model part (typically at index 1 in Sequential)
                core_model = model[1]
                if use_norm:
                    output_original = model(image_tensor)  # Model includes normalization
                else:
                    output_original = core_model(image_tensor)  # Skip the built-in normalizer

            else:
                print("Model is not sequential with normalization")
                # Use manual normalization for non-sequential models
                if use_norm:
                    normalized_tensor = normalize_transform(image_tensor)
                    output_original = model(normalized_tensor)
                else:
                    output_original = model(image_tensor)
                core_model = model

            probs_orig = F.softmax(output_original, dim=1)
            conf_orig, classes_orig = torch.max(probs_orig, 1)

        # Get the least confident classes
        _, least_confident_classes = torch.topk(probs_orig, k=100, largest=False)

        # Initialize the inference step
        infer_step = InferStep(image_tensor, config['eps'], config['step_size'])

        # Storage for inference steps
        # Create a new tensor that requires gradients
        x = image_tensor.clone().detach().requires_grad_(True)
        all_steps = [image_tensor[0].detach().cpu()]

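        # InferStep is defined earlier in this file; schematically it is assumed to
        # implement one projected-gradient update per iteration, something like:
        #   adjusted = step_size * grad / (grad.norm() + 1e-12)      # .step()
        #   x_new    = orig + clamp(x_new - orig, -eps, eps)         # .project()
        #   x_new    = clamp(x_new, 0, 1)
        # keeping every iterate inside an eps-ball around the input image and
        # within the valid pixel range.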
        # For Prior-Guided Drift Diffusion, extract the selected layer and initialize with noisy features
        noisy_features = None
        layer_model = None
        if config['loss_infer'] == 'Prior-Guided Drift Diffusion':
            print(f"Setting up Prior-Guided Drift Diffusion with layer {config['top_layer']} "
                  f"and noise {config['initial_inference_noise_ratio']}...")

            # Extract the model up to the specified layer
            try:
                # Start by finding the actual model to use
                base_model = model

                # Handle a DataParallel wrapper if present
                if hasattr(base_model, 'module'):
                    base_model = base_model.module

                # Log the initial model structure
                print(f"DEBUG - Initial model structure: {type(base_model)}")

                # If we have a Sequential model (likely our normalizer + model structure)
                if isinstance(base_model, nn.Sequential):
                    print(f"DEBUG - Sequential model with {len(list(base_model.children()))} children")

                    # If this is our NormalizeByChannelMeanStd + ResNet pattern
                    if len(list(base_model.children())) >= 2:
                        # The actual ResNet model is the second component (index 1)
                        actual_model = list(base_model.children())[1]
                        print(f"DEBUG - Using ResNet component: {type(actual_model)}")
                        print(f"DEBUG - Available layers: {[name for name, _ in actual_model.named_children()]}")

                        # Extract from the actual ResNet
                        layer_model = extract_middle_layers(actual_model, config['top_layer'])
                    else:
                        # Just a single-component Sequential
                        layer_model = extract_middle_layers(base_model, config['top_layer'])
                else:
                    # Not Sequential, might be a direct model
                    print(f"DEBUG - Available layers: {[name for name, _ in base_model.named_children()]}")
                    layer_model = extract_middle_layers(base_model, config['top_layer'])

                print(f"Successfully extracted model up to layer: {config['top_layer']}")
            except ValueError as e:
                print(f"Layer extraction failed: {e}. Using full model.")
                layer_model = model

            # Add noise to the image - exactly match the original code
            added_noise = config['initial_inference_noise_ratio'] * torch.randn_like(image_tensor).to(device)
            noisy_image_tensor = image_tensor + added_noise

            # Compute noisy features - simplified to match the original code
            noisy_features = layer_model(noisy_image_tensor)

            print(f"Noisy features computed for the Prior-Guided Drift Diffusion target with shape: "
                  f"{noisy_features.shape if hasattr(noisy_features, 'shape') else 'unknown'}")

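        # extract_middle_layers is defined earlier in this file; conceptually it is
        # assumed to truncate the network after the named block, e.g.:
        #   kept = []
        #   for name, module in actual_model.named_children():   # conv1 ... fc
        #       kept.append(module)
        #       if name == top_layer:
        #           break
        #   layer_model = nn.Sequential(*kept)
        # so layer_model(x) returns intermediate activations rather than logits.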
        # Main inference loop
        print(f"Starting inference loop with {config['n_itr']} iterations for {config['loss_infer']}...")
        loop_start = time.time()
        for i in range(config['n_itr']):
            # Reset gradients
            x.grad = None

            # Forward pass - use layer_model for Prior-Guided Drift Diffusion, the full model otherwise
            if config['loss_infer'] == 'Prior-Guided Drift Diffusion' and layer_model is not None:
                # Use the extracted layer model for Prior-Guided Drift Diffusion.
                # In the original code, normalization is handled at transform time, not during the forward pass.
                output = layer_model(x)
            else:
                # Standard forward pass with the full model, simplified to match the original code's approach
                output = model(x)

            # Calculate loss and gradients based on the inference type
            try:
                if config['loss_infer'] == 'Prior-Guided Drift Diffusion':
                    # Use MSE loss to match the noisy features
                    assert config['loss_function'] == 'MSE', "Reverse Diffusion loss function must be MSE"
                    if noisy_features is not None:
                        loss = F.mse_loss(output, noisy_features)
                        grad = torch.autograd.grad(loss, x)[0]  # No retain_graph=True, matching the original
                    else:
                        raise ValueError("Noisy features not computed for Prior-Guided Drift Diffusion")

                else:  # Default 'IncreaseConfidence' approach
                    # Get the least confident classes
                    num_classes = min(10, least_confident_classes.size(1))
                    target_classes = least_confident_classes[0, :num_classes]

                    # Create targets for the least confident classes
                    targets = torch.tensor([idx.item() for idx in target_classes], device=device)

                    # Use a combined loss to increase confidence
                    loss = 0
                    for target in targets:
                        # Create a one-hot target
                        one_hot = torch.zeros_like(output)
                        one_hot[0, target] = 1
                        # MSE between the softmax output and the one-hot target pushes confidence up
                        loss = loss + F.mse_loss(F.softmax(output, dim=1), one_hot)

                    grad = torch.autograd.grad(loss, x, retain_graph=True)[0]

                if grad is None:
                    print("Warning: Direct gradient calculation failed")
                    # Fall back to random perturbation
                    random_noise = (torch.rand_like(x) - 0.5) * 2 * config['step_size']
                    x = infer_step.project(x + random_noise)
                else:
                    # Update the image with the gradient - exactly as in the original code
                    adjusted_grad = infer_step.step(x, grad)

                    # Add diffusion noise if specified
                    diffusion_noise = config['diffusion_noise_ratio'] * torch.randn_like(x).to(device)

                    # Apply gradient and noise in one operation before projecting, exactly as in the original
                    x = infer_step.project(x.clone() + adjusted_grad + diffusion_noise)

            except Exception as e:
                print(f"Error in gradient calculation: {e}")
                # Fall back to random perturbation - match the original code
                random_noise = (torch.rand_like(x) - 0.5) * 2 * config['step_size']
                x = infer_step.project(x.clone() + random_noise)

            # Store this step if it is in iterations_to_show
            if i + 1 in config['iterations_to_show'] or i + 1 == config['n_itr']:
                all_steps.append(x[0].detach().cpu())

        # Print some info about the inference
        with torch.no_grad():
            if is_sequential and isinstance(model[0], NormalizeByChannelMeanStd):
                if use_norm:
                    final_output = model(x)
                else:
                    final_output = core_model(x)
            else:
                if use_norm:
                    normalized_x = normalize_transform(x)
                    final_output = model(normalized_x)
                else:
                    final_output = model(x)

            final_probs = F.softmax(final_output, dim=1)
            final_conf, final_classes = torch.max(final_probs, 1)

        # Calculate timing information
        loop_time = time.time() - loop_start
        total_time = time.time() - inference_start
        avg_iter_time = loop_time / config['n_itr'] if config['n_itr'] > 0 else 0

        print(f"Original top class: {classes_orig.item()} ({conf_orig.item():.4f})")
        print(f"Final top class: {final_classes.item()} ({final_conf.item():.4f})")
        print(f"Inference loop completed in {loop_time:.2f} seconds ({avg_iter_time:.4f} sec/iteration)")
        print(f"Total inference time: {total_time:.2f} seconds")

        # Return results in a format compatible with both old and new code
        return {
            'final_image': x[0].detach().cpu(),
            'steps': all_steps,
            'original_class': classes_orig.item(),
            'original_confidence': conf_orig.item(),
            'final_class': final_classes.item(),
            'final_confidence': final_conf.item()
        }

# Utility function to show inference steps
def show_inference_steps(steps, figsize=(15, 10)):
    """Plot each stored inference step side by side and return the figure."""
    import matplotlib.pyplot as plt

    n_steps = len(steps)
    fig, axes = plt.subplots(1, n_steps, figsize=figsize)
    if n_steps == 1:
        axes = [axes]  # plt.subplots returns a bare Axes for a single panel

    for i, step_img in enumerate(steps):
        img = step_img.permute(1, 2, 0).numpy()  # CHW tensor -> HWC array
        axes[i].imshow(img)
        axes[i].set_title(f"Step {i}")
        axes[i].axis('off')

    plt.tight_layout()
    return fig
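# Example usage (illustrative sketch): plot a 'steps' list the way inference()
# returns it. Random tensors stand in here for real inference iterates.
if __name__ == "__main__":
    demo_steps = [torch.rand(3, 224, 224) for _ in range(4)]
    fig = show_inference_steps(demo_steps, figsize=(12, 3))
    fig.savefig("inference_steps_demo.png")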
logs/model_loading_resnet50_robust.log
ADDED
@@ -0,0 +1,9 @@

===== MODEL LOADING REPORT: resnet50_robust =====
Total parameters in checkpoint: 320
Total parameters in model: 320
Missing keys: 0 parameters
Unexpected keys: 0 parameters
Successfully loaded: 320 parameters (100.0%)
Loading status: ✅ COMPLETE - All important parameters loaded
========================================
logs/model_loading_resnet50_robust_face.log
ADDED
@@ -0,0 +1,46 @@

===== MODEL LOADING REPORT: resnet50_robust_face =====
Total parameters in checkpoint: 320
Total parameters in model: 320
Missing keys: 267 parameters
Unexpected keys: 320 parameters
Successfully loaded: 0 parameters (0.0%)
Loading status: ❌ FAILED - Critical parameters missing, will not function properly

⚠️ WARNING: Loading from checkpoint is too incomplete.
⚠️ Falling back to PyTorch's pretrained model to avoid broken inference.
✅ Successfully loaded PyTorch's pretrained ResNet50 model

Missing keys by layer type:
  layer3: 95 parameters
  layer2: 65 parameters
  layer1: 50 parameters
  layer4: 50 parameters
  bn1: 4 parameters
  fc: 2 parameters
  conv1: 1 parameters

First 10 missing keys:
  1. bn1.bias
  2. bn1.running_mean
  3. bn1.running_var
  4. bn1.weight
  5. conv1.weight
  6. fc.bias
  7. fc.weight
  8. layer1.0.bn1.bias
  9. layer1.0.bn1.running_mean
  10. layer1.0.bn1.running_var

First 10 unexpected keys:
  1. model.bn1.bias
  2. model.bn1.num_batches_tracked
  3. model.bn1.running_mean
  4. model.bn1.running_var
  5. model.bn1.weight
  6. model.conv1.weight
  7. model.fc.bias
  8. model.fc.weight
  9. model.layer1.0.bn1.bias
  10. model.layer1.0.bn1.num_batches_tracked
========================================
logs/model_loading_robust_resnet50.log
ADDED
@@ -0,0 +1,9 @@

===== MODEL LOADING REPORT: robust_resnet50 =====
Total parameters in checkpoint: 320
Total parameters in model: 320
Missing keys: 0 parameters
Unexpected keys: 0 parameters
Successfully loaded: 320 parameters (100.0%)
Loading status: ✅ COMPLETE - All important parameters loaded
========================================
logs/model_loading_standard_resnet50.log
ADDED
@@ -0,0 +1,9 @@

===== MODEL LOADING REPORT: standard_resnet50 =====
Total parameters in checkpoint: 320
Total parameters in model: 320
Missing keys: 0 parameters
Unexpected keys: 0 parameters
Successfully loaded: 320 parameters (100.0%)
Loading status: ✅ COMPLETE - All important parameters loaded
========================================
models/resnet50_robust.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:380b14e6f9750bffa1447cf7017f65da4dc5ce71a3dd112f107515dcf7b14d9d
size 204818947
models/resnet50_robust_face_100_checkpoint.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c48a5c16ca0d5ac4cb20f1b98e2128838746f18b658728ac661f1ffd589c37bf
size 196695413
models/robust_resnet50.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:380b14e6f9750bffa1447cf7017f65da4dc5ce71a3dd112f107515dcf7b14d9d
size 204818947
models/standard_resnet50.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:72d4a99582db5d7fa86c3fd2a089f0bfd6a10f69d635bca51f6ad72ac6b458f0
size 204818947
requirements.txt
ADDED
@@ -0,0 +1,9 @@
torch
torchvision
numpy
pillow
gradio
matplotlib
requests
tqdm
huggingface_hub
stimuli/Confetti_illusion.png
ADDED
Git LFS Details
stimuli/CornsweetBlock.png
ADDED
Git LFS Details
stimuli/EhresteinSingleColor.png
ADDED
Git LFS Details
stimuli/GroupingByContinuity.png
ADDED
stimuli/Kanizsa_square.jpg
ADDED
stimuli/Neon_Color_Circle.jpg
ADDED
Git LFS Details
stimuli/face_vase.png
ADDED
stimuli/figure_ground.png
ADDED
Git LFS Details