Tahereh Toosi commited on
Commit
679456b
·
1 Parent(s): c4da6bc

Update application for Hugging Face Space deployment

Browse files

- Updated app.py with comprehensive visual illusion examples
- Modified README.md with deployment information
- Updated huggingface-metadata.json configuration
- Enhanced inference.py functionality
- Added face_vase.png stimulus image

Files changed (5) hide show
  1. README.md +49 -36
  2. app.py +101 -75
  3. face_vase.png +0 -0
  4. huggingface-metadata.json +4 -4
  5. inference.py +8 -2
README.md CHANGED
@@ -1,7 +1,7 @@
1
  ---
2
- title: Generative Inference Demo
3
- emoji: 🧠
4
- colorFrom: indigo
5
  colorTo: purple
6
  sdk: gradio
7
  sdk_version: 5.23.1
@@ -10,37 +10,38 @@ pinned: false
10
  license: mit
11
  ---
12
 
13
- # Generative Inference Demo
14
 
15
- This Gradio demo showcases how neural networks perceive visual illusions through generative inference. The demo uses both standard and robust ResNet50 models to reveal emergent perception of contours, figure-ground separation, and other visual phenomena.
16
 
17
- ## Models
 
 
18
 
19
- - **Robust ResNet50**: A model trained with adversarial examples (ε=3.0), exhibiting more human-like visual perception
20
- - **Standard ResNet50**: A model trained without adversarial examples (ε=0.0)
 
 
 
21
 
22
  ## Features
23
 
24
- - Upload your own images or use example illusions
25
- - Choose between robust and standard models
26
- - Adjust perturbation size (epsilon) and iteration count
27
- - Visualize how perception emerges over time
28
- - Includes classic illusions:
29
- - Kanizsa shapes
30
- - Face-Vase illusions
31
- - Figure-Ground segmentation
32
- - Neon color spreading
33
 
34
  ## Usage
35
 
36
- 1. Select an example image or upload your own
37
- 2. Choose the model type (robust or standard)
38
- 3. Adjust epsilon and iteration parameters
39
- 4. Click "Run Inference" to see how the model perceives the image
40
 
41
- ## About
42
 
43
- This demo is based on research showing how adversarially robust models develop more human-like visual representations. The generative inference process reveals these perceptual biases by optimizing the input to maximize the model's confidence.
44
 
45
  ## Installation
46
 
@@ -48,8 +49,8 @@ To run this demo locally:
48
 
49
  ```bash
50
  # Clone the repository
51
- git clone [repo-url]
52
- cd GenerativeInferenceDemo
53
 
54
  # Install dependencies
55
  pip install -r requirements.txt
@@ -58,25 +59,37 @@ pip install -r requirements.txt
58
  python app.py
59
  ```
60
 
61
- The web app will be available at http://localhost:7860 (or another port if 7860 is busy).
62
 
63
- ## About the Models
64
 
65
- - **Robust ResNet50**: A model trained with adversarial examples, making it more robust to small perturbations. These models often exhibit more human-like visual perception.
66
- - **Standard ResNet50**: A standard ImageNet-trained ResNet50 model.
 
 
67
 
68
- ## How It Works
69
 
70
- 1. The algorithm starts with an input image
71
- 2. It iteratively updates the image to increase the model's confidence in its predictions
72
- 3. These updates are constrained to a small neighborhood (controlled by epsilon) around the original image
73
- 4. The resulting changes reveal how the network "sees" the image
74
 
75
  ## Citation
76
 
77
- If you use this work in your research, please cite the original paper:
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
- [Citation information will be added here]
80
 
81
  ## License
82
 
 
1
  ---
2
+ title: Human Hallucination Prediction
3
+ emoji: 👁️
4
+ colorFrom: blue
5
  colorTo: purple
6
  sdk: gradio
7
  sdk_version: 5.23.1
 
10
  license: mit
11
  ---
12
 
13
+ # Human Hallucination Prediction
14
 
15
+ This Gradio demo predicts whether humans will experience visual hallucinations or illusions when viewing specific images. Using adversarially robust neural networks, this tool can forecast perceptual phenomena like illusory contours, figure-ground reversals, and other Gestalt effects before humans report them.
16
 
17
+ ## How It Works
18
+
19
+ This tool uses **generative inference** with adversarially robust neural networks to predict human visual hallucinations. Robust models trained with adversarial examples develop more human-like perceptual biases, allowing them to predict when humans will perceive:
20
 
