Spaces:
Build error
Build error
| import torch | |
| import gradio as gr | |
| from torchvision import models, transforms | |
| from PIL import Image | |
| import requests | |
| from huggingface_hub import hf_hub_download | |
| from PIL import Image | |
| import numpy as np | |
| import random | |
| from datasets import load_dataset | |
| from datasets import DatasetDict | |
# Load the validation split of the cropped VGGFace2 dataset from the Hugging Face
# Hub and wrap it in a DatasetDict keyed by split name, so rows are reached
# through ds["validation"].
_vggface2_validation = load_dataset("chronopt-research/cropped-vggface2-224", split="validation")
ds = DatasetDict({"validation": _vggface2_validation})
# Load the model checkpoint from Hugging Face
checkpoint_path = hf_hub_download(repo_id="ttoosi/resnet50_robust_face", filename="100_checkpoint.pt")

# Initialize a torchvision ResNet-50 and replace its head with a 500-way classifier
model = models.resnet50()
model.fc = torch.nn.Linear(model.fc.in_features, 500)

# NOTE(review): torch.load unpickles the file, so only point it at a trusted
# repository. Newer PyTorch releases default to weights_only=True, which may
# refuse this checkpoint if it stores non-tensor objects -- confirm against the
# installed torch version.
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))['model']

# Wrapper prefixes that the checkpoint's parameter names may be nested under,
# in any combination and order.
_CHECKPOINT_KEY_PREFIXES = ('module.', 'model.', 'attacker.')


def _strip_checkpoint_prefixes(key):
    """Return *key* with every leading 'module.' / 'model.' / 'attacker.' removed.

    Prefixes are stripped repeatedly and in any order, but only at the start of
    the key: 'attacker.model.fc.weight' becomes 'fc.weight', while a 'model.'
    appearing later inside a key is left untouched.
    """
    while True:
        for prefix in _CHECKPOINT_KEY_PREFIXES:
            if key.startswith(prefix):
                key = key[len(prefix):]
                break
        else:
            return key


# Several source keys can collapse to the same name (e.g. both "model.X" and
# "attacker.model.X" become "X"); the later entry in the checkpoint's
# iteration order wins.
new_state_dict = {_strip_checkpoint_prefixes(k): v for k, v in checkpoint.items()}

# strict=False is only meant to tolerate the extra normalizer tensors
# ("normalizer.new_mean", "normalizer.new_std", "normalize.new_mean",
# "normalize.new_std"). A missing ResNet parameter instead means the weights
# did not load and the app would predict from random initialization, so fail
# loudly. BatchNorm's num_batches_tracked counters are not used in eval mode
# and are tolerated if absent.
load_result = model.load_state_dict(new_state_dict, strict=False)
missing_keys = [k for k in load_result.missing_keys if not k.endswith('num_batches_tracked')]
if missing_keys:
    raise RuntimeError(f"Checkpoint is missing ResNet-50 parameters: {missing_keys}")
if load_result.unexpected_keys:
    print(f"Ignoring checkpoint keys not used by ResNet-50: {load_result.unexpected_keys}")
model.eval()
# Image preprocessing.
# load_state_dict(strict=False) drops the checkpoint's normalizer tensors
# ("normalizer.new_mean" / "normalizer.new_std") because torchvision's ResNet-50
# has no layer for them. They presumably hold the per-channel input statistics
# the network was trained with (TODO confirm), so reuse them rather than
# hard-coding values that could drift from the checkpoint. Fall back to the
# 0.5/0.5 VGGFace2 values when the checkpoint does not provide them.
_ckpt_mean = new_state_dict.get("normalizer.new_mean")
_ckpt_std = new_state_dict.get("normalizer.new_std")
if _ckpt_mean is not None and _ckpt_std is not None:
    # flatten() accepts either a (3,) or a broadcastable (3, 1, 1) tensor.
    norm_mean = _ckpt_mean.flatten().tolist()
    norm_std = _ckpt_std.flatten().tolist()
else:
    norm_mean = [0.5, 0.5, 0.5]
    norm_std = [0.5, 0.5, 0.5]

preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),  # PIL image -> float tensor in [0, 1], CHW
    transforms.Normalize(mean=norm_mean, std=norm_std),
])
# Function to make predictions
def predict(image):
    """Classify a face image and return example images from the predicted class.

    Args:
        image: The uploaded image. This is a NumPy array when it comes from
            ``gr.Image(type="numpy")``, or a PIL image. It is ``None`` when
            the form is submitted without an image.

    Returns:
        A ``(label_text, samples)`` pair. ``label_text`` is
        ``"Predicted class: <index>"``. ``samples`` holds up to 9 ``"image"``
        values drawn at random from the validation rows whose ``label`` equals
        the predicted index. For ``None`` input the pair is
        ``("No image provided.", [])``.
    """
    if image is None:
        # Gradio passes None when the user submits without uploading anything.
        return "No image provided.", []
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)  # Convert to PIL Image if it's a NumPy array
    # The 3-channel Normalize in `preprocess` requires RGB input; grayscale or
    # RGBA images would otherwise fail there.
    image = image.convert("RGB")
    batch = preprocess(image).unsqueeze(0)  # Add batch dimension
    with torch.no_grad():
        output = model(batch)  # Perform inference on CPU
    predicted_class = int(output.argmax(dim=1).item())

    # `ds` is a DatasetDict: filtering it directly returns another DatasetDict,
    # whose iteration yields split names rather than rows. Work on the split
    # itself, and scan only the label column so images are decoded just for
    # the samples actually shown.
    # NOTE(review): assumes the dataset's integer labels use the same indexing
    # as the model's 500 output classes -- confirm.
    split = ds["validation"]
    matching = [idx for idx, label in enumerate(split["label"]) if label == predicted_class]
    chosen = random.sample(matching, min(len(matching), 9))
    sample_images = [split[idx]["image"] for idx in chosen]
    return f"Predicted class: {predicted_class}", sample_images
# Build the Gradio UI: an uploaded image is passed to `predict`, and its two
# return values fill a text box (predicted class) and a gallery (class samples).
output_components = [
    gr.Textbox(label="Predicted Class"),
    gr.Gallery(label="Class Samples"),
]
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy"),
    outputs=output_components,
    title="ResNet-50 VGGFace2 Classifier",
)

# Start the Gradio server.
iface.launch()