File size: 1,605 Bytes
57a5006
 
 
 
 
 
d20dc7b
57a5006
fcfe8bc
57a5006
02ca130
fcfe8bc
115c771
57a5006
 
 
 
 
 
 
 
14cf308
57a5006
 
 
 
 
 
 
 
 
 
 
 
 
 
88ac1ec
13fcf41
88ac1ec
13fcf41
88ac1ec
57a5006
14cf308
13fcf41
57a5006
fb198fb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import gradio as gr
import torch
import os
from torchvision import transforms
from huggingface_hub import hf_hub_download

# Hugging Face access token. It is read from the HF_TOKEN environment variable
# so the secret never has to be committed alongside the code. When the
# variable is unset or empty, None is passed instead of an empty string so
# that huggingface_hub applies its default credential handling.
token = os.environ.get("HF_TOKEN") or None

# Download model from HuggingFace Hub and load it with PyTorch
REPO_ID = "gonzalocordova/DistractionDetectorCNN"
FILENAME = "resnet50_xlarge_resnet50_2023-05-19_21-41-02.pth"
model_pth = hf_hub_download(repo_id=REPO_ID, filename=FILENAME, repo_type="model", use_auth_token=token)
# NOTE(security): torch.load unpickles the downloaded file, which can execute
# arbitrary code. Only keep this if REPO_ID is a repository you trust.
# The result is used directly as a module below (eval / call), so the
# checkpoint is expected to hold the whole model object, not a state dict.
model = torch.load(model_pth, map_location=torch.device('cpu'))
# Switch to evaluation mode for inference.
model.eval()

# Preprocessing applied to every input image before it reaches the model:
# Resize(224), then conversion to a tensor.
transform = transforms.Compose(
    [transforms.Resize(224), transforms.ToTensor()]
)


def predict_fn(image):
    """
    Classify a single image as focused or distracted.

    :param image: The input image; it must provide ``convert('RGB')`` (the
        Gradio interface in this file is configured to pass a PIL image)
    :return: Dict mapping the labels ``'focused'`` and ``'distracted'`` to
        their confidences, each rounded to two decimals
    """
    # Force three channels, run the module-level preprocessing, and add a
    # leading batch dimension of size one.
    batch = transform(image.convert('RGB')).unsqueeze(0)

    # Gradients are not needed for inference.
    with torch.no_grad():
        scores = model(batch)

    # exp() is applied to the raw outputs to obtain confidences; this
    # presumes the model emits log-probabilities -- TODO confirm.
    confidences = torch.exp(scores)[0]

    # Output index 0 is the focused class, index 1 the distracted class.
    labels = ('focused', 'distracted')
    return {
        label: round(confidences[idx].item(), 2)
        for idx, label in enumerate(labels)
    }

# Expose predict_fn through a Gradio interface (PIL image in, label widget
# out) and start it with share=True.
# NOTE(review): gr.inputs looks like the legacy Gradio input namespace --
# confirm it is still available in the pinned Gradio version.
gr.Interface(
    fn=predict_fn,
    inputs=gr.inputs.Image(type="pil", label="Input Image"),
    outputs="label",
).launch(share=True)