21
+ - **Illusory contours** (Kanizsa shapes, Ehrenstein illusion)
22
+ - **Figure-ground ambiguity** (Rubin's vase, bistable images)
23
+ - **Color spreading effects** (Neon color illusion)
24
+ - **Gestalt grouping** (Continuity, proximity)
25
+ - **Brightness illusions** (Cornsweet effect)
26
 
27
  ## Features
28
 
29
+ - **Predict hallucinations** from uploaded images or example illusions
30
+ - **Visualize the prediction process** step-by-step
31
+ - **Compare different models** (robust vs. standard)
32
+ - **Adjust prediction parameters** for different perceptual phenomena
33
+ - **Pre-configured examples** of classic visual illusions
 
 
 
 
34
 
35
  ## Usage
36
 
37
+ 1. **Select an example illusion** or upload your own image
38
+ 2. **Click "Load Parameters"** to set optimal prediction settings
39
+ 3. **Click "Run Generative Inference"** to predict the hallucination
40
+ 4. **View the results**: The model will show what perceptual effects it predicts humans will experience
41
 
42
+ ## Scientific Background
43
 
44
+ This demo is based on research showing that adversarially robust neural networks develop perceptual representations similar to human vision. By using generative inference (optimizing images to maximize model confidence), we can reveal what perceptual structures the network expects to see—which often matches what humans hallucinate or perceive in ambiguous images.
45
 
46
  ## Installation
47
 
 
49
 
50
  ```bash
51
  # Clone the repository
52
+ git clone https://huggingface.co/spaces/ttoosi/Human_Hallucination_Prediction
53
+ cd Human_Hallucination_Prediction
54
 
55
  # Install dependencies
56
  pip install -r requirements.txt
 
59
  python app.py
60
  ```
61
 
62
+ The web app will be available at http://localhost:7860.
63
 
64
+ ## The Prediction Process
65
 
66
+ 1. **Input**: Start with an ambiguous or illusion-inducing image
67
+ 2. **Generative Inference**: The robust neural network iteratively modifies the image to maximize its confidence
68
+ 3. **Prediction**: The modifications reveal what perceptual structures the network expects—predicting what humans will hallucinate
69
+ 4. **Visualization**: View the predicted hallucination emerging step-by-step
70
 
71
+ ## Models
72
 
73
+ - **Robust ResNet50**: Trained with adversarial examples (ε=3.0), develops human-like perceptual biases
74
+ - **Standard ResNet50**: Standard ImageNet training without adversarial robustness
 
 
75
 
76
  ## Citation
77
 
78
+ If you use this work in your research, please cite:
79
+
80
+ ```bibtex
81
+ @article{toosi2024hallucination,
82
+ title={Predicting Human Visual Hallucinations with Robust Neural Networks},
83
+ author={Toosi, Tahereh},
84
+ year={2024}
85
+ }
86
+ ```
87
+
88
+ ## About
89
+
90
+ **Developed by [Tahereh Toosi](https://toosi.github.io)**
91
 
92
+ This demo demonstrates how adversarially robust neural networks can predict human perceptual hallucinations before they occur.
93
 
94
  ## License
95
 
app.py CHANGED
@@ -79,6 +79,7 @@ examples = [
79
  "[Brightness Perception](https://doi.org/10.1016/j.visres.2000.200.1)",
80
  "[Edge Effects](https://doi.org/10.1016/j.tics.2003.08.003)"
81
  ],
 
82
  "method": "Prior-Guided Drift Diffusion",
83
  "reverse_diff": {
84
  "model": "resnet50_robust",
@@ -101,18 +102,18 @@ examples = [
101
  "method": "Prior-Guided Drift Diffusion",
102
  "reverse_diff": {
103
  "model": "resnet50_robust",
104
- "layer": "layer4",
105
- "initial_noise": 0.5,
106
- "diffusion_noise": 0.01,
107
- "step_size": 0.2,
108
- "iterations": 301,
109
- "epsilon": 40.0
110
  }
111
  },
112
  {
113
  "image": os.path.join("stimuli", "Confetti_illusion.png"),
114
  "name": "Confetti Illusion",
115
- "wiki": "https://en.wikipedia.org/wiki/Optical_illusion",
116
  "papers": [
117
  "[Color Perception](https://doi.org/10.1016/j.visres.2000.200.1)",
118
  "[Context Effects](https://doi.org/10.1016/j.tics.2003.08.003)"
@@ -243,69 +244,86 @@ def apply_example(example):
243
  example["reverse_diff"]["initial_noise"], # Initial noise
244
  example["reverse_diff"]["diffusion_noise"], # Diffusion noise value (corrected)
245
  example["reverse_diff"]["step_size"], # Step size (added)
246
- example["reverse_diff"]["layer"] # Model layer
 
247
  ]
248
 
249
  # Define the interface
250
- with gr.Blocks(title="Generative Inference Demo") as demo:
251
- gr.Markdown("# Generative Inference Demo")
252
- gr.Markdown("This demo showcases how neural networks can perceive visual illusions through generative inference.")
 
 
 
 
 
 
 
 
 
253
 
254
  gr.Markdown("""
255
- **How to use this demo:**
256
- - **Load pre-configured examples**: Click on any visual illusion below and hit "Load Parameters" to automatically set up the optimal parameters for that illusion
257
- - **Upload your own images**: Use the image upload area to test your own images with different parameter settings
258
- - **Experiment with parameters**: Adjust the inference method, iterations, noise levels, and other parameters to see how they affect the generative inference process
 
259
  """)
260
 
261
  # Main processing interface
262
  with gr.Row():
263
  with gr.Column(scale=1):
264
  # Inputs
265
- image_input = gr.Image(label="Input Image", type="pil")
266
 
267
- with gr.Row():
268
- model_choice = gr.Dropdown(
269
- choices=["resnet50_robust", "standard_resnet50"],
270
- value="resnet50_robust",
271
- label="Model"
272
- )
273
-
274
- inference_type = gr.Dropdown(
275
- choices=["Prior-Guided Drift Diffusion", "IncreaseConfidence"],
276
- value="Prior-Guided Drift Diffusion",
277
- label="Inference Method"
278
- )
279
 
280
- with gr.Row():
281
- eps_slider = gr.Slider(minimum=0.01, maximum=3.0, value=0.5, step=0.01, label="Epsilon (Perturbation Size)")
282
- iterations_slider = gr.Slider(minimum=1, maximum=600, value=50, step=1, label="Number of Iterations") # Updated max to 600
283
 
284
- with gr.Row():
285
- initial_noise_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.05, step=0.01,
286
- label="Initial Noise Ratio")
287
- diffusion_noise_slider = gr.Slider(minimum=0.0, maximum=0.05, value=0.01, step=0.001,
288
- label="Diffusion Noise Ratio") # Corrected name
 
 
 
 
 
 
 
 
 
289
 
290
- with gr.Row():
291
- step_size_slider = gr.Slider(minimum=0.01, maximum=2.0, value=0.5, step=0.01,
292
- label="Step Size") # Added step size slider
293
- layer_choice = gr.Dropdown(
294
- choices=["all", "conv1", "bn1", "relu", "maxpool", "layer1", "layer2", "layer3", "layer4", "avgpool"],
295
- value="all",
296
- label="Model Layer"
297
- )
298
-
299
- run_button = gr.Button("Run Inference", variant="primary")
 
 
 
 
 
 
 
 
300
 
301
  with gr.Column(scale=2):
302
  # Outputs
303
- output_image = gr.Image(label="Final Inferred Image")
304
- output_frames = gr.Gallery(label="Inference Steps", columns=5, rows=2)
305
 
306
  # Examples section with integrated explanations
307
- gr.Markdown("## Visual Illusion Examples")
308
- gr.Markdown("Select an illusion to load its parameters and see how generative inference reveals perceptual effects")
309
 
310
  # For each example, create a row with the image and explanation side by side
311
  for i, ex in enumerate(examples):
@@ -323,7 +341,7 @@ with gr.Blocks(title="Generative Inference Demo") as demo:
323
  image_input, model_choice, inference_type,
324
  eps_slider, iterations_slider,
325
  initial_noise_slider, diffusion_noise_slider,
326
- step_size_slider, layer_choice
327
  ]
328
  )
329
 
@@ -332,17 +350,10 @@ with gr.Blocks(title="Generative Inference Demo") as demo:
332
  gr.Markdown(f"### {ex['name']}")
333
  gr.Markdown(f"[Read more on Wikipedia]({ex['wiki']})")
334
 
335
- gr.Markdown("**Generative Inference Parameters:**")
336
- params_md = f"""
337
- - **Method**: {ex['method']}
338
- - **Model Layer**: {ex['reverse_diff']['layer']}
339
- - **Initial Noise**: {ex['reverse_diff']['initial_noise']}
340
- - **Diffusion Noise**: {ex['reverse_diff']['diffusion_noise']}
341
- - **Step Size**: {ex['reverse_diff']['step_size']}
342
- - **Iterations**: {ex['reverse_diff']['iterations']}
343
- - **Epsilon**: {ex['reverse_diff']['epsilon']}
344
- """
345
- gr.Markdown(params_md)
346
 
347
  if i < len(examples) - 1: # Don't add separator after the last example
348
  gr.Markdown("---")
@@ -359,27 +370,42 @@ with gr.Blocks(title="Generative Inference Demo") as demo:
359
  outputs=[output_image, output_frames]
360
  )
