Spaces:
Build error
Build error
File size: 5,869 Bytes
import cv2
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
# Define your model class (same as during training)
class Plant_Disease_VGG16(nn.Module):
    """VGG16 backbone whose final classifier layer is resized for plant-disease classes.

    The wrapped torchvision model is stored on ``self.network`` so that
    state-dict keys keep the ``network.`` prefix.

    Args:
        num_classes: Number of output classes for the replaced final linear
            layer. Defaults to 38.
        pretrained: Forwarded to ``torchvision.models.vgg16``. Pass ``False``
            to skip loading ImageNet weights, e.g. when a fine-tuned
            checkpoint is loaded straight afterwards. Defaults to ``True``.
            NOTE(review): newer torchvision releases deprecate this keyword in
            favour of ``weights=`` - confirm against the pinned version.
    """

    def __init__(self, num_classes=38, pretrained=True):
        super().__init__()
        self.network = models.vgg16(pretrained=pretrained)
        # Freeze every feature-extractor parameter tensor except the last five,
        # leaving only the tail of the convolutional stack trainable.
        for param in list(self.network.features.parameters())[:-5]:
            param.requires_grad = False
        # Swap the final classifier layer for one with `num_classes` outputs.
        num_ftrs = self.network.classifier[-1].in_features
        self.network.classifier[-1] = nn.Linear(num_ftrs, num_classes)

    def forward(self, xb):
        """Run the batch ``xb`` through the wrapped VGG16 and return its output."""
        return self.network(xb)
# Load the model
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Plant_Disease_VGG16()
# map_location=device places the checkpoint tensors on the selected device.
# The loaded state dict is passed inline, so no module-level reference to it
# is kept after the weights have been copied into the model.
# NOTE(review): torch.load deserializes with pickle on many PyTorch versions;
# only load checkpoint files from a trusted source.
model.load_state_dict(torch.load("model/vgg_model_ft.pth", map_location=device))
model.to(device)
# Switch the module to evaluation mode before serving predictions.
model.eval()
# Class labels with plant and disease information.
# Each entry is "<plant>___<disease-or-healthy>"; predict() indexes this list
# with the model's argmax, so the order must not change.
class_labels = [
    'Apple___Apple_scab',
    'Apple___Black_rot',
    'Apple___Cedar_apple_rust',
    'Apple___healthy',
    'Blueberry___healthy',
    'Cherry_(including_sour)___Powdery_mildew',
    'Cherry_(including_sour)___healthy',
    'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot',
    'Corn_(maize)___Common_rust_',
    'Corn_(maize)___Northern_Leaf_Blight',
    'Corn_(maize)___healthy',
    'Grape___Black_rot',
    'Grape___Esca_(Black_Measles)',
    'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)',
    'Grape___healthy',
    'Orange___Haunglongbing_(Citrus_greening)',
    'Peach___Bacterial_spot',
    'Peach___healthy',
    'Pepper,_bell___Bacterial_spot',
    'Pepper,_bell___healthy',
    'Potato___Early_blight',
    'Potato___Late_blight',
    'Potato___healthy',
    'Raspberry___healthy',
    'Soybean___healthy',
    'Squash___Powdery_mildew',
    'Strawberry___Leaf_scorch',
    'Strawberry___healthy',
    'Tomato___Bacterial_spot',
    'Tomato___Early_blight',
    'Tomato___Late_blight',
    'Tomato___Leaf_Mold',
    'Tomato___Septoria_leaf_spot',
    'Tomato___Spider_mites Two-spotted_spider_mite',
    'Tomato___Target_Spot',
    'Tomato___Tomato_Yellow_Leaf_Curl_Virus',
    'Tomato___Tomato_mosaic_virus',
    'Tomato___healthy',
]
# Enhanced preprocessing
def preprocess_image(image):
    """Mask out non-green background and convert the image to a model input tensor.

    Steps: HSV green-range mask -> morphological closing of the mask ->
    apply the mask -> resize to 224x224 -> tensor -> per-channel normalization.

    Args:
        image: A PIL image (any mode) or an RGB array-like.

    Returns:
        The normalized tensor produced by the torchvision transform pipeline
        (no batch dimension; the caller adds it).
    """
    # cv2.COLOR_RGB2HSV needs exactly three channels, so normalize PIL inputs
    # such as RGBA or greyscale to RGB first.
    if isinstance(image, Image.Image):
        image = image.convert("RGB")
    # Convert to numpy array for processing
    img = np.array(image)

    # Simple background removal (assuming leaf is dominant green object)
    hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
    mask = cv2.inRange(hsv, (36, 25, 25), (86, 255, 255))  # Green color range
    # Closing fills small holes inside the green region of the mask.
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (11, 11))
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)

    # If nothing fell inside the green range, masking would hand the model an
    # all-black image; keep the unmasked pixels in that case.
    if mask.any():
        img = cv2.bitwise_and(img, img, mask=mask)

    # Convert back to PIL for the torchvision transforms
    image = Image.fromarray(img)
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return transform(image)
def parse_class_label(class_label):
    """Split a ``<plant>___<disease>`` class label into display-ready names.

    Underscores become spaces, commas are removed from the plant name, and
    surrounding whitespace is stripped so that labels ending in an underscore
    (e.g. ``'Corn_(maize)___Common_rust_'``) do not yield a trailing space.

    Args:
        class_label: Raw label string; the plant and disease parts are
            separated by three underscores.

    Returns:
        A ``(plant, disease)`` tuple of strings. ``disease`` is ``"healthy"``
        when the label contains no ``___`` separator.
    """
    parts = class_label.split('___')
    plant = parts[0].replace('_', ' ').replace(',', '').strip()
    disease = parts[1].replace('_', ' ').strip() if len(parts) > 1 else "healthy"
    return plant, disease
