# Author: gonzalocordova
# Commit d20dc7b: "Update app.py"
import gradio as gr
import torch
import os
from torchvision import transforms
from huggingface_hub import hf_hub_download
# Hugging Face access token, read from the environment (e.g. a Space secret)
# rather than hard-coded. An unset or empty variable becomes None, so no
# explicit token is passed instead of an empty (invalid) token string.
token = os.environ.get("HF_TOKEN") or None
# Download model from HuggingFace Hub and load it with PyTorch
REPO_ID = "gonzalocordova/DistractionDetectorCNN"
FILENAME = "resnet50_xlarge_resnet50_2023-05-19_21-41-02.pth"
model_pth = hf_hub_download(repo_id=REPO_ID, filename=FILENAME, repo_type="model", use_auth_token=token)
# NOTE: torch.load unpickles the whole model object, which can execute
# arbitrary code -- only load checkpoints from a trusted repository.
# NOTE(review): newer torch releases default torch.load to weights_only=True,
# which rejects fully pickled models; if loading fails there, pass
# weights_only=False explicitly -- confirm against the deployed torch version.
model = torch.load(model_pth, map_location=torch.device('cpu'))
# Put the model in evaluation mode (affects layers such as dropout/batch-norm).
model.eval()
# Preprocessing: resize so the shorter side is 224 px (aspect ratio kept), then
# convert the PIL image to a float tensor. No normalization is applied.
# NOTE(review): presumably this matches the training pipeline -- confirm.
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
])
def _label_confidences(output):
    """Turn the model output for a single image into Gradio label confidences.

    :param output: Model output for a batch of one image; ``output[0]`` holds
        the per-class scores (index 0 = focused, index 1 = distracted).
    :return: Dict ``{'focused': p0, 'distracted': p1}`` with probabilities
        rounded to 2 decimals.
    """
    # softmax(log_softmax(x)) == exp(log_softmax(x)), so this equals
    # exp(output) when the model emits log-probabilities, and still yields
    # proper probabilities if the model emits raw logits instead.
    probabilities = torch.softmax(output, dim=1)[0]
    return {
        'focused': round(probabilities[0].item(), 2),
        'distracted': round(probabilities[1].item(), 2),
    }


def predict_fn(image):
    """
    Classify an image as 'focused' or 'distracted'.

    :param image: PIL image supplied by the Gradio ``type="pil"`` image input,
        or None when the user submitted without uploading an image.
    :return: Dict mapping the labels 'focused' and 'distracted' to their
        probabilities (rounded to 2 decimals), the format a Gradio Label
        output expects.
    :raises ValueError: If no image was provided.
    """
    if image is None:
        # Fail with a clear message instead of an AttributeError on .convert().
        raise ValueError("No image provided.")
    # Force 3 channels (e.g. RGBA or grayscale uploads) before preprocessing.
    image = image.convert('RGB')
    # Preprocess and add a batch dimension of size 1.
    batch = transform(image).unsqueeze(0)
    with torch.no_grad():
        output = model(batch)
    return _label_confidences(output)
# Build the web UI: a PIL image input feeding predict_fn, with the returned
# {label: confidence} dict rendered by a Label output. gr.Image is used instead
# of the legacy gr.inputs.Image namespace, which was deprecated in Gradio 3.x
# and removed in 4.0. share=True requests a public share link.
demo = gr.Interface(predict_fn, gr.Image(type="pil", label="Input Image"), outputs="label")
demo.launch(share=True)