361
 
 
 
 
 
 
 
 
 
 
362
  # About section
363
  gr.Markdown("""
364
- ## About Generative Inference
 
 
365
 
366
- Generative inference is a technique that reveals how neural networks perceive visual stimuli. This demo primarily uses the Prior-Guided Drift Diffusion method.
367
 
368
- ### Prior-Guided Drift Diffusion
369
- Moving away from a noisy representation of the input images
370
 
371
- ### IncreaseConfidence
372
- Moving away from the least likely class identified at iteration 0 (fast perception)
373
 
374
  ### Parameters:
375
- - **Initial Noise Ratio**: Controls the amount of noise added to the image at the beginning
376
- - **Diffusion Noise Ratio**: Controls the amount of noise added at each optimization step
377
- - **Step Size**: Learning rate for the optimization process
378
- - **Number of Iterations**: How many optimization steps to perform
379
- - **Model Layer**: Select a specific layer of the ResNet50 model to extract features from
380
- - **Epsilon**: Controls the size of perturbation during optimization
 
 
 
 
381
 
382
- **Generative Inference was developed by [Tahereh Toosi](https://toosi.github.io).**
383
  """)
384
 
385
  # Launch the demo
 
79
  "[Brightness Perception](https://doi.org/10.1016/j.visres.2000.200.1)",
