Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torchvision import transforms | |
| import gradio as gr | |
| from PIL import Image | |
| import numpy as np | |
| import pandas as pd | |
| import plotly.express as px | |
| from datasets import load_dataset | |
| from CNN_model import BasicCNN | |
# Load the PlantVillage dataset; in this file it is only used to recover the
# class-name list that maps model output indices back to readable labels.
ds = load_dataset("DScomp380/plant_village")
train_split = ds['train']
labels = train_split['label']  # NOTE(review): appears unused in this file
label_names = train_split.features['label'].names

# Startup sanity output: the number of classes, then the class names.
for shown in (len(label_names), label_names):
    print(shown)
# Preprocessing applied to every input image before it reaches the model:
# resize to 256x256, convert to a float tensor, then normalise each channel
# with the standard ImageNet mean/std.
# NOTE(review): assumes the model was trained with this exact pipeline — confirm.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
resize = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
# Build the classifier and load its trained weights.
# map_location="cpu" lets weights saved from a GPU session load on a CPU-only
# host. weights_only=True restricts unpickling to tensors and plain
# containers, which is all a state_dict needs; it avoids executing arbitrary
# pickled code and silences torch.load's FutureWarning on recent PyTorch.
# NOTE(review): NUM_CLASSES must match both the checkpoint's output layer and
# len(label_names), otherwise predicted indices map to the wrong names.
NUM_CLASSES = 39
model = BasicCNN(num_classes=NUM_CLASSES)
state_dict = torch.load(
    "final_model_final_model.pt", map_location="cpu", weights_only=True
)
model.load_state_dict(state_dict)
# Put any dropout / batch-norm layers into inference mode.
model.eval()
def detect_disease(image, temp):
    """Classify a leaf image and return the most likely diseases.

    Args:
        image: PIL image from the ``gr.Image(type='pil')`` input, or ``None``
            when the user submits without uploading anything.
        temp: Softmax temperature. The logits are divided by it before the
            softmax, so values above 1 flatten the distribution and values
            below 1 sharpen it. Must be positive.

    Returns:
        A 2-tuple matching the interface's two outputs: either
        ``({class_name: probability}, plotly_figure)`` for the top (at most 5)
        classes, or ``("No image passed", None)`` when ``image`` is ``None``.

    Raises:
        ValueError: if ``temp`` is not positive.
    """
    if image is None:
        # The interface has two outputs (Label and Plot); returning a single
        # value makes Gradio fail with "not enough output values".
        return "No image passed", None
    if temp <= 0:
        # Zero would divide by zero; negatives would invert the ranking.
        raise ValueError("temp must be positive")

    # Uploads may be RGBA or greyscale, but the Normalize step in `resize`
    # expects exactly 3 channels.
    batch = resize(image.convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        logits = model(batch)
    probs = F.softmax(logits / temp, dim=1)

    # Never ask topk for more classes than the model produces.
    k = min(5, probs.shape[1])
    top = torch.topk(probs, k)
    top_probs = top.values[0].tolist()
    top_diseases = [label_names[i] for i in top.indices[0].tolist()]

    # Rounded values are only for display in the bar chart.
    df = pd.DataFrame({
        'Disease': top_diseases,
        'pr': [round(p, 4) for p in top_probs],
    })
    visual = px.bar(df, x='Disease', y='pr', color='Disease', text='pr')
    visual.update_layout(
        title='Disease Probability',
        yaxis_title='Probability',
        width=700,
        height=500,
    )
    return dict(zip(top_diseases, top_probs)), visual
# Web UI: an image upload plus a temperature slider feed detect_disease; the
# two outputs are the top-5 label widget and the probability bar chart.
demo = gr.Interface(
    fn=detect_disease,
    inputs=[
        gr.Image(type='pil'),
        # detect_disease divides the logits by this value, so LOWER values
        # give sharper predictions and HIGHER values flatten them; the old
        # "Prediction Sharpness" label implied the opposite.
        gr.Slider(
            0.5, 2.0, value=1.0,
            label="Temperature (lower = sharper predictions)",
        ),
    ],
    outputs=[
        gr.Label(num_top_classes=5, label="Diagnosis"),
        gr.Plot(label="Top 5 Possible Diseases"),
    ],
    title="Disease Classifier",
    description="Upload your leaf image to get diagnosis.",
    # Clickable sample leaf images, each paired with the default temperature.
    examples=[
        ["examples/soybean.jpg", 1.0],
        ["examples/apple_image.jpg", 1.0],
        ["examples/apple_scab.jpg", 1.0],
        ["examples/cherry_healthy.jpg", 1.0],
        ["examples/Squash.jpg", 1.0],
        ["examples/tomato.jpg", 1.0],
        ["examples/peach.jpg", 1.0],
        ["examples/grape.jpg", 1.0],
    ],
    # Users can manually flag a result with one of these reasons; flagged
    # samples are stored under ./flagged.
    flagging_dir="flagged",
    flagging_mode="manual",
    flagging_options=["Wrong disease", "Low Confidence", "Other"],
)
demo.launch()