ttoosi commited on
Commit
9c56da2
·
verified ·
1 Parent(s): bdecc48

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -36
app.py CHANGED
@@ -9,18 +9,9 @@ from PIL import Image
9
  import numpy as np
10
  import random
11
 
12
- # from datasets import load_dataset
13
- # from datasets import DatasetDict
14
- # ds = DatasetDict({
15
- # "validation": load_dataset("chronopt-research/cropped-vggface2-224", split="validation"),
16
- # })
17
 
18
 
19
 
20
- # Load the VGGFace2 dataset using Hugging Face's datasets library
21
- # ds = load_dataset("chronopt-research/cropped-vggface2-224", split="validation")
22
-
23
-
24
  # Load the model checkpoint from Hugging Face
25
  checkpoint_path = hf_hub_download(repo_id="ttoosi/resnet50_robust_face", filename="100_checkpoint.pt")
26
 
@@ -67,11 +58,17 @@ preprocess = transforms.Compose([
67
  # return f"Predicted class: {predicted_class.item()}", sample_images_urls
68
 
69
 
70
- # Simplified Generative Inference
 
 
 
 
 
 
71
  def simple_generative_inference(image, mode, model, n_iterations=10, step_size=0.01):
72
  """
73
  Perform Generative Perceptual Inference on the input image.
74
- :param image: Input image as a PIL image or numpy array.
75
  :param mode: Either 'increase confidence' or 'ReverseDiffuse'.
76
  :param model: Pretrained PyTorch model.
77
  :param n_iterations: Number of inference iterations.
@@ -91,20 +88,11 @@ def simple_generative_inference(image, mode, model, n_iterations=10, step_size=0
91
  for _ in range(n_iterations):
92
  optimizer.zero_grad()
93
  output = model(image_tensor)
94
- probs = torch.nn.functional.softmax(output, dim=1)
95
 
96
  # Define inference loss based on mode
97
  if mode == "increase confidence":
98
- # Push away from the least likely classes
99
- _, least_likely_indices = torch.topk(probs, k=2, largest=False)
100
- losses = []
101
- for idx in least_likely_indices[0]:
102
- target = torch.full((1,), idx, dtype=torch.long, device=output.device)
103
- loss = torch.nn.CrossEntropyLoss()(output, target)
104
- losses.append(loss)
105
- loss = torch.stack(losses).mean() # Average the losses for the least likely classes
106
  elif mode == "ReverseDiffuse":
107
- # Push away from noisy versions
108
  noisy_image = image_tensor + torch.randn_like(image_tensor) * 0.1
109
  loss = torch.nn.functional.mse_loss(image_tensor, noisy_image)
110
  else:
@@ -125,29 +113,18 @@ def simple_generative_inference(image, mode, model, n_iterations=10, step_size=0
125
 
126
  return processed_image, grad_image
127
 
128
-
129
- # # Create the Gradio interface
130
- # iface = gr.Interface(fn=predict, inputs=gr.Image(type="numpy"), outputs="text") # Updated from gr.inputs.Image to gr.Image
131
-
132
- # # Create the Gradio interface
133
- # iface = gr.Interface(
134
- # fn=predict,
135
- # inputs=gr.Image(type="numpy"),
136
- # outputs=[gr.Textbox(label="Predicted Class"), gr.Gallery(label="Class Samples")],
137
- # title="ResNet-50 VGGFace2 Classifier"
138
- # )
139
-
140
  iface = gr.Interface(
141
  fn=lambda image, mode: simple_generative_inference(image, mode, model),
142
  inputs=[
143
- gr.Image(type="pil"), # Input image
144
- gr.Radio(["increase confidence", "ReverseDiffuse"], label="GPI Mode") # Mode selection
145
  ],
146
  outputs=[
147
  gr.Image(label="Processed Image"), # Processed image
148
  gr.Image(label="Gradient Visualization") # Gradient visualization
149
  ],
150
- title="Generative Perceptual Inference (GPI)"
151
  )
152
 
153
 
 
9
  import numpy as np
10
  import random
11
 
 
 
 
 
 
12
 
13
 
14
 
 
 
 
 
15
  # Load the model checkpoint from Hugging Face
16
  checkpoint_path = hf_hub_download(repo_id="ttoosi/resnet50_robust_face", filename="100_checkpoint.pt")
17
 
 
58
  # return f"Predicted class: {predicted_class.item()}", sample_images_urls
59
 
60
 
61
+ import torch
62
+ import torch.nn.functional as F
63
+ from torchvision import transforms
64
+ from PIL import Image
65
+ import numpy as np
66
+
67
+ # Simple Generative Inference function
68
  def simple_generative_inference(image, mode, model, n_iterations=10, step_size=0.01):
69
  """
70
  Perform Generative Perceptual Inference on the input image.
71
+ :param image: Input image as a PIL image.
72
  :param mode: Either 'increase confidence' or 'ReverseDiffuse'.
73
  :param model: Pretrained PyTorch model.
74
  :param n_iterations: Number of inference iterations.
 
88
  for _ in range(n_iterations):
89
  optimizer.zero_grad()
90
  output = model(image_tensor)
 
91
 
92
  # Define inference loss based on mode
93
  if mode == "increase confidence":
94
+ loss = -torch.nn.functional.cross_entropy(output, output.softmax(dim=1).argmax(dim=1))
 
 
 
 
 
 
 
95
  elif mode == "ReverseDiffuse":
 
96
  noisy_image = image_tensor + torch.randn_like(image_tensor) * 0.1
97
  loss = torch.nn.functional.mse_loss(image_tensor, noisy_image)
98
  else:
 
113
 
114
  return processed_image, grad_image
115
 
116
+ # Gradio Interface
 
 
 
 
 
 
 
 
 
 
 
117
  iface = gr.Interface(
118
  fn=lambda image, mode: simple_generative_inference(image, mode, model),
119
  inputs=[
120
+ gr.Image(type="pil", label="Input Image"), # Input image
121
+ gr.Radio(["increase confidence", "ReverseDiffuse"], label="Inference Mode") # Mode selection
122
  ],
123
  outputs=[
124
  gr.Image(label="Processed Image"), # Processed image
125
  gr.Image(label="Gradient Visualization") # Gradient visualization
126
  ],
127
+ title="Generative Inference"
128
  )
129
 
130