ttoosi commited on
Commit
26af353
·
1 Parent(s): 1762704
Files changed (1) hide show
  1. app.py +36 -0
app.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ from torchvision import models, transforms
4
+ from PIL import Image
5
+ import requests
6
+ from huggingface_hub import hf_hub_download
7
+
8
+ # Load the model checkpoint from Hugging Face
9
+ checkpoint_path = hf_hub_download(repo_id="ttoosi/resnet50_robust_face", filename="100_checkpoint.pt")
10
+
11
+ # Initialize the model
12
+ model = models.resnet50()
13
+ model.load_state_dict(torch.load(checkpoint_path))
14
+ model.eval()
15
+
16
+ # Image preprocessing
17
+ preprocess = transforms.Compose([
18
+ transforms.Resize(256),
19
+ transforms.CenterCrop(224),
20
+ transforms.ToTensor(),
21
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
22
+ ])
23
+
24
+ # Function to make predictions
25
+ def predict(image):
26
+ image = preprocess(image).unsqueeze(0) # Add batch dimension
27
+ with torch.no_grad():
28
+ output = model(image)
29
+ _, predicted_class = output.max(1)
30
+ return f"Predicted class: {predicted_class.item()}"
31
+
32
+ # Create the Gradio interface
33
+ iface = gr.Interface(fn=predict, inputs=gr.inputs.Image(type="pil"), outputs="text")
34
+
35
+ # Launch the interface
36
+ iface.launch()