michaela299 commited on
Commit
34915b6
·
1 Parent(s): 1c9a477

restarting

Browse files
__pycache__/data_pipeline.cpython-310.pyc ADDED
Binary file (3.98 kB). View file
 
__pycache__/model.cpython-310.pyc ADDED
Binary file (1.29 kB). View file
 
__pycache__/ui.cpython-310.pyc CHANGED
Binary files a/__pycache__/ui.cpython-310.pyc and b/__pycache__/ui.cpython-310.pyc differ
 
app.py CHANGED
@@ -1,4 +1,56 @@
1
- from ui import app # import the Blocks object from ui.py
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  if __name__ == "__main__":
4
- app.launch(ssr_mode=False, debug=False) # disable SSR to avoid hot-reload
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ from model import BaselineCNN
4
+ from data_pipeline import val_test_transform, IMAGE_SIZE
5
+ import torch
6
+
7
+ from datasets import load_dataset
8
+ dataset = load_dataset("DScomp380/plant_village", split="train")
9
+ CLASS_NAMES = dataset.features["label"].names
10
+
11
+
12
+ #load the model
13
+ CLASSES = 39
14
+ model = BaselineCNN(num_classes=CLASSES)
15
+ model.load_state_dict(torch.load("best_model.pth", map_location=torch.device('cpu')))
16
+ model.eval()
17
+
18
+ def predict(input_image):
19
+ #apply the transform
20
+ image_tensor = val_test_transform(input_image)
21
+
22
+ #add batch dimension
23
+ image_tensor = image_tensor.unsqueeze(0)
24
+
25
+ #run inference
26
+ with torch.no_grad():
27
+ output = model(image_tensor)
28
+
29
+ #get probabilitiees
30
+ probabilities = torch.nn.functional.softmax(output,dim=1)[0]
31
+
32
+ #create the output dictionary
33
+ result = {CLASS_NAMES[i]: probabilities[i].item() for i in range(len(probabilities))}
34
+ return result
35
+
36
+
37
+ with gr.Blocks(title="Plant Disease Classifier") as app:
38
+ gr.Markdown("# Plant Disease Classification")
39
+ gr.Markdown("Upload an image of a plant leaf to classify its disease.")
40
+
41
+ with gr.Row():
42
+ image_input = gr.Image(type="pil", label="Upload Leaf Image")
43
+ label_output = gr.Label(label="Predicted Disease")
44
+
45
+ gr.Examples(
46
+ examples =[], inputs=image_input)
47
+
48
+ submit_btn = gr.Button("Submit")
49
+ submit_btn.click(fn=predict, inputs=image_input, outputs=label_output)
50
+
51
+ #fn=predict,
52
+ # inputs=gr.Image(type="pil"),
53
+ # outputs=gr.Label(num_top_classes=3))
54
 
55
  if __name__ == "__main__":
56
+ app.launch()