"""Gradio demo that classifies an image as 'focused' or 'distracted'.

The model checkpoint is downloaded from the HuggingFace Hub, loaded on the
CPU, and served through a single-image Gradio interface.
"""
import os

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from torchvision import transforms

# Read the Hub token from the environment instead of hard-coding it. Falls
# back to None (no explicit token) rather than passing an empty string as a
# credential.
token = (
    os.environ.get("HF_TOKEN")
    or os.environ.get("HUGGING_FACE_HUB_TOKEN")
    or None
)

# Download model from HuggingFace Hub and load it with PyTorch
REPO_ID = "gonzalocordova/DistractionDetectorCNN"
FILENAME = "resnet50_xlarge_resnet50_2023-05-19_21-41-02.pth"

model_pth = hf_hub_download(
    repo_id=REPO_ID,
    filename=FILENAME,
    repo_type="model",
    use_auth_token=token,
)

# SECURITY NOTE: torch.load unpickles the checkpoint, which can execute
# arbitrary code. Only load checkpoints from a repository you trust.
# NOTE(review): recent PyTorch releases may require an explicit
# `weights_only=False` to load a fully pickled model - confirm against the
# deployed torch version.
model = torch.load(model_pth, map_location=torch.device('cpu'))
model.eval()

# Preprocessing applied to every input image before inference.
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
])


def predict_fn(image) -> dict:
    """Predict whether the person in an image is focused or distracted.

    :param image: The input image as a PIL image (the Gradio input below is
        configured with ``type="pil"``).
    :return: Dictionary mapping the labels ``'focused'`` and ``'distracted'``
        to their confidences, each rounded to two decimals.
    """
    image = image.convert('RGB')
    # Add a leading batch dimension: the model is called on a batch of one.
    batch = transform(image).unsqueeze(0)

    with torch.no_grad():
        output = model(batch)

    # NOTE(review): exponentiating assumes the model emits log-probabilities
    # (e.g. a LogSoftmax head) - confirm against the training code.
    probabilities = torch.exp(output)

    # Index 0 is 'focused', index 1 is 'distracted'.
    focused_prob = round(probabilities[0][0].item(), 2)
    distracted_prob = round(probabilities[0][1].item(), 2)

    # return dictionary whose keys are labels and values are confidences
    return {'focused': focused_prob, 'distracted': distracted_prob}


# Prefer the top-level Image component; `gr.inputs` is the deprecated
# namespace and is only used as a fallback on releases lacking `gr.Image`.
_image_component = gr.Image if hasattr(gr, "Image") else gr.inputs.Image

demo = gr.Interface(
    predict_fn,
    _image_component(type="pil", label="Input Image"),
    outputs="label",
)
demo.launch(share=True)