# BrandonFors's picture
# uploading files to space
# 8e1235e
import gradio as gr
import os
import torch
import torchvision
from modeling import EffNetPlantDiseaseClassification
from timeit import default_timer as timer
from typing import Dict, Tuple
# Preprocessing: take the default EfficientNetV2-S weights entry from
# torchvision and build the transform pipeline that ships with it.
effnetv2_s_weights = torchvision.models.EfficientNet_V2_S_Weights.DEFAULT
effnetv2_s_auto_transforms = effnetv2_s_weights.transforms()

# Classifier: pull the pretrained model from its Hugging Face repo.
model = EffNetPlantDiseaseClassification.from_pretrained(
    "BrandonFors/effnetv2_s_plant_disease"
)
# Labels the classifier can predict, in output-index order
# (position i names the i-th probability read in predict()).
class_names = [
    # Apple
    'Apple___Apple_scab',
    'Apple___Black_rot',
    'Apple___Cedar_apple_rust',
    'Apple___healthy',
    # Blueberry
    'Blueberry___healthy',
    # Cherry
    'Cherry_(including_sour)___Powdery_mildew',
    'Cherry_(including_sour)___healthy',
    # Corn
    'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot',
    'Corn_(maize)___Common_rust_',
    'Corn_(maize)___Northern_Leaf_Blight',
    'Corn_(maize)___healthy',
    # Grape
    'Grape___Black_rot',
    'Grape___Esca_(Black_Measles)',
    'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)',
    'Grape___healthy',
    # Orange
    'Orange___Haunglongbing_(Citrus_greening)',
    # Peach
    'Peach___Bacterial_spot',
    'Peach___healthy',
    # Bell pepper
    'Pepper,_bell___Bacterial_spot',
    'Pepper,_bell___healthy',
    # Potato
    'Potato___Early_blight',
    'Potato___Late_blight',
    'Potato___healthy',
    # Raspberry
    'Raspberry___healthy',
    # Soybean
    'Soybean___healthy',
    # Squash
    'Squash___Powdery_mildew',
    # Strawberry
    'Strawberry___Leaf_scorch',
    'Strawberry___healthy',
    # Tomato
    'Tomato___Bacterial_spot',
    'Tomato___Early_blight',
    'Tomato___Late_blight',
    'Tomato___Leaf_Mold',
    'Tomato___Septoria_leaf_spot',
    'Tomato___Spider_mites Two-spotted_spider_mite',
    'Tomato___Target_Spot',
    'Tomato___Tomato_Yellow_Leaf_Curl_Virus',
    'Tomato___Tomato_mosaic_virus',
    'Tomato___healthy',
]
def predict(img) -> Tuple[Dict[str, float], float]:
    """Classify one image and report how long the prediction took.

    Args:
        img: The image handed over by the Gradio image input (the interface
            in this file configures it with ``type="pil"``). ``None`` when
            the user submits without supplying an image.

    Returns:
        A tuple ``(pred_labels_and_probs, pred_time)``: a dict mapping every
        entry of ``class_names`` to its softmax probability, and the elapsed
        time in seconds rounded to four decimal places.

    Raises:
        gr.Error: If ``img`` is ``None``, so the UI shows a readable message
            instead of an opaque failure inside the transform pipeline.
    """
    if img is None:
        raise gr.Error("Please provide an image to classify.")

    # Time everything from preprocessing through building the result dict.
    start_time = timer()

    # Preprocess the image and add a leading batch dimension of size 1.
    batch = effnetv2_s_auto_transforms(img).unsqueeze(0)

    # Eval mode + inference mode: no dropout/batch-norm updates, no autograd.
    model.eval()
    with torch.inference_mode():
        pred_logits = model(batch)["logits"]
        pred_probs = torch.softmax(pred_logits, dim=1)

    # Convert the single row of probabilities to Python floats in one call
    # rather than once per class. Indexing (instead of zip) keeps a loud
    # IndexError if the model emits fewer outputs than there are class names.
    probs = pred_probs[0].tolist()
    pred_labels_and_probs = {name: probs[i] for i, name in enumerate(class_names)}

    pred_time = round(timer() - start_time, 4)
    return pred_labels_and_probs, pred_time
# Gradio app
# Strings handed to gr.Interface as its title, description, and article.
title, description, article = (
    "EffNetV2_S_PlantDisease",
    "EffNetV2-S trained on Plant Village Dataset",
    "Personal project",
)
# Build the example list: one single-element row (the file path) for every
# entry in the "examples" directory. os.listdir() returns names in arbitrary
# order, so sort them to keep the examples in a stable order across restarts.
example_list = [
    [os.path.join("examples", example)]
    for example in sorted(os.listdir("examples"))
]
# Wire predict() into a Gradio interface: one PIL image input, and two
# outputs (a top-3 label widget and the prediction time in seconds).
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Label(num_top_classes=3, label="Predictions"),
        gr.Number(label="Prediction time (s)"),
    ],
    title=title,
    description=description,
    article=article,
    examples=example_list,
)

# Start the app.
demo.launch(debug=False)