Spaces:
Runtime error
Runtime error
File size: 1,605 Bytes
57a5006 d20dc7b 57a5006 fcfe8bc 57a5006 02ca130 fcfe8bc 115c771 57a5006 14cf308 57a5006 88ac1ec 13fcf41 88ac1ec 13fcf41 88ac1ec 57a5006 14cf308 13fcf41 57a5006 fb198fb | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 | import gradio as gr
import torch
import os
from torchvision import transforms
from huggingface_hub import hf_hub_download
# Hugging Face access token, read from the HF_TOKEN environment variable (e.g. a Space
# secret) instead of a hard-coded empty string. An empty/unset variable becomes None so no
# blank token is sent and huggingface_hub falls back to its default credential lookup.
token = os.environ.get("HF_TOKEN") or None
# Download model from HuggingFace Hub and load it with PyTorch
REPO_ID = "gonzalocordova/DistractionDetectorCNN"
FILENAME = "resnet50_xlarge_resnet50_2023-05-19_21-41-02.pth"
# `token` replaces the deprecated `use_auth_token` argument of hf_hub_download.
model_pth = hf_hub_download(repo_id=REPO_ID, filename=FILENAME, repo_type="model", token=token)
# The checkpoint is a whole pickled nn.Module (it is .eval()-ed and called directly), not a
# state_dict, so it must be loaded with weights_only=False: PyTorch >= 2.6 defaults to
# weights_only=True, which refuses to unpickle arbitrary module objects.
# SECURITY: weights_only=False executes pickle code from the file; only load checkpoints
# from a repository you trust.
model = torch.load(model_pth, map_location=torch.device('cpu'), weights_only=False)
# Inference mode (disables dropout / uses running batch-norm statistics).
model.eval()
# Preprocessing applied to every input image: Resize with an int scales the shorter side to
# 224 px (aspect ratio kept), then ToTensor converts the PIL image to a float tensor in [0, 1].
# NOTE(review): no mean/std normalization is applied; presumably this matches how the model
# was trained — TODO confirm.
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
])
def predict_fn(image):
    """
    Classify a single image as 'focused' or 'distracted'.

    :param image: The input image as a PIL image (the Gradio input is created with type="pil").
    :return: Dict mapping each label ('focused', 'distracted') to its probability,
             rounded to 2 decimal places, suitable for a Gradio Label output.
    """
    # Normalize any colour mode (RGBA, grayscale, palette, ...) to 3-channel RGB.
    image = image.convert('RGB')
    # Preprocess, then add a leading batch dimension of size 1.
    batch = transform(image).unsqueeze(0)
    with torch.no_grad():
        output = model(batch)
    # Softmax over the class dimension (dim 1; dim 0 is the batch). If the model already
    # ends in LogSoftmax this equals torch.exp(output); if it emits raw logits it still
    # produces valid probabilities, which torch.exp alone would not.
    probabilities = torch.softmax(output, dim=1)
    # focused (0) round probability
    focused_prob = round(probabilities[0][0].item(), 2)
    # distracted (1) round probability
    distracted_prob = round(probabilities[0][1].item(), 2)
    # return dictionary whose keys are labels and values are confidences
    return {'focused': focused_prob, 'distracted': distracted_prob}
# Web UI: a PIL image input fed to predict_fn, rendered with a label/confidence widget.
# gr.Image replaces the gr.inputs.Image namespace, which was deprecated in Gradio 3.x and
# removed in Gradio 4.0 (using it there raises AttributeError at startup).
demo = gr.Interface(
    predict_fn,
    gr.Image(type="pil", label="Input Image"),
    outputs="label",
)
# share=True requests a public share link when run locally.
demo.launch(share=True)