# swapnith / app.py
# Uploaded by madan2248c ("Upload app.py"), commit 8c7d1ac (verified).
import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import os
from transformers import ViTForImageClassification
# Output labels of the trained classifier (19 classes), listed in output-index
# order: predict_disease() maps the argmax index straight into this list, so the
# position of each label is assumed to match the index used during training.
# Every label has the form "<Crop> - <Condition>".
DISEASE_CLASSES = [
    # Chilli
    "Chilli - Healthy",                   # 00
    "Chilli - Leaf Curl Virus",           # 01
    # Pepper Bell
    "Pepper Bell - Bacterial Spot",       # 02
    "Pepper Bell - Healthy",              # 03
    # Potato
    "Potato - Early Blight",              # 04
    "Potato - Healthy",                   # 05
    "Potato - Late Blight",               # 06
    # Tomato
    "Tomato - Bacterial Spot",            # 07
    "Tomato - Early Blight",              # 08
    "Tomato - Healthy",                   # 09
    "Tomato - Late Blight",               # 10
    "Tomato - Leaf Mold",                 # 11
    "Tomato - Mosaic Virus",              # 12
    "Tomato - Septoria Leaf Spot",        # 13
    "Tomato - Target Spot",               # 14
    "Tomato - Two Spotted Spider Mite",   # 15
    "Tomato - Yellow Leaf Curl Virus",    # 16
    # GroundNut
    "GroundNut - Healthy",                # 17
    "GroundNut - Rust",                   # 18
]
# Human-readable description and treatment advice, keyed by the exact label
# strings used in DISEASE_CLASSES. Only a few classes are covered here; for any
# other label predict_disease() substitutes generic fallback text.
DISEASE_INFO = {
    "Chilli - Healthy": {
        "description": "The chilli plant appears healthy with no visible signs of disease.",
        "treatment": "Continue good agricultural practices and regular monitoring.",
    },
    "Chilli - Leaf Curl Virus": {
        "description": "Leaf curl virus causes leaves to curl, wrinkle, and become distorted.",
        "treatment": "Remove infected plants, control whiteflies with neem oil, use yellow sticky traps.",
    },
    "Tomato - Early Blight": {
        "description": "Early blight causes characteristic target-spot patterns on older leaves.",
        "treatment": "Apply fungicides (chlorothalonil, mancozeb), remove infected leaves, improve air circulation.",
    },
    "Potato - Late Blight": {
        "description": "Late blight is a devastating disease that can destroy entire crops rapidly.",
        "treatment": "Apply systemic fungicides immediately, destroy infected plants, improve air circulation.",
    },
}
# Shared inference state. `model` stays None until load_model() succeeds;
# `device` prefers the GPU when torch reports CUDA as available.
model = None
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Preprocessing applied to each uploaded image before inference: resize to a
# fixed 224x224, convert to a tensor, then normalise each channel.
# NOTE(review): these are the commonly used ImageNet mean/std values; they must
# match the normalisation the student model was trained with — confirm.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)
def load_model():
    """Load the trained ViT student classifier into the module-level ``model``.

    Looks for the checkpoint in a fixed list of candidate paths, builds the
    ``WinKawaks/vit-tiny-patch16-224`` architecture with a
    ``len(DISEASE_CLASSES)``-way classification head, loads the trained
    weights, moves the network to ``device`` and puts it in eval mode.

    The global ``model`` is replaced only after every step has succeeded, so a
    failed load never leaves a half-initialised network (pretrained backbone
    plus a freshly initialised head) in place for ``predict_disease`` to use.

    Returns:
        True if the model was loaded; False otherwise (the reason is printed).
    """
    global model
    model_paths = [
        "best_student_attention_kd.pth",
        "model/best_student_attention_kd.pth",
    ]
    model_path = next((path for path in model_paths if os.path.exists(path)), None)
    if model_path is None:
        # Report this like every other failure (boolean return) so the caller's
        # "demo mode" fallback runs instead of the app crashing at import time.
        print(f"Error loading model: Model file not found (searched: {', '.join(model_paths)})")
        return False
    try:
        print("Loading TinyViT student model...")
        # Build the architecture; the pretrained head has a different number of
        # outputs, hence ignore_mismatched_sizes=True.
        net = ViTForImageClassification.from_pretrained(
            "WinKawaks/vit-tiny-patch16-224",
            num_labels=len(DISEASE_CLASSES),
            ignore_mismatched_sizes=True,
        )
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(model_path, map_location=device)
        # Accept either a bare state_dict or a training checkpoint that wraps it.
        if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
            checkpoint = checkpoint["model_state_dict"]
        net.load_state_dict(checkpoint)
        net.to(device)
        net.eval()
    except Exception as e:
        print(f"Error loading model: {e}")
        return False
    # Publish only a fully loaded model.
    model = net
    print(f"✓ Model loaded successfully on {device}")
    return True
