File size: 2,766 Bytes
26af353
 
 
 
 
 
 
ea98391
 
 
8cd11f9
 
 
 
ee9d2e9
8cd11f9
 
26af353
 
 
 
 
ea98391
 
 
 
 
 
 
 
 
 
 
 
 
26af353
 
 
 
 
 
 
ea98391
26af353
 
 
 
ea98391
 
26af353
 
741020d
26af353
8cd11f9
 
 
 
 
 
 
 
 
 
26af353
 
8cd11f9
 
 
 
 
 
26af353
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import torch
import gradio as gr
from torchvision import models, transforms
from PIL import Image
import requests
from huggingface_hub import hf_hub_download

from PIL import Image
import numpy as np

from datasets import load_dataset
import random

# Face images used to illustrate predictions: the "validation" split of the
# cropped VGGFace2 dataset, fetched through the Hugging Face `datasets` library.
ds = load_dataset(
    path="chronopt-research/cropped-vggface2-224",
    split="validation",
)


# Download the trained checkpoint from the Hugging Face Hub (cached locally
# by huggingface_hub after the first run).
checkpoint_path = hf_hub_download(repo_id="ttoosi/resnet50_robust_face", filename="100_checkpoint.pt")

# Build a ResNet-50 and swap its classification head for a 500-way linear
# layer so its shape matches the head stored in the checkpoint.
model = models.resnet50()
model.fc = torch.nn.Linear(model.fc.in_features, 500)

# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code. Only point this at checkpoints from a repository you trust.
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))['model']

# The checkpoint keys carry wrapper names ('module.', 'model.', 'attacker.')
# that a plain torchvision ResNet-50 does not use. Strip them one wrapper at
# a time; if two keys collapse to the same name, the later one wins.
new_state_dict = checkpoint
for wrapper_name in ('module.', 'model.', 'attacker.'):
    new_state_dict = {k.replace(wrapper_name, ''): v for k, v in new_state_dict.items()}

# strict=False is required because the checkpoint also stores entries the
# bare ResNet has no slot for (e.g. "normalizer.new_mean", "normalizer.new_std",
# "normalize.new_mean", "normalize.new_std"). It would also hide weights that
# failed to load, so inspect the result instead of discarding it.
load_result = model.load_state_dict(new_state_dict, strict=False)
if load_result.missing_keys:
    print(
        "WARNING: these model weights were NOT found in the checkpoint and "
        f"remain randomly initialised: {load_result.missing_keys}"
    )
print(f"Checkpoint entries ignored by the model: {load_result.unexpected_keys}")
print('********************')
model.eval()  # inference mode: fixes batch-norm statistics, disables dropout

# Input pipeline applied to every uploaded image before it reaches the model:
# resize, take the central 224x224 crop, convert to a float tensor, then
# normalise each channel with mean 0.5 / std 0.5 (the VGGFace2 setting used
# here), which maps [0, 1] pixel values to [-1, 1].
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),  # vggface2
    ]
)

def predict(image):
    """Classify a face image and fetch example images of the predicted class.

    Args:
        image: The uploaded picture, either a NumPy array (what the
            ``gr.Image(type="numpy")`` input delivers) or a PIL image.
            ``None`` means nothing was uploaded.

    Returns:
        A ``(message, samples)`` tuple. ``message`` is the text
        ``"Predicted class: <index>"`` and ``samples`` is a list of up to
        nine randomly chosen dataset images whose label equals the
        predicted index (empty when the dataset has none, or when no image
        was supplied).
    """
    if image is None:
        # Gradio calls the function with None when the user submits an
        # empty image box; answer gracefully instead of crashing.
        return "No image provided.", []

    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)  # Convert to a PIL Image if it is a NumPy array
    # Normalize is configured for exactly three channels, but uploads may be
    # grayscale or carry an alpha channel.
    image = image.convert("RGB")

    batch = preprocess(image).unsqueeze(0)  # Add batch dimension
    with torch.no_grad():
        logits = model(batch)  # Perform inference on CPU
    predicted_class = int(logits.argmax(dim=1).item())

    # `ds` is a single split, so it is indexed directly (there is no 'train'
    # sub-dataset to select). Scan only the label column to find candidate
    # rows, so images are decoded for the few sampled rows rather than for
    # every example of the class.
    # NOTE(review): assumes dataset labels use the same indices as the
    # model's 500 output classes — confirm against the training setup.
    matching_rows = [
        row for row, label in enumerate(ds["label"]) if label == predicted_class
    ]
    chosen_rows = random.sample(matching_rows, min(len(matching_rows), 9))
    sample_images = [ds[row]["image"] for row in chosen_rows]

    return f"Predicted class: {predicted_class}", sample_images

# Gradio UI: a single image upload goes in; the predicted-class text and a
# gallery of example images for that class come out.
iface = gr.Interface(
    predict,
    gr.Image(type="numpy"),
    [
        gr.Textbox(label="Predicted Class"),
        gr.Gallery(label="Class Samples"),
    ],
    title="ResNet-50 VGGFace2 Classifier",
)

# Start serving the interface.
iface.launch()