michaela299 commited on
Commit
b6311de
·
1 Parent(s): 34915b6

remove ui

Browse files
Files changed (1) hide show
  1. ui.py +0 -75
ui.py CHANGED
@@ -1,75 +0,0 @@
1
-
2
- import gradio as gr
3
- import numpy as np
4
- from model import BaselineCNN
5
- from data_pipeline import val_test_transform, IMAGE_SIZE
6
- import torch
7
-
8
- from datasets import load_dataset
9
- dataset = load_dataset("DScomp380/plant_village", split="train")
10
- CLASS_NAMES = dataset.features["label"].names
11
-
12
-
13
- #load the model
14
- CLASSES = 39
15
- model = None
16
-
17
-
18
- def predict(input_image):
19
- global model
20
- if model is None:
21
- model = BaselineCNN(num_classes=CLASSES)
22
- model.load_state_dict(torch.load("best_model.pth", map_location=torch.device('cpu')))
23
- model.eval()
24
- #resize to models image size, convert to tensor, normalize values
25
- image_tensor = val_test_transform(input_image)
26
-
27
- #add new dimension at index 0 so each image has a batch size of atleast 1
28
- image_tensor = image_tensor.unsqueeze(0)
29
-
30
- #run inference
31
- with torch.no_grad():
32
- #pass the batch through the model
33
- output = model(image_tensor)
34
-
35
- #convert to probabilitiees
36
- probabilities = torch.nn.functional.softmax(output,dim=1)[0]
37
-
38
- numPredictionsToShow = 10
39
-
40
- #get the top 5 predictions
41
- topProbs, TopClassIndicies = torch.topk(probabilities, numPredictionsToShow)
42
- #returns 5 largest probabilities
43
-
44
- #create the output dictionary
45
- result = {}
46
- for rank in range(numPredictionsToShow):#loop through top 5
47
- classIndex = TopClassIndicies[rank].item()#get the int value from the tensor at index rank
48
- className = CLASS_NAMES[classIndex]#get human readable class name
49
- probabilityValue = topProbs[rank].item()#convert prob from tensor to python float
50
-
51
- result[className] = probabilityValue
52
-
53
- return result
54
-
55
-
56
- with gr.Blocks(title="Plant Disease Classifier") as app:
57
- gr.Markdown("# Plant Disease Classification")
58
- gr.Markdown("Upload an image of a plant leaf to classify its disease.")
59
-
60
- with gr.Row():
61
- image_input = gr.Image(type="pil", label="Upload Leaf Image")
62
- label_output = gr.Label(label="Predicted Disease")
63
-
64
- gr.Examples(
65
- examples =[], inputs=image_input)
66
-
67
- submit_btn = gr.Button("Submit")
68
- submit_btn.click(fn=predict, inputs=image_input, outputs=label_output)
69
-
70
- #fn=predict,
71
- # inputs=gr.Image(type="pil"),
72
- # outputs=gr.Label(num_top_classes=3))
73
-
74
-
75
-