def predict_disease(image):
    """Classify a plant-leaf image and build the three values shown in the UI.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or None when
            nothing was uploaded.

    Returns:
        A ``(markdown_report, predicted_class, confidence_text)`` tuple. On any
        failure the first element is an error message and the other two are
        empty strings.
    """
    if model is None:
        return "❌ Model not loaded", "", ""
    if image is None:
        return "❌ No image provided", "", ""
    try:
        # Convert RGBA / grayscale / palette uploads to 3-channel RGB.
        if image.mode != "RGB":
            image = image.convert("RGB")
        # unsqueeze(0) adds the batch dimension for a single image.
        batch = transform(image).unsqueeze(0).to(device)

        with torch.no_grad():
            logits = model(batch).logits
            probabilities = torch.nn.functional.softmax(logits, dim=1)
            confidence, predicted_idx = torch.max(probabilities, 1)
        predicted_class = DISEASE_CLASSES[predicted_idx.item()]
        confidence_score = confidence.item()

        # Labels look like "<Crop> - <Condition>"; if the separator is missing,
        # use the whole label for both parts.
        crop_type, separator, condition = predicted_class.partition(" - ")
        disease = condition if separator else predicted_class
        status = "🟢 Healthy" if "healthy" in disease.lower() else "🔴 Diseased"

        disease_info = DISEASE_INFO.get(predicted_class, {
            "description": f"Information about {disease}.",
            "treatment": "Consult with a plant pathologist or agricultural extension service.",
        })

        # Build the Markdown line by line: an indented triple-quoted string can
        # be rendered as a code block instead of headings.
        result = "\n".join([
            f"## 🌱 **Crop Type:** {crop_type}",
            f"## 🦠 **Disease:** {disease}",
            f"## 📊 **Status:** {status}",
            f"## 🎯 **Confidence:** {confidence_score:.2%}",
            "### 📝 **Description:**",
            disease_info["description"],
            "### 💊 **Treatment:**",
            disease_info["treatment"],
        ])
        return result, predicted_class, f"Confidence: {confidence_score:.2%}"
    except Exception as e:
        # UI boundary: surface the error to the user instead of crashing the request.
        return f"❌ Error processing image: {str(e)}", "", ""
# Load the model once at import time, before the UI is built.
print("Initializing Plant Disease Detection Model...")
model_loaded = load_model()
if not model_loaded:
    # load_model() signalled failure; keep serving the UI rather than exiting.
    print("⚠️ Model failed to load - running in demo mode")
# Gradio UI: one image input and three outputs, matching the
# (markdown_report, predicted_class, confidence_text) tuple that
# predict_disease returns.
# The supported-plant list and class count are derived from DISEASE_CLASSES so
# the description cannot drift from the model's labels.
_supported_crops = ", ".join(
    dict.fromkeys(label.split(" - ")[0] for label in DISEASE_CLASSES)
)
demo = gr.Interface(
    fn=predict_disease,
    inputs=gr.Image(type="pil", label="📸 Upload Plant Image"),
    outputs=[
        gr.Markdown(label="📋 Analysis Results"),
        gr.Textbox(label="🏷️ Predicted Class"),
        gr.Textbox(label="📊 Confidence Score"),
    ],
    title="🌱 Plant Disease Detection AI",
    # Built from explicit lines so Markdown never sees leading indentation.
    description="\n".join([
        "Upload an image of a plant leaf to detect diseases and get treatment recommendations.",
        "",
        f"**Supported Plants:** {_supported_crops}",
        "",
        f"**Supported Diseases:** {len(DISEASE_CLASSES)} different disease categories including healthy plants",
    ]),
    examples=None,  # You can add example images here later
    cache_examples=False,
)
# When run as a script, serve the UI on all network interfaces at port 7860,
# without creating a public share link.
if __name__ == "__main__":
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
    }
    demo.launch(**launch_options)