File size: 2,163 Bytes
34915b6
 
 
 
 
10dc028
34915b6
 
 
e79c75c
34915b6
 
 
 
 
 
 
 
 
a9b9cc0
 
34915b6
a9b9cc0
 
34915b6
a9b9cc0
34915b6
a9b9cc0
 
 
 
 
 
 
34915b6
a9b9cc0
 
e7ecfd9
10dc028
 
 
c1529fa
10dc028
34915b6
10dc028
34915b6
 
 
 
 
 
 
27fcfa2
34915b6
91a1dae
 
34915b6
 
 
 
 
 
 
26fbfd4
 
34915b6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import gradio as gr
import numpy as np
from model import BaselineCNN
from data_pipeline import val_test_transform, IMAGE_SIZE
import torch 
import pandas as pd

from datasets import load_dataset
# Fetch the Plant Village training split. In this part of the file only its
# label metadata is used, to turn predicted class indices into readable names.
dataset = load_dataset(
    path="DScomp380/plant_village",
    split="train",
)
CLASS_NAMES = dataset.features["label"].names


# Rebuild the network and restore the trained weights onto the CPU.
# NOTE(review): CLASSES is hard-coded; presumably it has to match both
# len(CLASS_NAMES) and the output size stored in the checkpoint -- confirm.
CLASSES = 39
model = BaselineCNN(num_classes=CLASSES)
# NOTE(review): torch.load unpickles the file, so only trusted checkpoints
# should be loaded here (consider weights_only=True where it is supported).
model.load_state_dict(
    torch.load("best_model.pth", map_location=torch.device("cpu"))
)
# Put the module into evaluation mode before serving predictions.
model.eval()

def predict(input_image):
    """Classify a leaf image and return the most probable classes as a table.

    Args:
        input_image: The image handed over by the Gradio image component, or
            ``None`` when Submit is clicked without an uploaded image. It is
            passed unchanged to ``val_test_transform``.

    Returns:
        A ``pandas.DataFrame`` with the columns ``"Class"`` (name looked up
        in ``CLASS_NAMES``) and ``"Probability"`` (percentage string with two
        decimals), one row per prediction, most probable first. At most five
        rows are returned; the frame is empty when no image was provided.
    """
    # Gradio passes None when nothing has been uploaded; show an empty table
    # instead of crashing inside the transform.
    if input_image is None:
        return pd.DataFrame(columns=["Class", "Probability"])

    # Transform the image and add a leading batch dimension, since the model
    # is fed a batch of one.
    batch = val_test_transform(input_image).unsqueeze(0)

    # Inference only: no gradients are needed for the forward pass or for
    # turning the raw scores into probabilities.
    with torch.no_grad():
        model_output = model(batch)
        probabilities = torch.nn.functional.softmax(model_output, dim=1)[0]

    # torch.topk raises if asked for more entries than exist, so never request
    # more than the number of classes the model produced.
    number_of_predictions_to_show = min(5, probabilities.numel())
    top_probabilities, top_class_indices = torch.topk(
        probabilities, number_of_predictions_to_show
    )

    return pd.DataFrame({
        "Class": [CLASS_NAMES[index] for index in top_class_indices.tolist()],
        "Probability": [
            f"{probability * 100:.2f}%"
            for probability in top_probabilities.tolist()
        ],
    })

# Assemble the web UI: an image upload next to a table of predictions, plus a
# button that runs `predict` on the uploaded image and fills the table.
with gr.Blocks(title="Plant Disease Classifier") as app:
    gr.Markdown(value="# Plant Disease Classification")
    gr.Markdown(value="Upload an image of a plant leaf to classify its disease.")

    with gr.Row():
        image_input = gr.Image(label="Upload Leaf Image", type="pil")
        label_output = gr.DataFrame(
            type="pandas",
            headers=["Class", "Probability"],
        )

    # TODO: add gr.Examples(...) for image_input once sample images exist.

    submit_btn = gr.Button(value="Submit")
    submit_btn.click(predict, image_input, label_output)

# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    app.launch()