def is_healthy_override(image, predicted_class, confidence, *,
                        confidence_threshold=0.9, green_ratio_threshold=0.7):
    """Heuristic check for false disease predictions.

    Returns ``True`` only when the model predicts a disease with confidence
    above ``confidence_threshold`` and more than ``green_ratio_threshold`` of
    the image's pixels fall inside the green HSV range.

    Args:
        image: A PIL image (any mode) or an RGB array-like. Not inspected when
            the prediction is already "healthy" or the confidence is too low.
        predicted_class: Raw class label of the top prediction.
        confidence: Probability of the top prediction.
        confidence_threshold: Confidence that must be exceeded before the
            pixel check runs. Defaults to 0.9.
        green_ratio_threshold: Fraction of green pixels that must be exceeded
            to override. Defaults to 0.7.

    Returns:
        ``True`` if the prediction should be overridden to healthy, else ``False``.
    """
    # Cheap exits first: nothing to override, or the model is not confident enough.
    if "healthy" in predicted_class or confidence <= confidence_threshold:
        return False

    # cv2.COLOR_RGB2HSV needs exactly three channels; normalize PIL inputs.
    if isinstance(image, Image.Image):
        image = image.convert("RGB")
    img = np.array(image)

    # Simple check: count green pixels vs total
    hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
    green_mask = cv2.inRange(hsv, (36, 25, 25), (86, 255, 255))
    green_ratio = np.count_nonzero(green_mask) / (img.shape[0] * img.shape[1])

    # Mostly green leaf with no visible spots
    return bool(green_ratio > green_ratio_threshold)
# Prediction function with fixes
def predict(image):
    """Classify a leaf image and return a human-readable result string.

    The top-1 class is taken from the softmax over the model output. The
    result is reported as healthy when ``is_healthy_override`` fires, as
    uncertain when the confidence is below 0.7, and as the predicted
    plant/disease otherwise.

    Args:
        image: The uploaded image, or ``None`` when nothing was provided.

    Returns:
        A message string. Failures are reported as ``"Error: ..."`` text
        instead of being raised.
    """
    # The input may be None (e.g. the form was submitted without an upload);
    # report that directly instead of surfacing a cryptic library error.
    if image is None:
        return "Error: No image provided. Please upload a leaf image."

    try:
        # Preprocess and add the batch dimension
        input_tensor = preprocess_image(image).unsqueeze(0).to(device)

        # Predict
        with torch.no_grad():
            preds = model(input_tensor)
            probabilities = torch.nn.functional.softmax(preds[0], dim=0)

        # Get top prediction
        top_prob, top_idx = torch.max(probabilities, 0)
        top_class = class_labels[top_idx.item()]
        plant, disease = parse_class_label(top_class)
        confidence = top_prob.item()

        # Apply fixes
        if is_healthy_override(image, top_class, confidence):
            return f"Plant: {plant}\nDisease: healthy (Override: Original prediction '{disease}' had {confidence:.2%} confidence but leaf appears healthy)"

        # Confidence thresholding
        if confidence < 0.7:
            return f"Uncertain prediction for {plant} (Confidence: {confidence:.2%})\nPlease upload a clearer image."

        return f"Plant: {plant}\nDisease: {disease} (Confidence: {confidence:.2%})"
    except Exception as e:
        # UI boundary: show the failure in the results box rather than crash.
        return f"Error: {str(e)}"
# Gradio UI with additional instructions
# One image input, one text output; every submission is routed to predict().
iface = gr.Interface(
    title="Plant Disease Detection (With Error Correction)",
    description="""Upload a clear image of a plant leaf. Tips:
- Crop to show only the leaf
- Use even lighting
- Avoid shadows/reflections""",
    fn=predict,
    # The image component is configured with type="pil".
    inputs=gr.Image(type="pil", label="Upload Leaf Image"),
    outputs=gr.Textbox(label="Prediction Results"),
    allow_flagging="manual",
    # Example image paths offered in the UI.
    examples=[
        ["examples/healthy_apple.jpg"],
        ["examples/diseased_tomato.jpg"],
    ],
)
# Start the Gradio server only when this file is executed as a script.
# cv2 is imported with the other top-level imports, so the image helpers also
# work when this module is imported rather than run directly.
if __name__ == "__main__":
    iface.launch()