Commit 420f791
Tahereh committed · Parent(s): 0b968b4
Update to Generative Inference for Psychiatry Demo: add Noise stimulus, update parameters, fix model loading, and improve UI
- .DS_Store +0 -0
- .gitignore +30 -0
- .smbdeleteAAA29de78c16 +0 -0
- DIFFERENCES.md +137 -0
- README.md +1 -1
- TROUBLESHOOTING.md +140 -0
- __pycache__/inference.cpython-311.pyc +0 -0
- app.py +121 -39
- face_vase_black.png +0 -0
- huggingface-metadata.json +1 -1
- inference.py +485 -729
- logs/model_loading_resnet50_robust_face.log +4 -41
- models/resnet50_imagenet_L2_eps_0.50_checkpoint150.pt +0 -3
- models/resnet50_robust.pt +0 -3
- models/resnet50_robust_face_100_checkpoint.pt +0 -3
- models/robust_resnet50.pt +0 -3
- models/standard_resnet50.pt +0 -3
- stimuli/RandomizedPhaseOvalGray.png +0 -0
.DS_Store
ADDED
Binary file (6.15 kB).
.gitignore
ADDED
@@ -0,0 +1,30 @@
+# Model checkpoints (downloaded automatically)
+models/*.pt
+models/*.ckpt
+
+# Python cache
+__pycache__/
+*.py[cod]
+*$py.class
+*.so
+
+# Logs
+logs/
+*.log
+
+# Environment
+.env
+.venv
+env/
+venv/
+
+# IDE
+.vscode/
+.idea/
+*.swp
+*.swo
+
+# OS
+.DS_Store
+Thumbs.db
+
.smbdeleteAAA29de78c16
ADDED
Binary file (10.9 kB).
DIFFERENCES.md
ADDED
@@ -0,0 +1,137 @@
+# Differences Between Reference Code and Current Implementation
+
+## Critical Differences Affecting Results
+
+### 1. **First Iteration Handling** ⚠️ **CRITICAL**
+**Reference Code:**
+```python
+if itr == 0:
+    # Don't add priors or diffusion noise to the first iteration
+    output = model(image_tensor)
+    # ... just get predictions, no gradient update
+else:
+    # Calculate loss and gradients
+    if loss_infer == 'PGDD':
+        loss = torch.nn.functional.mse_loss(features, noisy_features)
+    grad = torch.autograd.grad(loss, image_tensor)[0]
+    adjusted_grad = inferstep.step(image_tensor, grad)
+    # ... apply gradient and noise
+```
+
+**Current Implementation:**
+- **MISSING**: No check for `itr == 0` or `i == 0`
+- Applies gradients and diffusion noise from the very first iteration
+- This causes different starting behavior
+
+### 2. **Model Extraction for PGDD**
+**Reference Code:**
+```python
+new_model = extract_middle_layers(model.module, top_layer)
+```
+
+**Current Implementation:**
+- Complex logic to handle Sequential models with normalizers
+- Extracts from `model[1]` if Sequential, otherwise from `model`
+- May handle DataParallel differently
+
+### 3. **Gradient Calculation**
+**Reference Code:**
+```python
+grad = torch.autograd.grad(loss, image_tensor)[0]  # No retain_graph for PGDD
+```
+
+**Current Implementation:**
+- Same for PGDD (no retain_graph)
+- But uses `retain_graph=True` for IncreaseConfidence
+
+### 4. **Normalization Handling**
+**Reference Code:**
+- Normalization is applied in the transform at the beginning
+- `inference_normalization` controls whether the transform includes normalization
+- Model forward pass uses the already-normalized tensor
+
+**Current Implementation:**
+- Complex logic checking if the model is Sequential with NormalizeByChannelMeanStd
+- May apply normalization multiple times or inconsistently
+- Different paths for sequential vs non-sequential models
+
+### 5. **Variable Naming and Structure**
+**Reference Code:**
+- Uses `image_tensor` throughout the loop
+- Directly modifies `image_tensor` with `requires_grad=True`
+
+**Current Implementation:**
+- Creates a separate `x = image_tensor.clone().detach().requires_grad_(True)`
+- Uses `x` in the loop instead of `image_tensor`
+
+### 6. **Loss Function for IncreaseConfidence**
+**Reference Code:**
+```python
+loss = calculate_loss(features, least_confident_classes[0], loss_function)
+# Uses CrossEntropyLoss or MSELoss based on loss_function
+```
+
+**Current Implementation:**
+```python
+# Creates one-hot targets and uses MSE on softmax outputs
+loss = loss + F.mse_loss(F.softmax(output, dim=1), one_hot)
+```
+- Different loss calculation method
+- Uses MSE on softmax probabilities vs CrossEntropy on logits
+
+### 7. **Diffusion Noise Application**
+**Reference Code:**
+```python
+if itr == 0:
+    # Skip noise
+else:
+    diffusion_noise = diffusion_noise_ratio * torch.randn_like(image_tensor).cuda()
+    if loss_infer == 'GradModulation':
+        image_tensor = inferstep.project(
+            image_tensor.clone() +
+            adjusted_grad * grad_modulation +
+            diffusion_noise * grad_modulation
+        )
+    else:
+        image_tensor = inferstep.project(
+            image_tensor.clone() + adjusted_grad + diffusion_noise
+        )
+```
+
+**Current Implementation:**
+- Always applies diffusion noise (no `itr == 0` check)
+- Applies noise in all iterations, including the first
+
+### 8. **Model Forward Pass in Loop**
+**Reference Code:**
+```python
+if inference_config['misc_info'].get('smooth_inference', False):
+    # Smooth inference logic
+else:
+    new_model.zero_grad()
+    features = new_model(image_tensor)
+```
+
+**Current Implementation:**
+```python
+x.grad = None  # Instead of new_model.zero_grad()
+if config['loss_infer'] == 'Prior-Guided Drift Diffusion' and layer_model is not None:
+    output = layer_model(x)
+else:
+    output = model(x)
+```
+
+## Summary of Impact
+
+1. **First iteration difference**: Most critical; the reference skips the gradient update on iteration 0
+2. **Normalization**: Different application may cause numerical differences
+3. **Loss calculation**: Different methods for IncreaseConfidence
+4. **Model extraction**: May extract different layers due to Sequential handling
+
+## Recommended Fixes
+
+1. Add an `if i == 0:` check to skip the gradient update on the first iteration
+2. Simplify model extraction to match the reference: `extract_middle_layers(model.module, top_layer)`
+3. Align the loss calculation for IncreaseConfidence with the reference
+4. Ensure normalization is applied consistently
+
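Note: a minimal sketch of Recommended Fix 1, assembled from the reference snippets above. `model`, `inferstep`, `image_tensor`, `n_itr`, and `diffusion_noise_ratio` follow the naming in those snippets; `loss_fn` is a hypothetical placeholder for whichever loss the method uses.

```python
import torch

# Guard the first iteration: forward pass only, no gradient update, no noise,
# matching the reference behavior described in section 1.
for i in range(n_itr):
    if i == 0:
        with torch.no_grad():
            output = model(image_tensor)  # predictions only
        continue
    image_tensor.requires_grad_(True)
    loss = loss_fn(model(image_tensor))          # placeholder loss
    grad = torch.autograd.grad(loss, image_tensor)[0]
    adjusted_grad = inferstep.step(image_tensor, grad)
    diffusion_noise = diffusion_noise_ratio * torch.randn_like(image_tensor)
    image_tensor = inferstep.project(
        image_tensor.detach() + adjusted_grad + diffusion_noise
    )
```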
README.md
CHANGED
@@ -1,5 +1,5 @@
 ---
-title: Generative Inference Demo
+title: Generative Inference for Psychiatry Demo
 emoji: 🧠
 colorFrom: indigo
 colorTo: purple
TROUBLESHOOTING.md
ADDED
@@ -0,0 +1,140 @@
+# Troubleshooting: Why It Works on Hugging Face Spaces But Not Locally
+
+## Common Issues and Solutions
+
+### 1. **Missing Dependencies** ⚠️ (Most Common)
+
+**Problem**: The required Python packages are not installed locally.
+
+**Solution**: Install all dependencies:
+```bash
+cd /home/tahereh/engram/users/Tahereh/Codes/Public_Codes/Generative_Inference_Faces
+pip install -r requirements.txt
+```
+
+**Required packages**:
+- `torch` and `torchvision` (PyTorch)
+- `gradio` (for the web interface)
+- `numpy`, `pillow` (PIL), `matplotlib`
+- `requests`, `tqdm`, `huggingface_hub`
+
+### 2. **GPU Decorator** ✅ (Fixed)
+
+**Problem**: The `@GPU` decorator from Hugging Face Spaces is not available locally.
+
+**Solution**: The code now automatically handles this:
+- On Hugging Face Spaces: uses the `spaces.GPU` decorator
+- Locally: uses a no-op decorator (GPU detection is automatic via PyTorch)
+
+**Status**: ✅ Fixed in the code
+
+### 3. **Port Configuration** ✅ (Fixed)
+
+**Problem**: Port configuration was inconsistent between local and Spaces environments.
+
+**Solution**: The code now:
+- Uses port 7860 by default (same as Spaces)
+- Allows a custom port via the `--port` argument
+- Automatically detects the Hugging Face Spaces environment
+
+**Status**: ✅ Fixed in the code
+
+### 4. **Model Files Not Downloaded**
+
+**Problem**: Model checkpoint files may not be downloaded yet.
+
+**Solution**: The code will automatically download models on first run, but you can verify:
+```bash
+ls models/
+```
+
+Expected files:
+- `resnet50_robust.pt`
+- `standard_resnet50.pt` (optional)
+- `resnet50_robust_face_100_checkpoint.pt` (optional)
+
+### 5. **Missing Stimuli Images**
+
+**Problem**: Example images may be missing.
+
+**Solution**: Verify that the stimuli directory exists:
+```bash
+ls stimuli/
+```
+
+All example images should be present for the demo to work fully.
+
+### 6. **CUDA/GPU Issues**
+
+**Problem**: GPU may not be available or configured correctly.
+
+**Solution**: The code automatically detects available hardware:
+- CUDA (NVIDIA GPUs)
+- MPS (Apple Silicon)
+- CPU (fallback)
+
+Check your setup:
+```python
+import torch
+print("CUDA available:", torch.cuda.is_available())
+print("Device:", torch.device("cuda" if torch.cuda.is_available() else "cpu"))
+```
+
+### 7. **Python Version**
+
+**Problem**: Incompatible Python version.
+
+**Solution**: Use Python 3.8+ (tested with 3.11.5):
+```bash
+python --version
+```
+
+## Quick Start Guide
+
+1. **Install dependencies**:
+```bash
+pip install -r requirements.txt
+```
+
+2. **Run the app**:
+```bash
+python app.py
+```
+
+Or with a custom port:
+```bash
+python app.py --port 8080
+```
+
+3. **Access the web interface**:
+- Open your browser to `http://localhost:7860`
+- Or the port you specified
+
+## Differences Between Hugging Face Spaces and Local
+
+| Feature | Hugging Face Spaces | Local |
+|---------|---------------------|-------|
+| GPU Decorator | `@spaces.GPU` available | No-op decorator (automatic GPU) |
+| Port | Set via `PORT` env var | Default 7860, or `--port` arg |
+| Dependencies | Pre-installed | Must install manually |
+| Environment | `SPACE_ID` env var set | Not set |
+| Model Storage | Persistent storage | Local `models/` directory |
+
+## Testing the Fixes
+
+After applying the fixes, test with:
+```bash
+# Check imports work
+python -c "import gradio, torch, numpy, PIL; print('All imports OK')"
+
+# Run the app
+python app.py --port 7860
+```
+
+## Still Having Issues?
+
+1. **Check error messages**: Look for specific import errors or file-not-found errors
+2. **Verify Python environment**: Make sure you're using the correct virtual environment
+3. **Check file permissions**: Ensure the `models/` and `stimuli/` directories are writable
+4. **Review logs**: Check the `logs/` directory for model loading issues
+
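Note: a minimal sketch of the three-way fallback described in issue 6, mirroring the device-selection block at the top of `inference.py`. The `torch.backends.mps` check assumes PyTorch 1.12+.

```python
import torch

# Prefer CUDA, then MPS (Apple Silicon), then fall back to CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")
```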
__pycache__/inference.cpython-311.pyc
ADDED
Binary file (40.8 kB).
app.py
CHANGED
@@ -4,10 +4,13 @@ import numpy as np
 from PIL import Image
 try:
     from spaces import GPU
+    print("Running on Hugging Face Spaces - GPU decorator available")
 except ImportError:
     # Define a no-op decorator if running locally
     def GPU(func):
+        """No-op decorator for local execution (GPU handling is automatic)"""
         return func
+    print("Running locally - GPU decorator not available (using automatic GPU detection)")
 
 import os
 import argparse
@@ -15,7 +18,7 @@ from inference import GenerativeInferenceModel, get_inference_configs
 
 # Parse command line arguments
 parser = argparse.ArgumentParser(description='Run Generative Inference Demo')
-parser.add_argument('--port', type=int, default=
+parser.add_argument('--port', type=int, default=None, help='Port to run the server on')
 args = parser.parse_args()
 
 # Create model directories if they don't exist
@@ -26,13 +29,54 @@ os.makedirs("stimuli", exist_ok=True)
 if "SPACE_ID" in os.environ:
     default_port = int(os.environ.get("PORT", 7860))
 else:
-    default_port =
+    default_port = 7860  # Use same default port locally
+
+# Use command line port if provided, otherwise use default
+server_port = args.port if args.port is not None else default_port
 
 # Initialize model
 model = GenerativeInferenceModel()
 
 # Define example images and their parameters with updated values from the research
 examples = [
+    {
+        "image": os.path.join("stimuli", "face_vase.png"),
+        "name": "Rubin's Face-Vase (Object Prior)",
+        "wiki": "https://en.wikipedia.org/wiki/Rubin_vase",
+        "papers": [
+            "[Figure-Ground Perception](https://en.wikipedia.org/wiki/Figure-ground_(perception))",
+            "[Bistable Perception](https://doi.org/10.1016/j.tics.2003.08.003)"
+        ],
+        "method": "Prior-Guided Drift Diffusion",
+        "reverse_diff": {
+            "model": "resnet50_robust_face",
+            "layer": "layer4",
+            "initial_noise": 0.0,
+            "diffusion_noise": 0.006,
+            "step_size": 0.18,
+            "iterations": 100,
+            "epsilon": 9.53
+        }
+    },
+    {
+        "image": os.path.join("stimuli", "RandomizedPhaseOvalGray.png"),
+        "name": "Noise (Randomized Phase Oval)",
+        "wiki": "https://en.wikipedia.org/wiki/Visual_noise",
+        "papers": [
+            "[Perceptual Organization](https://doi.org/10.1016/j.tics.2003.08.003)",
+            "[Pattern Recognition](https://en.wikipedia.org/wiki/Pattern_recognition)"
+        ],
+        "method": "Prior-Guided Drift Diffusion",
+        "reverse_diff": {
+            "model": "resnet50_robust_face",
+            "layer": "all",
+            "initial_noise": 0.0,
+            "diffusion_noise": 0.05,
+            "step_size": 1.12,
+            "iterations": 428,
+            "epsilon": 198.62
+        }
+    },
     {
         "image": os.path.join("stimuli", "Neon_Color_Circle.jpg"),
         "name": "Neon Color Spreading",
@@ -91,25 +135,6 @@ examples = [
             "epsilon": 20.0
         }
     },
-    {
-        "image": os.path.join("stimuli", "face_vase.png"),
-        "name": "Rubin's Face-Vase (Object Prior)",
-        "wiki": "https://en.wikipedia.org/wiki/Rubin_vase",
-        "papers": [
-            "[Figure-Ground Perception](https://en.wikipedia.org/wiki/Figure-ground_(perception))",
-            "[Bistable Perception](https://doi.org/10.1016/j.tics.2003.08.003)"
-        ],
-        "method": "Prior-Guided Drift Diffusion",
-        "reverse_diff": {
-            "model": "resnet50_robust",
-            "layer": "avgpool",
-            "initial_noise": 0.9,
-            "diffusion_noise": 0.003,
-            "step_size": 0.58,
-            "iterations": 100,
-            "epsilon": 0.81
-        }
-    },
     {
         "image": os.path.join("stimuli", "Confetti_illusion.png"),
         "name": "Confetti Illusion",
@@ -223,21 +248,76 @@ def run_inference(image, model_type, inference_type, eps_value, num_iterations,
     # Create animation frames
     frames = []
     for i, step_image in enumerate(all_steps):
-        # Convert tensor to PIL image
-
-
+        # Convert tensor to PIL image with proper error handling
+        try:
+            # Ensure tensor is on CPU and detached
+            if isinstance(step_image, torch.Tensor):
+                step_image = step_image.detach().cpu()
+                # Handle different tensor shapes
+                if len(step_image.shape) == 4:  # [B, C, H, W]
+                    step_image = step_image[0]  # Take first batch item
+                elif len(step_image.shape) == 3:  # [C, H, W]
+                    pass  # Already correct shape
+                else:
+                    raise ValueError(f"Unexpected tensor shape: {step_image.shape}")
+
+                # Clamp values to [0, 1] range before converting
+                step_image = torch.clamp(step_image, 0, 1)
+                # Convert to numpy and ensure contiguous array
+                step_np = step_image.permute(1, 2, 0).numpy()
+                # Ensure it's a contiguous array with correct dtype
+                step_np = np.ascontiguousarray(step_np, dtype=np.float32)
+                # Convert to uint8
+                step_np = (step_np * 255).astype(np.uint8)
+                # Create PIL image
+                step_pil = Image.fromarray(step_np, mode='RGB')
+                frames.append(step_pil)
+            else:
+                print(f"Warning: step_image at index {i} is not a tensor: {type(step_image)}")
+        except Exception as e:
+            print(f"Error converting step {i} to PIL image: {e}, shape: {step_image.shape if hasattr(step_image, 'shape') else 'N/A'}")
+            # Skip this frame if conversion fails
+            continue
 
     # Convert the final output image to PIL
-
+    try:
+        if isinstance(output_image, torch.Tensor):
+            output_image = output_image.detach().cpu()
+            # Handle different tensor shapes
+            if len(output_image.shape) == 4:  # [B, C, H, W]
+                output_image = output_image[0]  # Take first batch item
+            elif len(output_image.shape) == 3:  # [C, H, W]
+                pass  # Already correct shape
+            else:
+                raise ValueError(f"Unexpected tensor shape: {output_image.shape}")
+
+            # Clamp values to [0, 1] range before converting
+            output_image = torch.clamp(output_image, 0, 1)
+            # Convert to numpy and ensure contiguous array
+            output_np = output_image.permute(1, 2, 0).numpy()
+            # Ensure it's a contiguous array with correct dtype
+            output_np = np.ascontiguousarray(output_np, dtype=np.float32)
+            # Convert to uint8
+            output_np = (output_np * 255).astype(np.uint8)
+            # Create PIL image
+            final_image = Image.fromarray(output_np, mode='RGB')
+        else:
+            raise ValueError(f"output_image is not a tensor: {type(output_image)}")
+    except Exception as e:
+        print(f"Error converting final image to PIL: {e}, shape: {output_image.shape if hasattr(output_image, 'shape') else 'N/A'}")
+        # Return a black image as fallback
+        final_image = Image.new('RGB', (224, 224), color='black')
 
     # Return the final inferred image and the animation frames directly
     return final_image, frames
 
 # Helper function to apply example parameters
 def apply_example(example):
+    # Get the full path to the image file
+    image_path = os.path.abspath(example["image"]) if os.path.exists(example["image"]) else example["image"]
     return [
-
-        "
+        image_path,
+        example["reverse_diff"]["model"],  # Model type from example
         example["method"],  # Inference type
         example["reverse_diff"]["epsilon"],  # Epsilon value
         example["reverse_diff"]["iterations"],  # Number of iterations
@@ -249,7 +329,7 @@ def apply_example(example):
     ]
 
 # Define the interface
-with gr.Blocks(title="Generative Inference Demo", css="""
+with gr.Blocks(title="Generative Inference for Psychiatry Demo", css="""
 .purple-button {
     background-color: #8B5CF6 !important;
     color: white !important;
@@ -259,7 +339,7 @@ with gr.Blocks(title="Generative Inference Demo", css="""
     background-color: #7C3AED !important;
 }
 """) as demo:
-    gr.Markdown("# Generative Inference Demo")
+    gr.Markdown("# Generative Inference for Psychiatry Demo")
     gr.Markdown("This demo showcases how neural networks can perceive visual illusions and develop Gestalt principles of perceptual organization through generative inference.")
 
     gr.Markdown("""
@@ -273,7 +353,9 @@ with gr.Blocks(title="Generative Inference Demo", css="""
     with gr.Row():
         with gr.Column(scale=1):
             # Inputs
-
+            # Use absolute path for default image to avoid directory errors
+            default_image_path = os.path.abspath(os.path.join("stimuli", "face_vase.png")) if os.path.exists(os.path.join("stimuli", "face_vase.png")) else None
+            image_input = gr.Image(label="Input Image", type="pil", value=default_image_path)
 
             # Run Inference button right below the image
             run_button = gr.Button("🪄 Run Generative Inference", variant="primary", elem_classes="purple-button")
@@ -286,7 +368,7 @@ with gr.Blocks(title="Generative Inference Demo", css="""
             with gr.Row():
                 model_choice = gr.Dropdown(
                     choices=["resnet50_robust", "standard_resnet50", "resnet50_robust_face"],  # "resnet50_robust_face" - hidden for deployment
-                    value="
+                    value="resnet50_robust_face",
                     label="Model"
                 )
@@ -297,21 +379,21 @@ with gr.Blocks(title="Generative Inference Demo", css="""
             )
 
             with gr.Row():
-                eps_slider = gr.Slider(minimum=0.0, maximum=
-                iterations_slider = gr.Slider(minimum=1, maximum=600, value=
+                eps_slider = gr.Slider(minimum=0.0, maximum=200.0, value=9.53, step=0.01, label="Epsilon (Stimulus Fidelity)")
+                iterations_slider = gr.Slider(minimum=1, maximum=600, value=100, step=1, label="Number of Iterations")  # Updated max to 600
 
             with gr.Row():
-                initial_noise_slider = gr.Slider(minimum=0.0, maximum=
+                initial_noise_slider = gr.Slider(minimum=0.0, maximum=5.0, value=0.0, step=0.01,
                                                  label="Drift Noise")
-                diffusion_noise_slider = gr.Slider(minimum=0.0, maximum=0
+                diffusion_noise_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.006, step=0.001,
                                                    label="Diffusion Noise")  # Corrected name
 
             with gr.Row():
-                step_size_slider = gr.Slider(minimum=0.
+                step_size_slider = gr.Slider(minimum=0.0, maximum=10.0, value=0.18, step=0.01,
                                              label="Update Rate")  # Added step size slider
                 layer_choice = gr.Dropdown(
                     choices=["all", "conv1", "bn1", "relu", "maxpool", "layer1", "layer2", "layer3", "layer4", "avgpool"],
-                    value="
+                    value="layer4",
                     label="Model Layer"
                 )
@@ -403,10 +485,10 @@ with gr.Blocks(title="Generative Inference Demo", css="""
 
 # Launch the demo
 if __name__ == "__main__":
-    print(f"Starting server on port {
+    print(f"Starting server on port {server_port}")
     demo.launch(
         server_name="0.0.0.0",
-        server_port=
+        server_port=server_port,
         share=False,
         debug=True
     )
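Note: the frame loop and the final-image branch above repeat the same tensor-to-PIL conversion. A possible follow-up (not part of this commit) is to factor it into one helper; `tensor_to_pil` is a hypothetical name:

```python
import numpy as np
import torch
from PIL import Image

def tensor_to_pil(t: torch.Tensor) -> Image.Image:
    """Hypothetical helper collapsing the duplicated conversion blocks above."""
    t = t.detach().cpu()
    if t.dim() == 4:        # [B, C, H, W]: take the first batch item
        t = t[0]
    elif t.dim() != 3:      # expect [C, H, W]
        raise ValueError(f"Unexpected tensor shape: {tuple(t.shape)}")
    t = torch.clamp(t, 0, 1)                  # keep a valid pixel range
    arr = t.permute(1, 2, 0).numpy()          # HWC float array
    arr = np.ascontiguousarray(arr, dtype=np.float32)
    return Image.fromarray((arr * 255).astype(np.uint8), mode='RGB')
```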
face_vase_black.png
ADDED
huggingface-metadata.json
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
{
|
| 2 |
-
"title": "Generative Inference Demo",
|
| 3 |
"emoji": "🧠",
|
| 4 |
"colorFrom": "indigo",
|
| 5 |
"colorTo": "purple",
|
|
|
|
| 1 |
{
|
| 2 |
+
"title": "Generative Inference for Psychiatry Demo",
|
| 3 |
"emoji": "🧠",
|
| 4 |
"colorFrom": "indigo",
|
| 5 |
"colorTo": "purple",
|
inference.py
CHANGED
|
@@ -1,8 +1,10 @@
|
|
|
|
|
|
|
|
| 1 |
import torch
|
| 2 |
import torch.nn as nn
|
| 3 |
import torch.nn.functional as F
|
| 4 |
-
import torchvision.models as models
|
| 5 |
import torchvision.transforms as transforms
|
|
|
|
| 6 |
from torchvision.models.resnet import ResNet50_Weights
|
| 7 |
from PIL import Image
|
| 8 |
import numpy as np
|
|
@@ -12,6 +14,7 @@ import time
|
|
| 12 |
import copy
|
| 13 |
from collections import OrderedDict
|
| 14 |
from pathlib import Path
|
|
|
|
| 15 |
|
| 16 |
# Check for available hardware acceleration
|
| 17 |
if torch.cuda.is_available():
|
|
@@ -22,175 +25,98 @@ else:
|
|
| 22 |
device = torch.device("cpu")
|
| 23 |
print(f"Using device: {device}")
|
| 24 |
|
| 25 |
-
# Constants
|
| 26 |
MODEL_URLS = {
|
| 27 |
'resnet50_robust': 'https://huggingface.co/madrylab/robust-imagenet-models/resolve/main/resnet50_l2_eps3.ckpt',
|
| 28 |
'resnet50_standard': 'https://huggingface.co/madrylab/robust-imagenet-models/resolve/main/resnet50_l2_eps0.ckpt',
|
| 29 |
'resnet50_robust_face': 'https://huggingface.co/ttoosi/resnet50_robust_face/resolve/main/resnet50_imagenet_L2_eps_0.50_checkpoint150.pt'
|
| 30 |
}
|
| 31 |
|
| 32 |
-
#
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
}
|
| 39 |
|
| 40 |
IMAGENET_MEAN = [0.485, 0.456, 0.406]
|
| 41 |
IMAGENET_STD = [0.229, 0.224, 0.225]
|
| 42 |
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
if
|
| 46 |
-
return
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
|
|
|
| 52 |
else:
|
| 53 |
-
return
|
| 54 |
-
|
| 55 |
-
transforms.CenterCrop(input_size),
|
| 56 |
-
transforms.ToTensor(),
|
| 57 |
-
])
|
| 58 |
-
|
| 59 |
-
# Default transform without normalization
|
| 60 |
-
transform = transforms.Compose([
|
| 61 |
-
transforms.Resize(224),
|
| 62 |
-
transforms.CenterCrop(224),
|
| 63 |
-
transforms.ToTensor(),
|
| 64 |
-
])
|
| 65 |
-
|
| 66 |
-
normalize_transform = transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)
|
| 67 |
|
| 68 |
-
def
|
| 69 |
-
"""
|
| 70 |
-
Extract a subset of the model up to a specific layer.
|
| 71 |
-
|
| 72 |
-
Args:
|
| 73 |
-
model: The neural network model
|
| 74 |
-
layer_index: String 'all' for the full model, or a layer identifier (string or int)
|
| 75 |
-
For ResNet: integers 0-8 representing specific layers
|
| 76 |
-
For ViT: strings like 'encoder.layers.encoder_layer_3'
|
| 77 |
-
|
| 78 |
-
Returns:
|
| 79 |
-
A modified model that outputs features from the specified layer
|
| 80 |
-
"""
|
| 81 |
-
if isinstance(layer_index, str) and layer_index == 'all':
|
| 82 |
-
return model
|
| 83 |
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
layer_name = f"encoder_layer_{i}"
|
| 98 |
-
if hasattr(new_model.module.encoder.layers, layer_name):
|
| 99 |
-
encoder_layers.add_module(layer_name,
|
| 100 |
-
getattr(new_model.module.encoder.layers, layer_name))
|
| 101 |
-
|
| 102 |
-
# Replace the encoder layers with our truncated version
|
| 103 |
-
new_model.module.encoder.layers = encoder_layers
|
| 104 |
-
|
| 105 |
-
# Remove the heads since we're stopping at the encoder layer
|
| 106 |
-
new_model.module.heads = nn.Identity()
|
| 107 |
-
|
| 108 |
-
return new_model
|
| 109 |
-
else:
|
| 110 |
-
# Direct model access (not DataParallel)
|
| 111 |
-
encoder_layers = nn.Sequential()
|
| 112 |
-
for i in range(target_layer_idx + 1):
|
| 113 |
-
layer_name = f"encoder_layer_{i}"
|
| 114 |
-
if hasattr(new_model.encoder.layers, layer_name):
|
| 115 |
-
encoder_layers.add_module(layer_name,
|
| 116 |
-
getattr(new_model.encoder.layers, layer_name))
|
| 117 |
-
|
| 118 |
-
# Replace the encoder layers with our truncated version
|
| 119 |
-
new_model.encoder.layers = encoder_layers
|
| 120 |
-
|
| 121 |
-
# Remove the heads since we're stopping at the encoder layer
|
| 122 |
-
new_model.heads = nn.Identity()
|
| 123 |
-
|
| 124 |
-
return new_model
|
| 125 |
-
|
| 126 |
-
except (ValueError, IndexError) as e:
|
| 127 |
-
raise ValueError(f"Invalid ViT layer specification: {layer_index}. Error: {e}")
|
| 128 |
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
return new_model
|
| 144 |
|
| 145 |
-
|
| 146 |
-
# Original ResNet/VGG handling
|
| 147 |
-
modules = list(model.named_children())
|
| 148 |
-
print(f"DEBUG - extract_middle_layers - Looking for '{layer_index}' in {[name for name, _ in modules]}")
|
| 149 |
-
|
| 150 |
-
cutoff_idx = next((i for i, (name, _) in enumerate(modules)
|
| 151 |
-
if name == str(layer_index)), None)
|
| 152 |
-
|
| 153 |
-
if cutoff_idx is not None:
|
| 154 |
-
# Keep modules up to and including the target
|
| 155 |
-
new_model = nn.Sequential(OrderedDict(modules[:cutoff_idx+1]))
|
| 156 |
-
return new_model
|
| 157 |
-
else:
|
| 158 |
-
raise ValueError(f"Module {layer_index} not found in model")
|
| 159 |
-
|
| 160 |
-
# Get ImageNet labels
|
| 161 |
-
def get_imagenet_labels():
|
| 162 |
-
url = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
|
| 163 |
-
response = requests.get(url)
|
| 164 |
-
if response.status_code == 200:
|
| 165 |
-
return response.json()
|
| 166 |
-
else:
|
| 167 |
-
raise RuntimeError("Failed to fetch ImageNet labels")
|
| 168 |
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
if model_type not in
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
if model_type == 'resnet50_robust_face':
|
| 176 |
-
model_path = Path("models/resnet50_robust_face_100_checkpoint.pt")
|
| 177 |
-
else:
|
| 178 |
-
model_path = Path(f"models/{model_type}.pt")
|
| 179 |
-
|
| 180 |
-
if not model_path.exists():
|
| 181 |
-
print(f"Downloading {model_type} model...")
|
| 182 |
-
url = MODEL_URLS[model_type]
|
| 183 |
-
response = requests.get(url, stream=True)
|
| 184 |
-
if response.status_code == 200:
|
| 185 |
-
with open(model_path, 'wb') as f:
|
| 186 |
-
for chunk in response.iter_content(chunk_size=8192):
|
| 187 |
-
f.write(chunk)
|
| 188 |
-
print(f"Model downloaded and saved to {model_path}")
|
| 189 |
-
else:
|
| 190 |
-
raise RuntimeError(f"Failed to download model: {response.status_code}")
|
| 191 |
-
return model_path
|
| 192 |
|
| 193 |
class NormalizeByChannelMeanStd(nn.Module):
|
|
|
|
| 194 |
def __init__(self, mean, std):
|
| 195 |
super(NormalizeByChannelMeanStd, self).__init__()
|
| 196 |
if not isinstance(mean, torch.Tensor):
|
|
@@ -205,737 +131,567 @@ class NormalizeByChannelMeanStd(nn.Module):
|
|
| 205 |
|
| 206 |
def normalize_fn(self, tensor, mean, std):
|
| 207 |
"""Differentiable version of torchvision.functional.normalize"""
|
| 208 |
-
# here we assume the color channel is at dim=1
|
| 209 |
mean = mean[None, :, None, None]
|
| 210 |
std = std[None, :, None, None]
|
| 211 |
return tensor.sub(mean).div(std)
|
| 212 |
|
| 213 |
class InferStep:
|
| 214 |
-
|
|
|
|
|
|
|
| 215 |
self.orig_image = orig_image
|
| 216 |
self.eps = eps
|
| 217 |
self.step_size = step_size
|
| 218 |
|
| 219 |
-
def project(self, x):
|
|
|
|
| 220 |
diff = x - self.orig_image
|
| 221 |
diff = torch.clamp(diff, -self.eps, self.eps)
|
| 222 |
return torch.clamp(self.orig_image + diff, 0, 1)
|
| 223 |
|
| 224 |
-
def step(self, x, grad):
|
| 225 |
-
|
| 226 |
-
|
|
|
|
| 227 |
scaled_grad = grad / (grad_norm + 1e-10)
|
| 228 |
return scaled_grad * self.step_size
|
| 229 |
|
| 230 |
-
def
|
| 231 |
-
"""
|
| 232 |
-
if
|
| 233 |
-
return
|
| 234 |
-
elif n_itr <= 100:
|
| 235 |
-
return [1, 5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, n_itr]
|
| 236 |
-
elif n_itr <= 200:
|
| 237 |
-
return [1, 5, 10, 20, 30, 40, 50, 75, 100, 125, 150, 175, 200, n_itr]
|
| 238 |
-
elif n_itr <= 500:
|
| 239 |
-
return [1, 5, 10, 20, 30, 40, 50, 75, 100, 150, 200, 250, 300, 350, 400, 450, 500, n_itr]
|
| 240 |
-
else:
|
| 241 |
-
# For very large iterations, show more evenly distributed points
|
| 242 |
-
return [1, 5, 10, 20, 30, 40, 50, 75, 100, 150, 200, 250, 300, 350, 400, 450, 500,
|
| 243 |
-
int(n_itr*0.6), int(n_itr*0.7), int(n_itr*0.8), int(n_itr*0.9), n_itr]
|
| 244 |
-
|
| 245 |
-
def get_inference_configs(inference_type='IncreaseConfidence', eps=0.5, n_itr=50, step_size=1.0):
|
| 246 |
-
"""Generate inference configuration with customizable parameters.
|
| 247 |
-
|
| 248 |
-
Args:
|
| 249 |
-
inference_type (str): Type of inference ('IncreaseConfidence' or 'Prior-Guided Drift Diffusion')
|
| 250 |
-
eps (float): Maximum perturbation size
|
| 251 |
-
n_itr (int): Number of iterations
|
| 252 |
-
step_size (float): Step size for each iteration
|
| 253 |
-
"""
|
| 254 |
|
| 255 |
-
#
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
'diffusion_noise_ratio': 0.0, # No diffusion noise
|
| 262 |
-
'initial_inference_noise_ratio': 0.0, # No initial noise
|
| 263 |
-
'top_layer': 'all', # Use all layers of the model
|
| 264 |
-
'inference_normalization': False, # Apply normalization during inference
|
| 265 |
-
'recognition_normalization': False, # Apply normalization during recognition
|
| 266 |
-
'iterations_to_show': get_iterations_to_show(n_itr), # Dynamic iterations to visualize
|
| 267 |
-
'misc_info': {'keep_grads': False} # Additional configuration
|
| 268 |
-
}
|
| 269 |
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 273 |
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
|
|
|
|
|
|
| 278 |
|
| 279 |
-
|
| 280 |
-
config['loss_function'] = 'CE' # Cross Entropy
|
| 281 |
-
config['misc_info']['grad_modulation'] = 0.5 # Gradient modulation strength
|
| 282 |
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
|
| 288 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 289 |
|
| 290 |
class GenerativeInferenceModel:
|
|
|
|
|
|
|
| 291 |
def __init__(self):
|
| 292 |
self.models = {}
|
| 293 |
-
#self.normalizer = NormalizeByChannelMeanStd(IMAGENET_MEAN, IMAGENET_STD).to(device)
|
| 294 |
self.model_preproc = {}
|
| 295 |
-
self.labels = get_imagenet_labels()
|
| 296 |
|
| 297 |
-
def
|
| 298 |
-
"""
|
| 299 |
-
|
| 300 |
-
Returns whether the model passes basic integrity check.
|
| 301 |
-
"""
|
| 302 |
try:
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
# Run forward pass
|
| 310 |
-
with torch.no_grad():
|
| 311 |
-
output = model(test_input)
|
| 312 |
-
|
| 313 |
-
# Check output shape
|
| 314 |
-
if output.shape != (1, 1000):
|
| 315 |
-
print(f"❌ Unexpected output shape: {output.shape}, expected (1, 1000)")
|
| 316 |
-
return False
|
| 317 |
-
|
| 318 |
-
# Get top prediction
|
| 319 |
-
probs = torch.nn.functional.softmax(output, dim=1)
|
| 320 |
-
confidence, prediction = torch.max(probs, 1)
|
| 321 |
-
|
| 322 |
-
# Calculate basic statistics on output
|
| 323 |
-
mean = output.mean().item()
|
| 324 |
-
std = output.std().item()
|
| 325 |
-
min_val = output.min().item()
|
| 326 |
-
max_val = output.max().item()
|
| 327 |
-
|
| 328 |
-
print(f"Model integrity check results:")
|
| 329 |
-
print(f"- Output shape: {output.shape}")
|
| 330 |
-
print(f"- Top prediction: Class {prediction.item()} with {confidence.item()*100:.2f}% confidence")
|
| 331 |
-
print(f"- Output statistics: mean={mean:.3f}, std={std:.3f}, min={min_val:.3f}, max={max_val:.3f}")
|
| 332 |
-
|
| 333 |
-
# Basic sanity checks
|
| 334 |
-
if torch.isnan(output).any():
|
| 335 |
-
print("❌ Model produced NaN outputs")
|
| 336 |
-
return False
|
| 337 |
-
|
| 338 |
-
if output.std().item() < 0.1:
|
| 339 |
-
print("⚠️ Low output variance, model may not be discriminative")
|
| 340 |
-
|
| 341 |
-
print("✅ Model passes basic integrity check")
|
| 342 |
-
return True
|
| 343 |
-
|
| 344 |
except Exception as e:
|
| 345 |
-
print(f"
|
| 346 |
-
|
| 347 |
-
return True
|
| 348 |
|
| 349 |
def load_model(self, model_type):
|
|
|
|
| 350 |
if model_type in self.models:
|
| 351 |
print(f"Using cached {model_type} model")
|
| 352 |
return self.models[model_type]
|
| 353 |
|
| 354 |
start_time = time.time()
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 363 |
model = nn.Sequential(normalizer, resnet)
|
| 364 |
|
| 365 |
-
#
|
|
|
|
|
|
|
| 366 |
if model_path:
|
| 367 |
print(f"Loading {model_type} model from {model_path}...")
|
| 368 |
try:
|
| 369 |
checkpoint = torch.load(model_path, map_location=device)
|
| 370 |
|
| 371 |
-
# Print checkpoint structure for better understanding
|
| 372 |
-
print("\n=== Analyzing checkpoint structure ===")
|
| 373 |
-
if isinstance(checkpoint, dict):
|
| 374 |
-
print(f"Checkpoint contains keys: {list(checkpoint.keys())}")
|
| 375 |
-
|
| 376 |
-
# Examine 'model' structure if it exists
|
| 377 |
-
if 'model' in checkpoint and isinstance(checkpoint['model'], dict):
|
| 378 |
-
model_dict = checkpoint['model']
|
| 379 |
-
# Get sample of keys to understand structure
|
| 380 |
-
first_keys = list(model_dict.keys())[:5]
|
| 381 |
-
print(f"'model' contains keys like: {first_keys}")
|
| 382 |
-
|
| 383 |
-
# Check for common prefixes in the model dict
|
| 384 |
-
prefixes = set()
|
| 385 |
-
for key in list(model_dict.keys())[:100]: # Check first 100 keys
|
| 386 |
-
parts = key.split('.')
|
| 387 |
-
if len(parts) > 1:
|
| 388 |
-
prefixes.add(parts[0])
|
| 389 |
-
if prefixes:
|
| 390 |
-
print(f"Common prefixes in model dict: {prefixes}")
|
| 391 |
-
else:
|
| 392 |
-
print(f"Checkpoint is not a dictionary, but a {type(checkpoint)}")
|
| 393 |
-
|
| 394 |
# Handle different checkpoint formats
|
| 395 |
if 'model' in checkpoint:
|
| 396 |
-
# Format from madrylab robust models
|
| 397 |
state_dict = checkpoint['model']
|
| 398 |
print("Using 'model' key from checkpoint")
|
| 399 |
elif 'state_dict' in checkpoint:
|
| 400 |
state_dict = checkpoint['state_dict']
|
| 401 |
print("Using 'state_dict' key from checkpoint")
|
| 402 |
else:
|
| 403 |
-
# Direct state dict
|
| 404 |
state_dict = checkpoint
|
| 405 |
print("Using checkpoint directly as state_dict")
|
| 406 |
|
| 407 |
-
#
|
| 408 |
resnet_state_dict = {}
|
| 409 |
-
prefixes_to_try = ['', 'module.', 'model.', 'attacker.model.']
|
| 410 |
resnet_keys = set(resnet.state_dict().keys())
|
| 411 |
|
| 412 |
-
#
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
print(f"Found 'module.model' structure with {len(module_model_keys)} parameters")
|
| 419 |
-
# Extract all parameters from module.model
|
| 420 |
-
for source_key, value in state_dict.items():
|
| 421 |
-
if source_key.startswith('module.model.'):
|
| 422 |
-
target_key = source_key[len('module.model.'):]
|
| 423 |
-
# Some ckpts have 'module.model.model.<...>'; remove the extra 'model.' too
|
| 424 |
-
if target_key.startswith('model.'):
|
| 425 |
-
target_key = target_key[len('model.'):]
|
| 426 |
-
resnet_state_dict[target_key] = value
|
| 427 |
-
|
| 428 |
-
print(f"Extracted {len(resnet_state_dict)} parameters from module.model")
|
| 429 |
-
|
| 430 |
-
# Check for 'attacker.model' structure
|
| 431 |
-
attacker_model_keys = [key for key in state_dict.keys() if key.startswith('attacker.model.')]
|
| 432 |
-
if attacker_model_keys:
|
| 433 |
-
print(f"Found 'attacker.model' structure with {len(attacker_model_keys)} parameters")
|
| 434 |
-
# Extract all parameters from attacker.model
|
| 435 |
-
for source_key, value in state_dict.items():
|
| 436 |
-
if source_key.startswith('attacker.model.'):
|
| 437 |
-
target_key = source_key[len('attacker.model.'):]
|
| 438 |
-
resnet_state_dict[target_key] = value
|
| 439 |
-
|
| 440 |
-
print(f"Extracted {len(resnet_state_dict)} parameters from attacker.model")
|
| 441 |
-
|
| 442 |
-
# Check if 'model' (not attacker.model) exists as a fallback
|
| 443 |
-
model_keys = [key for key in state_dict.keys() if key.startswith('model.') and not key.startswith('attacker.model.')]
|
| 444 |
-
if model_keys and len(resnet_state_dict) < len(resnet_keys):
|
| 445 |
-
print(f"Found additional 'model.' structure with {len(model_keys)} parameters")
|
| 446 |
-
# Try to complete missing parameters
|
| 447 |
for source_key, value in state_dict.items():
|
| 448 |
-
if source_key.startswith('model.'):
|
| 449 |
-
target_key = source_key[len('model.'):]
|
| 450 |
-
if target_key in resnet_keys
|
| 451 |
resnet_state_dict[target_key] = value
|
| 452 |
-
|
| 453 |
-
else:
|
| 454 |
-
# Check for other known structures
|
| 455 |
-
structure_found = False
|
| 456 |
-
|
| 457 |
-
# Check for 'model.' prefix
|
| 458 |
-
model_keys = [key for key in state_dict.keys() if key.startswith('model.')]
|
| 459 |
-
if model_keys:
|
| 460 |
-
print(f"Found 'model.' structure with {len(model_keys)} parameters")
|
| 461 |
-
for source_key, value in state_dict.items():
|
| 462 |
-
if source_key.startswith('model.'):
|
| 463 |
-
target_key = source_key[len('model.'):]
|
| 464 |
-
resnet_state_dict[target_key] = value
|
| 465 |
-
structure_found = True
|
| 466 |
|
| 467 |
-
#
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
if
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
if not structure_found:
|
| 482 |
-
print("No standard model structure found, trying prefix mappings...")
|
| 483 |
-
for target_key in resnet_keys:
|
| 484 |
-
for prefix in prefixes_to_try:
|
| 485 |
-
source_key = prefix + target_key
|
| 486 |
-
if source_key in state_dict:
|
| 487 |
-
resnet_state_dict[target_key] = state_dict[source_key]
|
| 488 |
-
break
|
| 489 |
|
| 490 |
-
#
|
| 491 |
-
if len(resnet_state_dict)
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
# Track matches found through prefix removal
|
| 495 |
-
prefix_matches = {prefix: 0 for prefix in ['module.', 'model.', 'attacker.model.', 'attacker.']}
|
| 496 |
-
layer_matches = {} # Track matches by layer type
|
| 497 |
-
|
| 498 |
-
# Count parameter keys by layer type for analysis
|
| 499 |
-
for key in resnet_keys:
|
| 500 |
-
layer_name = key.split('.')[0] if '.' in key else key
|
| 501 |
-
if layer_name not in layer_matches:
|
| 502 |
-
layer_matches[layer_name] = {'total': 0, 'matched': 0}
|
| 503 |
-
layer_matches[layer_name]['total'] += 1
|
| 504 |
|
| 505 |
-
# Try keys with common prefixes
|
| 506 |
for source_key, value in state_dict.items():
|
| 507 |
-
# Skip if already found
|
| 508 |
target_key = source_key
|
| 509 |
-
matched_prefix = None
|
| 510 |
|
| 511 |
# Try removing various prefixes
|
| 512 |
-
for prefix in
|
| 513 |
if source_key.startswith(prefix):
|
| 514 |
target_key = source_key[len(prefix):]
|
| 515 |
-
matched_prefix = prefix
|
| 516 |
break
|
| 517 |
|
| 518 |
-
#
|
| 519 |
-
if target_key
|
|
|
|
|
|
|
|
|
|
|
|
|
| 520 |
resnet_state_dict[target_key] = value
|
| 521 |
-
|
| 522 |
-
# Update match statistics
|
| 523 |
-
if matched_prefix:
|
| 524 |
-
prefix_matches[matched_prefix] += 1
|
| 525 |
-
|
| 526 |
-
# Update layer matches
|
| 527 |
-
layer_name = target_key.split('.')[0] if '.' in target_key else target_key
|
| 528 |
-
if layer_name in layer_matches:
|
| 529 |
-
layer_matches[layer_name]['matched'] += 1
|
| 530 |
-
|
| 531 |
-
# Print detailed prefix removal statistics
|
| 532 |
-
print("\n=== Prefix Removal Statistics ===")
|
| 533 |
-
total_matches = sum(prefix_matches.values())
|
| 534 |
-
print(f"Total parameters matched through prefix removal: {total_matches}/{len(resnet_keys)} ({(total_matches/len(resnet_keys))*100:.1f}%)")
|
| 535 |
-
|
| 536 |
-
# Show matches by prefix
|
| 537 |
-
print("\nMatches by prefix:")
|
| 538 |
-
for prefix, count in sorted(prefix_matches.items(), key=lambda x: x[1], reverse=True):
|
| 539 |
-
if count > 0:
|
| 540 |
-
print(f" {prefix}: {count} parameters")
|
| 541 |
-
|
| 542 |
-
# Show matches by layer type
|
| 543 |
-
print("\nMatches by layer type:")
|
| 544 |
-
for layer, stats in sorted(layer_matches.items(), key=lambda x: x[1]['total'], reverse=True):
|
| 545 |
-
match_percent = (stats['matched'] / stats['total']) * 100 if stats['total'] > 0 else 0
|
| 546 |
-
print(f" {layer}: {stats['matched']}/{stats['total']} ({match_percent:.1f}%)")
|
| 547 |
-
|
| 548 |
-
# Check for specific important layers (conv1, layer1, etc.)
|
| 549 |
-
critical_layers = ['conv1', 'bn1', 'layer1', 'layer2', 'layer3', 'layer4', 'fc']
|
| 550 |
-
print("\nStatus of critical layers:")
|
| 551 |
-
for layer in critical_layers:
|
| 552 |
-
if layer in layer_matches:
|
| 553 |
-
match_percent = (layer_matches[layer]['matched'] / layer_matches[layer]['total']) * 100
|
| 554 |
-
status = "✅ COMPLETE" if layer_matches[layer]['matched'] == layer_matches[layer]['total'] else "⚠️ INCOMPLETE"
|
| 555 |
-
print(f" {layer}: {layer_matches[layer]['matched']}/{layer_matches[layer]['total']} ({match_percent:.1f}%) - {status}")
|
| 556 |
-
else:
|
| 557 |
-
print(f" {layer}: Not found in model")
|
| 558 |
|
| 559 |
-
# Load the
|
| 560 |
if resnet_state_dict:
|
| 561 |
-
|
| 562 |
-
|
| 563 |
-
|
| 564 |
-
|
| 565 |
-
|
| 566 |
-
|
| 567 |
-
|
| 568 |
-
|
| 569 |
-
|
| 570 |
-
loading_report.append(f"Total parameters in model: {len(resnet.state_dict()):,}")
|
| 571 |
-
loading_report.append(f"Missing keys: {len(missing_keys):,} parameters")
|
| 572 |
-
loading_report.append(f"Unexpected keys: {len(unexpected_keys):,} parameters")
|
| 573 |
-
|
| 574 |
-
# Calculate percentage of parameters loaded
|
| 575 |
-
loaded_keys = set(resnet_state_dict.keys()) - set(unexpected_keys)
|
| 576 |
-
loaded_percent = (len(loaded_keys) / len(resnet.state_dict())) * 100
|
| 577 |
-
|
| 578 |
-
-                    # Determine loading success status
-                    if loaded_percent >= 99.5:
-                        status = "✅ COMPLETE - All important parameters loaded"
-                    elif loaded_percent >= 90:
-                        status = "🟡 PARTIAL - Most parameters loaded, should still function"
-                    elif loaded_percent >= 50:
-                        status = "⚠️ INCOMPLETE - Many parameters missing, may not function properly"
-                    else:
-                        status = "❌ FAILED - Critical parameters missing, will not function properly"
-
-                    loading_report.append(f"Successfully loaded: {len(loaded_keys):,} parameters ({loaded_percent:.1f}%)")
-                    loading_report.append(f"Loading status: {status}")
-
-                    # If loading is severely incomplete, fall back to PyTorch's pretrained model
-                    if loaded_percent < 50:
-                        loading_report.append("\n⚠️ WARNING: Loading from checkpoint is too incomplete.")
-                        loading_report.append("⚠️ Falling back to PyTorch's pretrained model to avoid broken inference.")
-
-                        # Create a new ResNet model with pretrained weights
                         resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
                         model = nn.Sequential(normalizer, resnet)
-                        loading_report.append("✅ Successfully loaded PyTorch's pretrained ResNet50 model")
-
-                    # Show missing keys by layer type
-                    if missing_keys:
-                        loading_report.append("\nMissing keys by layer type:")
-                        layer_types = {}
-                        for key in missing_keys:
-                            # Extract layer type (e.g., 'conv', 'bn', 'layer1', etc.)
-                            parts = key.split('.')
-                            if len(parts) > 0:
-                                layer_type = parts[0]
-                                if layer_type not in layer_types:
-                                    layer_types[layer_type] = 0
-                                layer_types[layer_type] += 1
-
-                        # Add counts by layer type
-                        for layer_type, count in sorted(layer_types.items(), key=lambda x: x[1], reverse=True):
-                            loading_report.append(f"  {layer_type}: {count:,} parameters")
-
-                        loading_report.append("\nFirst 10 missing keys:")
-                        for i, key in enumerate(sorted(missing_keys)[:10]):
-                            loading_report.append(f"  {i+1}. {key}")
-
-                    # Show unexpected keys if any
-                    if unexpected_keys:
-                        loading_report.append("\nFirst 10 unexpected keys:")
-                        for i, key in enumerate(sorted(unexpected_keys)[:10]):
-                            loading_report.append(f"  {i+1}. {key}")
-
-                    loading_report.append("========================================")
-
-                    os.makedirs("logs", exist_ok=True)
-                    with open(f"logs/model_loading_{model_type}.log", "w") as f:
-                        f.write(report_text)
-
-                # Look for normalizer parameters as well
-                if any(key.startswith('attacker.normalize.') for key in state_dict.keys()):
-                    norm_state_dict = {}
-                    for key, value in state_dict.items():
-                        if key.startswith('attacker.normalize.'):
-                            norm_key = key[len('attacker.normalize.'):]
-                            norm_state_dict[norm_key] = value
-
-                    if norm_state_dict:
-                        try:
-                            normalizer.load_state_dict(norm_state_dict, strict=False)
-                            print("Successfully loaded normalizer parameters")
-                        except Exception as e:
-                            print(f"Warning: Could not load normalizer parameters: {e}")
-            except Exception as e:
-                print(f"Warning: Error loading ResNet parameters: {e}")
-                # Fall back to loading without normalizer
-                model = resnet  # Use just the ResNet model without normalizer
         except Exception as e:
-            print(f"Error loading checkpoint: {e}")
-
             resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
             model = nn.Sequential(normalizer, resnet)
-        else:
-            print("No checkpoint available, using PyTorch's pretrained model")
-            resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
-            model = nn.Sequential(normalizer, resnet)

         model = model.to(device)
-        model.eval()

-        # Verify model
         self.verify_model_integrity(model, model_type)

-        #
         self.models[model_type] = model
         end_time = time.time()
-
-        print(f"Model {model_type} loaded in {load_time:.2f} seconds")
         return model
-
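Both the removed loader above and the rewritten one below hinge on the same pattern: filter the checkpoint down to keys the torchvision ResNet actually owns, load with `strict=False`, and judge health by coverage. A minimal, self-contained sketch of that pattern, assuming a local checkpoint (the path `ckpt_path.pt` and the 1000-class head are illustrative, not from the commit):

```python
# Sketch of partial checkpoint loading with a coverage check (illustrative).
import torch
import torchvision.models as models

resnet = models.resnet50(num_classes=1000)
state_dict = torch.load("ckpt_path.pt", map_location="cpu")
state_dict = state_dict.get("model", state_dict)  # unwrap a common wrapper key

wanted = set(resnet.state_dict().keys())
filtered = {k.removeprefix("module."): v for k, v in state_dict.items()}
filtered = {k: v for k, v in filtered.items() if k in wanted}

# load_state_dict(strict=False) returns (missing_keys, unexpected_keys)
missing, unexpected = resnet.load_state_dict(filtered, strict=False)
coverage = 100.0 * len(filtered) / len(wanted)
print(f"loaded {coverage:.1f}% of parameters; {len(missing)} missing")
```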
     def inference(self, image, model_type, config):
-        """Run generative inference"""
-        # Time the entire inference process
         inference_start = time.time()

-        # Load model
         model = self.load_model(model_type)

-        #
         if isinstance(image, str):
             if os.path.exists(image):
                 image = Image.open(image).convert('RGB')
             else:
                 raise ValueError(f"Image path does not exist: {image}")
-        elif isinstance(image, np.ndarray):
-            ...

-        # Pick the right preproc for this model
-        pre = self.model_preproc.get(model_type, {"size": 224, "mean": IMAGENET_MEAN, "std": IMAGENET_STD})

-        #
-        ...
-            norm_std=pre["std"]
-        )
-
-        print(f"[PREPROC] {model_type}: size={pre['size']} mean={pre['mean']} std={pre['std']} (transform normalize=False; model has internal normalizer)")

-        #
-        if config['inference_normalization']:
-            ...
         else:
-            ...

-        #
-        ...

         # Get original predictions
         with torch.no_grad():
-            if isinstance(model, nn.Sequential):
-                print("Model is sequential with normalization")
-                # Get the core model part (typically at index 1 in Sequential)
-                core_model = model[1]
-                if config['inference_normalization']:
-                    output_original = model(image_tensor)  # Model includes normalization
-                else:
-                    output_original = core_model(image_tensor)  # Model includes normalization
-
             else:
-                ...
-                # Use manual normalization for non-sequential models
-                if config['inference_normalization']:
-                    normalized_tensor = normalize_transform(image_tensor)
-                    output_original = model(normalized_tensor)
-                else:
-                    output_original = model(image_tensor)
-                core_model = model

         probs_orig = F.softmax(output_original, dim=1)
         conf_orig, classes_orig = torch.max(probs_orig, 1)

-        # Get least confident classes
-        ...

         # Initialize inference step
         infer_step = InferStep(image_tensor, config['eps'], config['step_size'])

         # Storage for inference steps
-        # Create a new tensor that requires gradients
         x = image_tensor.clone().detach().requires_grad_(True)
         all_steps = [image_tensor[0].detach().cpu()]

-        ...
-        if config['loss_infer'] == 'Prior-Guided Drift Diffusion':
-            print(f"Setting up Prior-Guided Drift Diffusion with layer {config['top_layer']} and noise {config['initial_inference_noise_ratio']}...")
-
-            # Extract model up to the specified layer
-            try:
-                # Start by finding the actual model to use
-                base_model = model
-
-                # Handle DataParallel wrapper if present
-                if hasattr(base_model, 'module'):
-                    base_model = base_model.module
-
-                # Log the initial model structure
-                print(f"DEBUG - Initial model structure: {type(base_model)}")
-
-                # If we have a Sequential model (which is likely our normalizer + model structure)
-                if isinstance(base_model, nn.Sequential):
-                    print(f"DEBUG - Sequential model with {len(list(base_model.children()))} children")
-
-                    # If this is our NormalizeByChannelMeanStd + ResNet pattern
-                    if len(list(base_model.children())) >= 2:
-                        # The actual ResNet model is the second component (index 1)
-                        actual_model = list(base_model.children())[1]
-                        print(f"DEBUG - Using ResNet component: {type(actual_model)}")
-                        print(f"DEBUG - Available layers: {[name for name, _ in actual_model.named_children()]}")
-
-                        # Extract from the actual ResNet
-                        layer_model = extract_middle_layers(actual_model, config['top_layer'])
-                    else:
-                        # Just a single component Sequential
-                        layer_model = extract_middle_layers(base_model, config['top_layer'])
-                else:
-                    # Not Sequential, might be direct model
-                    print(f"DEBUG - Available layers: {[name for name, _ in base_model.named_children()]}")
-                    layer_model = extract_middle_layers(base_model, config['top_layer'])
-
-                print(f"Successfully extracted model up to layer: {config['top_layer']}")
-            except ValueError as e:
-                print(f"Layer extraction failed: {e}. Using full model.")
-                layer_model = model
-
-            # Add noise to the image - exactly match original code
-            added_noise = config['initial_inference_noise_ratio'] * torch.randn_like(image_tensor).to(device)
-            noisy_image_tensor = image_tensor + added_noise
-
-            # Compute noisy features - simplified to match original code
-            noisy_features = layer_model(noisy_image_tensor)
-
-            print(f"Noisy features computed for Prior-Guided Drift Diffusion target with shape: {noisy_features.shape if hasattr(noisy_features, 'shape') else 'unknown'}")

         # Main inference loop
         print(f"Starting inference loop with {config['n_itr']} iterations for {config['loss_infer']}...")
-
         for i in range(config['n_itr']):
             # Reset gradients
             x.grad = None

-            ...
-            # Standard forward pass with full model
-            # Simplified to match original code's approach
-            output = model(x)
-
-            # Calculate loss and gradients based on inference type
-            try:
-                if config['loss_infer'] == 'Prior-Guided Drift Diffusion':
-                    # Use MSE loss to match the noisy features
-                    assert config['loss_function'] == 'MSE', "Reverse Diffusion loss function must be MSE"
-                    if noisy_features is not None:
-                        loss = F.mse_loss(output, noisy_features)
-                        grad = torch.autograd.grad(loss, x)[0]  # Removed retain_graph=True to match original
-                    else:
-                        raise ValueError("Noisy features not computed for Prior-Guided Drift Diffusion")
-
-                else:  # Default 'IncreaseConfidence' approach
-                    # Get the least confident classes
-                    num_classes = min(10, least_confident_classes.size(1))
-                    target_classes = least_confident_classes[0, :num_classes]
-
-                    # Create targets for least confident classes
-                    targets = torch.tensor([idx.item() for idx in target_classes], device=device)
-
-                    # Use a combined loss to increase confidence
-                    loss = 0
-                    for target in targets:
-                        # Create one-hot target
-                        one_hot = torch.zeros_like(output)
-                        one_hot[0, target] = 1
-                        # Use loss to maximize confidence
-                        loss = loss + F.mse_loss(F.softmax(output, dim=1), one_hot)
-
-                    grad = torch.autograd.grad(loss, x, retain_graph=True)[0]

-                if grad is None:
-                    ...
-                    random_noise = (torch.rand_like(x) - 0.5) * 2 * config['step_size']
-                    x = infer_step.project(x + random_noise)
                 else:
-                    ...
-            ...

             # Store step if in iterations_to_show
-            if i+1 in config['iterations_to_show']:
                 all_steps.append(x[0].detach().cpu())

-        #
         with torch.no_grad():
-            if isinstance(model, nn.Sequential):
-                if config['inference_normalization']:
-                    final_output = model(x)
-                else:
-                    final_output = core_model(x)
             else:
-                if config['inference_normalization']:
-                    normalized_x = normalize_transform(x)
-                    final_output = model(normalized_x)
-                else:
-                    final_output = model(x)

         final_probs = F.softmax(final_output, dim=1)
         final_conf, final_classes = torch.max(final_probs, 1)

-        # Calculate timing information
-        loop_time = time.time() - loop_start
         total_time = time.time() - inference_start
-        avg_iter_time = loop_time / config['n_itr'] if config['n_itr'] > 0 else 0

         print(f"Original top class: {classes_orig.item()} ({conf_orig.item():.4f})")
         print(f"Final top class: {final_classes.item()} ({final_conf.item():.4f})")
-        print(f"Inference loop completed in {loop_time:.2f} seconds ({avg_iter_time:.4f} sec/iteration)")
         print(f"Total inference time: {total_time:.2f} seconds")

-        # Return results in
         return {
             'final_image': x[0].detach().cpu(),
             'steps': all_steps,
             'original_class': classes_orig.item(),
             'original_confidence': conf_orig.item(),
             'final_class': final_classes.item(),
-            'final_confidence': final_conf.item()
         }

-# Utility function to show inference steps
 def show_inference_steps(steps, figsize=(15, 10)):
-    ...
+"""Complete generative inference module with model loading and inference capabilities."""
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torchvision.transforms as transforms
+import torchvision.models as models
 from torchvision.models.resnet import ResNet50_Weights
 from PIL import Image
 import numpy as np
 import os
 import requests
 import time
 import copy
 from collections import OrderedDict
 from pathlib import Path
+from typing import Dict, List, Optional, Tuple, Union

 # Check for available hardware acceleration
 if torch.cuda.is_available():
     device = torch.device("cuda")
 elif torch.backends.mps.is_available():
     device = torch.device("mps")
 else:
     device = torch.device("cpu")
 print(f"Using device: {device}")

+# Constants for model URLs
 MODEL_URLS = {
     'resnet50_robust': 'https://huggingface.co/madrylab/robust-imagenet-models/resolve/main/resnet50_l2_eps3.ckpt',
     'resnet50_standard': 'https://huggingface.co/madrylab/robust-imagenet-models/resolve/main/resnet50_l2_eps0.ckpt',
     'resnet50_robust_face': 'https://huggingface.co/ttoosi/resnet50_robust_face/resolve/main/resnet50_imagenet_L2_eps_0.50_checkpoint150.pt'
 }

+# Model-specific preprocessing configurations
+MODEL_CONFIGS = {
+    'resnet50_robust_face': {
+        'input_size': 112,
+        'norm_mean': [0.5, 0.5, 0.5],
+        'norm_std': [0.5, 0.5, 0.5],
+        'n_classes': 500,
+        'dataset': 'VGGFace2'
+    },
+    'resnet50_standard': {
+        'input_size': 224,
+        'norm_mean': [0.485, 0.456, 0.406],
+        'norm_std': [0.229, 0.224, 0.225],
+        'n_classes': 1000,
+        'dataset': 'ImageNet'
+    },
+    'resnet50_robust': {
+        'input_size': 224,
+        'norm_mean': [0.485, 0.456, 0.406],
+        'norm_std': [0.229, 0.224, 0.225],
+        'n_classes': 1000,
+        'dataset': 'ImageNet'
+    }
 }

 IMAGENET_MEAN = [0.485, 0.456, 0.406]
 IMAGENET_STD = [0.229, 0.224, 0.225]

+def get_iterations_to_show(n_itr):
+    """Generate a dynamic list of iterations to show based on total iterations."""
+    if n_itr <= 50:
+        return [1, 5, 10, 20, 30, 40, 50, n_itr]
+    elif n_itr <= 100:
+        return [1, 5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, n_itr]
+    elif n_itr <= 200:
+        return [1, 5, 10, 20, 30, 40, 50, 75, 100, 125, 150, 175, 200, n_itr]
+    elif n_itr <= 500:
+        return [1, 5, 10, 20, 30, 40, 50, 75, 100, 150, 200, 250, 300, 350, 400, 450, 500, n_itr]
     else:
+        return [1, 5, 10, 20, 30, 40, 50, 75, 100, 150, 200, 250, 300, 350, 400, 450, 500,
+                int(n_itr*0.6), int(n_itr*0.7), int(n_itr*0.8), int(n_itr*0.9), n_itr]

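Note the schedule can contain values above `n_itr` and a duplicated endpoint; this is harmless downstream because the loop only tests membership with `i+1 in iterations_to_show`. A quick check, runnable as-is once the function above is defined:

```python
# Values beyond n_itr never fire (i+1 <= n_itr), duplicates are benign.
print(get_iterations_to_show(50))
# [1, 5, 10, 20, 30, 40, 50, 50]
print(get_iterations_to_show(30))
# [1, 5, 10, 20, 30, 40, 50, 30]  (40 and 50 never trigger when n_itr=30)
```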
+def get_inference_configs(inference_type='IncreaseConfidence', eps=0.5, n_itr=50, step_size=1.0):
+    """Generate inference configuration with customizable parameters."""

+    config = {
+        'loss_infer': inference_type,
+        'n_itr': n_itr,
+        'eps': eps,
+        'step_size': step_size,
+        'diffusion_noise_ratio': 0.0,
+        'initial_inference_noise_ratio': 0.0,
+        'top_layer': 'all',
+        'inference_normalization': False,
+        'recognition_normalization': False,
+        'iterations_to_show': get_iterations_to_show(n_itr),
+        'misc_info': {'keep_grads': False}
+    }

+    if inference_type == 'IncreaseConfidence':
+        config['loss_function'] = 'CE'
+    elif inference_type == 'Prior-Guided Drift Diffusion':
+        config['loss_function'] = 'MSE'
+        config['initial_inference_noise_ratio'] = 0.05
+        config['diffusion_noise_ratio'] = 0.01
+        config['top_layer'] = 'layer4'
+    elif inference_type == 'GradModulation':
+        config['loss_function'] = 'CE'
+        config['misc_info']['grad_modulation'] = 0.5
+    elif inference_type == 'CompositionalFusion':
+        config['loss_function'] = 'CE'
+        config['misc_info']['positive_classes'] = []
+        config['misc_info']['negative_classes'] = []

+    return config

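For reference, a small usage sketch; every value asserted below follows directly from the defaults in the function above:

```python
cfg = get_inference_configs('Prior-Guided Drift Diffusion', eps=0.5, n_itr=100)
assert cfg['loss_function'] == 'MSE'
assert cfg['top_layer'] == 'layer4'
assert cfg['initial_inference_noise_ratio'] == 0.05
assert cfg['diffusion_noise_ratio'] == 0.01

# IncreaseConfidence keeps the generic defaults and a CE loss:
cfg_ce = get_inference_configs()  # inference_type='IncreaseConfidence'
assert cfg_ce['loss_function'] == 'CE' and cfg_ce['top_layer'] == 'all'
```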
+def get_model_preprocessing(model_type: str) -> Dict:
+    """Get preprocessing configuration for specific model type."""
+    if model_type not in MODEL_CONFIGS:
+        print(f"Fall-back: Unknown model type {model_type}, using ImageNet defaults")
+        return MODEL_CONFIGS['resnet50_standard']
+    return MODEL_CONFIGS[model_type]

 class NormalizeByChannelMeanStd(nn.Module):
+    """Normalization layer for models."""
     def __init__(self, mean, std):
         super(NormalizeByChannelMeanStd, self).__init__()
         if not isinstance(mean, torch.Tensor):
             mean = torch.tensor(mean)
         if not isinstance(std, torch.Tensor):
             std = torch.tensor(std)
         self.register_buffer("mean", mean)
         self.register_buffer("std", std)

     def forward(self, tensor):
         return self.normalize_fn(tensor, self.mean, self.std)

     def normalize_fn(self, tensor, mean, std):
         """Differentiable version of torchvision.functional.normalize"""
         mean = mean[None, :, None, None]
         std = std[None, :, None, None]
         return tensor.sub(mean).div(std)

 class InferStep:
+    """Inference step class for gradient-based optimization."""
+
+    def __init__(self, orig_image: torch.Tensor, eps: float, step_size: float):
         self.orig_image = orig_image
         self.eps = eps
         self.step_size = step_size

+    def project(self, x: torch.Tensor) -> torch.Tensor:
+        """Project x onto epsilon-ball around original image."""
         diff = x - self.orig_image
         diff = torch.clamp(diff, -self.eps, self.eps)
         return torch.clamp(self.orig_image + diff, 0, 1)

+    def step(self, x: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
+        """Take a normalized gradient step."""
+        dim = len(x.shape) - 1
+        grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=1).reshape(-1, *([1] * dim))
         scaled_grad = grad / (grad_norm + 1e-10)
         return scaled_grad * self.step_size

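Geometrically, `step` returns a drift of fixed L2 length `step_size` (the gradient is normalized), while `project` clamps each pixel's deviation from the original image to ±`eps` (an L-infinity box) and re-clamps to [0, 1]. A runnable sketch of one projected update, assuming `InferStep` as defined above (all values illustrative):

```python
import torch

x0 = torch.rand(1, 3, 8, 8)                  # stands in for the input image
step = InferStep(x0, eps=0.5, step_size=1.0)

g = torch.randn_like(x0)                     # stands in for a loss gradient
delta = step.step(x0, g)                     # L2-normalized, length == step_size
print(delta.flatten().norm())                # ~1.0

x1 = step.project(x0 + delta)                # per-pixel |x1 - x0| <= eps, x1 in [0, 1]
print((x1 - x0).abs().max() <= 0.5)          # tensor(True)
```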
+
def extract_middle_layers(model: nn.Module, layer_index: Union[str, int]) -> nn.Module:
|
| 160 |
+
"""Extract middle layers from a model up to a specified layer index."""
|
| 161 |
+
if isinstance(layer_index, str) and layer_index == 'all':
|
| 162 |
+
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
|
| 164 |
+
# Handle ResNet layer extraction
|
| 165 |
+
modules = list(model.named_children())
|
| 166 |
+
cutoff_idx = next(
|
| 167 |
+
(i for i, (name, _) in enumerate(modules) if name == str(layer_index)),
|
| 168 |
+
None
|
| 169 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
|
| 171 |
+
if cutoff_idx is not None:
|
| 172 |
+
new_model = nn.Sequential(OrderedDict(modules[:cutoff_idx + 1]))
|
| 173 |
+
return new_model
|
| 174 |
+
else:
|
| 175 |
+
print(f"Fall-back: Module {layer_index} not found, using full model")
|
| 176 |
+
return model
|
| 177 |
+
|
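Truncating after `layer4` drops `avgpool` and `fc`, which is why the Prior-Guided Drift Diffusion mode matches spatial feature maps rather than logits. A quick shape check against a plain torchvision ResNet-50:

```python
import torch
import torchvision.models as models

trunk = extract_middle_layers(models.resnet50(), 'layer4').eval()
with torch.no_grad():
    feats = trunk(torch.randn(1, 3, 224, 224))
print(feats.shape)  # torch.Size([1, 2048, 7, 7]) -- spatial features, not logits
```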
+def calculate_loss(output_model: torch.Tensor, class_indices: List[int], loss_inference: str) -> torch.Tensor:
+    """Calculate loss for specified class indices."""
+    losses = []
+    for idx in class_indices:
+        target = torch.full((1,), idx, dtype=torch.long, device=output_model.device)
+        if loss_inference == 'CE':
+            loss = nn.CrossEntropyLoss()(output_model, target)
+        elif loss_inference == 'MSE':
+            one_hot_target = torch.zeros_like(output_model)
+            one_hot_target[0, target] = 1
+            loss = nn.MSELoss()(output_model, one_hot_target)
+        else:
+            raise ValueError(f"Unsupported loss_inference: {loss_inference}")
+        losses.append(loss)

+    return torch.stack(losses).mean()
+
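Note that the MSE branch compares raw model outputs to a one-hot vector, whereas the removed in-loop version compared `F.softmax(output)` to the one-hot target. A tiny sanity check with uniform logits (values follow from the definitions above):

```python
import torch

logits = torch.zeros(1, 10)
ce = calculate_loss(logits, [3, 7], 'CE')   # mean CE over two targets
mse = calculate_loss(logits, [3], 'MSE')    # MSE of zeros vs one-hot for class 3
print(ce.item(), mse.item())                 # ~2.3026 (= log 10), 0.1
```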
+def download_model(model_type):
+    """Download model if needed."""
+    if model_type not in MODEL_URLS or MODEL_URLS[model_type] is None:
+        return None

+    os.makedirs("models", exist_ok=True)

+    if model_type == 'resnet50_robust_face':
+        model_path = Path("models/resnet50_vggface2_L2_eps_0.50_checkpoint150.pt")
+    else:
+        model_path = Path(f"models/{model_type}.pt")

+    if not model_path.exists():
+        print(f"Downloading {model_type} model...")
+        url = MODEL_URLS[model_type]
+        response = requests.get(url, stream=True)
+        if response.status_code == 200:
+            with open(model_path, 'wb') as f:
+                for chunk in response.iter_content(chunk_size=8192):
+                    f.write(chunk)
+            print(f"Model downloaded and saved to {model_path}")
+        else:
+            raise RuntimeError(f"Failed to download model: {response.status_code}")
+    return model_path

 class GenerativeInferenceModel:
+    """Complete generative inference model with model loading and inference."""
+
     def __init__(self):
         self.models = {}
         self.model_preproc = {}
+        self.labels = self.get_imagenet_labels()

+    def get_imagenet_labels(self):
+        """Get ImageNet labels."""
+        url = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
         try:
+            response = requests.get(url)
+            if response.status_code == 200:
+                return response.json()
+            else:
+                print("Fall-back: Failed to fetch ImageNet labels, using placeholder")
+                return [f"class_{i}" for i in range(1000)]
         except Exception as e:
+            print(f"Fall-back: Error fetching labels: {e}")
+            return [f"class_{i}" for i in range(1000)]

     def load_model(self, model_type):
+        """Load and cache models for different model types."""
         if model_type in self.models:
             print(f"Using cached {model_type} model")
             return self.models[model_type]

         start_time = time.time()
+
+        # Get model-specific preprocessing config
+        preproc_config = get_model_preprocessing(model_type)
+        self.model_preproc[model_type] = preproc_config
+
+        # Create normalizer
+        normalizer = NormalizeByChannelMeanStd(
+            preproc_config['norm_mean'],
+            preproc_config['norm_std']
+        ).to(device)
+
+        # Create base model architecture
+        num_classes = preproc_config['n_classes']
+        resnet = models.resnet50(num_classes=num_classes)
         model = nn.Sequential(normalizer, resnet)

+        # Download and load checkpoint
+        model_path = download_model(model_type)
+
         if model_path:
             print(f"Loading {model_type} model from {model_path}...")
             try:
                 checkpoint = torch.load(model_path, map_location=device)

                 # Handle different checkpoint formats
                 if 'model' in checkpoint:
                     state_dict = checkpoint['model']
                     print("Using 'model' key from checkpoint")
                 elif 'state_dict' in checkpoint:
                     state_dict = checkpoint['state_dict']
                     print("Using 'state_dict' key from checkpoint")
                 else:
                     state_dict = checkpoint
                     print("Using checkpoint directly as state_dict")

+                # Extract ResNet state dict
                 resnet_state_dict = {}
                 resnet_keys = set(resnet.state_dict().keys())

+                # For face model, prioritize 'module.model.model.' structure (seen in actual checkpoint)
+                if model_type == 'resnet50_robust_face':
+                    # Check for 'module.model.model.' structure first (face checkpoints use this)
+                    module_model_model_keys = [key for key in state_dict.keys() if key.startswith('module.model.model.')]
+                    if module_model_model_keys:
+                        print(f"Found 'module.model.model.' structure with {len(module_model_model_keys)} parameters (face model)")
                         for source_key, value in state_dict.items():
+                            if source_key.startswith('module.model.model.'):
+                                target_key = source_key[len('module.model.model.'):]
+                                if target_key in resnet_keys:
                                     resnet_state_dict[target_key] = value
+                        print(f"Extracted {len(resnet_state_dict)} parameters from module.model.model.")

+                    # Also check for 'module.model.' structure as fallback
+                    if len(resnet_state_dict) < len(resnet_keys):
+                        module_model_keys = [key for key in state_dict.keys() if key.startswith('module.model.') and not key.startswith('module.model.model.')]
+                        if module_model_keys:
+                            print(f"Found additional 'module.model.' structure with {len(module_model_keys)} parameters")
+                            for source_key, value in state_dict.items():
+                                if source_key.startswith('module.model.') and not source_key.startswith('module.model.model.'):
+                                    target_key = source_key[len('module.model.'):]
+                                    # Remove extra 'model.' if present
+                                    if target_key.startswith('model.'):
+                                        target_key = target_key[len('model.'):]
+                                    if target_key in resnet_keys and target_key not in resnet_state_dict:
+                                        resnet_state_dict[target_key] = value
+                            print(f"Now have {len(resnet_state_dict)} parameters after adding module.model. keys")

+                # Handle different key prefixes in checkpoints (for other models)
+                if len(resnet_state_dict) == 0:
+                    prefixes_to_try = ['', 'module.', 'model.', 'attacker.model.', 'attacker.']

                     for source_key, value in state_dict.items():
                         target_key = source_key

                         # Try removing various prefixes
+                        for prefix in prefixes_to_try:
                             if source_key.startswith(prefix):
                                 target_key = source_key[len(prefix):]
                                 break

+                        # Handle nested model keys
+                        if target_key.startswith('model.'):
+                            target_key = target_key[len('model.'):]
+
+                        # If the target key is in ResNet keys, add it
+                        if target_key in resnet_keys:
                             resnet_state_dict[target_key] = value

+                # Load the state dict
                 if resnet_state_dict:
+                    result = resnet.load_state_dict(resnet_state_dict, strict=False)
+                    missing_keys, unexpected_keys = result
+
+                    loaded_percent = (len(resnet_state_dict) / len(resnet_keys)) * 100
+                    print(f"Model loading: {len(resnet_state_dict)}/{len(resnet_keys)} parameters ({loaded_percent:.1f}%)")
+
+                    if loaded_percent < 50:
+                        print(f"Fall-back: Loading too incomplete ({loaded_percent:.1f}%), using PyTorch pretrained")
+                        if model_type != 'resnet50_robust_face':
                             resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
                             model = nn.Sequential(normalizer, resnet)

+                else:
+                    print("Fall-back: No matching keys found in checkpoint, using PyTorch pretrained")
+                    if model_type != 'resnet50_robust_face':
+                        resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
+                        model = nn.Sequential(normalizer, resnet)

             except Exception as e:
+                print(f"Fall-back: Error loading checkpoint: {e}")
+                if model_type != 'resnet50_robust_face':
+                    print("Fall-back: Using PyTorch pretrained model")
+                    resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
+                    model = nn.Sequential(normalizer, resnet)
+                else:
+                    print("Fall-back: Face model checkpoint failed, model may not work properly")
+
+        else:
+            # Use PyTorch's pretrained model for ImageNet models
+            if model_type != 'resnet50_robust_face':
+                print(f"No checkpoint for {model_type}, using PyTorch pretrained")
                 resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
                 model = nn.Sequential(normalizer, resnet)
+            else:
+                print("Fall-back: Face model requires checkpoint, model may not work properly")

         model = model.to(device)
+        model.eval()

+        # Verify model
         self.verify_model_integrity(model, model_type)

+        # Cache the model
         self.models[model_type] = model
+
         end_time = time.time()
+        print(f"Model {model_type} loaded in {end_time - start_time:.2f} seconds")
         return model
+
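The prefix gymnastics exist because robustness-style checkpoints wrap the backbone several layers deep. One caveat worth flagging: as written, `prefixes_to_try` begins with the empty string, which matches every key immediately and breaks out of the loop, so in practice only the subsequent `model.` unwrap strips anything for non-face checkpoints. The sketch below shows the *intended* mapping with the empty prefix removed (key names follow the code above and the loading log later in this commit):

```python
# Intended key mapping; prefix order rearranged so stripping actually fires.
examples = [
    'module.model.model.conv1.weight',   # face checkpoint nesting
    'module.model.conv1.weight',         # fallback nesting
    'attacker.model.conv1.weight',       # robustness-library attacker copy
    'conv1.weight',                      # already bare
]
for key in examples:
    bare = key
    for prefix in ('module.model.model.', 'module.model.', 'attacker.model.',
                   'module.', 'model.', 'attacker.'):
        if bare.startswith(prefix):
            bare = bare[len(prefix):]
            break
    if bare.startswith('model.'):        # second-level unwrap, as in the loader
        bare = bare[len('model.'):]
    print(f'{key!r:40} -> {bare!r}')     # every example resolves to 'conv1.weight'
```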
+    def verify_model_integrity(self, model, model_type):
+        """Verify model integrity."""
+        try:
+            print(f"Fall-back: Running model integrity check for {model_type}")
+            config = get_model_preprocessing(model_type)
+            H = W = config['input_size']
+
+            test_input = torch.zeros(1, 3, H, W, device=device)
+            test_input[0, 0, H//4:3*H//4, W//4:3*W//4] = 0.5
+
+            with torch.no_grad():
+                output = model(test_input)
+
+            expected_classes = config['n_classes']
+            if output.shape != (1, expected_classes):
+                print(f"Fall-back: Unexpected output shape: {output.shape}, expected (1, {expected_classes})")
+                return False
+
+            probs = torch.nn.functional.softmax(output, dim=1)
+            confidence, prediction = torch.max(probs, 1)
+
+            print(f"Model integrity check passed:")
+            print(f"- Output shape: {output.shape}")
+            print(f"- Top prediction: Class {prediction.item()} with {confidence.item()*100:.2f}% confidence")
+
+            return True
+
+        except Exception as e:
+            print(f"Fall-back: Model integrity check failed with error: {e}")
+            return False
+
     def inference(self, image, model_type, config):
+        """Run generative inference."""
         inference_start = time.time()

+        # Load the model
         model = self.load_model(model_type)

+        # Handle image input
         if isinstance(image, str):
             if os.path.exists(image):
                 image = Image.open(image).convert('RGB')
             else:
                 raise ValueError(f"Image path does not exist: {image}")
+        elif isinstance(image, np.ndarray):
+            if image.dtype != np.uint8:
+                if image.max() <= 1.0:
+                    image = (image * 255).astype(np.uint8)
+                else:
+                    image = image.astype(np.uint8)
+            if len(image.shape) == 3:
+                if image.shape[0] == 3 or image.shape[0] == 1:
+                    image = np.transpose(image, (1, 2, 0))
+                if image.shape[2] == 4:
+                    image = image[:, :, :3]
+                elif image.shape[2] == 1:
+                    image = np.repeat(image, 3, axis=2)
+            image = Image.fromarray(image)
+        elif not isinstance(image, Image.Image):
+            try:
+                image = Image.fromarray(np.array(image)).convert('RGB')
+            except Exception as e:
+                raise ValueError(f"Cannot convert image type {type(image)} to PIL Image: {e}")

+        if isinstance(image, Image.Image) and image.mode != 'RGB':
+            image = image.convert('RGB')

+        # Get preprocessing config
+        preproc_config = get_model_preprocessing(model_type)
+        input_size = preproc_config['input_size']
+        norm_mean = torch.tensor(preproc_config['norm_mean'])
+        norm_std = torch.tensor(preproc_config['norm_std'])
+        n_classes = preproc_config['n_classes']

+        # Create transform
+        if config.get('inference_normalization', False):
+            transform = transforms.Compose([
+                transforms.Resize(input_size),
+                transforms.CenterCrop(input_size),
+                transforms.ToTensor(),
+                transforms.Normalize(norm_mean.tolist(), norm_std.tolist()),
+            ])
+            print(f"Fall-back: Using normalization with mean={norm_mean.tolist()}, std={norm_std.tolist()}")
         else:
+            transform = transforms.Compose([
+                transforms.Resize(input_size),
+                transforms.CenterCrop(input_size),
+                transforms.ToTensor(),
+            ])
+            print(f"Normalization OFF - feeding raw [0,1] tensors to model (normalization applied in the model)")

+        # Helper function to safely apply transform with fallback for numpy compatibility
+        def safe_transform(img):
+            try:
+                return transform(img)
+            except TypeError as e:
+                if "expected np.ndarray" in str(e) or "got numpy.ndarray" in str(e):
+                    # Fallback: manually convert PIL to tensor
+                    print(f"[WARNING] Transform failed with numpy compatibility issue, using manual conversion")
+                    # Apply resize and center crop manually
+                    resize_transform = transforms.Resize(input_size)
+                    crop_transform = transforms.CenterCrop(input_size)
+                    img = crop_transform(resize_transform(img))
+                    # Convert to numpy array and then to tensor using torch.tensor() to avoid numpy compatibility issues
+                    img_array = np.array(img, dtype=np.uint8)
+                    # Use torch.tensor() instead of torch.from_numpy() to avoid compatibility issues
+                    # Convert to float and normalize to [0, 1], then convert from HWC to CHW format
+                    img_tensor = torch.tensor(img_array, dtype=torch.float32).div(255.0).permute(2, 0, 1)
+                    # Apply normalization if needed
+                    if config.get('inference_normalization', False):
+                        img_tensor = transforms.Normalize(norm_mean.tolist(), norm_std.tolist())(img_tensor)
+                    return img_tensor
+                else:
+                    raise

+        # Prepare image tensor with safe transform
+        image_tensor = safe_transform(image).unsqueeze(0).to(device)
+        image_tensor.requires_grad = True
+
+        # Get model components
+        is_sequential = isinstance(model, nn.Sequential)
+        if is_sequential and isinstance(model[0], NormalizeByChannelMeanStd):
+            core_model = model[1]
+        else:
+            core_model = model
+
+        # Prepare model for layer extraction
+        if config.get('top_layer', 'all') != 'all':
+            new_model = extract_middle_layers(core_model, config['top_layer'])
+        else:
+            new_model = model
+
         # Get original predictions
         with torch.no_grad():
+            if config.get('inference_normalization', False):
+                output_original = model(image_tensor)
             else:
+                output_original = core_model(image_tensor)

         probs_orig = F.softmax(output_original, dim=1)
         conf_orig, classes_orig = torch.max(probs_orig, 1)

+        # Get least confident classes for IncreaseConfidence
+        if config['loss_infer'] == 'IncreaseConfidence':
+            _, least_confident_classes = torch.topk(probs_orig, k=int(n_classes / 10), largest=False)
+
+        # Setup for Prior-Guided Drift Diffusion
+        noisy_features = None
+        if config['loss_infer'] == 'Prior-Guided Drift Diffusion':
+            print(f"Setting up Prior-Guided Drift Diffusion...")
+            added_noise = config.get('initial_inference_noise_ratio', 0.05) * torch.randn_like(image_tensor).to(device)
+            noisy_image_tensor = image_tensor + added_noise
+            noisy_features = new_model(noisy_image_tensor)
+
         # Initialize inference step
         infer_step = InferStep(image_tensor, config['eps'], config['step_size'])

         # Storage for inference steps
         x = image_tensor.clone().detach().requires_grad_(True)
         all_steps = [image_tensor[0].detach().cpu()]

+        selected_inferred_patterns = []
+        perceived_categories = []
+        confidence_list = []

         # Main inference loop
         print(f"Starting inference loop with {config['n_itr']} iterations for {config['loss_infer']}...")
+
         for i in range(config['n_itr']):
             # Reset gradients
             x.grad = None

+            if i == 0:
+                # Get predictions for first iteration
+                if config.get('inference_normalization', False):
+                    output = model(x)
+                else:
+                    output = core_model(x)

+                if isinstance(output, torch.Tensor) and output.size(-1) == n_classes:
+                    probs = F.softmax(output, dim=1)
+                    conf, classes = torch.max(probs, 1)
                 else:
+                    probs = 0
+                    conf = 0
+                    classes = 'N/A'
+            else:
+                # Calculate loss and gradients
+                try:
+                    # Forward pass through new_model for feature extraction
+                    features = new_model(x)

+                    if config['loss_infer'] == 'Prior-Guided Drift Diffusion':
+                        assert config.get('loss_function', 'MSE') == 'MSE', "Prior-Guided Drift Diffusion requires MSE loss"
+                        if noisy_features is not None:
+                            loss = F.mse_loss(features, noisy_features)
+                            grad = torch.autograd.grad(loss, x)[0]
+                            adjusted_grad = infer_step.step(x, grad)
+                        else:
+                            raise ValueError("Noisy features not computed for Prior-Guided Drift Diffusion")

+                    elif config['loss_infer'] == 'IncreaseConfidence':
+                        # Calculate loss using least confident classes
+                        num_target_classes = min(int(n_classes / 10), least_confident_classes.size(1))
+                        target_classes = least_confident_classes[0, :num_target_classes]
+
+                        loss = calculate_loss(features, target_classes.tolist(), config.get('loss_function', 'CE'))
+                        grad = torch.autograd.grad(loss, x, retain_graph=True)[0]
+                        adjusted_grad = infer_step.step(x, grad)

+                    else:
+                        raise ValueError(f"Loss inference method {config['loss_infer']} not supported")
+
+                    if grad is None:
+                        print("Fall-back: Direct gradient calculation failed")
+                        random_noise = (torch.rand_like(x) - 0.5) * 2 * config['step_size']
+                        x = infer_step.project(x.clone() + random_noise)
+                    else:
+                        # Add diffusion noise if specified
+                        diffusion_noise = config.get('diffusion_noise_ratio', 0.0) * torch.randn_like(x).to(device)
+                        x = infer_step.project(x.clone() + adjusted_grad + diffusion_noise)
+
+                except Exception as e:
+                    print(f"Fall-back: Error in gradient calculation: {e}")
+                    random_noise = (torch.rand_like(x) - 0.5) * 2 * config['step_size']
+                    x = infer_step.project(x.clone() + random_noise)

             # Store step if in iterations_to_show
+            if i+1 in config.get('iterations_to_show', []) or i+1 == config['n_itr']:
                 all_steps.append(x[0].detach().cpu())
+                selected_inferred_patterns.append(x[0].detach().cpu())
+
+                # Get current predictions
+                with torch.no_grad():
+                    if config.get('inference_normalization', False):
+                        current_output = model(x)
+                    else:
+                        current_output = core_model(x)
+
+                if isinstance(current_output, torch.Tensor) and current_output.size(-1) == n_classes:
+                    current_probs = F.softmax(current_output, dim=1)
+                    current_conf, current_classes = torch.max(current_probs, 1)
+                    perceived_categories.append(current_classes.item())
+                    confidence_list.append(current_conf.item())
+                else:
+                    perceived_categories.append('N/A')
+                    confidence_list.append(0.0)

+        # Final predictions
         with torch.no_grad():
+            if config.get('inference_normalization', False):
+                final_output = model(x)
             else:
+                final_output = core_model(x)

         final_probs = F.softmax(final_output, dim=1)
         final_conf, final_classes = torch.max(final_probs, 1)

         total_time = time.time() - inference_start

         print(f"Original top class: {classes_orig.item()} ({conf_orig.item():.4f})")
         print(f"Final top class: {final_classes.item()} ({final_conf.item():.4f})")
         print(f"Total inference time: {total_time:.2f} seconds")

+        # Return results in Code 1 format
         return {
             'final_image': x[0].detach().cpu(),
             'steps': all_steps,
             'original_class': classes_orig.item(),
             'original_confidence': conf_orig.item(),
             'final_class': final_classes.item(),
+            'final_confidence': final_conf.item(),
+            'all_categories': perceived_categories,
+            'all_confidences': confidence_list,
         }
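Stripped of fallbacks, each iteration after the first performs one projected drift step, x ← project(x + step_size · grad/||grad|| + σ · noise), exactly the first-iteration handling described in DIFFERENCES.md. A self-contained toy version of the PGDD branch, assuming `InferStep` as defined earlier (the tiny `net`, sizes, and `sigma` are stand-ins, not the demo's actual model or settings):

```python
import torch
import torch.nn.functional as F

net = torch.nn.Conv2d(3, 4, 3, padding=1)            # stand-in feature extractor
x0 = torch.rand(1, 3, 16, 16)
target_feats = net(x0 + 0.05 * torch.randn_like(x0)).detach()  # noisy "prior" target

step = InferStep(x0, eps=0.5, step_size=1.0)
x = x0.clone().requires_grad_(True)
sigma = 0.01                                          # diffusion_noise_ratio stand-in
for _ in range(5):
    loss = F.mse_loss(net(x), target_feats)
    grad = torch.autograd.grad(loss, x)[0]
    # The gradient is added, matching the update in the method above.
    x = step.project(x + step.step(x, grad) + sigma * torch.randn_like(x))
```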
 def show_inference_steps(steps, figsize=(15, 10)):
+    """Show inference steps using matplotlib."""
+    try:
+        import matplotlib.pyplot as plt
+
+        n_steps = len(steps)
+        fig, axes = plt.subplots(1, n_steps, figsize=figsize)
+
+        if n_steps == 1:
+            axes = [axes]
+
+        for i, step_img in enumerate(steps):
+            if isinstance(step_img, torch.Tensor):
+                img = step_img.permute(1, 2, 0).numpy()
+                img = np.clip(img, 0, 1)
+            else:
+                img = step_img
+
+            axes[i].imshow(img)
+            axes[i].set_title(f"Step {i+1}")
+            axes[i].axis('off')
+
+        plt.tight_layout()
+        return fig
+
+    except ImportError:
+        print("Fall-back: matplotlib not available for visualization")
+        return None
+    except Exception as e:
+        print(f"Fall-back: Visualization failed: {e}")
+        return None
+
+# Export the main classes and functions
+__all__ = ['GenerativeInferenceModel', 'get_inference_configs', 'show_inference_steps']
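Putting the pieces together, a minimal driver for the module as committed; `GenerativeInferenceModel`, `get_inference_configs`, and `show_inference_steps` are the module's declared exports, and `face_vase_black.png` ships with this Space (any RGB image path works):

```python
from inference import GenerativeInferenceModel, get_inference_configs, show_inference_steps

engine = GenerativeInferenceModel()
config = get_inference_configs('Prior-Guided Drift Diffusion', eps=0.5, n_itr=50)

result = engine.inference('face_vase_black.png', 'resnet50_robust', config)

print(result['original_class'], '->', result['final_class'])
print(result['all_confidences'][-3:])        # confidence at the last stored steps
fig = show_inference_steps(result['steps'])  # matplotlib figure, or None on failure
```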
logs/model_loading_resnet50_robust_face.log
CHANGED
@@ -2,45 +2,8 @@
 ===== MODEL LOADING REPORT: resnet50_robust_face =====
 Total parameters in checkpoint: 320
 Total parameters in model: 320
-Missing keys:
-Unexpected keys:
-Successfully loaded:
-Loading status:
-
-⚠️ WARNING: Loading from checkpoint is too incomplete.
-⚠️ Falling back to PyTorch's pretrained model to avoid broken inference.
-✅ Successfully loaded PyTorch's pretrained ResNet50 model
-
-Missing keys by layer type:
-  layer3: 95 parameters
-  layer2: 65 parameters
-  layer1: 50 parameters
-  layer4: 50 parameters
-  bn1: 4 parameters
-  fc: 2 parameters
-  conv1: 1 parameters
-
-First 10 missing keys:
-  1. bn1.bias
-  2. bn1.running_mean
-  3. bn1.running_var
-  4. bn1.weight
-  5. conv1.weight
-  6. fc.bias
-  7. fc.weight
-  8. layer1.0.bn1.bias
-  9. layer1.0.bn1.running_mean
-  10. layer1.0.bn1.running_var
-
-First 10 unexpected keys:
-  1. model.bn1.bias
-  2. model.bn1.num_batches_tracked
-  3. model.bn1.running_mean
-  4. model.bn1.running_var
-  5. model.bn1.weight
-  6. model.conv1.weight
-  7. model.fc.bias
-  8. model.fc.weight
-  9. model.layer1.0.bn1.bias
-  10. model.layer1.0.bn1.num_batches_tracked
+Missing keys: 0 parameters
+Unexpected keys: 0 parameters
+Successfully loaded: 320 parameters (100.0%)
+Loading status: ✅ COMPLETE - All important parameters loaded
 ========================================
models/resnet50_imagenet_L2_eps_0.50_checkpoint150.pt
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:40bfb9a204f1d9a305ed6374acbfc55fe2745433cf1e421952d4b461f577486a
-size 196695413

models/resnet50_robust.pt
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:380b14e6f9750bffa1447cf7017f65da4dc5ce71a3dd112f107515dcf7b14d9d
-size 204818947

models/resnet50_robust_face_100_checkpoint.pt
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:c48a5c16ca0d5ac4cb20f1b98e2128838746f18b658728ac661f1ffd589c37bf
-size 196695413

models/robust_resnet50.pt
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:380b14e6f9750bffa1447cf7017f65da4dc5ce71a3dd112f107515dcf7b14d9d
-size 204818947

models/standard_resnet50.pt
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:72d4a99582db5d7fa86c3fd2a089f0bfd6a10f69d635bca51f6ad72ac6b458f0
-size 204818947
stimuli/RandomizedPhaseOvalGray.png
ADDED