# msmaje's picture
# Create app.py
# 54611eb verified
import gradio as gr
import torch
import torchvision
from torchvision import transforms
from PIL import Image
import json
import os
from typing import Dict, Tuple
import numpy as np
def build_efficientnet_model(num_classes: int, device: torch.device):
    """Assemble an EfficientNet-B0 whose head emits ``num_classes`` logits.

    The backbone is created with torchvision's default weights and its
    classifier is swapped for a dropout + linear head.

    Args:
        num_classes: Number of output features of the replacement head.
        device: Device the finished model is moved to.

    Returns:
        The EfficientNet-B0 module after ``.to(device)``.
    """
    nn = torch.nn
    backbone = torchvision.models.efficientnet_b0(
        weights=torchvision.models.EfficientNet_B0_Weights.DEFAULT
    )

    # The head layout follows the training setup (per the original author's
    # note) so that a saved state dict lines up with these layers.
    dropout = nn.Dropout(p=0.3, inplace=True)
    projection = nn.Linear(in_features=1280, out_features=num_classes, bias=True)
    backbone.classifier = nn.Sequential(dropout, projection)

    return backbone.to(device)
class RiceDiseaseClassifier:
    """EfficientNet-B0 based classifier for rice plant disease images.

    Creating an instance loads the trained weights from
    ``efficientnet_b0.pth`` (resolved against the current working directory),
    so construction raises when that file is missing or cannot be applied to
    the architecture.
    """

    def __init__(self):
        # Use the GPU when torch can see one, otherwise fall back to the CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Order matters: output index i of the model is reported as classes[i].
        self.classes = ['bacterial', 'blast', 'brownspot', 'tungro', 'healthy']

        # Human-readable explanation shown next to the predicted class.
        self.class_descriptions = {
            'bacterial': 'Bacterial Blight - A serious disease causing wilting and yellowing of leaves',
            'blast': 'Rice Blast - Fungal disease causing diamond-shaped lesions on leaves',
            'brownspot': 'Brown Spot - Fungal disease causing brown spots with yellow halos',
            'tungro': 'Tungro Virus - Viral disease causing stunted growth and yellowing',
            'healthy': 'Healthy - No disease detected'
        }

        # Preprocessing pipeline (stated by the original author to match training).
        self.transform = transforms.Compose([
            transforms.Resize((128, 128)),
            transforms.ToTensor(),
            # Expand single-channel tensors to three channels before normalising.
            transforms.Lambda(lambda x: x.repeat(3, 1, 1) if x.shape[0] == 1 else x),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        # Load the trained network and switch it to inference mode.
        self.model = self.load_model()
        self.model.eval()

    def load_model(self):
        """Build the EfficientNet-B0 architecture and load the trained weights.

        Returns:
            The model with the state dict from ``efficientnet_b0.pth`` applied.

        Raises:
            FileNotFoundError: If the weights file does not exist.
            RuntimeError: If reading the file or applying the state dict
                fails. The original exception is chained as ``__cause__``.
        """
        model_path = "efficientnet_b0.pth"
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file {model_path} not found. Please upload your trained EfficientNet model.")

        print(f"Loading EfficientNet-B0 model from {model_path}")

        # Build EfficientNet model architecture
        model = build_efficientnet_model(len(self.classes), self.device)

        # Load trained weights
        try:
            # NOTE(review): torch.load unpickles the file; only ship checkpoints
            # from a trusted source, and consider weights_only=True if the
            # pinned torch version supports it.
            state_dict = torch.load(model_path, map_location=self.device)
            model.load_state_dict(state_dict)
        except Exception as e:
            # Chain the cause so the underlying traceback is not lost.
            raise RuntimeError(f"Error loading model weights: {str(e)}") from e
        else:
            # Kept outside the try body so a console/encoding problem while
            # printing is never misreported as a weight-loading failure.
            print("✅ EfficientNet model loaded successfully!")

        return model

    def predict(self, image: "Image.Image") -> Tuple[Dict[str, float], str]:
        """Classify one image and describe the result.

        Args:
            image: Image to classify; converted to RGB when its mode differs.

        Returns:
            A ``(confidences, result_text)`` pair. ``confidences`` maps every
            class name to its softmax probability; ``result_text`` is a
            Markdown summary. If anything fails, all confidences are ``0.0``
            and the text carries the error message instead.
        """
        try:
            # Preprocess image
            if image.mode != 'RGB':
                image = image.convert('RGB')

            # Apply transforms and add the batch dimension the model expects.
            input_tensor = self.transform(image).unsqueeze(0).to(self.device)

            # Make prediction
            with torch.no_grad():
                outputs = self.model(input_tensor)
                probabilities = torch.nn.functional.softmax(outputs[0], dim=0)

            probs = probabilities.cpu().numpy()

            # Indexing (rather than zip) keeps a class/logit count mismatch loud.
            confidences = {name: float(probs[i]) for i, name in enumerate(self.classes)}

            # Plain int so the index does not leak a numpy scalar type.
            predicted_idx = int(np.argmax(probs))
            predicted_class = self.classes[predicted_idx]
            confidence = float(probs[predicted_idx])

            # Create detailed result
            result_text = f"**Predicted Disease: {predicted_class.upper()}**\n\n"
            result_text += f"**Description:** {self.class_descriptions[predicted_class]}\n\n"
            result_text += f"**Confidence:** {confidence:.2%}\n\n"

            if predicted_class != 'healthy':
                result_text += "**Recommendation:** Consult with an agricultural specialist for proper treatment."
            else:
                result_text += "**Status:** Your rice plant appears to be healthy!"

            return confidences, result_text

        except Exception as e:
            # Deliberately broad: the caller displays the message instead of
            # the request crashing.
            error_msg = f"Error processing image: {str(e)}"
            return {cls: 0.0 for cls in self.classes}, error_msg
# Module-level classifier instance used by the Gradio callback below.
# Instantiating it loads the model weights, so this line raises at import time
# when the weights file is missing or cannot be loaded.
classifier = RiceDiseaseClassifier()
def classify_rice_disease(image):
    """Gradio callback that scores an uploaded image with the shared classifier.

    Args:
        image: The uploaded picture, or ``None`` when nothing was submitted.

    Returns:
        A ``(confidences, message)`` pair: per-class scores keyed by class
        name, and the text to display alongside them.
    """
    if image is not None:
        scores, summary = classifier.predict(image)
        return scores, summary

    # Nothing uploaded: report zero for every class and prompt the user.
    empty_scores = dict.fromkeys(classifier.classes, 0.0)
    return empty_scores, "Please upload an image."
# Create Gradio interface
def create_interface():
    """Create and return the Gradio interface for the classifier.

    The interface takes one PIL image and shows a label widget with the
    per-class confidences plus a Markdown panel with the detailed result.
    Example images are offered only for files that exist under ``examples/``.

    Returns:
        The configured ``gr.Interface`` (not yet launched).
    """
    # Only advertise example images that are actually present on disk.
    example_paths = (
        "examples/bacterial.jpg",
        "examples/blast.jpg",
        "examples/healthy.jpg",
    )
    examples = [[path] for path in example_paths if os.path.exists(path)]

    # User-facing copy. The emoji and the multiplication sign were previously
    # mis-encoded (UTF-8 bytes decoded with the wrong code page).
    description = """
Upload an image of a rice plant to detect potential diseases using our **EfficientNet-B0** deep learning model.
**Detectable Conditions:**
- 🦠 **Bacterial Blight**: Serious bacterial infection
- 🍄 **Rice Blast**: Common fungal disease
- 🟤 **Brown Spot**: Fungal disease with characteristic spots
- 🦠 **Tungro Virus**: Viral infection causing stunting
- ✅ **Healthy**: No disease detected
*This tool provides preliminary screening only. Always consult agricultural experts for definitive diagnosis and treatment.*
"""

    article = """
### 🤖 About This Model
This rice disease classification system uses **EfficientNet-B0**, a state-of-the-art deep learning architecture optimized for accuracy and efficiency.
The model was trained on a comprehensive dataset of rice plant images with various disease conditions.
**Model Specifications:**
- Architecture: EfficientNet-B0 with custom classification head
- Input Size: 128×128 pixels
- Parameters: ~5.3M total, ~1.3M trainable
- Training: Transfer learning with data augmentation
### 📋 How to Use
1. **Upload Image**: Select a clear photo of rice plant leaves
2. **Automatic Analysis**: The model processes the image instantly
3. **Review Results**: Check confidence scores and disease description
4. **Take Action**: Follow agricultural recommendations if disease detected
### 💡 Best Practices
- Use well-lit, focused images
- Show clear view of leaves/symptoms
- Avoid blurry or distant shots
- Include multiple leaves if possible
### ⚠️ Important Limitations
- **Screening Tool Only**: Not a replacement for professional diagnosis
- **Image Quality Dependent**: Results vary with photo quality
- **Environmental Factors**: Weather and lighting can affect plant appearance
- **Professional Consultation**: Always seek expert agricultural advice for treatment
### 🎯 Model Performance
- Optimized for high accuracy across all disease classes
- Trained with extensive data augmentation
- Validated on diverse rice varieties and conditions
- Continuously improved with feedback and new data
**Developed for agricultural research and education purposes.**
"""

    iface = gr.Interface(
        fn=classify_rice_disease,
        inputs=[
            gr.Image(type="pil", label="Upload Rice Plant Image")
        ],
        outputs=[
            gr.Label(num_top_classes=5, label="Disease Classification Confidence"),
            gr.Markdown(label="Detailed Results")
        ],
        title="🌾 Rice Disease Classification - EfficientNet Model",
        description=description,
        article=article,
        # An empty list is passed as None so no examples section is requested.
        examples=examples or None,
        theme=gr.themes.Soft(),
        # NOTE(review): newer Gradio releases appear to rename this option to
        # `flagging_mode` - confirm against the pinned Gradio version.
        allow_flagging="never"
    )
    return iface
# Launch the app
if __name__ == "__main__":
    # Build the UI, then serve it on 0.0.0.0:7860 without a public share link.
    demo = create_interface()
    demo.launch(share=False, server_port=7860, server_name="0.0.0.0")