File size: 3,716 Bytes
251af5e
4f0b287
251af5e
 
 
 
 
 
 
 
 
 
 
 
 
 
4f0b287
251af5e
 
 
 
 
4f0b287
251af5e
 
 
 
 
 
 
 
 
 
 
c00506e
251af5e
 
 
 
 
 
 
 
 
4f0b287
 
cc6218f
 
 
4f0b287
07f0645
4f0b287
 
 
 
cc6218f
 
 
 
07f0645
 
 
 
cc6218f
4f0b287
 
 
cc6218f
 
4f0b287
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7eaddef
cc6218f
 
 
 
 
251af5e
cc6218f
 
251af5e
cc6218f
 
 
 
 
 
 
 
 
 
251af5e
cc6218f
 
251af5e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
### 1. Imports and class names setup ### 
### 1. Imports and class names setup ###
import gradio as gr
import os
import torch

from model import create_mobilenet_model
from timeit import default_timer as timer
from typing import Tuple, Dict

# Setup class names
# NOTE(review): predict() pairs these names with the model outputs by index, so
# the order presumably has to match the label order used when the checkpoint
# was trained — confirm before reordering or adding entries.
class_names = ['bacterial', 'blast', 'brownspot', 'tungro']

### 2. Model and transforms preparation ###

# Size the classifier head from class_names so the two can never drift apart.
mobilenet, manual_transforms = create_mobilenet_model(
    num_classes=len(class_names)
)

# Load the saved weights, remapping every tensor to the CPU so the app does not
# depend on the device the checkpoint was saved from.
# NOTE: torch.load unpickles the file, so only ship a checkpoint you trust.
mobilenet.load_state_dict(
    torch.load(
        f="mobilenet_5_epochs.pth",
        map_location=torch.device("cpu"),
    )
)

### 3. Predict function ###
def predict(img) -> Tuple[Dict, float]:
    """Classify one image and report how long the prediction took.

    Args:
        img: Image to classify. It is handed directly to ``manual_transforms``;
            the Gradio input wired to this function uses ``type="pil"``, so a
            PIL image is expected. ``None`` means no image was provided.

    Returns:
        A tuple ``(pred_labels_and_probs, pred_time)``: a dict mapping each
        entry of ``class_names`` to its softmax probability, and the elapsed
        time in seconds rounded to 5 decimal places.

    Raises:
        gr.Error: If ``img`` is ``None``.
    """
    # An empty image box reaches this function as None; surface a readable
    # message in the UI instead of failing inside the transform pipeline.
    if img is None:
        raise gr.Error("Please upload an image before clicking Predict.")

    start_time = timer()

    # The model is called on a batch, so add a leading batch dimension of 1.
    batch = manual_transforms(img).unsqueeze(0)

    # Make sure the model is in evaluation mode and skip autograd bookkeeping.
    mobilenet.eval()
    with torch.inference_mode():
        pred_probs = torch.softmax(mobilenet(batch), dim=1)

    # Row 0 is the only item in the batch; pair each class name with its score.
    pred_labels_and_probs = {
        name: float(pred_probs[0][idx]) for idx, name in enumerate(class_names)
    }

    pred_time = round(timer() - start_time, 5)

    return pred_labels_and_probs, pred_time

### 4. Gradio app ###

# Create a Blocks app (only one!)
with gr.Blocks() as gradio_app:

    # Page title.
    gr.HTML(
        """
        <h1 style='text-align: center'>
        Rice Disease Classification - MobileNet Model
        </h1>
        """
    )

    # Author links (the Twitter and HuggingFace entries are commented out in the HTML).
    gr.HTML(
        """
        <h3 style='text-align: center'>
        Follow me for more!
<!--     <a href='https://twitter.com/kadirnar_ai' target='_blank'>Twitter</a> | -->
        <a href='https://github.com/ExplorerGumel' target='_blank'>Github</a> | 
        <a href='https://www.linkedin.com/in/munzali-alhassan/' target='_blank'>Linkedin</a>  | 
<!--         <a href='https://www.huggingface.co/kadirnar/' target='_blank'>HuggingFace</a> -->
        </h3>
        """
    )

    with gr.Row():
        # Left column: image input, trigger button and clickable examples.
        with gr.Column():
            image = gr.Image(type="pil", label="Upload Image")
            infer = gr.Button(value="Predict")

            # Examples linked to the input component 'image'.
            # os.listdir returns entries in arbitrary order and also lists
            # hidden files and sub-directories, so sort the names and keep only
            # regular, non-hidden files. A missing folder yields no examples
            # instead of crashing the app at start-up.
            examples_dir = "examples"
            example_list = []
            if os.path.isdir(examples_dir):
                example_list = [
                    [os.path.join(examples_dir, name)]
                    for name in sorted(os.listdir(examples_dir))
                    if not name.startswith(".")
                    and os.path.isfile(os.path.join(examples_dir, name))
                ]
            if example_list:
                gr.Examples(
                    examples=example_list,
                    inputs=[image]  # Pass the actual input component
                )

        # Right column: class probabilities and the measured prediction time.
        with gr.Column():
            label = gr.Label(num_top_classes=4, label="Predictions")
            pred_time = gr.Number(label="Prediction Time (s)")

        # predict returns (probabilities dict, seconds), matching the two outputs.
        infer.click(
            fn=predict,
            inputs=[image],
            outputs=[label, pred_time]
        )

# Launch the app
gradio_app.launch(debug=True, share=True)
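# gradio_app.launch(debug=True, share=True)

# NOTE: The commented-out code below is an alternative gr.Interface version of
# this app. It is not executed; the Blocks app above is the one that runs.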
# # Create title, description and article strings
# title = "RICE DISEASES CLASSIFICATION"
# description = "A MobileNetV2 feature extractor computer vision model to classify images of Rice diseases."
# article = "Created by Munzali Alhassan."

# # Create examples list from "examples/" directory
# example_list = [["examples/" + example] for example in os.listdir("examples")]

# # Create the Gradio demo
# demo = gr.Interface(fn=predict, # mapping function from input to output
#                     inputs=gr.Image(type="pil"), # what are the inputs?
#                     outputs=[gr.Label(num_top_classes=4, label="Predictions"), # what are the outputs?
#                              gr.Number(label="Prediction time (s)")], # our fn has two outputs, therefore we have two outputs
#                     # Create examples list from "examples/" directory
#                     examples=example_list, 
#                     title=title,
#                     description=description,
#                     article=article)

# # Launch the demo!
# demo.launch(share=True)