import random

import gradio as gr
import numpy as np
import requests
import torch
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import models, transforms

# Load the VGGFace2 validation split. Because `split=` is given, this is a
# single `Dataset`, not a `DatasetDict` keyed by split name.
ds = load_dataset("chronopt-research/cropped-vggface2-224", split="validation")

# Download the model checkpoint from the Hugging Face Hub.
checkpoint_path = hf_hub_download(
    repo_id="ttoosi/resnet50_robust_face", filename="100_checkpoint.pt"
)

# Plain torchvision ResNet-50 with the classifier head resized to 500 classes.
model = models.resnet50()
model.fc = torch.nn.Linear(model.fc.in_features, 500)

# NOTE(review): torch.load unpickles arbitrary Python objects; only load
# checkpoints from a trusted repo. Consider weights_only=True once confirmed
# compatible with this checkpoint's contents.
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))['model']

# Wrapper prefixes to remove from checkpoint keys so they match the plain
# ResNet-50 parameter names.
_KEY_PREFIXES = ('module.', 'model.', 'attacker.')


def _strip_prefixes(key, prefixes=_KEY_PREFIXES):
    """Repeatedly remove any leading wrapper prefix from a state-dict key.

    Prefixes may be stacked in any order (e.g. "attacker.model.conv1.weight"),
    so stripping loops until no prefix matches.
    """
    stripped = True
    while stripped:
        stripped = False
        for prefix in prefixes:
            if key.startswith(prefix):
                key = key[len(prefix):]
                stripped = True
    return key


# If two checkpoint keys collapse to the same name, the later entry wins.
new_state_dict = {_strip_prefixes(k): v for k, v in checkpoint.items()}

# strict=False tolerates the checkpoint's normalizer buffers
# ("normalizer.new_mean", "normalizer.new_std", "normalize.new_mean",
# "normalize.new_std"), which the plain ResNet has no slot for. It would also
# silently accept *missing* weights, so report both lists explicitly.
load_result = model.load_state_dict(new_state_dict, strict=False)
if load_result.missing_keys:
    print(f"WARNING: checkpoint is missing model keys: {load_result.missing_keys}")
if load_result.unexpected_keys:
    print(f"Ignored unexpected checkpoint keys: {load_result.unexpected_keys}")
model.eval()

# Image preprocessing: resize/crop arbitrary uploads to the 224x224 input size,
# then map [0, 1] pixel values to [-1, 1].
# NOTE(review): the checkpoint's own normalizer buffers are dropped when the
# weights are loaded; confirm this mean/std matches what was used in training.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # vggface2
])

# Number of example images shown for the predicted class.
NUM_CLASS_SAMPLES = 9


def _build_label_index(dataset):
    """Return a dict mapping each label to the row indices that carry it."""
    index = {}
    for row, label in enumerate(dataset["label"]):
        index.setdefault(label, []).append(row)
    return index


# Built once at startup so each prediction is a dict lookup rather than a
# full pass over the dataset.
# NOTE(review): assumes dataset label ids coincide with the model's 500 output
# indices; confirm against how the checkpoint was trained.
label_to_indices = _build_label_index(ds)


def predict(image):
    """Classify a face image and return sample images of the predicted class.

    Args:
        image: a numpy array (as supplied by ``gr.Image(type="numpy")``), a PIL
            image, or None when the user submitted nothing.

    Returns:
        A ``(label_text, images)`` tuple: the predicted-class message for the
        textbox, and up to NUM_CLASS_SAMPLES dataset images with that label
        for the gallery (empty if no image was given or no rows match).
    """
    if image is None:
        return "No image provided.", []
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)  # Convert to PIL Image if it is an array
    # Normalize expects exactly 3 channels; uploads may be RGBA or grayscale.
    image = image.convert("RGB")
    batch = preprocess(image).unsqueeze(0)  # Add batch dimension
    with torch.no_grad():
        output = model(batch)  # Perform inference on CPU
    predicted_class = output.argmax(dim=1).item()

    # Pick random examples of the predicted class from the precomputed index.
    indices = label_to_indices.get(predicted_class, [])
    chosen = random.sample(indices, min(len(indices), NUM_CLASS_SAMPLES))
    sample_images = [ds[i]["image"] for i in chosen]
    return f"Predicted class: {predicted_class}", sample_images


# Create the Gradio interface
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy"),
    outputs=[gr.Textbox(label="Predicted Class"), gr.Gallery(label="Class Samples")],
    title="ResNet-50 VGGFace2 Classifier"
)

# Launch the interface
iface.launch()