# Source: Hugging Face Hub file view, uploaded by gonzalocordova.
# Commit 14cf308: "fix: predict_fn return bug" (1.48 kB).
import gradio as gr
import torch
import os
from torchvision import transforms
from huggingface_hub import hf_hub_download
# SECURITY: a Hugging Face access token used to be hard-coded on this line.
# It was published together with the source, so treat it as compromised and
# revoke it in the Hugging Face account settings. The token is now read from
# the HF_TOKEN environment variable (e.g. a Space secret). When the variable
# is unset, None is passed through and huggingface_hub applies its own
# default credential lookup.
token = os.environ.get("HF_TOKEN")

# Download model from HuggingFace Hub and load it with PyTorch
REPO_ID = "gonzalocordova/DistractionDetectorCNN"
FILENAME = "model_with_extended_dataset_resnet50_2023-03-28_10-42-07.pth"
model_pth = hf_hub_download(
    repo_id=REPO_ID,
    filename=FILENAME,
    repo_type="model",
    use_auth_token=token,
)

# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints from a repository you control.
# The loaded object is used directly as a module (.eval(), called in
# predict_fn), so the file presumably holds a whole pickled nn.Module rather
# than a state_dict. Newer torch releases default to weights_only=True, which
# may reject such a file. Confirm against the torch version pinned for this app.
model = torch.load(model_pth, map_location=torch.device('cpu'))
# Inference mode: disables dropout and freezes batch-norm statistics.
model.eval()
# Preprocessing applied to every uploaded image before it reaches the model.
# Resize(224) with an int scales the shorter side to 224 and keeps the aspect
# ratio (no crop). ToTensor converts the PIL image to a float tensor.
# NOTE(review): there is no Normalize() step and no crop. Confirm this
# matches the preprocessing used when the checkpoint was trained.
_PREPROCESS_STEPS = [
    transforms.Resize(224),
    transforms.ToTensor(),
]
transform = transforms.Compose(_PREPROCESS_STEPS)
def predict_fn(image):
    """
    Classify a single image with the downloaded model.

    :param image: PIL image supplied by the Gradio ``Image(type="pil")``
        input, or ``None`` when the user submits without uploading anything.
    :return: Dictionary mapping each class index (as a string) to its
        confidence as a float. This is the ``{label: confidence}`` shape
        that a ``gr.Label`` output expects. Returns ``None`` when no image
        was provided.
    """
    if image is None:
        return None

    # Force 3 channels: uploads may be RGBA, grayscale or palette images.
    pixels = transform(image.convert('RGB'))
    batch = pixels.unsqueeze(0)  # add the leading batch dimension

    with torch.no_grad():
        output = model(batch)

    # NOTE(review): exp() yields probabilities only if the model's final
    # layer is LogSoftmax. Confirm against the training code.
    probabilities = torch.exp(output)[0]

    # gr.Label needs a dict with float confidences. The previous version
    # returned a *set* of two strings, which Label cannot render.
    return {str(idx): conf for idx, conf in enumerate(probabilities.tolist())}
# Gradio UI: one PIL image upload in, a gr.Label output.
# gr.Image replaces gr.inputs.Image, which was deprecated in Gradio 3 and
# removed in Gradio 4. gr.Image accepts the same type/label arguments.
demo = gr.Interface(
    predict_fn,
    gr.Image(type="pil", label="Input Image"),
    outputs="label",
)

# Launch only when run as a script (`python app.py`, as Spaces does), so the
# module can be imported without starting a server.
if __name__ == "__main__":
    demo.launch()