80
  "[Edge Effects](https://doi.org/10.1016/j.tics.2003.08.003)"
81
  ],
82
+ "instructions": "Both blocks are gray in color (the same), use your finger to cover the middle line. Hit 'Load Parameters' and then hit 'Run Generative Inference' to see how the model sees the blocks.",
83
  "method": "Prior-Guided Drift Diffusion",
84
  "reverse_diff": {
85
  "model": "resnet50_robust",
 
102
  "method": "Prior-Guided Drift Diffusion",
103
  "reverse_diff": {
104
  "model": "resnet50_robust",
105
+ "layer": "avgpool",
106
+ "initial_noise": 0.9,
107
+ "diffusion_noise": 0.003,
108
+ "step_size": 0.58,
109
+ "iterations": 100,
110
+ "epsilon": 0.81
111
  }
112
  },
113
  {
114
  "image": os.path.join("stimuli", "Confetti_illusion.png"),
115
  "name": "Confetti Illusion",
116
+ "wiki": "https://www.youtube.com/watch?v=SvEiEi8O7QE",
117
  "papers": [
118
  "[Color Perception](https://doi.org/10.1016/j.visres.2000.200.1)",
119
  "[Context Effects](https://doi.org/10.1016/j.tics.2003.08.003)"
 
244
  example["reverse_diff"]["initial_noise"], # Initial noise
245
  example["reverse_diff"]["diffusion_noise"], # Diffusion noise value (corrected)
246
  example["reverse_diff"]["step_size"], # Step size (added)
247
+ example["reverse_diff"]["layer"], # Model layer
248
+ gr.Group(visible=True) # Show parameters section
249
  ]
250
 
251
  # Define the interface
252
+ with gr.Blocks(title="Human Hallucination Prediction", css="""
253
+ .purple-button {
254
+ background-color: #8B5CF6 !important;
255
+ color: white !important;
256
+ border: none !important;
257
+ }
258
+ .purple-button:hover {
259
+ background-color: #7C3AED !important;
260
+ }
261
+ """) as demo:
262
+ gr.Markdown("# 👁️ Human Hallucination Prediction")
263
+ gr.Markdown("**Predict what visual hallucinations humans will experience** using adversarially robust neural networks. This demo forecasts perceptual phenomena like illusory contours, figure-ground reversals, and Gestalt effects before humans report them.")
264
 
265
  gr.Markdown("""
266
+ **How to predict hallucinations:**
267
+ 1. **Select an example illusion** below and click "Load Parameters" to set optimal prediction settings
268
+ 2. **Click "Run Generative Inference"** to predict what hallucination humans will perceive
269
+ 3. **View the prediction**: Watch as the model reveals the perceptual structures it expects—matching what humans typically hallucinate
270
+ 4. **Upload your own images** to test if they will induce hallucinations in human observers
271
  """)
272
 
273
  # Main processing interface
274
  with gr.Row():
275
  with gr.Column(scale=1):
276
  # Inputs
277
+ image_input = gr.Image(label="Input Image", type="pil", value=os.path.join("stimuli", "Neon_Color_Circle.jpg"))
278
 
279
+ # Run Inference button right below the image
280
+ run_button = gr.Button("🔮 Predict Hallucination", variant="primary", elem_classes="purple-button")
 
 
 
 
 
 
 
 
 
 
281
 
282
+ # Parameters toggle button
283
+ params_button = gr.Button("⚙️ Play with the parameters", variant="secondary")
 
284
 
285
+ # Parameters section (initially hidden)
286
+ with gr.Group(visible=False) as params_section:
287
+ with gr.Row():
288
+ model_choice = gr.Dropdown(
289
+ choices=["resnet50_robust", "standard_resnet50"], # "resnet50_robust_face" - hidden for deployment
290
+ value="resnet50_robust",
291
+ label="Model"
292
+ )
293
+
294
+ inference_type = gr.Dropdown(
295
+ choices=["Prior-Guided Drift Diffusion", "IncreaseConfidence"],
296
+ value="Prior-Guided Drift Diffusion",
297
+ label="Inference Method"
298
+ )
299
 
300
+ with gr.Row():
301
+ eps_slider = gr.Slider(minimum=0.0, maximum=40.0, value=20.0, step=0.01, label="Epsilon (Stimulus Fidelity)")
302
+ iterations_slider = gr.Slider(minimum=1, maximum=600, value=101, step=1, label="Number of Iterations") # Updated max to 600
303
+
304
+ with gr.Row():
305
+ initial_noise_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.8, step=0.01,
306
+ label="Drift Noise")
307
+ diffusion_noise_slider = gr.Slider(minimum=0.0, maximum=0.05, value=0.003, step=0.001,
308
+ label="Diffusion Noise") # Corrected name
309
+
310
+ with gr.Row():
311
+ step_size_slider = gr.Slider(minimum=0.01, maximum=2.0, value=1.0, step=0.01,
312
+ label="Update Rate") # Added step size slider
313
+ layer_choice = gr.Dropdown(
314
+ choices=["all", "conv1", "bn1", "relu", "maxpool", "layer1", "layer2", "layer3", "layer4", "avgpool"],
315
+ value="layer3",
316
+ label="Model Layer"
317
+ )
318
 
