# Hugging Face Space status when this file was captured: Runtime error
import os

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from torchvision import transforms

# Hugging Face access token, read from the environment so no secret has to be
# committed to source. An unset or empty HF_TOKEN becomes None (anonymous
# download, sufficient only if the repository is public).
token = os.environ.get("HF_TOKEN") or None

# Download model from HuggingFace Hub and load it with PyTorch
REPO_ID = "gonzalocordova/DistractionDetectorCNN"
FILENAME = "resnet50_xlarge_resnet50_2023-05-19_21-41-02.pth"
model_pth = hf_hub_download(
    repo_id=REPO_ID,
    filename=FILENAME,
    repo_type="model",
    use_auth_token=token,
)

# The checkpoint is a whole pickled nn.Module (it is put in eval mode and
# called directly), not a state_dict. torch>=2.6 defaults to weights_only=True,
# which refuses such objects, so opt out explicitly.
# SECURITY: weights_only=False unpickles arbitrary code — only load
# checkpoints from a repository you trust.
model = torch.load(
    model_pth,
    map_location=torch.device('cpu'),
    weights_only=False,
)
model.eval()  # inference mode (e.g. fixes BatchNorm/Dropout behavior)

# Resize with an int scales the shorter side to 224 and keeps aspect ratio;
# ToTensor converts to a float CHW tensor in [0, 1].
# NOTE(review): no Normalize step — presumably this mirrors the training
# preprocessing; confirm against the training code.
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
])
def predict_fn(image):
    """Classify an image as 'focused' or 'distracted'.

    The image is converted to RGB, preprocessed with the module-level
    ``transform`` and run through the module-level ``model`` with gradients
    disabled.

    :param image: PIL image supplied by the Gradio ``Image(type="pil")`` input.
    :return: Dict mapping each label ('focused', 'distracted') to its
        probability rounded to two decimals, the format a Gradio ``label``
        output displays as confidences.
    :raises ValueError: If no image was supplied.
    """
    if image is None:
        # Presumably what Gradio passes when the user submits without an
        # image; fail with a clear message instead of an AttributeError.
        raise ValueError("No input image provided.")
    image = image.convert('RGB')
    # Add a leading batch dimension: the model is called on a batch of one.
    batch = transform(image).unsqueeze(0)
    with torch.no_grad():
        output = model(batch)
    # NOTE(review): assumes the model ends in LogSoftmax so that exp() yields
    # probabilities; if it returns raw logits this should be torch.softmax.
    probabilities = torch.exp(output)
    # Class index 0 is "focused", index 1 is "distracted".
    focused_prob = round(probabilities[0][0].item(), 2)
    distracted_prob = round(probabilities[0][1].item(), 2)
    # Keys are labels, values are confidences (consumed by the "label" output).
    return {'focused': focused_prob, 'distracted': distracted_prob}
# Build and start the web UI. gr.inputs.* was deprecated in Gradio 3 and
# removed in Gradio 4; gr.Image is the supported equivalent.
demo = gr.Interface(
    fn=predict_fn,
    inputs=gr.Image(type="pil", label="Input Image"),
    outputs="label",
)
# share=True requests a public share link when run locally.
demo.launch(share=True)