# Omarrr7's picture
# Upload app.py
# eb1ac9b verified
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import gradio as gr
from PIL import Image
import numpy as np
import pandas as pd
import plotly.express as px
from datasets import load_dataset
from CNN_model import BasicCNN
# Class names are read from the Hugging Face dataset. detect_disease reports
# model output index i as label_names[i], so this list's order must match the
# class order the checkpoint was trained with (presumably this same dataset —
# confirm when retraining).
ds = load_dataset("DScomp380/plant_village")
label_names = ds["train"].features["label"].names
print(len(label_names))
print(label_names)
# Preprocessing applied to every uploaded image before inference: resize to
# the fixed 256x256 resolution the model expects, convert to a float tensor,
# then normalize each channel with the standard ImageNet mean/std values.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
_preprocess_steps = [
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
]
resize = transforms.Compose(_preprocess_steps)
# Build the classifier and load the trained weights on CPU.
# NUM_CLASSES must match the checkpoint's output layer (a strict
# load_state_dict rejects shape mismatches).
NUM_CLASSES = 39

# detect_disease looks up label_names[index] for every predicted index in
# range(NUM_CLASSES); with fewer names than outputs that lookup would raise
# IndexError mid-request, so fail at startup with a clear message instead.
if len(label_names) < NUM_CLASSES:
    raise ValueError(
        f"Dataset provides only {len(label_names)} class names but the model "
        f"has {NUM_CLASSES} outputs; predictions could not be labelled."
    )

model = BasicCNN(num_classes=NUM_CLASSES)
# A state_dict only contains tensors in plain containers, so restrict
# unpickling to those (weights_only=True) rather than executing arbitrary
# pickle code from the checkpoint file.
model.load_state_dict(
    torch.load("final_model_final_model.pt", map_location="cpu", weights_only=True)
)
model.eval()  # inference mode for layers such as dropout / batch-norm
def detect_disease(image, temp):
    """Classify a leaf image and return its most likely diseases.

    Args:
        image: PIL image from the ``gr.Image(type='pil')`` input, or ``None``
            when nothing was uploaded.
        temp: Softmax temperature. Logits are divided by it, so values above
            1.0 flatten the probability distribution and values below 1.0
            sharpen it.

    Returns:
        A ``(label_scores, figure)`` pair matching the interface's two outputs:
        a ``{disease_name: probability}`` dict (up to five entries, highest
        first) for ``gr.Label``, and a Plotly bar chart of the same
        predictions (probabilities rounded to 4 places) for ``gr.Plot``.
        With no image, returns ``("No image passed", None)``.
    """
    # The interface declares two outputs, so every path must return two values;
    # returning a bare string here makes Gradio raise instead of showing it.
    if image is None:
        return "No image passed", None

    # Uploads can be RGBA (e.g. PNG with alpha) or grayscale, but the
    # preprocessing normalizes exactly three channels — force RGB first.
    batch = resize(image.convert("RGB")).unsqueeze(0)

    with torch.no_grad():
        logits = model(batch)
        probs = F.softmax(logits / temp, dim=1)

    # Guard against models with fewer than five classes (topk would raise).
    k = min(5, probs.shape[1])
    top = torch.topk(probs, k)
    top_probs = top.values[0].tolist()
    top_indices = top.indices[0].tolist()
    top_diseases = [label_names[i] for i in top_indices]

    df = pd.DataFrame({
        "Disease": top_diseases,
        "pr": [round(float(p), 4) for p in top_probs],
    })
    visual = px.bar(df, x="Disease", y="pr", color="Disease", text="pr")
    visual.update_layout(
        title="Disease Probability",
        yaxis_title="Probability",
        width=700,
        height=500,
    )

    # gr.Label gets the unrounded probabilities, in top-k order.
    return dict(zip(top_diseases, top_probs)), visual
# Gradio UI: a leaf image and a softmax-temperature slider go in; the top-5
# diagnosis (label widget + bar chart) comes out. Manual flagging lets users
# report questionable predictions, which Gradio stores under "flagged".
demo = gr.Interface(
    fn=detect_disease,
    inputs=[
        gr.Image(type="pil"),
        # detect_disease divides the logits by this value, so LOWER values give
        # sharper (more confident) distributions — the old "Prediction
        # Sharpness" label had the direction backwards.
        gr.Slider(0.5, 2.0, value=1.0, label="Temperature (lower = sharper predictions)"),
    ],
    outputs=[
        gr.Label(num_top_classes=5, label="Diagnosis"),
        gr.Plot(label="Top 5 Possible Diseases"),
    ],
    title="Disease Classifier",
    description="Upload your leaf image to get diagnosis.",
    # Clickable example leaf images; each row is [image path, temperature].
    examples=[
        ["examples/soybean.jpg", 1.0],
        ["examples/apple_image.jpg", 1.0],
        ["examples/apple_scab.jpg", 1.0],
        ["examples/cherry_healthy.jpg", 1.0],
        ["examples/Squash.jpg", 1.0],
        ["examples/tomato.jpg", 1.0],
        ["examples/peach.jpg", 1.0],
        ["examples/grape.jpg", 1.0],
    ],
    flagging_dir="flagged",
    flagging_mode="manual",
    flagging_options=["Wrong disease", "Low Confidence", "Other"],
)

# Only start the server when run as a script (`python app.py`), so importing
# this module (tests, Gradio reload mode) does not launch it.
if __name__ == "__main__":
    demo.launch()