319
  with gr.Column(scale=2):
320
  # Outputs
321
+ output_image = gr.Image(label="Predicted Hallucination")
322
+ output_frames = gr.Gallery(label="Hallucination Prediction Process", columns=5, rows=2)
323
 
324
  # Examples section with integrated explanations
325
+ gr.Markdown("## 🎯 Visual Illusion Examples")
326
+ gr.Markdown("Select an illusion to predict what hallucination humans will experience when viewing it")
327
 
328
  # For each example, create a row with the image and explanation side by side
329
  for i, ex in enumerate(examples):
 
341
  image_input, model_choice, inference_type,
342
  eps_slider, iterations_slider,
343
  initial_noise_slider, diffusion_noise_slider,
344
+ step_size_slider, layer_choice, params_section
345
  ]
346
  )
347
 
 
350
  gr.Markdown(f"### {ex['name']}")
351
  gr.Markdown(f"[Read more on Wikipedia]({ex['wiki']})")
352
 
353
+ # Show instructions if they exist
354
+ if "instructions" in ex:
355
+ gr.Markdown(f"**Instructions:** {ex['instructions']}")
356
+
 
 
 
 
 
 
 
357
 
358
  if i < len(examples) - 1: # Don't add separator after the last example
359
  gr.Markdown("---")
 
370
  outputs=[output_image, output_frames]
371
  )
372
 
373
+ # Toggle parameters visibility
374
+ def toggle_params():
375
+ return gr.Group(visible=True)
376
+
377
+ params_button.click(
378
+ fn=toggle_params,
379
+ outputs=[params_section]
380
+ )
381
+
382
  # About section
