|
|
import gradio as gr |
|
|
import os |
|
|
import torch |
|
|
import torchvision |
|
|
from modeling import EffNetPlantDiseaseClassification |
|
|
from timeit import default_timer as timer |
|
|
from typing import Dict, Tuple |
|
|
|
|
|
# Fine-tuned classifier, built through the project's own ``from_pretrained``
# constructor (``EffNetPlantDiseaseClassification`` comes from ``modeling``).
# The argument looks like a Hugging Face Hub repo id, so this presumably
# downloads/caches weights on first run — TODO confirm in modeling.py.
model = EffNetPlantDiseaseClassification.from_pretrained(
    "BrandonFors/effnetv2_s_plant_disease"
)

# Preprocessing that ships with torchvision's default EfficientNetV2-S weights.
# predict() runs every incoming image through this pipeline.
# NOTE(review): assumes the fine-tuned model was trained with these same
# transforms — verify against the training script.
effnetv2_s_weights = torchvision.models.EfficientNet_V2_S_Weights.DEFAULT
effnetv2_s_auto_transforms = effnetv2_s_weights.transforms()
|
|
|
|
|
|
|
|
# Output labels, indexed by the model's logit position. The order must match
# the label indices used at training time (the list is in sorted order, as an
# ImageFolder-style loader would produce), so never reorder or edit entries.
# Spellings are kept verbatim from the dataset folder names.
class_names = [
    # Apple
    "Apple___Apple_scab",
    "Apple___Black_rot",
    "Apple___Cedar_apple_rust",
    "Apple___healthy",
    # Blueberry
    "Blueberry___healthy",
    # Cherry
    "Cherry_(including_sour)___Powdery_mildew",
    "Cherry_(including_sour)___healthy",
    # Corn
    "Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot",
    "Corn_(maize)___Common_rust_",
    "Corn_(maize)___Northern_Leaf_Blight",
    "Corn_(maize)___healthy",
    # Grape
    "Grape___Black_rot",
    "Grape___Esca_(Black_Measles)",
    "Grape___Leaf_blight_(Isariopsis_Leaf_Spot)",
    "Grape___healthy",
    # Orange
    "Orange___Haunglongbing_(Citrus_greening)",
    # Peach
    "Peach___Bacterial_spot",
    "Peach___healthy",
    # Bell pepper
    "Pepper,_bell___Bacterial_spot",
    "Pepper,_bell___healthy",
    # Potato
    "Potato___Early_blight",
    "Potato___Late_blight",
    "Potato___healthy",
    # Raspberry, soybean, squash
    "Raspberry___healthy",
    "Soybean___healthy",
    "Squash___Powdery_mildew",
    # Strawberry
    "Strawberry___Leaf_scorch",
    "Strawberry___healthy",
    # Tomato
    "Tomato___Bacterial_spot",
    "Tomato___Early_blight",
    "Tomato___Late_blight",
    "Tomato___Leaf_Mold",
    "Tomato___Septoria_leaf_spot",
    "Tomato___Spider_mites Two-spotted_spider_mite",
    "Tomato___Target_Spot",
    "Tomato___Tomato_Yellow_Leaf_Curl_Virus",
    "Tomato___Tomato_mosaic_virus",
    "Tomato___healthy",
]
|
|
|
|
|
def predict(img):
    """Classify a leaf image and report how long inference took.

    Args:
        img: The value delivered by the ``gr.Image(type="pil")`` input — a PIL
            image, or ``None`` when the user submits without uploading one.

    Returns:
        A tuple ``(pred_labels_and_probs, pred_time)``:
        ``pred_labels_and_probs`` maps every entry of ``class_names`` to its
        softmax probability (the dict form ``gr.Label`` renders), and
        ``pred_time`` is the elapsed wall-clock time in seconds, rounded to
        4 decimal places.

    Raises:
        gr.Error: If no image was supplied.
    """
    if img is None:
        # Gradio passes None for an empty image input; show a readable message
        # in the UI instead of a traceback from deep inside the transforms.
        raise gr.Error("Please upload an image to classify.")

    start_time = timer()

    # Uploads (especially PNGs) may be RGBA, LA, P or grayscale. The
    # EfficientNet preprocessing normalizes with 3-channel mean/std, which
    # fails to broadcast on any other channel count, so force RGB first.
    img = img.convert("RGB")

    # Add a leading batch dimension so the model sees a batch of one.
    batch = effnetv2_s_auto_transforms(img).unsqueeze(0)

    # model.eval() is idempotent; calling it here guarantees inference mode
    # (e.g. no dropout) regardless of how the model was loaded.
    model.eval()
    with torch.inference_mode():
        # Assumes the model returns a mapping whose "logits" entry is
        # (batch, num_classes) — implied by softmax over dim=1 and the [0]
        # indexing below; TODO confirm against modeling.py.
        pred_logits = model(batch)["logits"]
        pred_probs = torch.softmax(pred_logits, dim=1)

    # Pair each class name with its probability for the single image in the batch.
    pred_labels_and_probs = {
        name: float(prob) for name, prob in zip(class_names, pred_probs[0])
    }

    pred_time = round(timer() - start_time, 4)
    return pred_labels_and_probs, pred_time
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Text shown in the Gradio page: heading, subtitle and footer article.
title, description, article = (
    "EffNetV2_S_PlantDisease",
    "EffNetV2-S trained on Plant Village Dataset",
    "Personal project",
)
|
|
|
|
|
|
|
|
# File extensions (lower-case) accepted as example images.
_IMAGE_EXTENSIONS = frozenset(
    {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp", ".tif", ".tiff"}
)


def collect_examples(directory="examples"):
    """Return Gradio example rows for the image files in *directory*.

    Each row is a one-element list ``[path]`` (one value per Interface input),
    where ``path`` is ``os.path.join(directory, filename)``. Only regular files
    with an image extension (case-insensitive) are included, so stray entries
    such as ``.DS_Store``, ``.gitkeep`` or sub-directories cannot break the
    examples gallery. Rows are sorted by filename for a stable display order.

    Args:
        directory: Folder to scan; relative paths resolve against the current
            working directory.

    Returns:
        A list of ``[path]`` rows; empty if the directory does not exist.
    """
    try:
        filenames = os.listdir(directory)
    except (FileNotFoundError, NotADirectoryError):
        # A missing examples folder should not stop the demo from starting.
        import warnings

        warnings.warn(f"No examples directory found at {directory!r}; "
                      "launching without examples.")
        return []

    rows = []
    for filename in sorted(filenames):
        path = os.path.join(directory, filename)
        if (os.path.splitext(filename)[1].lower() in _IMAGE_EXTENSIONS
                and os.path.isfile(path)):
            rows.append([path])
    return rows


example_list = collect_examples("examples")
|
|
|
|
|
|
|
|
# Gradio UI: one PIL image in; the top-3 class probabilities and the
# inference time out. The two outputs line up, in order, with the
# (probabilities dict, seconds) tuple that predict() returns.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Label(num_top_classes=3, label="Predictions"),
        gr.Number(label="Prediction time (s)"),
    ],
    examples=example_list,
    title=title,
    description=description,
    article=article,
)

# Start the server only when executed as a script (``python app.py``), so
# importing this module — e.g. to test predict() or by tooling that looks
# for ``demo`` — does not block on a running web server.
if __name__ == "__main__":
    demo.launch(debug=False)
|
|
|