Tahereh committed on
Commit 420f791 · 1 parent: 0b968b4

Update to Generative Inference for Psychiatry Demo: add Noise stimulus, update parameters, fix model loading, and improve UI

.DS_Store ADDED
Binary file (6.15 kB)
 
.gitignore ADDED
@@ -0,0 +1,30 @@
+ # Model checkpoints (downloaded automatically)
+ models/*.pt
+ models/*.ckpt
+
+ # Python cache
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.so
+
+ # Logs
+ logs/
+ *.log
+
+ # Environment
+ .env
+ .venv
+ env/
+ venv/
+
+ # IDE
+ .vscode/
+ .idea/
+ *.swp
+ *.swo
+
+ # OS
+ .DS_Store
+ Thumbs.db
+
.smbdeleteAAA29de78c16 ADDED
Binary file (10.9 kB)
 
DIFFERENCES.md ADDED
@@ -0,0 +1,137 @@
+ # Differences Between Reference Code and Current Implementation
+
+ ## Critical Differences Affecting Results
+
+ ### 1. **First Iteration Handling** ⚠️ **CRITICAL**
+ **Reference Code:**
+ ```python
+ if itr == 0:
+     # Don't add priors or diffusion noise on the first iteration
+     output = model(image_tensor)
+     # ... just get predictions, no gradient update
+ else:
+     # Calculate loss and gradients
+     if loss_infer == 'PGDD':
+         loss = torch.nn.functional.mse_loss(features, noisy_features)
+     grad = torch.autograd.grad(loss, image_tensor)[0]
+     adjusted_grad = inferstep.step(image_tensor, grad)
+     # ... apply gradient and noise
+ ```
+
+ **Current Implementation:**
+ - **MISSING**: No check for `itr == 0` or `i == 0`
+ - Applies gradients and diffusion noise from the very first iteration
+ - This causes different starting behavior
+
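The missing guard can be illustrated with a minimal, torch-free loop skeleton (hypothetical names; in the real loop, `update_step` stands in for the gradient-plus-noise update):

```python
def run_inference_loop(num_iterations, update_step):
    """Skeleton of the reference loop: iteration 0 only reads the
    model's predictions; gradient updates (and diffusion noise)
    start at iteration 1."""
    history = []
    for itr in range(num_iterations):
        if itr == 0:
            # First iteration: forward pass only, no gradient update
            history.append("forward_only")
        else:
            # Later iterations: compute loss, gradient, and add noise
            update_step(itr)
            history.append("update")
    return history

# The current implementation effectively performs "update" on every
# iteration, including itr == 0, which changes the trajectory.
steps = run_inference_loop(3, update_step=lambda itr: None)
```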
+ ### 2. **Model Extraction for PGDD**
+ **Reference Code:**
+ ```python
+ new_model = extract_middle_layers(model.module, top_layer)
+ ```
+
+ **Current Implementation:**
+ - Complex logic to handle Sequential models with normalizers
+ - Extracts from `model[1]` if Sequential, otherwise from `model`
+ - May handle DataParallel differently
+
+ ### 3. **Gradient Calculation**
+ **Reference Code:**
+ ```python
+ grad = torch.autograd.grad(loss, image_tensor)[0]  # No retain_graph for PGDD
+ ```
+
+ **Current Implementation:**
+ - Same for PGDD (no `retain_graph`)
+ - But uses `retain_graph=True` for IncreaseConfidence
+
+ ### 4. **Normalization Handling**
+ **Reference Code:**
+ - Normalization is applied in the transform at the beginning
+ - `inference_normalization` controls whether the transform includes normalization
+ - The model forward pass uses the already-normalized tensor
+
+ **Current Implementation:**
+ - Complex logic checking whether the model is a Sequential with NormalizeByChannelMeanStd
+ - May apply normalization multiple times or inconsistently
+ - Different paths for Sequential vs non-Sequential models
+
+ ### 5. **Variable Naming and Structure**
+ **Reference Code:**
+ - Uses `image_tensor` throughout the loop
+ - Directly modifies `image_tensor` with `requires_grad=True`
+
+ **Current Implementation:**
+ - Creates a separate `x = image_tensor.clone().detach().requires_grad_(True)`
+ - Uses `x` in the loop instead of `image_tensor`
+
+ ### 6. **Loss Function for IncreaseConfidence**
+ **Reference Code:**
+ ```python
+ loss = calculate_loss(features, least_confident_classes[0], loss_function)
+ # Uses CrossEntropyLoss or MSELoss based on loss_function
+ ```
+
+ **Current Implementation:**
+ ```python
+ # Creates one-hot targets and uses MSE on softmax outputs
+ loss = loss + F.mse_loss(F.softmax(output, dim=1), one_hot)
+ ```
+ - Different loss calculation method
+ - Uses MSE on softmax probabilities vs CrossEntropy on logits
+
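The practical difference between the two losses can be seen with a small, dependency-free sketch (pure-Python stand-ins for `CrossEntropyLoss` and `F.mse_loss(F.softmax(...), one_hot)`; function names are illustrative):

```python
import math

def softmax(logits):
    # Numerically stable softmax over a list of logits
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]

def cross_entropy(logits, target_idx):
    # Reference-style loss: cross-entropy on raw logits
    return -math.log(softmax(logits)[target_idx])

def mse_on_softmax(logits, target_idx):
    # Current-implementation-style loss: MSE between softmax
    # probabilities and a one-hot target
    probs = softmax(logits)
    one_hot = [1.0 if i == target_idx else 0.0 for i in range(len(logits))]
    return sum((p - t) ** 2 for p, t in zip(probs, one_hot)) / len(logits)

# Both losses shrink as the target class gets more confident, but
# their magnitudes and gradients scale very differently, so the two
# implementations drive the image update at different rates.
logits = [2.0, 1.0, 0.1]
```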
+ ### 7. **Diffusion Noise Application**
+ **Reference Code:**
+ ```python
+ if itr == 0:
+     pass  # Skip noise
+ else:
+     diffusion_noise = diffusion_noise_ratio * torch.randn_like(image_tensor).cuda()
+     if loss_infer == 'GradModulation':
+         image_tensor = inferstep.project(
+             image_tensor.clone() +
+             adjusted_grad * grad_modulation +
+             diffusion_noise * grad_modulation
+         )
+     else:
+         image_tensor = inferstep.project(
+             image_tensor.clone() + adjusted_grad + diffusion_noise
+         )
+ ```
+
+ **Current Implementation:**
+ - Always applies diffusion noise (no `itr == 0` check)
+ - Applies noise in all iterations, including the first
+
+ ### 8. **Model Forward Pass in Loop**
+ **Reference Code:**
+ ```python
+ if inference_config['misc_info'].get('smooth_inference', False):
+     ...  # Smooth inference logic
+ else:
+     new_model.zero_grad()
+     features = new_model(image_tensor)
+ ```
+
+ **Current Implementation:**
+ ```python
+ x.grad = None  # Instead of new_model.zero_grad()
+ if config['loss_infer'] == 'Prior-Guided Drift Diffusion' and layer_model is not None:
+     output = layer_model(x)
+ else:
+     output = model(x)
+ ```
+
+ ## Summary of Impact
+
+ 1. **First iteration difference**: Most critical - the reference skips the gradient update on iteration 0
+ 2. **Normalization**: Different application may cause numerical differences
+ 3. **Loss calculation**: Different methods for IncreaseConfidence
+ 4. **Model extraction**: May extract different layers due to Sequential handling
+
+ ## Recommended Fixes
+
+ 1. Add an `if i == 0:` check to skip the gradient update on the first iteration
+ 2. Simplify model extraction to match the reference: `extract_middle_layers(model.module, top_layer)`
+ 3. Align the loss calculation for IncreaseConfidence with the reference
+ 4. Ensure normalization is applied consistently
+
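Fix 2 (the reference-style layer cutoff) amounts to a simple scan over named children. A torch-free sketch of the idea, operating on `(name, module)` pairs as `named_children()` would yield them (helper name is hypothetical; the real code wraps the result in `nn.Sequential`):

```python
from collections import OrderedDict

def keep_layers_up_to(named_children, top_layer):
    """Keep child modules up to and including `top_layer`.

    In the real code the pairs come from model.module.named_children()
    and the returned OrderedDict is wrapped in nn.Sequential."""
    kept = OrderedDict()
    for name, module in named_children:
        kept[name] = module
        if name == top_layer:
            return kept
    raise ValueError(f"Module {top_layer} not found in model")
```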
README.md CHANGED
@@ -1,5 +1,5 @@
  ---
- title: Generative Inference Demo
+ title: Generative Inference for Psychiatry Demo
  emoji: 🧠
  colorFrom: indigo
  colorTo: purple
TROUBLESHOOTING.md ADDED
@@ -0,0 +1,140 @@
+ # Troubleshooting: Why It Works on Hugging Face Spaces But Not Locally
+
+ ## Common Issues and Solutions
+
+ ### 1. **Missing Dependencies** ⚠️ (Most Common)
+
+ **Problem**: The required Python packages are not installed locally.
+
+ **Solution**: Install all dependencies:
+ ```bash
+ cd /home/tahereh/engram/users/Tahereh/Codes/Public_Codes/Generative_Inference_Faces
+ pip install -r requirements.txt
+ ```
+
+ **Required packages**:
+ - `torch` and `torchvision` (PyTorch)
+ - `gradio` (for the web interface)
+ - `numpy`, `pillow` (PIL), `matplotlib`
+ - `requests`, `tqdm`, `huggingface_hub`
+
+ ### 2. **GPU Decorator** ✅ (Fixed)
+
+ **Problem**: The `@GPU` decorator from Hugging Face Spaces is not available locally.
+
+ **Solution**: The code now automatically handles this:
+ - On Hugging Face Spaces: uses the `spaces.GPU` decorator
+ - Locally: uses a no-op decorator (GPU detection is automatic via PyTorch)
+
+ **Status**: ✅ Fixed in the code
+
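The fallback described above (and visible in the app.py diff below) reduces to a try/except around the import:

```python
try:
    from spaces import GPU  # available on Hugging Face Spaces
except ImportError:
    # Running locally: a no-op decorator so decorated functions
    # work unchanged; PyTorch handles device selection on its own.
    def GPU(func):
        return func

@GPU
def predict(x):
    # Decorated functions behave identically in both environments
    return x * 2
```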
+ ### 3. **Port Configuration** ✅ (Fixed)
+
+ **Problem**: Port configuration was inconsistent between local and Spaces environments.
+
+ **Solution**: The code now:
+ - Uses port 7860 by default (same as Spaces)
+ - Allows a custom port via the `--port` argument
+ - Automatically detects the Hugging Face Spaces environment
+
+ **Status**: ✅ Fixed in the code
+
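The port-resolution rules above can be sketched as one helper (the name `resolve_port` is hypothetical; the app inlines the same logic):

```python
import os

def resolve_port(cli_port=None):
    """An explicit --port value wins; on Spaces (SPACE_ID set) use
    the PORT env var; otherwise fall back to 7860 locally."""
    if cli_port is not None:
        return cli_port
    if "SPACE_ID" in os.environ:
        return int(os.environ.get("PORT", 7860))
    return 7860
```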
+ ### 4. **Model Files Not Downloaded**
+
+ **Problem**: Model checkpoint files may not have been downloaded yet.
+
+ **Solution**: The code automatically downloads models on first run, but you can verify:
+ ```bash
+ ls models/
+ ```
+
+ Expected files:
+ - `resnet50_robust.pt`
+ - `standard_resnet50.pt` (optional)
+ - `resnet50_robust_face_100_checkpoint.pt` (optional)
+
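A small helper (hypothetical, stdlib-only) can report which of the expected checkpoints listed above are still missing:

```python
from pathlib import Path

def missing_checkpoints(models_dir="models",
                        expected=("resnet50_robust.pt",
                                  "standard_resnet50.pt",
                                  "resnet50_robust_face_100_checkpoint.pt")):
    """Return the expected checkpoint file names not present in
    models_dir; optional checkpoints can be dropped from `expected`."""
    root = Path(models_dir)
    return [name for name in expected if not (root / name).exists()]
```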
+ ### 5. **Missing Stimuli Images**
+
+ **Problem**: Example images may be missing.
+
+ **Solution**: Verify that the stimuli directory exists:
+ ```bash
+ ls stimuli/
+ ```
+
+ All example images should be present for the demo to work fully.
+
+ ### 6. **CUDA/GPU Issues**
+
+ **Problem**: The GPU may not be available or configured correctly.
+
+ **Solution**: The code automatically detects available hardware:
+ - CUDA (NVIDIA GPUs)
+ - MPS (Apple Silicon)
+ - CPU (fallback)
+
+ Check your setup:
+ ```python
+ import torch
+ print("CUDA available:", torch.cuda.is_available())
+ print("Device:", torch.device("cuda" if torch.cuda.is_available() else "cpu"))
+ ```
+
+ ### 7. **Python Version**
+
+ **Problem**: Incompatible Python version.
+
+ **Solution**: Use Python 3.8+ (tested with 3.11.5):
+ ```bash
+ python --version
+ ```
+
+ ## Quick Start Guide
+
+ 1. **Install dependencies**:
+    ```bash
+    pip install -r requirements.txt
+    ```
+
+ 2. **Run the app**:
+    ```bash
+    python app.py
+    ```
+
+    Or with a custom port:
+    ```bash
+    python app.py --port 8080
+    ```
+
+ 3. **Access the web interface**:
+    - Open your browser to `http://localhost:7860`
+    - Or the port you specified
+
+ ## Differences Between Hugging Face Spaces and Local
+
+ | Feature | Hugging Face Spaces | Local |
+ |---------|---------------------|-------|
+ | GPU Decorator | `@spaces.GPU` available | No-op decorator (automatic GPU) |
+ | Port | Set via `PORT` env var | Default 7860, or `--port` arg |
+ | Dependencies | Pre-installed | Must install manually |
+ | Environment | `SPACE_ID` env var set | Not set |
+ | Model Storage | Persistent storage | Local `models/` directory |
+
+ ## Testing the Fixes
+
+ After applying the fixes, test with:
+ ```bash
+ # Check that imports work
+ python -c "import gradio, torch, numpy, PIL; print('All imports OK')"
+
+ # Run the app
+ python app.py --port 7860
+ ```
+
+ ## Still Having Issues?
+
+ 1. **Check error messages**: Look for specific import errors or file-not-found errors
+ 2. **Verify the Python environment**: Make sure you're using the correct virtual environment
+ 3. **Check file permissions**: Ensure the `models/` and `stimuli/` directories are writable
+ 4. **Review logs**: Check the `logs/` directory for model-loading issues
__pycache__/inference.cpython-311.pyc ADDED
Binary file (40.8 kB)
 
app.py CHANGED
@@ -4,10 +4,13 @@ import numpy as np
  from PIL import Image
  try:
      from spaces import GPU
+     print("Running on Hugging Face Spaces - GPU decorator available")
  except ImportError:
      # Define a no-op decorator if running locally
      def GPU(func):
+         """No-op decorator for local execution (GPU handling is automatic)"""
          return func
+     print("Running locally - GPU decorator not available (using automatic GPU detection)")
  
  import os
  import argparse
@@ -15,7 +18,7 @@ from inference import GenerativeInferenceModel, get_inference_configs
  
  # Parse command line arguments
  parser = argparse.ArgumentParser(description='Run Generative Inference Demo')
- parser.add_argument('--port', type=int, default=7860, help='Port to run the server on')
+ parser.add_argument('--port', type=int, default=None, help='Port to run the server on')
  args = parser.parse_args()
  
  # Create model directories if they don't exist
@@ -26,13 +29,54 @@ os.makedirs("stimuli", exist_ok=True)
  if "SPACE_ID" in os.environ:
      default_port = int(os.environ.get("PORT", 7860))
  else:
-     default_port = 8861  # Local default port
+     default_port = 7860  # Use same default port locally
+ 
+ # Use command line port if provided, otherwise use default
+ server_port = args.port if args.port is not None else default_port
  
  # Initialize model
  model = GenerativeInferenceModel()
  
  # Define example images and their parameters with updated values from the research
  examples = [
+     {
+         "image": os.path.join("stimuli", "face_vase.png"),
+         "name": "Rubin's Face-Vase (Object Prior)",
+         "wiki": "https://en.wikipedia.org/wiki/Rubin_vase",
+         "papers": [
+             "[Figure-Ground Perception](https://en.wikipedia.org/wiki/Figure-ground_(perception))",
+             "[Bistable Perception](https://doi.org/10.1016/j.tics.2003.08.003)"
+         ],
+         "method": "Prior-Guided Drift Diffusion",
+         "reverse_diff": {
+             "model": "resnet50_robust_face",
+             "layer": "layer4",
+             "initial_noise": 0.0,
+             "diffusion_noise": 0.006,
+             "step_size": 0.18,
+             "iterations": 100,
+             "epsilon": 9.53
+         }
+     },
+     {
+         "image": os.path.join("stimuli", "RandomizedPhaseOvalGray.png"),
+         "name": "Noise (Randomized Phase Oval)",
+         "wiki": "https://en.wikipedia.org/wiki/Visual_noise",
+         "papers": [
+             "[Perceptual Organization](https://doi.org/10.1016/j.tics.2003.08.003)",
+             "[Pattern Recognition](https://en.wikipedia.org/wiki/Pattern_recognition)"
+         ],
+         "method": "Prior-Guided Drift Diffusion",
+         "reverse_diff": {
+             "model": "resnet50_robust_face",
+             "layer": "all",
+             "initial_noise": 0.0,
+             "diffusion_noise": 0.05,
+             "step_size": 1.12,
+             "iterations": 428,
+             "epsilon": 198.62
+         }
+     },
      {
          "image": os.path.join("stimuli", "Neon_Color_Circle.jpg"),
          "name": "Neon Color Spreading",
@@ -91,25 +135,6 @@ examples = [
              "epsilon": 20.0
          }
      },
-     {
-         "image": os.path.join("stimuli", "face_vase.png"),
-         "name": "Rubin's Face-Vase (Object Prior)",
-         "wiki": "https://en.wikipedia.org/wiki/Rubin_vase",
-         "papers": [
-             "[Figure-Ground Perception](https://en.wikipedia.org/wiki/Figure-ground_(perception))",
-             "[Bistable Perception](https://doi.org/10.1016/j.tics.2003.08.003)"
-         ],
-         "method": "Prior-Guided Drift Diffusion",
-         "reverse_diff": {
-             "model": "resnet50_robust",
-             "layer": "avgpool",
-             "initial_noise": 0.9,
-             "diffusion_noise": 0.003,
-             "step_size": 0.58,
-             "iterations": 100,
-             "epsilon": 0.81
-         }
-     },
      {
          "image": os.path.join("stimuli", "Confetti_illusion.png"),
          "name": "Confetti Illusion",
@@ -223,21 +248,76 @@ def run_inference(image, model_type, inference_type, eps_value, num_iterations,
      # Create animation frames
      frames = []
      for i, step_image in enumerate(all_steps):
-         # Convert tensor to PIL image
-         step_pil = Image.fromarray((step_image.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8))
-         frames.append(step_pil)
+         # Convert tensor to PIL image with proper error handling
+         try:
+             # Ensure tensor is on CPU and detached
+             if isinstance(step_image, torch.Tensor):
+                 step_image = step_image.detach().cpu()
+                 # Handle different tensor shapes
+                 if len(step_image.shape) == 4:  # [B, C, H, W]
+                     step_image = step_image[0]  # Take first batch item
+                 elif len(step_image.shape) == 3:  # [C, H, W]
+                     pass  # Already correct shape
+                 else:
+                     raise ValueError(f"Unexpected tensor shape: {step_image.shape}")
+ 
+                 # Clamp values to [0, 1] range before converting
+                 step_image = torch.clamp(step_image, 0, 1)
+                 # Convert to numpy and ensure contiguous array
+                 step_np = step_image.permute(1, 2, 0).numpy()
+                 # Ensure it's a contiguous array with correct dtype
+                 step_np = np.ascontiguousarray(step_np, dtype=np.float32)
+                 # Convert to uint8
+                 step_np = (step_np * 255).astype(np.uint8)
+                 # Create PIL image
+                 step_pil = Image.fromarray(step_np, mode='RGB')
+                 frames.append(step_pil)
+             else:
+                 print(f"Warning: step_image at index {i} is not a tensor: {type(step_image)}")
+         except Exception as e:
+             print(f"Error converting step {i} to PIL image: {e}, shape: {step_image.shape if hasattr(step_image, 'shape') else 'N/A'}")
+             # Skip this frame if conversion fails
+             continue
  
      # Convert the final output image to PIL
-     final_image = Image.fromarray((output_image.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8))
+     try:
+         if isinstance(output_image, torch.Tensor):
+             output_image = output_image.detach().cpu()
+             # Handle different tensor shapes
+             if len(output_image.shape) == 4:  # [B, C, H, W]
+                 output_image = output_image[0]  # Take first batch item
+             elif len(output_image.shape) == 3:  # [C, H, W]
+                 pass  # Already correct shape
+             else:
+                 raise ValueError(f"Unexpected tensor shape: {output_image.shape}")
+ 
+             # Clamp values to [0, 1] range before converting
+             output_image = torch.clamp(output_image, 0, 1)
+             # Convert to numpy and ensure contiguous array
+             output_np = output_image.permute(1, 2, 0).numpy()
+             # Ensure it's a contiguous array with correct dtype
+             output_np = np.ascontiguousarray(output_np, dtype=np.float32)
+             # Convert to uint8
+             output_np = (output_np * 255).astype(np.uint8)
+             # Create PIL image
+             final_image = Image.fromarray(output_np, mode='RGB')
+         else:
+             raise ValueError(f"output_image is not a tensor: {type(output_image)}")
+     except Exception as e:
+         print(f"Error converting final image to PIL: {e}, shape: {output_image.shape if hasattr(output_image, 'shape') else 'N/A'}")
+         # Return a black image as fallback
+         final_image = Image.new('RGB', (224, 224), color='black')
  
      # Return the final inferred image and the animation frames directly
      return final_image, frames
  
  # Helper function to apply example parameters
  def apply_example(example):
+     # Get the full path to the image file
+     image_path = os.path.abspath(example["image"]) if os.path.exists(example["image"]) else example["image"]
      return [
-         example["image"],
-         "resnet50_robust",  # Model type
+         image_path,
+         example["reverse_diff"]["model"],  # Model type from example
          example["method"],  # Inference type
          example["reverse_diff"]["epsilon"],  # Epsilon value
          example["reverse_diff"]["iterations"],  # Number of iterations
@@ -249,7 +329,7 @@ def apply_example(example):
      ]
  
  # Define the interface
- with gr.Blocks(title="Generative Inference Demo", css="""
+ with gr.Blocks(title="Generative Inference for Psychiatry Demo", css="""
  .purple-button {
      background-color: #8B5CF6 !important;
      color: white !important;
@@ -259,7 +339,7 @@ with gr.Blocks(title="Generative Inference Demo", css="""
      background-color: #7C3AED !important;
  }
  """) as demo:
-     gr.Markdown("# Generative Inference Demo")
+     gr.Markdown("# Generative Inference for Psychiatry Demo")
      gr.Markdown("This demo showcases how neural networks can perceive visual illusions and develop Gestalt principles of perceptual organization through generative inference.")
  
      gr.Markdown("""
@@ -273,7 +353,9 @@ with gr.Blocks(title="Generative Inference Demo", css="""
      with gr.Row():
          with gr.Column(scale=1):
              # Inputs
-             image_input = gr.Image(label="Input Image", type="pil", value=os.path.join("stimuli", "Neon_Color_Circle.jpg"))
+             # Use absolute path for default image to avoid directory errors
+             default_image_path = os.path.abspath(os.path.join("stimuli", "face_vase.png")) if os.path.exists(os.path.join("stimuli", "face_vase.png")) else None
+             image_input = gr.Image(label="Input Image", type="pil", value=default_image_path)
  
              # Run Inference button right below the image
              run_button = gr.Button("🪄 Run Generative Inference", variant="primary", elem_classes="purple-button")
@@ -286,7 +368,7 @@ with gr.Blocks(title="Generative Inference Demo", css="""
              with gr.Row():
                  model_choice = gr.Dropdown(
                      choices=["resnet50_robust", "standard_resnet50", "resnet50_robust_face"],  # "resnet50_robust_face" - hidden for deployment
-                     value="resnet50_robust",
+                     value="resnet50_robust_face",
                      label="Model"
                  )
  
@@ -297,21 +379,21 @@ with gr.Blocks(title="Generative Inference Demo", css="""
                  )
  
              with gr.Row():
-                 eps_slider = gr.Slider(minimum=0.0, maximum=40.0, value=20.0, step=0.01, label="Epsilon (Stimulus Fidelity)")
-                 iterations_slider = gr.Slider(minimum=1, maximum=600, value=101, step=1, label="Number of Iterations")  # Updated max to 600
+                 eps_slider = gr.Slider(minimum=0.0, maximum=200.0, value=9.53, step=0.01, label="Epsilon (Stimulus Fidelity)")
+                 iterations_slider = gr.Slider(minimum=1, maximum=600, value=100, step=1, label="Number of Iterations")  # Updated max to 600
  
              with gr.Row():
-                 initial_noise_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.8, step=0.01,
+                 initial_noise_slider = gr.Slider(minimum=0.0, maximum=5.0, value=0.0, step=0.01,
                      label="Drift Noise")
-                 diffusion_noise_slider = gr.Slider(minimum=0.0, maximum=0.05, value=0.003, step=0.001,
+                 diffusion_noise_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.006, step=0.001,
                      label="Diffusion Noise")  # Corrected name
  
              with gr.Row():
-                 step_size_slider = gr.Slider(minimum=0.01, maximum=2.0, value=1.0, step=0.01,
+                 step_size_slider = gr.Slider(minimum=0.0, maximum=10.0, value=0.18, step=0.01,
                      label="Update Rate")  # Added step size slider
                  layer_choice = gr.Dropdown(
                      choices=["all", "conv1", "bn1", "relu", "maxpool", "layer1", "layer2", "layer3", "layer4", "avgpool"],
-                     value="layer3",
+                     value="layer4",
                      label="Model Layer"
                  )
  
@@ -403,10 +485,10 @@
  
  # Launch the demo
  if __name__ == "__main__":
-     print(f"Starting server on port {args.port}")
+     print(f"Starting server on port {server_port}")
      demo.launch(
          server_name="0.0.0.0",
-         server_port=args.port,
+         server_port=server_port,
          share=False,
          debug=True
      )
face_vase_black.png ADDED
huggingface-metadata.json CHANGED
@@ -1,5 +1,5 @@
  {
-   "title": "Generative Inference Demo",
+   "title": "Generative Inference for Psychiatry Demo",
    "emoji": "🧠",
    "colorFrom": "indigo",
    "colorTo": "purple",
inference.py CHANGED
@@ -1,8 +1,10 @@
 
 
1
  import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
- import torchvision.models as models
5
  import torchvision.transforms as transforms
 
6
  from torchvision.models.resnet import ResNet50_Weights
7
  from PIL import Image
8
  import numpy as np
@@ -12,6 +14,7 @@ import time
12
  import copy
13
  from collections import OrderedDict
14
  from pathlib import Path
 
15
 
16
  # Check for available hardware acceleration
17
  if torch.cuda.is_available():
@@ -22,175 +25,98 @@ else:
22
  device = torch.device("cpu")
23
  print(f"Using device: {device}")
24
 
25
- # Constants
26
  MODEL_URLS = {
27
  'resnet50_robust': 'https://huggingface.co/madrylab/robust-imagenet-models/resolve/main/resnet50_l2_eps3.ckpt',
28
  'resnet50_standard': 'https://huggingface.co/madrylab/robust-imagenet-models/resolve/main/resnet50_l2_eps0.ckpt',
29
  'resnet50_robust_face': 'https://huggingface.co/ttoosi/resnet50_robust_face/resolve/main/resnet50_imagenet_L2_eps_0.50_checkpoint150.pt'
30
  }
31
 
32
- # Per-model input size and normalization
33
- MODEL_PREPROC = {
34
- "resnet50_robust": {"size": 224, "mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]},
35
- "resnet50_standard": {"size": 224, "mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]},
36
- # Typical for face models trained ArcFace/InsightFace-style
37
- "resnet50_robust_face": {"size": 112, "mean": [0.5, 0.5, 0.5], "std": [0.5, 0.5, 0.5]},
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  }
39
 
40
  IMAGENET_MEAN = [0.485, 0.456, 0.406]
41
  IMAGENET_STD = [0.229, 0.224, 0.225]
42
 
43
- # Define the transforms based on whether normalization is on or off
44
- def get_transform(input_size=224, normalize=False, norm_mean=IMAGENET_MEAN, norm_std=IMAGENET_STD):
45
- if normalize:
46
- return transforms.Compose([
47
- transforms.Resize(input_size),
48
- transforms.CenterCrop(input_size),
49
- transforms.ToTensor(),
50
- transforms.Normalize(norm_mean, norm_std),
51
- ])
 
52
  else:
53
- return transforms.Compose([
54
- transforms.Resize(input_size),
55
- transforms.CenterCrop(input_size),
56
- transforms.ToTensor(),
57
- ])
58
-
59
- # Default transform without normalization
60
- transform = transforms.Compose([
61
- transforms.Resize(224),
62
- transforms.CenterCrop(224),
63
- transforms.ToTensor(),
64
- ])
65
-
66
- normalize_transform = transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)
67
 
68
- def extract_middle_layers(model, layer_index):
69
- """
70
- Extract a subset of the model up to a specific layer.
71
-
72
- Args:
73
- model: The neural network model
74
- layer_index: String 'all' for the full model, or a layer identifier (string or int)
75
- For ResNet: integers 0-8 representing specific layers
76
- For ViT: strings like 'encoder.layers.encoder_layer_3'
77
-
78
- Returns:
79
- A modified model that outputs features from the specified layer
80
- """
81
- if isinstance(layer_index, str) and layer_index == 'all':
82
- return model
83
 
84
- # Special case for ViT's encoder layers with DataParallel wrapper
85
- if isinstance(layer_index, str) and layer_index.startswith('encoder.layers.encoder_layer_'):
86
- try:
87
- target_layer_idx = int(layer_index.split('_')[-1])
88
-
89
- # Create a deep copy of the model to avoid modifying the original
90
- new_model = copy.deepcopy(model)
91
-
92
- # For models wrapped in DataParallel
93
- if hasattr(new_model, 'module'):
94
- # Create a subset of encoder layers up to the specified index
95
- encoder_layers = nn.Sequential()
96
- for i in range(target_layer_idx + 1):
97
- layer_name = f"encoder_layer_{i}"
98
- if hasattr(new_model.module.encoder.layers, layer_name):
99
- encoder_layers.add_module(layer_name,
100
- getattr(new_model.module.encoder.layers, layer_name))
101
-
102
- # Replace the encoder layers with our truncated version
- new_model.module.encoder.layers = encoder_layers
-
- # Remove the heads since we're stopping at the encoder layer
- new_model.module.heads = nn.Identity()
-
- return new_model
- else:
- # Direct model access (not DataParallel)
- encoder_layers = nn.Sequential()
- for i in range(target_layer_idx + 1):
- layer_name = f"encoder_layer_{i}"
- if hasattr(new_model.encoder.layers, layer_name):
- encoder_layers.add_module(layer_name,
- getattr(new_model.encoder.layers, layer_name))
-
- # Replace the encoder layers with our truncated version
- new_model.encoder.layers = encoder_layers
-
- # Remove the heads since we're stopping at the encoder layer
- new_model.heads = nn.Identity()
-
- return new_model
-
- except (ValueError, IndexError) as e:
- raise ValueError(f"Invalid ViT layer specification: {layer_index}. Error: {e}")

- # Handling for ViT whole blocks
- elif hasattr(model, 'blocks') or (hasattr(model, 'module') and hasattr(model.module, 'blocks')):
- # Check for DataParallel wrapper
- base_model = model.module if hasattr(model, 'module') else model
-
- # Create a deep copy to avoid modifying the original
- new_model = copy.deepcopy(model)
- base_new_model = new_model.module if hasattr(new_model, 'module') else new_model
-
- # Keep only the desired number of transformer blocks
- if isinstance(layer_index, int):
- # Truncate the blocks
- base_new_model.blocks = base_new_model.blocks[:layer_index+1]
-
- return new_model

- else:
- # Original ResNet/VGG handling
- modules = list(model.named_children())
- print(f"DEBUG - extract_middle_layers - Looking for '{layer_index}' in {[name for name, _ in modules]}")
-
- cutoff_idx = next((i for i, (name, _) in enumerate(modules)
- if name == str(layer_index)), None)
-
- if cutoff_idx is not None:
- # Keep modules up to and including the target
- new_model = nn.Sequential(OrderedDict(modules[:cutoff_idx+1]))
- return new_model
- else:
- raise ValueError(f"Module {layer_index} not found in model")
-
- # Get ImageNet labels
- def get_imagenet_labels():
- url = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
- response = requests.get(url)
- if response.status_code == 200:
- return response.json()
- else:
- raise RuntimeError("Failed to fetch ImageNet labels")

- # Download model if needed
- def download_model(model_type):
- if model_type not in MODEL_URLS or MODEL_URLS[model_type] is None:
- return None # Use PyTorch's pretrained model
-
- # Handle special case for face model
- if model_type == 'resnet50_robust_face':
- model_path = Path("models/resnet50_robust_face_100_checkpoint.pt")
- else:
- model_path = Path(f"models/{model_type}.pt")
-
- if not model_path.exists():
- print(f"Downloading {model_type} model...")
- url = MODEL_URLS[model_type]
- response = requests.get(url, stream=True)
- if response.status_code == 200:
- with open(model_path, 'wb') as f:
- for chunk in response.iter_content(chunk_size=8192):
- f.write(chunk)
- print(f"Model downloaded and saved to {model_path}")
- else:
- raise RuntimeError(f"Failed to download model: {response.status_code}")
- return model_path

 class NormalizeByChannelMeanStd(nn.Module):

 def __init__(self, mean, std):
 super(NormalizeByChannelMeanStd, self).__init__()
 if not isinstance(mean, torch.Tensor):
@@ -205,737 +131,567 @@ class NormalizeByChannelMeanStd(nn.Module):

 def normalize_fn(self, tensor, mean, std):
 """Differentiable version of torchvision.functional.normalize"""
- # here we assume the color channel is at dim=1
 mean = mean[None, :, None, None]
 std = std[None, :, None, None]
 return tensor.sub(mean).div(std)
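The reshaping in `normalize_fn` is just per-channel standardization broadcast over an `(N, C, H, W)` batch. A plain-Python sketch of the same per-channel arithmetic (lists stand in for tensors; `normalize_channels` is a hypothetical helper, not part of the codebase):

```python
def normalize_channels(image, mean, std):
    # Per-channel analogue of normalize_fn: every pixel in channel c is
    # mapped to (pixel - mean[c]) / std[c]. Reshaping the stats to
    # (1, C, 1, 1) does the same thing for a whole (N, C, H, W) batch.
    return [[(v - m) / s for v in channel]
            for channel, m, s in zip(image, mean, std)]

image = [[0.5, 1.0], [0.5, 1.0], [0.5, 1.0]]  # 3 channels, 2 pixels each
out = normalize_channels(image, mean=[0.5, 0.5, 0.5], std=[0.25, 0.25, 0.25])
# each channel becomes [(0.5-0.5)/0.25, (1.0-0.5)/0.25] = [0.0, 2.0]
```

Using `sub`/`div` in the real method (rather than in-place ops) keeps the normalization differentiable, which matters because gradients must flow back to the input image during inference.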
 class InferStep:
- def __init__(self, orig_image, eps, step_size):
 self.orig_image = orig_image
 self.eps = eps
 self.step_size = step_size

- def project(self, x):
 diff = x - self.orig_image
 diff = torch.clamp(diff, -self.eps, self.eps)
 return torch.clamp(self.orig_image + diff, 0, 1)

- def step(self, x, grad):
- l = len(x.shape) - 1
- grad_norm = torch.norm(grad.view(grad.shape[0], -1), dim=1).view(-1, *([1]*l))
 scaled_grad = grad / (grad_norm + 1e-10)
 return scaled_grad * self.step_size
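The two operations above are an L2-normalized gradient step and a projection back into an eps-ball around the original image. A minimal plain-Python sketch of the same logic (flat lists stand in for the image tensors; this is an illustration, not the torch implementation):

```python
import math

class InferStep:
    """Sketch of the step/project logic: fixed-size drift plus eps-ball projection."""

    def __init__(self, orig_image, eps, step_size):
        self.orig_image = orig_image
        self.eps = eps
        self.step_size = step_size

    def step(self, grad):
        # L2-normalize the gradient, then scale by step_size: each iteration
        # moves a fixed distance regardless of the raw gradient magnitude.
        norm = math.sqrt(sum(g * g for g in grad)) + 1e-10
        return [g / norm * self.step_size for g in grad]

    def project(self, x):
        # Clamp the total perturbation to [-eps, eps] around the original
        # image, then clamp pixels back into the valid [0, 1] range.
        out = []
        for xi, oi in zip(x, self.orig_image):
            diff = max(-self.eps, min(self.eps, xi - oi))
            out.append(max(0.0, min(1.0, oi + diff)))
        return out
```

The `1e-10` in the denominator guards against a zero gradient; without it the first flat region encountered would produce a division by zero.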

- def get_iterations_to_show(n_itr):
- """Generate a dynamic list of iterations to show based on total iterations."""
- if n_itr <= 50:
- return [1, 5, 10, 20, 30, 40, 50, n_itr]
- elif n_itr <= 100:
- return [1, 5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, n_itr]
- elif n_itr <= 200:
- return [1, 5, 10, 20, 30, 40, 50, 75, 100, 125, 150, 175, 200, n_itr]
- elif n_itr <= 500:
- return [1, 5, 10, 20, 30, 40, 50, 75, 100, 150, 200, 250, 300, 350, 400, 450, 500, n_itr]
- else:
- # For very large iterations, show more evenly distributed points
- return [1, 5, 10, 20, 30, 40, 50, 75, 100, 150, 200, 250, 300, 350, 400, 450, 500,
- int(n_itr*0.6), int(n_itr*0.7), int(n_itr*0.8), int(n_itr*0.9), n_itr]
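Note that the removed function can return duplicates (e.g. `n_itr=50` yields `50` twice) and milestones larger than `n_itr` when `n_itr` falls between table entries. A sketch of the same milestone tables that also filters and dedupes (one possible cleanup, not the code actually shipped):

```python
def get_iterations_to_show(n_itr):
    # Same milestone tables as above, but deduplicated, sorted, and
    # restricted to iterations that will actually be executed.
    if n_itr <= 50:
        base = [1, 5, 10, 20, 30, 40, 50]
    elif n_itr <= 100:
        base = [1, 5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
    elif n_itr <= 200:
        base = [1, 5, 10, 20, 30, 40, 50, 75, 100, 125, 150, 175, 200]
    elif n_itr <= 500:
        base = [1, 5, 10, 20, 30, 40, 50, 75, 100, 150, 200, 250, 300, 350, 400, 450, 500]
    else:
        base = [1, 5, 10, 20, 30, 40, 50, 75, 100, 150, 200, 250, 300, 350, 400, 450, 500,
                int(n_itr * 0.6), int(n_itr * 0.7), int(n_itr * 0.8), int(n_itr * 0.9)]
    # Drop milestones beyond n_itr and always include the final iteration.
    return sorted({i for i in base if i <= n_itr} | {n_itr})
```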

- def get_inference_configs(inference_type='IncreaseConfidence', eps=0.5, n_itr=50, step_size=1.0):
- """Generate inference configuration with customizable parameters.
-
- Args:
- inference_type (str): Type of inference ('IncreaseConfidence', 'Prior-Guided Drift Diffusion', 'GradModulation', or 'CompositionalFusion')
- eps (float): Maximum perturbation size
- n_itr (int): Number of iterations
- step_size (float): Step size for each iteration
- """

- # Base configuration common to all inference types
- config = {
- 'loss_infer': inference_type, # How to guide the optimization
- 'n_itr': n_itr, # Number of iterations
- 'eps': eps, # Maximum perturbation size
- 'step_size': step_size, # Step size for each iteration
- 'diffusion_noise_ratio': 0.0, # No diffusion noise by default
- 'initial_inference_noise_ratio': 0.0, # No initial noise by default
- 'top_layer': 'all', # Use all layers of the model
- 'inference_normalization': False, # Whether to normalize during inference
- 'recognition_normalization': False, # Whether to normalize during recognition
- 'iterations_to_show': get_iterations_to_show(n_itr), # Dynamic iterations to visualize
- 'misc_info': {'keep_grads': False} # Additional configuration
- }

- # Customize based on inference type
- if inference_type == 'IncreaseConfidence':
- config['loss_function'] = 'CE' # Cross Entropy

- elif inference_type == 'Prior-Guided Drift Diffusion':
- config['loss_function'] = 'MSE' # Mean Squared Error
- config['initial_inference_noise_ratio'] = 0.05 # Initial noise for diffusion
- config['diffusion_noise_ratio'] = 0.01 # Add noise during diffusion

- elif inference_type == 'GradModulation':
- config['loss_function'] = 'CE' # Cross Entropy
- config['misc_info']['grad_modulation'] = 0.5 # Gradient modulation strength

- elif inference_type == 'CompositionalFusion':
- config['loss_function'] = 'CE' # Cross Entropy
- config['misc_info']['positive_classes'] = [] # Classes to maximize
- config['misc_info']['negative_classes'] = [] # Classes to minimize

- return config
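The pattern here is a shared base config with small per-type overrides: the two most-used modes differ only in loss function and noise ratios. A trimmed, self-contained sketch of that layering (only the fields needed to show the pattern; the full function carries more keys):

```python
def get_inference_configs(inference_type='IncreaseConfidence', eps=0.5,
                          n_itr=50, step_size=1.0):
    # Base configuration shared by all inference types; the per-type block
    # below overrides only the fields that differ.
    config = {
        'loss_infer': inference_type,
        'n_itr': n_itr,
        'eps': eps,
        'step_size': step_size,
        'diffusion_noise_ratio': 0.0,
        'initial_inference_noise_ratio': 0.0,
        'loss_function': 'CE',  # default: cross-entropy-style guidance
    }
    if inference_type == 'Prior-Guided Drift Diffusion':
        # PGDD matches noisy target features with MSE and adds noise both
        # at initialization and at every diffusion step.
        config.update(loss_function='MSE',
                      initial_inference_noise_ratio=0.05,
                      diffusion_noise_ratio=0.01)
    return config

pgdd = get_inference_configs('Prior-Guided Drift Diffusion', eps=0.3, n_itr=100)
```

Keeping the overrides in one place means a new inference type only has to declare its deltas from the base, not restate the whole dictionary.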
 
 
 
 
 
 
 
 
 
 
 
 class GenerativeInferenceModel:
 def __init__(self):
 self.models = {}
- #self.normalizer = NormalizeByChannelMeanStd(IMAGENET_MEAN, IMAGENET_STD).to(device)
 self.model_preproc = {}
- self.labels = get_imagenet_labels()

- def verify_model_integrity(self, model, model_type):
- """
- Verify model integrity by running a test input through it.
- Returns whether the model passes a basic integrity check.
- """
 try:
- print(f"\n=== Running model integrity check for {model_type} ===")
- # Create a deterministic test input directly on the correct device
- H = W = MODEL_PREPROC.get(model_type, {"size": 224})["size"]
- test_input = torch.zeros(1, 3, H, W, device=device)
- test_input[0, 0, 100:124, 100:124] = 0.5 # Half-intensity square in the red channel
-
- # Run forward pass
- with torch.no_grad():
- output = model(test_input)
-
- # Check output shape
- if output.shape != (1, 1000):
- print(f"❌ Unexpected output shape: {output.shape}, expected (1, 1000)")
- return False
-
- # Get top prediction
- probs = torch.nn.functional.softmax(output, dim=1)
- confidence, prediction = torch.max(probs, 1)
-
- # Calculate basic statistics on the output
- mean = output.mean().item()
- std = output.std().item()
- min_val = output.min().item()
- max_val = output.max().item()
-
- print(f"Model integrity check results:")
- print(f"- Output shape: {output.shape}")
- print(f"- Top prediction: Class {prediction.item()} with {confidence.item()*100:.2f}% confidence")
- print(f"- Output statistics: mean={mean:.3f}, std={std:.3f}, min={min_val:.3f}, max={max_val:.3f}")
-
- # Basic sanity checks
- if torch.isnan(output).any():
- print("❌ Model produced NaN outputs")
- return False
-
- if output.std().item() < 0.1:
- print("⚠️ Low output variance, model may not be discriminative")
-
- print("✅ Model passes basic integrity check")
- return True
-
 except Exception as e:
- print(f"❌ Model integrity check failed with error: {e}")
- # Rather than failing completely, we'll continue
- return True
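The integrity check boils down to three tests on the raw logits: expected shape, no NaNs, and a readable top prediction via softmax. A dependency-free sketch of those checks (`check_logits` is a hypothetical helper operating on a plain list of logits, not the method above):

```python
import math

def check_logits(logits, expected_classes=1000):
    # Minimal version of the integrity checks: shape, NaNs, then a
    # numerically stable softmax to read off the top prediction.
    if len(logits) != expected_classes:
        return None  # wrong output shape
    if any(math.isnan(v) for v in logits):
        return None  # NaN outputs
    m = max(logits)  # subtract the max so exp() cannot overflow
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    top = max(range(len(probs)), key=probs.__getitem__)
    return top, probs[top]
```

The low-variance warning in the real method is a softer signal: near-constant logits usually mean the weights never loaded, even when the forward pass technically succeeds.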

 def load_model(self, model_type):
 if model_type in self.models:
 print(f"Using cached {model_type} model")
 return self.models[model_type]

 start_time = time.time()
- model_path = download_model(model_type)
-
- # Pick the preprocessing settings for this model
- pre = MODEL_PREPROC.get(model_type, {"size": 224, "mean": IMAGENET_MEAN, "std": IMAGENET_STD})
- normalizer = NormalizeByChannelMeanStd(pre["mean"], pre["std"]).to(device)
- self.model_preproc[model_type] = pre
-
- resnet = models.resnet50()
 model = nn.Sequential(normalizer, resnet)

- # Load the model checkpoint
 if model_path:
 print(f"Loading {model_type} model from {model_path}...")
 try:
 checkpoint = torch.load(model_path, map_location=device)

- # Print checkpoint structure for better understanding
- print("\n=== Analyzing checkpoint structure ===")
- if isinstance(checkpoint, dict):
- print(f"Checkpoint contains keys: {list(checkpoint.keys())}")
-
- # Examine 'model' structure if it exists
- if 'model' in checkpoint and isinstance(checkpoint['model'], dict):
- model_dict = checkpoint['model']
- # Get a sample of keys to understand the structure
- first_keys = list(model_dict.keys())[:5]
- print(f"'model' contains keys like: {first_keys}")
-
- # Check for common prefixes in the model dict
- prefixes = set()
- for key in list(model_dict.keys())[:100]: # Check first 100 keys
- parts = key.split('.')
- if len(parts) > 1:
- prefixes.add(parts[0])
- if prefixes:
- print(f"Common prefixes in model dict: {prefixes}")
- else:
- print(f"Checkpoint is not a dictionary, but a {type(checkpoint)}")
-
 # Handle different checkpoint formats
 if 'model' in checkpoint:
- # Format from madrylab robust models
 state_dict = checkpoint['model']
 print("Using 'model' key from checkpoint")
 elif 'state_dict' in checkpoint:
 state_dict = checkpoint['state_dict']
 print("Using 'state_dict' key from checkpoint")
 else:
- # Direct state dict
 state_dict = checkpoint
 print("Using checkpoint directly as state_dict")

- # Handle prefixes in state dict keys for the ResNet part
 resnet_state_dict = {}
- prefixes_to_try = ['', 'module.', 'model.', 'attacker.model.']
 resnet_keys = set(resnet.state_dict().keys())

- # First check if we can find keys directly in the attacker.model path
- print("\n=== Phase 1: Checking for specific model structures ===")
-
- # Check for 'module.model' structure (seen in actual checkpoint)
- module_model_keys = [key for key in state_dict.keys() if key.startswith('module.model.')]
- if module_model_keys:
- print(f"Found 'module.model' structure with {len(module_model_keys)} parameters")
- # Extract all parameters from module.model
- for source_key, value in state_dict.items():
- if source_key.startswith('module.model.'):
- target_key = source_key[len('module.model.'):]
- # Some ckpts have 'module.model.model.<...>'; remove the extra 'model.' too
- if target_key.startswith('model.'):
- target_key = target_key[len('model.'):]
- resnet_state_dict[target_key] = value
-
- print(f"Extracted {len(resnet_state_dict)} parameters from module.model")
-
- # Check for 'attacker.model' structure
- attacker_model_keys = [key for key in state_dict.keys() if key.startswith('attacker.model.')]
- if attacker_model_keys:
- print(f"Found 'attacker.model' structure with {len(attacker_model_keys)} parameters")
- # Extract all parameters from attacker.model
- for source_key, value in state_dict.items():
- if source_key.startswith('attacker.model.'):
- target_key = source_key[len('attacker.model.'):]
- resnet_state_dict[target_key] = value
-
- print(f"Extracted {len(resnet_state_dict)} parameters from attacker.model")
-
- # Check if 'model' (not attacker.model) exists as a fallback
- model_keys = [key for key in state_dict.keys() if key.startswith('model.') and not key.startswith('attacker.model.')]
- if model_keys and len(resnet_state_dict) < len(resnet_keys):
- print(f"Found additional 'model.' structure with {len(model_keys)} parameters")
- # Try to complete missing parameters
 for source_key, value in state_dict.items():
- if source_key.startswith('model.'):
- target_key = source_key[len('model.'):]
- if target_key in resnet_keys and target_key not in resnet_state_dict:
 resnet_state_dict[target_key] = value
-
- else:
- # Check for other known structures
- structure_found = False
-
- # Check for 'model.' prefix
- model_keys = [key for key in state_dict.keys() if key.startswith('model.')]
- if model_keys:
- print(f"Found 'model.' structure with {len(model_keys)} parameters")
- for source_key, value in state_dict.items():
- if source_key.startswith('model.'):
- target_key = source_key[len('model.'):]
- resnet_state_dict[target_key] = value
- structure_found = True

- # Check for ResNet parameters at the top level
- top_level_resnet_keys = 0
- for key in resnet_keys:
- if key in state_dict:
- top_level_resnet_keys += 1
-
- if top_level_resnet_keys > 0:
- print(f"Found {top_level_resnet_keys} ResNet parameters at top level")
- for target_key in resnet_keys:
- if target_key in state_dict:
- resnet_state_dict[target_key] = state_dict[target_key]
- structure_found = True
-
- # If no structure was recognized, try the prefix mapping approach
- if not structure_found:
- print("No standard model structure found, trying prefix mappings...")
- for target_key in resnet_keys:
- for prefix in prefixes_to_try:
- source_key = prefix + target_key
- if source_key in state_dict:
- resnet_state_dict[target_key] = state_dict[source_key]
- break

- # If we still can't find enough keys, try a final approach of removing prefixes
- if len(resnet_state_dict) < len(resnet_keys):
- print(f"Found only {len(resnet_state_dict)}/{len(resnet_keys)} parameters, trying prefix removal...")
-
- # Track matches found through prefix removal
- prefix_matches = {prefix: 0 for prefix in ['module.', 'model.', 'attacker.model.', 'attacker.']}
- layer_matches = {} # Track matches by layer type
-
- # Count parameter keys by layer type for analysis
- for key in resnet_keys:
- layer_name = key.split('.')[0] if '.' in key else key
- if layer_name not in layer_matches:
- layer_matches[layer_name] = {'total': 0, 'matched': 0}
- layer_matches[layer_name]['total'] += 1

- # Try keys with common prefixes
 for source_key, value in state_dict.items():
- # Skip if already found
 target_key = source_key
- matched_prefix = None

 # Try removing various prefixes
- for prefix in ['module.', 'model.', 'attacker.model.', 'attacker.']:
 if source_key.startswith(prefix):
 target_key = source_key[len(prefix):]
- matched_prefix = prefix
 break

- # If the target key is in the ResNet keys, add it to the state dict
- if target_key in resnet_keys and target_key not in resnet_state_dict:
 resnet_state_dict[target_key] = value
-
- # Update match statistics
- if matched_prefix:
- prefix_matches[matched_prefix] += 1
-
- # Update layer matches
- layer_name = target_key.split('.')[0] if '.' in target_key else target_key
- if layer_name in layer_matches:
- layer_matches[layer_name]['matched'] += 1
-
- # Print detailed prefix removal statistics
- print("\n=== Prefix Removal Statistics ===")
- total_matches = sum(prefix_matches.values())
- print(f"Total parameters matched through prefix removal: {total_matches}/{len(resnet_keys)} ({(total_matches/len(resnet_keys))*100:.1f}%)")
-
- # Show matches by prefix
- print("\nMatches by prefix:")
- for prefix, count in sorted(prefix_matches.items(), key=lambda x: x[1], reverse=True):
- if count > 0:
- print(f" {prefix}: {count} parameters")
-
- # Show matches by layer type
- print("\nMatches by layer type:")
- for layer, stats in sorted(layer_matches.items(), key=lambda x: x[1]['total'], reverse=True):
- match_percent = (stats['matched'] / stats['total']) * 100 if stats['total'] > 0 else 0
- print(f" {layer}: {stats['matched']}/{stats['total']} ({match_percent:.1f}%)")
-
- # Check for specific important layers (conv1, layer1, etc.)
- critical_layers = ['conv1', 'bn1', 'layer1', 'layer2', 'layer3', 'layer4', 'fc']
- print("\nStatus of critical layers:")
- for layer in critical_layers:
- if layer in layer_matches:
- match_percent = (layer_matches[layer]['matched'] / layer_matches[layer]['total']) * 100
- status = "✅ COMPLETE" if layer_matches[layer]['matched'] == layer_matches[layer]['total'] else "⚠️ INCOMPLETE"
- print(f" {layer}: {layer_matches[layer]['matched']}/{layer_matches[layer]['total']} ({match_percent:.1f}%) - {status}")
- else:
- print(f" {layer}: Not found in model")
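The core of all these branches is the same operation: peel wrapper prefixes (`module.`, `model.`, `attacker.`) off checkpoint keys until they match the bare model's parameter names. A self-contained sketch with hypothetical helper names (note it strips repeatedly, so nested wrappers like `module.model.` also collapse; in rare models a parameter could legitimately start with one of these prefixes, which this naive version would mangle):

```python
def strip_prefixes(key, prefixes=('module.', 'model.', 'attacker.')):
    # Peel wrapper prefixes repeatedly so nested wrappers such as
    # 'module.model.conv1.weight' reduce all the way to 'conv1.weight'.
    changed = True
    while changed:
        changed = False
        for prefix in prefixes:
            if key.startswith(prefix):
                key = key[len(prefix):]
                changed = True
    return key

def remap_state_dict(state_dict):
    # First match wins: if two checkpoint keys collapse to the same target
    # name, keep the first, mirroring the 'not in resnet_state_dict' guard.
    out = {}
    for key, value in state_dict.items():
        out.setdefault(strip_prefixes(key), value)
    return out

ckpt = {'module.model.conv1.weight': 1, 'attacker.model.fc.bias': 2, 'bn1.weight': 3}
mapped = remap_state_dict(ckpt)  # {'conv1.weight': 1, 'fc.bias': 2, 'bn1.weight': 3}
```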

- # Load the ResNet state dict
 if resnet_state_dict:
- try:
- # Use strict=False to allow missing keys
- result = resnet.load_state_dict(resnet_state_dict, strict=False)
- missing_keys, unexpected_keys = result
-
- # Generate detailed information with better formatting
- loading_report = []
- loading_report.append(f"\n===== MODEL LOADING REPORT: {model_type} =====")
- loading_report.append(f"Total parameters in checkpoint: {len(resnet_state_dict):,}")
- loading_report.append(f"Total parameters in model: {len(resnet.state_dict()):,}")
- loading_report.append(f"Missing keys: {len(missing_keys):,} parameters")
- loading_report.append(f"Unexpected keys: {len(unexpected_keys):,} parameters")
-
- # Calculate percentage of parameters loaded
- loaded_keys = set(resnet_state_dict.keys()) - set(unexpected_keys)
- loaded_percent = (len(loaded_keys) / len(resnet.state_dict())) * 100
-
- # Determine loading success status
- if loaded_percent >= 99.5:
- status = "✅ COMPLETE - All important parameters loaded"
- elif loaded_percent >= 90:
- status = "🟡 PARTIAL - Most parameters loaded, should still function"
- elif loaded_percent >= 50:
- status = "⚠️ INCOMPLETE - Many parameters missing, may not function properly"
- else:
- status = "❌ FAILED - Critical parameters missing, will not function properly"
-
- loading_report.append(f"Successfully loaded: {len(loaded_keys):,} parameters ({loaded_percent:.1f}%)")
- loading_report.append(f"Loading status: {status}")
-
- # If loading is severely incomplete, fall back to PyTorch's pretrained model
- if loaded_percent < 50:
- loading_report.append("\n⚠️ WARNING: Loading from checkpoint is too incomplete.")
- loading_report.append("⚠️ Falling back to PyTorch's pretrained model to avoid broken inference.")
-
- # Create a new ResNet model with pretrained weights
 resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
 model = nn.Sequential(normalizer, resnet)
- loading_report.append("✅ Successfully loaded PyTorch's pretrained ResNet50 model")
-
- # Show missing keys by layer type
- if missing_keys:
- loading_report.append("\nMissing keys by layer type:")
- layer_types = {}
- for key in missing_keys:
- # Extract layer type (e.g., 'conv', 'bn', 'layer1', etc.)
- parts = key.split('.')
- if len(parts) > 0:
- layer_type = parts[0]
- if layer_type not in layer_types:
- layer_types[layer_type] = 0
- layer_types[layer_type] += 1
-
- # Add counts by layer type
- for layer_type, count in sorted(layer_types.items(), key=lambda x: x[1], reverse=True):
- loading_report.append(f" {layer_type}: {count:,} parameters")
-
- loading_report.append("\nFirst 10 missing keys:")
- for i, key in enumerate(sorted(missing_keys)[:10]):
- loading_report.append(f" {i+1}. {key}")
-
- # Show unexpected keys if any
- if unexpected_keys:
- loading_report.append("\nFirst 10 unexpected keys:")
- for i, key in enumerate(sorted(unexpected_keys)[:10]):
- loading_report.append(f" {i+1}. {key}")
-
- loading_report.append("========================================")

- # Convert the report to a string and print it
- report_text = "\n".join(loading_report)
- print(report_text)
-
- # Also save it to a file for reference
- os.makedirs("logs", exist_ok=True)
- with open(f"logs/model_loading_{model_type}.log", "w") as f:
- f.write(report_text)

- # Look for normalizer parameters as well
- if any(key.startswith('attacker.normalize.') for key in state_dict.keys()):
- norm_state_dict = {}
- for key, value in state_dict.items():
- if key.startswith('attacker.normalize.'):
- norm_key = key[len('attacker.normalize.'):]
- norm_state_dict[norm_key] = value
-
- if norm_state_dict:
- try:
- normalizer.load_state_dict(norm_state_dict, strict=False)
- print("Successfully loaded normalizer parameters")
- except Exception as e:
- print(f"Warning: Could not load normalizer parameters: {e}")
 except Exception as e:
- print(f"Warning: Error loading ResNet parameters: {e}")
- # Fall back to loading without the normalizer
- model = resnet # Use just the ResNet model without the normalizer
 except Exception as e:
- print(f"Error loading model checkpoint: {e}")
- # Fall back to PyTorch's pretrained model
- print("Falling back to PyTorch's pretrained model")
 resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
 model = nn.Sequential(normalizer, resnet)
- else:
- # Fall back to PyTorch's pretrained model
- print("No checkpoint available, using PyTorch's pretrained model")
- resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
- model = nn.Sequential(normalizer, resnet)

 model = model.to(device)
- model.eval() # Set to evaluation mode

- # Verify model integrity
 self.verify_model_integrity(model, model_type)

- # Store the model for future use
 self.models[model_type] = model
 end_time = time.time()
- load_time = end_time - start_time
- print(f"Model {model_type} loaded in {load_time:.2f} seconds")
 return model

 def inference(self, image, model_type, config):
- """Run generative inference on the image."""
- # Time the entire inference process
 inference_start = time.time()

- # Load the model if not already loaded
 model = self.load_model(model_type)

- # Check if the image is a file path
 if isinstance(image, str):
 if os.path.exists(image):
 image = Image.open(image).convert('RGB')
 else:
 raise ValueError(f"Image path does not exist: {image}")
- elif isinstance(image, torch.Tensor):
- raise ValueError(f"Image type {type(image)} looks like an already-transformed tensor")

- # Prepare the image tensor - match the original code's conditional transform
- load_start = time.time()
- # Pick the right preprocessing for this model
- pre = self.model_preproc.get(model_type, {"size": 224, "mean": IMAGENET_MEAN, "std": IMAGENET_STD})

- # IMPORTANT: the model already includes a NormalizeByChannelMeanStd as layer 0,
- # so do NOT normalize again here, or you'll double-normalize.
- custom_transform = get_transform(
- input_size=pre["size"], # 112 for resnet50_robust_face
- normalize=False, # leave False; the model handles normalization internally
- norm_mean=pre["mean"],
- norm_std=pre["std"]
- )
-
- print(f"[PREPROC] {model_type}: size={pre['size']} mean={pre['mean']} std={pre['std']} (transform normalize=False; model has internal normalizer)")

- # Special handling for GradModulation as in the original
- if config['loss_infer'] == 'GradModulation' and 'misc_info' in config and 'grad_modulation' in config['misc_info']:
- grad_modulation = config['misc_info']['grad_modulation']
- image_tensor = custom_transform(image).unsqueeze(0).to(device)
- image_tensor = image_tensor * (1-grad_modulation) + grad_modulation * torch.randn_like(image_tensor).to(device)
 else:
- image_tensor = custom_transform(image).unsqueeze(0).to(device)
-
- image_tensor.requires_grad = True
- print(f"Image loaded and processed in {time.time() - load_start:.2f} seconds")

- # Check the model structure
- is_sequential = isinstance(model, nn.Sequential)

 # Get original predictions
 with torch.no_grad():
- # If the model is sequential with a normalizer, skip the manual normalization step
- if is_sequential and isinstance(model[0], NormalizeByChannelMeanStd):
- print("Model is sequential with normalization")
- # Get the core model part (typically at index 1 in Sequential)
- core_model = model[1]
- if config['inference_normalization']:
- output_original = model(image_tensor) # Model includes normalization
- else:
- output_original = core_model(image_tensor) # Bypass the built-in normalizer
-
 else:
- print("Model is not sequential with normalization")
- # Use manual normalization for non-sequential models
- if config['inference_normalization']:
- normalized_tensor = normalize_transform(image_tensor)
- output_original = model(normalized_tensor)
- else:
- output_original = model(image_tensor)
- core_model = model

 probs_orig = F.softmax(output_original, dim=1)
 conf_orig, classes_orig = torch.max(probs_orig, 1)

- # Get the least confident classes
- _, least_confident_classes = torch.topk(probs_orig, k=100, largest=False)

 # Initialize inference step
 infer_step = InferStep(image_tensor, config['eps'], config['step_size'])

 # Storage for inference steps
- # Create a new tensor that requires gradients
 x = image_tensor.clone().detach().requires_grad_(True)
 all_steps = [image_tensor[0].detach().cpu()]

- # For Prior-Guided Drift Diffusion, extract the selected layer and initialize with noisy features
- noisy_features = None
- layer_model = None
- if config['loss_infer'] == 'Prior-Guided Drift Diffusion':
- print(f"Setting up Prior-Guided Drift Diffusion with layer {config['top_layer']} and noise {config['initial_inference_noise_ratio']}...")
-
- # Extract the model up to the specified layer
- try:
- # Start by finding the actual model to use
- base_model = model
-
- # Handle DataParallel wrapper if present
- if hasattr(base_model, 'module'):
- base_model = base_model.module
-
- # Log the initial model structure
- print(f"DEBUG - Initial model structure: {type(base_model)}")
-
- # If we have a Sequential model (likely our normalizer + model structure)
- if isinstance(base_model, nn.Sequential):
- print(f"DEBUG - Sequential model with {len(list(base_model.children()))} children")
-
- # If this is our NormalizeByChannelMeanStd + ResNet pattern
- if len(list(base_model.children())) >= 2:
- # The actual ResNet model is the second component (index 1)
- actual_model = list(base_model.children())[1]
- print(f"DEBUG - Using ResNet component: {type(actual_model)}")
- print(f"DEBUG - Available layers: {[name for name, _ in actual_model.named_children()]}")
-
- # Extract from the actual ResNet
- layer_model = extract_middle_layers(actual_model, config['top_layer'])
- else:
- # Just a single-component Sequential
- layer_model = extract_middle_layers(base_model, config['top_layer'])
- else:
- # Not Sequential, might be a direct model
- print(f"DEBUG - Available layers: {[name for name, _ in base_model.named_children()]}")
- layer_model = extract_middle_layers(base_model, config['top_layer'])
-
- print(f"Successfully extracted model up to layer: {config['top_layer']}")
- except ValueError as e:
- print(f"Layer extraction failed: {e}. Using full model.")
- layer_model = model
-
- # Add noise to the image - exactly match the original code
- added_noise = config['initial_inference_noise_ratio'] * torch.randn_like(image_tensor).to(device)
- noisy_image_tensor = image_tensor + added_noise
-
- # Compute noisy features - simplified to match the original code
- noisy_features = layer_model(noisy_image_tensor)
-
- print(f"Noisy features computed for Prior-Guided Drift Diffusion target with shape: {noisy_features.shape if hasattr(noisy_features, 'shape') else 'unknown'}")
817
 
818
  # Main inference loop
819
  print(f"Starting inference loop with {config['n_itr']} iterations for {config['loss_infer']}...")
820
- loop_start = time.time()
821
  for i in range(config['n_itr']):
822
  # Reset gradients
823
  x.grad = None
824
 
825
- # Forward pass - use layer_model for Prior-Guided Drift Diffusion, full model otherwise
826
- if config['loss_infer'] == 'Prior-Guided Drift Diffusion' and layer_model is not None:
827
- # Use the extracted layer model for Prior-Guided Drift Diffusion
828
- # In original code, normalization is handled at transform time, not during forward pass
829
- output = layer_model(x)
830
- else:
831
- # Standard forward pass with full model
832
- # Simplified to match original code's approach
833
- output = model(x)
834
-
835
- # Calculate loss and gradients based on inference type
836
- try:
837
- if config['loss_infer'] == 'Prior-Guided Drift Diffusion':
838
- # Use MSE loss to match the noisy features
839
- assert config['loss_function'] == 'MSE', "Reverse Diffusion loss function must be MSE"
840
- if noisy_features is not None:
841
- loss = F.mse_loss(output, noisy_features)
842
- grad = torch.autograd.grad(loss, x)[0] # Removed retain_graph=True to match original
843
- else:
844
- raise ValueError("Noisy features not computed for Prior-Guided Drift Diffusion")
845
-
846
- else: # Default 'IncreaseConfidence' approach
847
- # Get the least confident classes
848
- num_classes = min(10, least_confident_classes.size(1))
849
- target_classes = least_confident_classes[0, :num_classes]
850
-
851
- # Create targets for least confident classes
852
- targets = torch.tensor([idx.item() for idx in target_classes], device=device)
853
-
854
- # Use a combined loss to increase confidence
855
- loss = 0
856
- for target in targets:
857
- # Create one-hot target
858
- one_hot = torch.zeros_like(output)
859
- one_hot[0, target] = 1
860
- # Use loss to maximize confidence
861
- loss = loss + F.mse_loss(F.softmax(output, dim=1), one_hot)
862
-
863
- grad = torch.autograd.grad(loss, x, retain_graph=True)[0]
864
 
865
- if grad is None:
866
- print("Warning: Direct gradient calculation failed")
867
- # Fall back to random perturbation
868
- random_noise = (torch.rand_like(x) - 0.5) * 2 * config['step_size']
869
- x = infer_step.project(x + random_noise)
870
  else:
871
- # Update image with gradient - do this exactly as in original code
872
- adjusted_grad = infer_step.step(x, grad)
 
 
 
 
 
 
873
 
874
- # Add diffusion noise if specified
875
- diffusion_noise = config['diffusion_noise_ratio'] * torch.randn_like(x).to(device)
 
 
 
 
 
 
876
 
877
- # Apply gradient and noise in one operation before projecting, exactly as in original
878
- x = infer_step.project(x.clone() + adjusted_grad + diffusion_noise)
 
 
 
 
 
 
879
 
880
- except Exception as e:
881
- print(f"Error in gradient calculation: {e}")
882
- # Fall back to random perturbation - match original code
883
- random_noise = (torch.rand_like(x) - 0.5) * 2 * config['step_size']
884
- x = infer_step.project(x.clone() + random_noise)
885
 
886
  # Store step if in iterations_to_show
887
- if i+1 in config['iterations_to_show'] or i+1 == config['n_itr']:
888
  all_steps.append(x[0].detach().cpu())
889
 
890
- # Print some info about the inference
891
  with torch.no_grad():
892
- if is_sequential and isinstance(model[0], NormalizeByChannelMeanStd):
893
- if config['inference_normalization']:
894
- final_output = model(x)
895
- else:
896
- final_output = core_model(x)
897
  else:
898
- if config['inference_normalization']:
899
- normalized_x = normalize_transform(x)
900
- final_output = model(normalized_x)
901
- else:
902
- final_output = model(x)
903
 
904
  final_probs = F.softmax(final_output, dim=1)
905
  final_conf, final_classes = torch.max(final_probs, 1)
906
 
907
- # Calculate timing information
908
- loop_time = time.time() - loop_start
909
  total_time = time.time() - inference_start
910
- avg_iter_time = loop_time / config['n_itr'] if config['n_itr'] > 0 else 0
911
 
912
  print(f"Original top class: {classes_orig.item()} ({conf_orig.item():.4f})")
913
  print(f"Final top class: {final_classes.item()} ({final_conf.item():.4f})")
914
- print(f"Inference loop completed in {loop_time:.2f} seconds ({avg_iter_time:.4f} sec/iteration)")
915
  print(f"Total inference time: {total_time:.2f} seconds")
916
 
917
- # Return results in format compatible with both old and new code
918
  return {
919
  'final_image': x[0].detach().cpu(),
920
  'steps': all_steps,
921
  'original_class': classes_orig.item(),
922
  'original_confidence': conf_orig.item(),
923
  'final_class': final_classes.item(),
924
- 'final_confidence': final_conf.item()
 
 
925
  }
926
 
927
- # Utility function to show inference steps
928
  def show_inference_steps(steps, figsize=(15, 10)):
929
- import matplotlib.pyplot as plt
930
-
931
- n_steps = len(steps)
932
- fig, axes = plt.subplots(1, n_steps, figsize=figsize)
933
-
934
- for i, step_img in enumerate(steps):
935
- img = step_img.permute(1, 2, 0).numpy()
936
- axes[i].imshow(img)
937
- axes[i].set_title(f"Step {i}")
938
- axes[i].axis('off')
939
-
940
- plt.tight_layout()
941
- return fig
1
+ """Complete generative inference module with model loading and inference capabilities."""
2
+
3
  import torch
4
  import torch.nn as nn
5
  import torch.nn.functional as F
 
6
  import torchvision.transforms as transforms
7
+ import torchvision.models as models
8
  from torchvision.models.resnet import ResNet50_Weights
9
  from PIL import Image
10
  import numpy as np
 
14
  import copy
15
  from collections import OrderedDict
16
  from pathlib import Path
17
+ from typing import Dict, List, Optional, Tuple, Union
18
 
19
  # Check for available hardware acceleration
20
  if torch.cuda.is_available():
 
25
  device = torch.device("cpu")
26
  print(f"Using device: {device}")
27
 
28
+ # Constants for model URLs
29
  MODEL_URLS = {
30
  'resnet50_robust': 'https://huggingface.co/madrylab/robust-imagenet-models/resolve/main/resnet50_l2_eps3.ckpt',
31
  'resnet50_standard': 'https://huggingface.co/madrylab/robust-imagenet-models/resolve/main/resnet50_l2_eps0.ckpt',
32
  'resnet50_robust_face': 'https://huggingface.co/ttoosi/resnet50_robust_face/resolve/main/resnet50_imagenet_L2_eps_0.50_checkpoint150.pt'
33
  }
34
 
35
+ # Model-specific preprocessing configurations
36
+ MODEL_CONFIGS = {
37
+ 'resnet50_robust_face': {
38
+ 'input_size': 112,
39
+ 'norm_mean': [0.5, 0.5, 0.5],
40
+ 'norm_std': [0.5, 0.5, 0.5],
41
+ 'n_classes': 500,
42
+ 'dataset': 'VGGFace2'
43
+ },
44
+ 'resnet50_standard': {
45
+ 'input_size': 224,
46
+ 'norm_mean': [0.485, 0.456, 0.406],
47
+ 'norm_std': [0.229, 0.224, 0.225],
48
+ 'n_classes': 1000,
49
+ 'dataset': 'ImageNet'
50
+ },
51
+ 'resnet50_robust': {
52
+ 'input_size': 224,
53
+ 'norm_mean': [0.485, 0.456, 0.406],
54
+ 'norm_std': [0.229, 0.224, 0.225],
55
+ 'n_classes': 1000,
56
+ 'dataset': 'ImageNet'
57
+ }
58
  }
59
 
60
  IMAGENET_MEAN = [0.485, 0.456, 0.406]
61
  IMAGENET_STD = [0.229, 0.224, 0.225]
62
 
63
+ def get_iterations_to_show(n_itr):
64
+ """Generate a dynamic list of iterations to show based on total iterations."""
65
+ if n_itr <= 50:
66
+ return [1, 5, 10, 20, 30, 40, 50, n_itr]
67
+ elif n_itr <= 100:
68
+ return [1, 5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, n_itr]
69
+ elif n_itr <= 200:
70
+ return [1, 5, 10, 20, 30, 40, 50, 75, 100, 125, 150, 175, 200, n_itr]
71
+ elif n_itr <= 500:
72
+ return [1, 5, 10, 20, 30, 40, 50, 75, 100, 150, 200, 250, 300, 350, 400, 450, 500, n_itr]
73
  else:
74
+ return [1, 5, 10, 20, 30, 40, 50, 75, 100, 150, 200, 250, 300, 350, 400, 450, 500,
75
+ int(n_itr*0.6), int(n_itr*0.7), int(n_itr*0.8), int(n_itr*0.9), n_itr]
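A quick sketch of how this schedule behaves (the function is re-declared here so the snippet is self-contained; branches above `n_itr=100` are omitted for brevity):

```python
# Self-contained copy of the checkpoint schedule, for illustration only.
def get_iterations_to_show(n_itr):
    if n_itr <= 50:
        return [1, 5, 10, 20, 30, 40, 50, n_itr]
    return [1, 5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, n_itr]

# Note the list may contain checkpoints above n_itr (e.g. 40 and 50 for
# n_itr=30); the inference loop only tests membership of i+1, so the
# extra entries are harmless.
print(get_iterations_to_show(30))  # [1, 5, 10, 20, 30, 40, 50, 30]
```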
76
 
77
+ def get_inference_configs(inference_type='IncreaseConfidence', eps=0.5, n_itr=50, step_size=1.0):
78
+ """Generate inference configuration with customizable parameters."""
79
 
80
+ config = {
81
+ 'loss_infer': inference_type,
82
+ 'n_itr': n_itr,
83
+ 'eps': eps,
84
+ 'step_size': step_size,
85
+ 'diffusion_noise_ratio': 0.0,
86
+ 'initial_inference_noise_ratio': 0.0,
87
+ 'top_layer': 'all',
88
+ 'inference_normalization': False,
89
+ 'recognition_normalization': False,
90
+ 'iterations_to_show': get_iterations_to_show(n_itr),
91
+ 'misc_info': {'keep_grads': False}
92
+ }
93
 
94
+ if inference_type == 'IncreaseConfidence':
95
+ config['loss_function'] = 'CE'
96
+ elif inference_type == 'Prior-Guided Drift Diffusion':
97
+ config['loss_function'] = 'MSE'
98
+ config['initial_inference_noise_ratio'] = 0.05
99
+ config['diffusion_noise_ratio'] = 0.01
100
+ config['top_layer'] = 'layer4'
101
+ elif inference_type == 'GradModulation':
102
+ config['loss_function'] = 'CE'
103
+ config['misc_info']['grad_modulation'] = 0.5
104
+ elif inference_type == 'CompositionalFusion':
105
+ config['loss_function'] = 'CE'
106
+ config['misc_info']['positive_classes'] = []
107
+ config['misc_info']['negative_classes'] = []
 
108
 
109
+ return config
110
 
111
+ def get_model_preprocessing(model_type: str) -> Dict:
112
+ """Get preprocessing configuration for specific model type."""
113
+ if model_type not in MODEL_CONFIGS:
114
+ print(f"Fall-back: Unknown model type {model_type}, using ImageNet defaults")
115
+ return MODEL_CONFIGS['resnet50_standard']
116
+ return MODEL_CONFIGS[model_type]
117
 
118
  class NormalizeByChannelMeanStd(nn.Module):
119
+ """Normalization layer for models."""
120
  def __init__(self, mean, std):
121
  super(NormalizeByChannelMeanStd, self).__init__()
122
  if not isinstance(mean, torch.Tensor):
 
131
 
132
  def normalize_fn(self, tensor, mean, std):
133
  """Differentiable version of torchvision.functional.normalize"""
 
134
  mean = mean[None, :, None, None]
135
  std = std[None, :, None, None]
136
  return tensor.sub(mean).div(std)
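A minimal sketch of why this normalization stays differentiable: `mean` and `std` of shape `(C,)` are broadcast to `(1, C, 1, 1)` so one `sub`/`div` handles a batched NCHW tensor while gradients keep flowing to the input (values here are arbitrary).

```python
import torch

# Per-channel normalization, broadcast over a batched NCHW tensor.
mean = torch.tensor([0.5, 0.5, 0.5])
std = torch.tensor([0.5, 0.5, 0.5])
x = torch.full((1, 3, 4, 4), 0.75, requires_grad=True)
y = x.sub(mean[None, :, None, None]).div(std[None, :, None, None])
print(y[0, 0, 0, 0].item())  # 0.5, i.e. (0.75 - 0.5) / 0.5
y.sum().backward()           # gradients reach x: sub/div are differentiable
```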
137
 
138
  class InferStep:
139
+ """Inference step class for gradient-based optimization."""
140
+
141
+ def __init__(self, orig_image: torch.Tensor, eps: float, step_size: float):
142
  self.orig_image = orig_image
143
  self.eps = eps
144
  self.step_size = step_size
145
 
146
+ def project(self, x: torch.Tensor) -> torch.Tensor:
147
+ """Project x onto epsilon-ball around original image."""
148
  diff = x - self.orig_image
149
  diff = torch.clamp(diff, -self.eps, self.eps)
150
  return torch.clamp(self.orig_image + diff, 0, 1)
151
 
152
+ def step(self, x: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
153
+ """Take a normalized gradient step."""
154
+ dim = len(x.shape) - 1
155
+ grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=1).reshape(-1, *([1] * dim))
156
  scaled_grad = grad / (grad_norm + 1e-10)
157
  return scaled_grad * self.step_size
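The per-iteration update performed with these two methods can be sketched as an L2-normalized gradient step followed by projection onto the eps-ball around the original image and back into [0, 1] (random tensors stand in for a real image and gradient):

```python
import torch

orig = torch.rand(1, 3, 8, 8)   # stand-in for the original image
eps, step_size = 0.5, 1.0

x = orig.clone()
grad = torch.randn_like(x)      # stand-in for an autograd gradient

# InferStep.step: scale the gradient to unit L2 norm times step_size
dim = len(x.shape) - 1
grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=1).reshape(-1, *([1] * dim))
x = x + grad / (grad_norm + 1e-10) * step_size

# InferStep.project: clamp into the eps-ball, then into valid pixel range
diff = torch.clamp(x - orig, -eps, eps)
x = torch.clamp(orig + diff, 0, 1)

assert (x - orig).abs().max() <= eps and 0 <= x.min() and x.max() <= 1
```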
158
 
159
+ def extract_middle_layers(model: nn.Module, layer_index: Union[str, int]) -> nn.Module:
160
+ """Extract middle layers from a model up to a specified layer index."""
161
+ if isinstance(layer_index, str) and layer_index == 'all':
162
+ return model
163
 
164
+ # Handle ResNet layer extraction
165
+ modules = list(model.named_children())
166
+ cutoff_idx = next(
167
+ (i for i, (name, _) in enumerate(modules) if name == str(layer_index)),
168
+ None
169
+ )
170
 
171
+ if cutoff_idx is not None:
172
+ new_model = nn.Sequential(OrderedDict(modules[:cutoff_idx + 1]))
173
+ return new_model
174
+ else:
175
+ print(f"Fall-back: Module {layer_index} not found, using full model")
176
+ return model
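The truncation logic above can be demonstrated on a toy model in place of the ResNet (layer names here are illustrative): the named children are scanned for the cutoff name and everything up to and including it is repackaged as a `Sequential`.

```python
import torch
import torch.nn as nn
from collections import OrderedDict

# Toy stand-in for the ResNet backbone.
model = nn.Sequential(OrderedDict([
    ('stem',   nn.Conv2d(3, 8, 3, padding=1)),
    ('layer1', nn.Conv2d(8, 16, 3, padding=1)),
    ('layer2', nn.Conv2d(16, 32, 3, padding=1)),
    ('fc',     nn.Flatten()),
]))

# Same cutoff search as extract_middle_layers above.
modules = list(model.named_children())
cutoff_idx = next((i for i, (name, _) in enumerate(modules) if name == 'layer1'), None)
truncated = nn.Sequential(OrderedDict(modules[:cutoff_idx + 1]))

out = truncated(torch.zeros(1, 3, 4, 4))
print(out.shape)  # torch.Size([1, 16, 4, 4])
```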
177
+
178
+ def calculate_loss(output_model: torch.Tensor, class_indices: List[int], loss_inference: str) -> torch.Tensor:
179
+ """Calculate loss for specified class indices."""
180
+ losses = []
181
+ for idx in class_indices:
182
+ target = torch.full((1,), idx, dtype=torch.long, device=output_model.device)
183
+ if loss_inference == 'CE':
184
+ loss = nn.CrossEntropyLoss()(output_model, target)
185
+ elif loss_inference == 'MSE':
186
+ one_hot_target = torch.zeros_like(output_model)
187
+ one_hot_target[0, target] = 1
188
+ loss = nn.MSELoss()(output_model, one_hot_target)
189
+ else:
190
+ raise ValueError(f"Unsupported loss_inference: {loss_inference}")
191
+ losses.append(loss)
192
 
193
+ return torch.stack(losses).mean()
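A minimal sketch of the averaged multi-class loss above, for the CE branch: one cross-entropy term per target class, stacked and meaned. The logits and class indices are dummies; with all-zero logits the softmax is uniform, so each CE term equals ln(10) for 10 classes.

```python
import torch
import torch.nn as nn

output_model = torch.zeros(1, 10)  # dummy logits -> uniform softmax
class_indices = [2, 7]             # arbitrary target classes

losses = []
for idx in class_indices:
    target = torch.full((1,), idx, dtype=torch.long)
    losses.append(nn.CrossEntropyLoss()(output_model, target))
loss = torch.stack(losses).mean()
print(round(loss.item(), 4))  # 2.3026, i.e. ln(10)
```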
194
+
195
+ def download_model(model_type):
196
+ """Download model if needed."""
197
+ if model_type not in MODEL_URLS or MODEL_URLS[model_type] is None:
198
+ return None
199
 
200
+ os.makedirs("models", exist_ok=True)
 
 
201
 
202
+ if model_type == 'resnet50_robust_face':
203
+ model_path = Path("models/resnet50_vggface2_L2_eps_0.50_checkpoint150.pt")
204
+ else:
205
+ model_path = Path(f"models/{model_type}.pt")
206
 
207
+ if not model_path.exists():
208
+ print(f"Downloading {model_type} model...")
209
+ url = MODEL_URLS[model_type]
210
+ response = requests.get(url, stream=True)
211
+ if response.status_code == 200:
212
+ with open(model_path, 'wb') as f:
213
+ for chunk in response.iter_content(chunk_size=8192):
214
+ f.write(chunk)
215
+ print(f"Model downloaded and saved to {model_path}")
216
+ else:
217
+ raise RuntimeError(f"Failed to download model: {response.status_code}")
218
+ return model_path
219
 
220
  class GenerativeInferenceModel:
221
+ """Complete generative inference model with model loading and inference."""
222
+
223
  def __init__(self):
224
  self.models = {}
 
225
  self.model_preproc = {}
226
+ self.labels = self.get_imagenet_labels()
227
 
228
+ def get_imagenet_labels(self):
229
+ """Get ImageNet labels."""
230
+ url = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
 
 
231
  try:
232
+ response = requests.get(url)
233
+ if response.status_code == 200:
234
+ return response.json()
235
+ else:
236
+ print("Fall-back: Failed to fetch ImageNet labels, using placeholder")
237
+ return [f"class_{i}" for i in range(1000)]
 
238
  except Exception as e:
239
+ print(f"Fall-back: Error fetching labels: {e}")
240
+ return [f"class_{i}" for i in range(1000)]
 
241
 
242
  def load_model(self, model_type):
243
+ """Load and cache models for different model types."""
244
  if model_type in self.models:
245
  print(f"Using cached {model_type} model")
246
  return self.models[model_type]
247
 
248
  start_time = time.time()
249
+
250
+ # Get model-specific preprocessing config
251
+ preproc_config = get_model_preprocessing(model_type)
252
+ self.model_preproc[model_type] = preproc_config
253
+
254
+ # Create normalizer
255
+ normalizer = NormalizeByChannelMeanStd(
256
+ preproc_config['norm_mean'],
257
+ preproc_config['norm_std']
258
+ ).to(device)
259
+
260
+ # Create base model architecture
261
+ num_classes = preproc_config['n_classes']
262
+ resnet = models.resnet50(num_classes=num_classes)
263
  model = nn.Sequential(normalizer, resnet)
264
 
265
+ # Download and load checkpoint
266
+ model_path = download_model(model_type)
267
+
268
  if model_path:
269
  print(f"Loading {model_type} model from {model_path}...")
270
  try:
271
  checkpoint = torch.load(model_path, map_location=device)
272
273
  # Handle different checkpoint formats
274
  if 'model' in checkpoint:
 
275
  state_dict = checkpoint['model']
276
  print("Using 'model' key from checkpoint")
277
  elif 'state_dict' in checkpoint:
278
  state_dict = checkpoint['state_dict']
279
  print("Using 'state_dict' key from checkpoint")
280
  else:
 
281
  state_dict = checkpoint
282
  print("Using checkpoint directly as state_dict")
283
 
284
+ # Extract ResNet state dict
285
  resnet_state_dict = {}
 
286
  resnet_keys = set(resnet.state_dict().keys())
287
 
288
+ # For face model, prioritize 'module.model.model.' structure (seen in actual checkpoint)
289
+ if model_type == 'resnet50_robust_face':
290
+ # Check for 'module.model.model.' structure first (face checkpoints use this)
291
+ module_model_model_keys = [key for key in state_dict.keys() if key.startswith('module.model.model.')]
292
+ if module_model_model_keys:
293
+ print(f"Found 'module.model.model.' structure with {len(module_model_model_keys)} parameters (face model)")
 
294
  for source_key, value in state_dict.items():
295
+ if source_key.startswith('module.model.model.'):
296
+ target_key = source_key[len('module.model.model.'):]
297
+ if target_key in resnet_keys:
298
  resnet_state_dict[target_key] = value
299
+ print(f"Extracted {len(resnet_state_dict)} parameters from module.model.model.")
300
 
301
+ # Also check for 'module.model.' structure as fallback
302
+ if len(resnet_state_dict) < len(resnet_keys):
303
+ module_model_keys = [key for key in state_dict.keys() if key.startswith('module.model.') and not key.startswith('module.model.model.')]
304
+ if module_model_keys:
305
+ print(f"Found additional 'module.model.' structure with {len(module_model_keys)} parameters")
306
+ for source_key, value in state_dict.items():
307
+ if source_key.startswith('module.model.') and not source_key.startswith('module.model.model.'):
308
+ target_key = source_key[len('module.model.'):]
309
+ # Remove extra 'model.' if present
310
+ if target_key.startswith('model.'):
311
+ target_key = target_key[len('model.'):]
312
+ if target_key in resnet_keys and target_key not in resnet_state_dict:
313
+ resnet_state_dict[target_key] = value
314
+ print(f"Now have {len(resnet_state_dict)} parameters after adding module.model. keys")
315
 
316
+ # Handle different key prefixes in checkpoints (for other models)
317
+ if len(resnet_state_dict) == 0:
318
+ prefixes_to_try = ['', 'module.', 'model.', 'attacker.model.', 'attacker.']
319
 
 
320
  for source_key, value in state_dict.items():
 
321
  target_key = source_key
 
322
 
323
  # Try removing various prefixes
324
+ for prefix in prefixes_to_try:
325
  if source_key.startswith(prefix):
326
  target_key = source_key[len(prefix):]
 
327
  break
328
 
329
+ # Handle nested model keys
330
+ if target_key.startswith('model.'):
331
+ target_key = target_key[len('model.'):]
332
+
333
+ # If the target key is in ResNet keys, add it
334
+ if target_key in resnet_keys:
335
  resnet_state_dict[target_key] = value
336
 
337
+ # Load the state dict
338
  if resnet_state_dict:
339
+ result = resnet.load_state_dict(resnet_state_dict, strict=False)
340
+ missing_keys, unexpected_keys = result
341
+
342
+ loaded_percent = (len(resnet_state_dict) / len(resnet_keys)) * 100
343
+ print(f"Model loading: {len(resnet_state_dict)}/{len(resnet_keys)} parameters ({loaded_percent:.1f}%)")
344
+
345
+ if loaded_percent < 50:
346
+ print(f"Fall-back: Loading too incomplete ({loaded_percent:.1f}%), using PyTorch pretrained")
347
+ if model_type != 'resnet50_robust_face':
348
  resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
349
  model = nn.Sequential(normalizer, resnet)
350
 
351
+ else:
352
+ print("Fall-back: No matching keys found in checkpoint, using PyTorch pretrained")
353
+ if model_type != 'resnet50_robust_face':
354
+ resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
355
+ model = nn.Sequential(normalizer, resnet)
 
 
 
356
357
  except Exception as e:
358
+ print(f"Fall-back: Error loading checkpoint: {e}")
359
+ if model_type != 'resnet50_robust_face':
360
+ print("Fall-back: Using PyTorch pretrained model")
361
+ resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
362
+ model = nn.Sequential(normalizer, resnet)
363
+ else:
364
+ print("Fall-back: Face model checkpoint failed, model may not work properly")
365
+
366
+ else:
367
+ # Use PyTorch's pretrained model for ImageNet models
368
+ if model_type != 'resnet50_robust_face':
369
+ print(f"No checkpoint for {model_type}, using PyTorch pretrained")
370
  resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
371
  model = nn.Sequential(normalizer, resnet)
372
+ else:
373
+ print("Fall-back: Face model requires checkpoint, model may not work properly")
 
 
 
374
 
375
  model = model.to(device)
376
+ model.eval()
377
 
378
+ # Verify model
379
  self.verify_model_integrity(model, model_type)
380
 
381
+ # Cache the model
382
  self.models[model_type] = model
383
+
384
  end_time = time.time()
385
+ print(f"Model {model_type} loaded in {end_time - start_time:.2f} seconds")
 
386
  return model
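The checkpoint-key remapping inside `load_model` boils down to stripping known wrapper prefixes until a key matches a bare ResNet parameter name. A simplified, self-contained sketch (keys and values here are hypothetical, and the real code additionally tries `attacker.`-style prefixes):

```python
# Bare parameter names the model expects.
resnet_keys = {'conv1.weight', 'bn1.weight', 'fc.bias'}

# Hypothetical checkpoint keys with wrapper prefixes.
state_dict = {
    'module.model.model.conv1.weight': 'w0',
    'module.model.model.bn1.weight': 'w1',
    'module.attacker.preprocess': 'ignored',  # no match -> dropped
}

resnet_state_dict = {}
for source_key, value in state_dict.items():
    target_key = source_key
    # Longest prefixes first, so 'module.model.model.' wins over 'module.'
    for prefix in ('module.model.model.', 'module.model.', 'module.', 'model.'):
        if source_key.startswith(prefix):
            target_key = source_key[len(prefix):]
            break
    if target_key in resnet_keys:
        resnet_state_dict[target_key] = value

print(sorted(resnet_state_dict))  # ['bn1.weight', 'conv1.weight']
```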
387
+
388
+ def verify_model_integrity(self, model, model_type):
389
+ """Verify model integrity."""
390
+ try:
391
+ print(f"Running model integrity check for {model_type}")
392
+ config = get_model_preprocessing(model_type)
393
+ H = W = config['input_size']
394
+
395
+ test_input = torch.zeros(1, 3, H, W, device=device)
396
+ test_input[0, 0, H//4:3*H//4, W//4:3*W//4] = 0.5
397
+
398
+ with torch.no_grad():
399
+ output = model(test_input)
400
+
401
+ expected_classes = config['n_classes']
402
+ if output.shape != (1, expected_classes):
403
+ print(f"Fall-back: Unexpected output shape: {output.shape}, expected (1, {expected_classes})")
404
+ return False
405
+
406
+ probs = torch.nn.functional.softmax(output, dim=1)
407
+ confidence, prediction = torch.max(probs, 1)
408
+
409
+ print("Model integrity check passed:")
410
+ print(f"- Output shape: {output.shape}")
411
+ print(f"- Top prediction: Class {prediction.item()} with {confidence.item()*100:.2f}% confidence")
412
+
413
+ return True
414
+
415
+ except Exception as e:
416
+ print(f"Fall-back: Model integrity check failed with error: {e}")
417
+ return False
418
+
419
  def inference(self, image, model_type, config):
420
+ """Run generative inference."""
 
421
  inference_start = time.time()
422
 
423
+ # Load the model
424
  model = self.load_model(model_type)
425
 
426
+ # Handle image input
427
  if isinstance(image, str):
428
  if os.path.exists(image):
429
  image = Image.open(image).convert('RGB')
430
  else:
431
  raise ValueError(f"Image path does not exist: {image}")
432
+ elif isinstance(image, np.ndarray):
433
+ if image.dtype != np.uint8:
434
+ if image.max() <= 1.0:
435
+ image = (image * 255).astype(np.uint8)
436
+ else:
437
+ image = image.astype(np.uint8)
438
+ if len(image.shape) == 3:
439
+ if image.shape[0] == 3 or image.shape[0] == 1:
440
+ image = np.transpose(image, (1, 2, 0))
441
+ if image.shape[2] == 4:
442
+ image = image[:, :, :3]
443
+ elif image.shape[2] == 1:
444
+ image = np.repeat(image, 3, axis=2)
445
+ image = Image.fromarray(image)
446
+ elif not isinstance(image, Image.Image):
447
+ try:
448
+ image = Image.fromarray(np.array(image)).convert('RGB')
449
+ except Exception as e:
450
+ raise ValueError(f"Cannot convert image type {type(image)} to PIL Image: {e}")
451
 
452
+ if isinstance(image, Image.Image) and image.mode != 'RGB':
453
+ image = image.convert('RGB')
 
 
454
 
455
+ # Get preprocessing config
456
+ preproc_config = get_model_preprocessing(model_type)
457
+ input_size = preproc_config['input_size']
458
+ norm_mean = torch.tensor(preproc_config['norm_mean'])
459
+ norm_std = torch.tensor(preproc_config['norm_std'])
460
+ n_classes = preproc_config['n_classes']
461
 
462
+ # Create transform
463
+ if config.get('inference_normalization', False):
464
+ transform = transforms.Compose([
465
+ transforms.Resize(input_size),
466
+ transforms.CenterCrop(input_size),
467
+ transforms.ToTensor(),
468
+ transforms.Normalize(norm_mean.tolist(), norm_std.tolist()),
469
+ ])
470
+ print(f"Normalization ON - applying mean={norm_mean.tolist()}, std={norm_std.tolist()} in the transform")
471
  else:
472
+ transform = transforms.Compose([
473
+ transforms.Resize(input_size),
474
+ transforms.CenterCrop(input_size),
475
+ transforms.ToTensor(),
476
+ ])
477
+ print(f"Normalization OFF - feeding raw [0,1] tensors to model (normalization applied in the model)")
478
 
479
+ # Helper function to safely apply transform with fallback for numpy compatibility
480
+ def safe_transform(img):
481
+ try:
482
+ return transform(img)
483
+ except TypeError as e:
484
+ if "expected np.ndarray" in str(e) or "got numpy.ndarray" in str(e):
485
+ # Fallback: manually convert PIL to tensor
486
+ print(f"[WARNING] Transform failed with numpy compatibility issue, using manual conversion")
487
+ # Apply resize and center crop manually
488
+ resize_transform = transforms.Resize(input_size)
489
+ crop_transform = transforms.CenterCrop(input_size)
490
+ img = crop_transform(resize_transform(img))
491
+ # Convert to numpy array and then to tensor using torch.tensor() to avoid numpy compatibility issues
492
+ img_array = np.array(img, dtype=np.uint8)
493
+ # Use torch.tensor() instead of torch.from_numpy() to avoid compatibility issues
494
+ # Convert to float and normalize to [0, 1], then convert from HWC to CHW format
495
+ img_tensor = torch.tensor(img_array, dtype=torch.float32).div(255.0).permute(2, 0, 1)
496
+ # Apply normalization if needed
497
+ if config.get('inference_normalization', False):
498
+ img_tensor = transforms.Normalize(norm_mean.tolist(), norm_std.tolist())(img_tensor)
499
+ return img_tensor
500
+ else:
501
+ raise
502
 
503
+ # Prepare image tensor with safe transform
504
+ image_tensor = safe_transform(image).unsqueeze(0).to(device)
505
+ image_tensor.requires_grad = True
506
+
507
+ # Get model components
508
+ is_sequential = isinstance(model, nn.Sequential)
509
+ if is_sequential and isinstance(model[0], NormalizeByChannelMeanStd):
510
+ core_model = model[1]
511
+ else:
512
+ core_model = model
513
+
514
+ # Prepare model for layer extraction
515
+ if config.get('top_layer', 'all') != 'all':
516
+ new_model = extract_middle_layers(core_model, config['top_layer'])
517
+ else:
518
+ new_model = model
519
+
520
  # Get original predictions
521
  with torch.no_grad():
522
+ if config.get('inference_normalization', False):
523
+ output_original = model(image_tensor)
524
  else:
525
+ output_original = core_model(image_tensor)
526
 
527
  probs_orig = F.softmax(output_original, dim=1)
528
  conf_orig, classes_orig = torch.max(probs_orig, 1)
529
 
530
+ # Get least confident classes for IncreaseConfidence
531
+ if config['loss_infer'] == 'IncreaseConfidence':
532
+ _, least_confident_classes = torch.topk(probs_orig, k=int(n_classes / 10), largest=False)
533
+
534
+ # Setup for Prior-Guided Drift Diffusion
535
+ noisy_features = None
536
+ if config['loss_infer'] == 'Prior-Guided Drift Diffusion':
537
+ print(f"Setting up Prior-Guided Drift Diffusion...")
538
+ added_noise = config.get('initial_inference_noise_ratio', 0.05) * torch.randn_like(image_tensor).to(device)
539
+ noisy_image_tensor = image_tensor + added_noise
540
+ noisy_features = new_model(noisy_image_tensor)
541
+
542
  # Initialize inference step
543
  infer_step = InferStep(image_tensor, config['eps'], config['step_size'])
544
 
545
  # Storage for inference steps
 
546
  x = image_tensor.clone().detach().requires_grad_(True)
547
  all_steps = [image_tensor[0].detach().cpu()]
548
 
549
+ selected_inferred_patterns = []
550
+ perceived_categories = []
551
+ confidence_list = []
552
 
553
  # Main inference loop
554
  print(f"Starting inference loop with {config['n_itr']} iterations for {config['loss_infer']}...")
555
+
556
  for i in range(config['n_itr']):
557
  # Reset gradients
558
  x.grad = None
559
 
560
+ if i == 0:
561
+ # Get predictions for first iteration
562
+ if config.get('inference_normalization', False):
563
+ output = model(x)
564
+ else:
565
+ output = core_model(x)
566
 
567
+ if isinstance(output, torch.Tensor) and output.size(-1) == n_classes:
568
+ probs = F.softmax(output, dim=1)
569
+ conf, classes = torch.max(probs, 1)
 
 
570
  else:
571
+ probs = 0
572
+ conf = 0
573
+ classes = 'N/A'
574
+ else:
575
+ # Calculate loss and gradients
576
+ try:
577
+ # Forward pass through new_model for feature extraction
578
+ features = new_model(x)
579
 
580
+ if config['loss_infer'] == 'Prior-Guided Drift Diffusion':
581
+ assert config.get('loss_function', 'MSE') == 'MSE', "Prior-Guided Drift Diffusion requires MSE loss"
582
+ if noisy_features is not None:
583
+ loss = F.mse_loss(features, noisy_features)
584
+ grad = torch.autograd.grad(loss, x)[0]
585
+ adjusted_grad = infer_step.step(x, grad)
586
+ else:
587
+ raise ValueError("Noisy features not computed for Prior-Guided Drift Diffusion")
588
 
589
+ elif config['loss_infer'] == 'IncreaseConfidence':
590
+ # Calculate loss using least confident classes
591
+ num_target_classes = min(int(n_classes / 10), least_confident_classes.size(1))
592
+ target_classes = least_confident_classes[0, :num_target_classes]
593
+
594
+ loss = calculate_loss(features, target_classes.tolist(), config.get('loss_function', 'CE'))
595
+ grad = torch.autograd.grad(loss, x, retain_graph=True)[0]
596
+ adjusted_grad = infer_step.step(x, grad)
597
 
598
+ else:
599
+ raise ValueError(f"Loss inference method {config['loss_infer']} not supported")
600
+
601
+ if grad is None:
602
+ print("Fall-back: Direct gradient calculation failed")
603
+ random_noise = (torch.rand_like(x) - 0.5) * 2 * config['step_size']
604
+ x = infer_step.project(x.clone() + random_noise)
605
+ else:
606
+ # Add diffusion noise if specified
607
+ diffusion_noise = config.get('diffusion_noise_ratio', 0.0) * torch.randn_like(x).to(device)
608
+ x = infer_step.project(x.clone() + adjusted_grad + diffusion_noise)
609
+
610
+ except Exception as e:
611
+ print(f"Fall-back: Error in gradient calculation: {e}")
612
+ random_noise = (torch.rand_like(x) - 0.5) * 2 * config['step_size']
613
+ x = infer_step.project(x.clone() + random_noise)
614
 
615
  # Store step if in iterations_to_show
616
+ if i+1 in config.get('iterations_to_show', []) or i+1 == config['n_itr']:
617
  all_steps.append(x[0].detach().cpu())
618
+ selected_inferred_patterns.append(x[0].detach().cpu())
619
+
620
+ # Get current predictions
621
+ with torch.no_grad():
622
+ if config.get('inference_normalization', False):
623
+ current_output = model(x)
624
+ else:
625
+ current_output = core_model(x)
626
+
627
+ if isinstance(current_output, torch.Tensor) and current_output.size(-1) == n_classes:
628
+ current_probs = F.softmax(current_output, dim=1)
629
+ current_conf, current_classes = torch.max(current_probs, 1)
630
+ perceived_categories.append(current_classes.item())
631
+ confidence_list.append(current_conf.item())
632
+ else:
633
+ perceived_categories.append('N/A')
634
+ confidence_list.append(0.0)
635
 
636
+ # Final predictions
637
  with torch.no_grad():
638
+ if config.get('inference_normalization', False):
639
+ final_output = model(x)
 
 
 
640
  else:
641
+ final_output = core_model(x)
 
 
 
 
642
 
643
  final_probs = F.softmax(final_output, dim=1)
644
  final_conf, final_classes = torch.max(final_probs, 1)
645
 
 
 
646
  total_time = time.time() - inference_start
 
647
 
648
  print(f"Original top class: {classes_orig.item()} ({conf_orig.item():.4f})")
649
  print(f"Final top class: {final_classes.item()} ({final_conf.item():.4f})")
 
650
  print(f"Total inference time: {total_time:.2f} seconds")
651
 
652
+ # Return results, including per-iteration category and confidence traces
653
  return {
654
  'final_image': x[0].detach().cpu(),
655
  'steps': all_steps,
656
  'original_class': classes_orig.item(),
657
  'original_confidence': conf_orig.item(),
658
  'final_class': final_classes.item(),
659
+ 'final_confidence': final_conf.item(),
660
+ 'all_categories': perceived_categories,
661
+ 'all_confidences': confidence_list,
662
  }
663
 
 
664
  def show_inference_steps(steps, figsize=(15, 10)):
665
+ """Show inference steps using matplotlib."""
666
+ try:
667
+ import matplotlib.pyplot as plt
668
+
669
+ n_steps = len(steps)
670
+ fig, axes = plt.subplots(1, n_steps, figsize=figsize)
671
+
672
+ if n_steps == 1:
673
+ axes = [axes]
674
+
675
+ for i, step_img in enumerate(steps):
676
+ if isinstance(step_img, torch.Tensor):
677
+ img = step_img.permute(1, 2, 0).numpy()
678
+ img = np.clip(img, 0, 1)
679
+ else:
680
+ img = step_img
681
+
682
+ axes[i].imshow(img)
683
+ axes[i].set_title(f"Step {i+1}")
684
+ axes[i].axis('off')
685
+
686
+ plt.tight_layout()
687
+ return fig
688
+
689
+ except ImportError:
690
+ print("Fall-back: matplotlib not available for visualization")
691
+ return None
692
+ except Exception as e:
693
+ print(f"Fall-back: Visualization failed: {e}")
694
+ return None
695
+
696
+ # Export the main classes and functions
697
+ __all__ = ['GenerativeInferenceModel', 'get_inference_configs', 'show_inference_steps']
logs/model_loading_resnet50_robust_face.log CHANGED
@@ -2,45 +2,8 @@
2
  ===== MODEL LOADING REPORT: resnet50_robust_face =====
3
  Total parameters in checkpoint: 320
4
  Total parameters in model: 320
5
- Missing keys: 267 parameters
6
- Unexpected keys: 320 parameters
7
- Successfully loaded: 0 parameters (0.0%)
8
- Loading status: FAILED - Critical parameters missing, will not function properly
9
-
10
- ⚠️ WARNING: Loading from checkpoint is too incomplete.
11
- ⚠️ Falling back to PyTorch's pretrained model to avoid broken inference.
12
- ✅ Successfully loaded PyTorch's pretrained ResNet50 model
13
-
14
- Missing keys by layer type:
15
- layer3: 95 parameters
16
- layer2: 65 parameters
17
- layer1: 50 parameters
18
- layer4: 50 parameters
19
- bn1: 4 parameters
20
- fc: 2 parameters
21
- conv1: 1 parameters
22
-
23
- First 10 missing keys:
24
- 1. bn1.bias
25
- 2. bn1.running_mean
26
- 3. bn1.running_var
27
- 4. bn1.weight
28
- 5. conv1.weight
29
- 6. fc.bias
30
- 7. fc.weight
31
- 8. layer1.0.bn1.bias
32
- 9. layer1.0.bn1.running_mean
33
- 10. layer1.0.bn1.running_var
34
-
35
- First 10 unexpected keys:
36
- 1. model.bn1.bias
37
- 2. model.bn1.num_batches_tracked
38
- 3. model.bn1.running_mean
39
- 4. model.bn1.running_var
40
- 5. model.bn1.weight
41
- 6. model.conv1.weight
42
- 7. model.fc.bias
43
- 8. model.fc.weight
44
- 9. model.layer1.0.bn1.bias
45
- 10. model.layer1.0.bn1.num_batches_tracked
46
  ========================================
 
2
  ===== MODEL LOADING REPORT: resnet50_robust_face =====
3
  Total parameters in checkpoint: 320
4
  Total parameters in model: 320
5
+ Missing keys: 0 parameters
6
+ Unexpected keys: 0 parameters
7
+ Successfully loaded: 320 parameters (100.0%)
8
+ Loading status: COMPLETE - All important parameters loaded
 
9
  ========================================
models/resnet50_imagenet_L2_eps_0.50_checkpoint150.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:40bfb9a204f1d9a305ed6374acbfc55fe2745433cf1e421952d4b461f577486a
3
- size 196695413
models/resnet50_robust.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:380b14e6f9750bffa1447cf7017f65da4dc5ce71a3dd112f107515dcf7b14d9d
3
- size 204818947
models/resnet50_robust_face_100_checkpoint.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:c48a5c16ca0d5ac4cb20f1b98e2128838746f18b658728ac661f1ffd589c37bf
3
- size 196695413
models/robust_resnet50.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:380b14e6f9750bffa1447cf7017f65da4dc5ce71a3dd112f107515dcf7b14d9d
3
- size 204818947
models/standard_resnet50.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:72d4a99582db5d7fa86c3fd2a089f0bfd6a10f69d635bca51f6ad72ac6b458f0
3
- size 204818947
stimuli/RandomizedPhaseOvalGray.png ADDED