383
  gr.Markdown("""
384
+ ## 🧠 About Hallucination Prediction
385
+
386
+ This tool predicts human visual hallucinations using **generative inference** with adversarially robust neural networks. Robust models develop human-like perceptual biases, allowing them to forecast what perceptual structures humans will experience.
387
 
388
+ ### Prediction Methods:
389
 
390
+ **Prior-Guided Drift Diffusion** (Primary Method)
391
+ Starting from a noisy representation, the model converges toward what it expects to perceive—revealing predicted hallucinations
392
 
393
+ **IncreaseConfidence**
394
+ Moving away from unlikely interpretations to reveal the most probable perceptual experience
395
 
396
  ### Parameters:
397
+ - **Drift Noise**: Initial uncertainty in the prediction process
398
+ - **Diffusion Noise**: Stochastic exploration during prediction
399
+ - **Update Rate**: Speed of convergence to the predicted hallucination
400
+ - **Number of Iterations**: How many prediction steps to perform
401
+ - **Model Layer**: Which perceptual level to predict from (early edges vs. high-level objects)
402
+ - **Epsilon (Stimulus Fidelity)**: How closely the prediction must match the input stimulus
403
+
404
+ ### Why Does This Work?
405
+
406
+ Adversarially robust neural networks develop perceptual representations similar to human vision. When we use generative inference to reveal what these networks "expect" to see, it matches what humans hallucinate in ambiguous images—allowing us to predict human perception.
407
 
408
+ **Developed by [Tahereh Toosi](https://toosi.github.io)**
409
  """)
410
 
411
  # Launch the demo
face_vase.png ADDED
huggingface-metadata.json CHANGED
@@ -1,10 +1,10 @@
1
  {
2
- "title": "Generative Inference Demo",
3
- "emoji": "🧠",
4
- "colorFrom": "indigo",
5
  "colorTo": "purple",
6
  "sdk": "gradio",
7
- "sdk_version": "3.32.0",
8
  "app_file": "app.py",
9
  "pinned": false,
10
  "license": "mit"
 
1
  {
2
+ "title": "Human Hallucination Prediction",
3
+ "emoji": "👁️",
4
+ "colorFrom": "blue",
5
  "colorTo": "purple",
6
  "sdk": "gradio",
7
+ "sdk_version": "5.23.1",
8
  "app_file": "app.py",
9
  "pinned": false,
10
  "license": "mit"
inference.py CHANGED
@@ -25,7 +25,8 @@ print(f"Using device: {device}")
25
  # Constants
26
  MODEL_URLS = {
27
  'resnet50_robust': 'https://huggingface.co/madrylab/robust-imagenet-models/resolve/main/resnet50_l2_eps3.ckpt',
28
- 'resnet50_standard': 'https://huggingface.co/madrylab/robust-imagenet-models/resolve/main/resnet50_l2_eps0.ckpt'
 
29
  }
30
 
31
  IMAGENET_MEAN = [0.485, 0.456, 0.406]
@@ -162,7 +163,12 @@ def download_model(model_type):
162
  if model_type not in MODEL_URLS or MODEL_URLS[model_type] is None:
163
  return None # Use PyTorch's pretrained model
164
 
165
- model_path = Path(f"models/{model_type}.pt")
 
 
 
 
 
166
  if not model_path.exists():
167
  print(f"Downloading {model_type} model...")
168
  url = MODEL_URLS[model_type]
 
25
  # Constants
26
  MODEL_URLS = {
27
  'resnet50_robust': 'https://huggingface.co/madrylab/robust-imagenet-models/resolve/main/resnet50_l2_eps3.ckpt',
28
+ 'resnet50_standard': 'https://huggingface.co/madrylab/robust-imagenet-models/resolve/main/resnet50_l2_eps0.ckpt',
29
+ 'resnet50_robust_face': 'https://huggingface.co/ttoosi/resnet50_robust_face/blob/main/100_checkpoint.pt'
30
  }
31
 
32
  IMAGENET_MEAN = [0.485, 0.456, 0.406]
 
163
  if model_type not in MODEL_URLS or MODEL_URLS[model_type] is None:
164
  return None # Use PyTorch's pretrained model
165
 
166
+ # Handle special case for face model
167
+ if model_type == 'resnet50_robust_face':
168
+ model_path = Path("models/resnet50_robust_face_100_checkpoint.pt")
169
+ else:
170
+ model_path = Path(f"models/{model_type}.pt")
171
+
172
  if not model_path.exists():
173
  print(f"Downloading {model_type} model...")
174
  url = MODEL_URLS[model_type]