Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| import os | |
| from torchvision import transforms | |
| from huggingface_hub import hf_hub_download | |
# Read the Hugging Face access token from the environment (e.g. a Space
# secret named HF_TOKEN) instead of hard-coding it in source. The token that
# was previously committed here is exposed and should be revoked.
# A None token is fine for public repos: hf_hub_download then falls back to
# any locally saved login.
token = os.environ.get("HF_TOKEN")

# Download model from HuggingFace Hub and load it with PyTorch
REPO_ID = "gonzalocordova/DistractionDetectorCNN"
FILENAME = "model_with_extended_dataset_resnet50_2023-03-28_10-42-07.pth"
model_pth = hf_hub_download(
    repo_id=REPO_ID, filename=FILENAME, repo_type="model", token=token
)
# The checkpoint appears to be a whole pickled model object (it is .eval()'d
# and called directly, never fed to load_state_dict), so weights_only must be
# False: torch>=2.6 defaults it to True and would refuse to load it
# (the kwarg itself needs torch>=1.13).
# NOTE: this unpickles arbitrary code -- only load checkpoints from a repo
# you trust.
model = torch.load(
    model_pth, map_location=torch.device("cpu"), weights_only=False
)
model.eval()

# Resize(224) with an int scales the *shorter* side to 224 (aspect ratio is
# kept), then ToTensor converts the PIL image to a float (C, H, W) tensor.
# NOTE(review): presumably mirrors the training-time preprocessing (no
# normalisation, no crop) -- confirm against the training code.
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
])
def predict_fn(image):
    """Classify a single image with the distraction-detector model.

    :param image: PIL image supplied by the Gradio ``Image(type="pil")`` input.
    :return: dict mapping the predicted class index (as a string) to its
        probability (float), the format a Gradio ``Label`` output renders.
    """
    # Force 3 channels; uploaded images may be RGBA, greyscale or palette.
    image = image.convert("RGB")
    # transform ends in ToTensor -> (C, H, W); add a batch dim -> (1, C, H, W).
    batch = transform(image).unsqueeze(0)
    with torch.no_grad():
        output = model(batch)
    # softmax yields valid probabilities whether the model emits raw logits or
    # log-probabilities (softmax(log_softmax(x)) == softmax(x)); the previous
    # torch.exp() was only correct in the log-probability case.
    probabilities = torch.softmax(output, dim=1)
    top_p, top_class = probabilities.topk(1, dim=1)
    # gr.Label expects {label: confidence}. The original built a *set* of two
    # strings, which Label cannot interpret, and stringified the confidence.
    return {str(top_class.item()): top_p.item()}
# Wire the classifier into a simple web UI. ``gr.inputs.*`` was deprecated in
# Gradio 3.x and removed in 4.0; ``gr.Image`` / ``gr.Label`` work in both.
demo = gr.Interface(
    fn=predict_fn,
    inputs=gr.Image(type="pil", label="Input Image"),
    outputs=gr.Label(),
)
demo.launch()