ttoosi commited on
Commit
ea98391
·
1 Parent(s): b47a80f

model load fixed but image upload persists

Browse files
Files changed (1) hide show
  1. app.py +20 -3
app.py CHANGED
@@ -5,12 +5,27 @@ from PIL import Image
5
  import requests
6
  from huggingface_hub import hf_hub_download
7
 
 
 
 
8
  # Load the model checkpoint from Hugging Face
9
  checkpoint_path = hf_hub_download(repo_id="ttoosi/resnet50_robust_face", filename="100_checkpoint.pt")
10
 
11
  # Initialize the model
12
  model = models.resnet50()
13
- model.load_state_dict(torch.load(checkpoint_path, map_location=torch.device('cpu'))['model']) # Force model to load on CPU
 
 
 
 
 
 
 
 
 
 
 
 
14
  model.eval()
15
 
16
  # Image preprocessing
@@ -18,11 +33,13 @@ preprocess = transforms.Compose([
18
  transforms.Resize(256),
19
  transforms.CenterCrop(224),
20
  transforms.ToTensor(),
21
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
22
  ])
23
 
24
  # Function to make predictions
25
  def predict(image):
 
 
26
  image = preprocess(image).unsqueeze(0) # Add batch dimension
27
  with torch.no_grad():
28
  output = model(image) # Perform inference on CPU
@@ -30,7 +47,7 @@ def predict(image):
30
  return f"Predicted class: {predicted_class.item()}"
31
 
32
  # Create the Gradio interface
33
- iface = gr.Interface(fn=predict, inputs=gr.inputs.Image(type="pil"), outputs="text")
34
 
35
  # Launch the interface
36
  iface.launch()
 
5
  import requests
6
  from huggingface_hub import hf_hub_download
7
 
8
+ from PIL import Image
9
+ import numpy as np
10
+
11
  # Load the model checkpoint from Hugging Face
12
  checkpoint_path = hf_hub_download(repo_id="ttoosi/resnet50_robust_face", filename="100_checkpoint.pt")
13
 
14
  # Initialize the model
15
  model = models.resnet50()
16
+ # change the num_classes to 500
17
+ model.fc = torch.nn.Linear(model.fc.in_features, 500)
18
+ checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))['model']
19
+ # remove the prefix 'module.' from the keys
20
+ # remove the prefix 'model.' from the keys that have it
21
+ new_state_dict = {k.replace('module.', ''): v for k, v in checkpoint.items()}
22
+ new_state_dict = {k.replace('model.', ''): v for k, v in new_state_dict.items()}
23
+ new_state_dict = {k.replace('attacker.', ''): v for k, v in new_state_dict.items()}
24
+
25
+
26
+ print(new_state_dict.keys())
27
+ print('********************')
28
+ model.load_state_dict(new_state_dict, strict=False) # ignore Unexpected key(s) in state_dict: "normalizer.new_mean", "normalizer.new_std", "normalize.new_mean", "normalize.new_std".
29
  model.eval()
30
 
31
  # Image preprocessing
 
33
  transforms.Resize(256),
34
  transforms.CenterCrop(224),
35
  transforms.ToTensor(),
36
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), # vggface2
37
  ])
38
 
39
  # Function to make predictions
40
  def predict(image):
41
+ if isinstance(image, np.ndarray):
42
+ image = Image.fromarray(image) # Convert to PIL Image if i
43
  image = preprocess(image).unsqueeze(0) # Add batch dimension
44
  with torch.no_grad():
45
  output = model(image) # Perform inference on CPU
 
47
  return f"Predicted class: {predicted_class.item()}"
48
 
49
  # Create the Gradio interface
50
+ iface = gr.Interface(fn=predict, inputs=gr.Image(type="numpy"), outputs="text") # Updated from gr.inputs.Image to gr.Image
51
 
52
  # Launch the interface
53
  iface.launch()