# Author: cabrel09
# Commit aba8b73: docs: update model information and supported crops/diseases
import gradio as gr
import torch
import torch.nn.functional as F
from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image
import google.generativeai as genai
import os
import logging
import asyncio
from typing import Tuple, Optional
import time
# Process-wide logging at INFO, plus a logger scoped to this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Gemini is optional: without an API key `gemini_model` stays None and the
# remedy helper returns the offline, generic advice instead.
GEMINI_API_KEY = os.getenv('GEMINI_API_KEY')
gemini_model = None
if not GEMINI_API_KEY:
    logger.warning("GEMINI_API_KEY not found. Remedy suggestions will be disabled.")
else:
    genai.configure(api_key=GEMINI_API_KEY)
    gemini_model = genai.GenerativeModel('gemini-2.0-flash-exp')

# Classifier cache, populated by load_model(); None until loading succeeds.
model = None
processor = None
# Lookup table: lowercased classifier label -> (crop display name, disease display name).
# parse_label() lowercases and strips the model's raw label and tries this table first,
# so keys must be written in exactly that form ("<crop>___<condition>", lowercase).
# Labels not listed here are not an error: parse_label falls back to generic parsing.
# NOTE(review): confirm these keys against model.config.id2label of the deployed
# checkpoint; a key that differs by even one character is silently never used.
CROP_DISEASE_MAPPING = {
    # Tomato diseases
    "tomato___bacterial_spot": ("Tomato", "Bacterial Spot"),
    "tomato___early_blight": ("Tomato", "Early Blight"),
    "tomato___late_blight": ("Tomato", "Late Blight"),
    "tomato___leaf_mold": ("Tomato", "Leaf Mold"),
    "tomato___septoria_leaf_spot": ("Tomato", "Septoria Leaf Spot"),
    "tomato___spider_mites_two_spotted_spider_mite": ("Tomato", "Two-Spotted Spider Mite"),
    "tomato___target_spot": ("Tomato", "Target Spot"),
    "tomato___tomato_yellow_leaf_curl_virus": ("Tomato", "Yellow Leaf Curl Virus"),
    "tomato___tomato_mosaic_virus": ("Tomato", "Mosaic Virus"),
    "tomato___healthy": ("Tomato", "Healthy"),
    # Potato diseases
    "potato___early_blight": ("Potato", "Early Blight"),
    "potato___late_blight": ("Potato", "Late Blight"),
    "potato___healthy": ("Potato", "Healthy"),
    # Apple diseases
    "apple___apple_scab": ("Apple", "Apple Scab"),
    "apple___black_rot": ("Apple", "Black Rot"),
    "apple___cedar_apple_rust": ("Apple", "Cedar Apple Rust"),
    "apple___healthy": ("Apple", "Healthy"),
    # Corn diseases (crop part keeps the "(maize)" suffix; display name is "Corn/Maize")
    "corn_(maize)___cercospora_leaf_spot": ("Corn/Maize", "Cercospora Leaf Spot"),
    "corn_(maize)___common_rust": ("Corn/Maize", "Common Rust"),
    "corn_(maize)___northern_leaf_blight": ("Corn/Maize", "Northern Leaf Blight"),
    "corn_(maize)___healthy": ("Corn/Maize", "Healthy"),
    # Grape diseases
    "grape___black_rot": ("Grape", "Black Rot"),
    "grape___esca_(black_measles)": ("Grape", "Esca (Black Measles)"),
    "grape___leaf_blight_(isariopsis_leaf_spot)": ("Grape", "Leaf Blight"),
    "grape___healthy": ("Grape", "Healthy"),
    # Bell Pepper diseases (note the comma in the crop part of the key)
    "pepper,_bell___bacterial_spot": ("Bell Pepper", "Bacterial Spot"),
    "pepper,_bell___healthy": ("Bell Pepper", "Healthy"),
    # Cherry diseases
    "cherry_(including_sour)___powdery_mildew": ("Cherry", "Powdery Mildew"),
    "cherry_(including_sour)___healthy": ("Cherry", "Healthy"),
    # Peach diseases
    "peach___bacterial_spot": ("Peach", "Bacterial Spot"),
    "peach___healthy": ("Peach", "Healthy"),
    # Strawberry diseases
    "strawberry___leaf_scorch": ("Strawberry", "Leaf Scorch"),
    "strawberry___healthy": ("Strawberry", "Healthy"),
    # Soybean diseases
    "soybean___healthy": ("Soybean", "Healthy"),
    # Squash diseases
    "squash___powdery_mildew": ("Squash", "Powdery Mildew"),
    # Orange diseases
    # The key spells "haunglongbing" (sic) while the display name says "Huanglongbing";
    # presumably the key mirrors the dataset's label spelling -- verify before "fixing" it.
    "orange___haunglongbing_(citrus_greening)": ("Orange", "Huanglongbing (Citrus Greening)"),
    # Raspberry diseases
    "raspberry___healthy": ("Raspberry", "Healthy"),
}
def load_model(model_name: str = "cabrel09/crop_leaf_disease_detector") -> bool:
    """Load the ViT image processor and classifier into the module-level cache.

    Both objects are loaded first and only then assigned to the globals
    ``processor`` and ``model``, so a failure part-way through never leaves a
    processor from one checkpoint paired with a model from another.

    Args:
        model_name: Hugging Face Hub repo id or local path handed to
            ``from_pretrained``. Defaults to this app's crop leaf disease detector.

    Returns:
        True if both the processor and the model loaded; False otherwise (the
        error and its traceback are logged, and the globals are left unchanged).
    """
    global model, processor
    try:
        logger.info("Loading model: %s", model_name)
        new_processor = ViTImageProcessor.from_pretrained(model_name)
        new_model = ViTForImageClassification.from_pretrained(model_name)
    except Exception:
        # Best-effort by design: the app stays importable and the UI reports the
        # failure. logger.exception keeps the traceback that str(e) used to drop.
        logger.exception("Error loading model: %s", model_name)
        return False
    processor, model = new_processor, new_model
    logger.info("Model loaded successfully!")
    return True
# Keyword tables for the free-form fallback of parse_label. Dict order matters:
# within one word, the first keyword (in insertion order) that occurs wins.
_CROP_KEYWORDS = {
    'tomato': 'Tomato', 'potato': 'Potato', 'apple': 'Apple',
    'corn': 'Corn', 'maize': 'Corn/Maize', 'grape': 'Grape',
    'pepper': 'Pepper', 'cherry': 'Cherry', 'peach': 'Peach',
    'strawberry': 'Strawberry', 'soybean': 'Soybean',
    'squash': 'Squash', 'orange': 'Orange', 'citrus': 'Citrus',
    'raspberry': 'Raspberry', 'wheat': 'Wheat', 'rice': 'Rice'
}
_DISEASE_KEYWORDS = {
    'blight': 'Blight', 'rust': 'Rust', 'spot': 'Spot',
    'rot': 'Rot', 'mold': 'Mold', 'virus': 'Virus',
    'bacterial': 'Bacterial', 'fungal': 'Fungal',
    'healthy': 'Healthy', 'scab': 'Scab', 'mildew': 'Mildew'
}
# (substring, display name) pairs for awkward humanized crop names; first hit wins.
_CROP_RENAMES = (
    ("Corn Maize", "Corn/Maize"),
    ("Pepper Bell", "Bell Pepper"),
    ("Cherry Including Sour", "Cherry"),
)


def _normalize_label_part(part: str) -> str:
    """Collapse spaces, hyphens and underscore runs to single underscores, trimmed."""
    return "_".join(part.replace("-", " ").replace("_", " ").split())


def _humanize_label_part(part: str) -> str:
    """Turn a snake_case label fragment into Title Case words separated by single spaces."""
    return " ".join(part.replace("_", " ").split()).title()


def _canonical_crop_name(crop: str) -> str:
    """Drop parentheses from a humanized crop name and map known variants to display names."""
    cleaned = " ".join(crop.replace("(", " ").replace(")", " ").split())
    # Commas are ignored only for matching, so "Pepper, Bell" is recognised.
    comparable = cleaned.replace(",", "")
    for needle, display in _CROP_RENAMES:
        if needle in comparable:
            return display
    return cleaned


def _last_keyword_match(words, keywords, default):
    """Return the name for the LAST word containing any keyword (substring match).

    Within a single word the first keyword in dict order is used; later words
    override earlier ones. Returns *default* when no word matches.
    """
    result = default
    for word in words:
        for keyword, name in keywords.items():
            if keyword in word:
                result = name
                break
    return result


def _parse_label_keywords(label: str) -> Tuple[str, str]:
    """Best-effort keyword guess of (crop, disease), used when structured parsing fails."""
    words = label.replace("-", " ").replace("_", " ").split()
    crop = _last_keyword_match(words, _CROP_KEYWORDS, "Unknown Crop")
    disease = _last_keyword_match(words, _DISEASE_KEYWORDS, "Unknown Condition")
    # Multi-word disease names override the single-keyword guess; first rule wins.
    if 'early' in label and 'blight' in label:
        disease = "Early Blight"
    elif 'late' in label and 'blight' in label:
        disease = "Late Blight"
    elif 'leaf' in label and 'spot' in label:
        disease = "Leaf Spot"
    elif 'mosaic' in label:
        disease = "Mosaic Virus"
    return crop, disease


def parse_label(raw_label: str) -> Tuple[str, str]:
    """Map a raw classifier label to a ``(crop, disease)`` display pair.

    Resolution order:
      1. Exact lookup of the lowercased, stripped label in CROP_DISEASE_MAPPING.
      2. For ``crop___disease`` labels: a tolerant lookup after normalising each
         part (stray spaces, hyphens, trailing underscores -- e.g.
         ``"Corn_(maize)___Common_rust_"``), then a generic Title-Case split of
         the first two parts with crop-name clean-up.
      3. A keyword scan over the words of the label.

    Args:
        raw_label: Label string as produced by the model's ``id2label``.

    Returns:
        (crop, disease). Parts that cannot be recognised come back as
        "Unknown Crop" / "Unknown Condition".
    """
    label = raw_label.lower().strip()
    if label in CROP_DISEASE_MAPPING:
        return CROP_DISEASE_MAPPING[label]

    if "___" in label:
        parts = label.split("___")
        normalized = "___".join(_normalize_label_part(part) for part in parts)
        if normalized in CROP_DISEASE_MAPPING:
            return CROP_DISEASE_MAPPING[normalized]
        crop = _canonical_crop_name(_humanize_label_part(parts[0]))
        disease = _humanize_label_part(parts[1])
        # Degenerate labels such as "tomato___" fall through to the keyword scan
        # rather than returning an empty crop or disease name.
        if crop and disease:
            return crop, disease

    return _parse_label_keywords(label)
def get_detailed_disease_info(crop: str, disease: str) -> str:
    """Return a one-line Markdown description of *disease* on *crop*.

    The description is chosen by keyword in the disease name, checked in this
    order: healthy (exact name), blight, spot, rust, virus. Anything else gets a
    generic "requires attention" message.

    Args:
        crop: Display name of the crop, e.g. "Tomato".
        disease: Display name of the condition, e.g. "Late Blight" or "Healthy".
            Matching is case-insensitive.

    Returns:
        A Markdown string with an emoji prefix and bold crop/disease names.
    """
    name = disease.lower()
    if name == "healthy":
        return f"βœ… **Healthy {crop}** - Your plant appears to be in good condition! Continue with proper care practices."
    if "blight" in name:
        # Not every blight is fungal: bacterial blights (e.g. rice bacterial leaf
        # blight) and late blight (Phytophthora infestans, an oomycete) differ, and
        # calling them fungal points users at the wrong treatments.
        if "bacterial" in name:
            kind = "a bacterial disease"
        elif "late blight" in name:
            kind = "a water mold (oomycete) disease"
        else:
            kind = "a fungal disease"
        return f"πŸ‚ **{disease}** detected in **{crop}** - This is {kind} that causes tissue death and browning."
    if "spot" in name:
        return f"⚫ **{disease}** found on **{crop}** - Characterized by dark lesions on plant tissues."
    if "rust" in name:
        return f"🦠 **{disease}** affecting **{crop}** - Shows as orange/reddish pustules on plant surfaces."
    if "virus" in name:
        return f"🦠 **{disease}** in **{crop}** - Viral infection causing various symptoms including mottling and distortion."
    return f"πŸ” **{disease}** detected in **{crop}** - This condition requires attention and proper management."
async def get_enhanced_remedy_suggestions(crop: str, disease: str, confidence: float) -> str:
    """Return a Markdown treatment plan for *disease* on *crop*.

    Asks Gemini for a structured plan when ``gemini_model`` is configured. Falls
    back to get_offline_remedy_suggestions when Gemini is not configured, when
    the call raises, or when it returns no text.

    Args:
        crop: Crop display name, e.g. "Tomato".
        disease: Disease display name, e.g. "Early Blight".
        confidence: Classifier probability in [0, 1], quoted in the prompt.

    Returns:
        Markdown text; never raises for API errors.
    """
    if not gemini_model:
        return get_offline_remedy_suggestions(crop, disease)
    try:
        prompt = f"""
As an expert plant pathologist and agricultural consultant, provide comprehensive treatment recommendations for {disease} affecting {crop} plants.
The diagnosis confidence is {confidence:.1%}.
Please provide a detailed response with the following sections:
1. **Immediate Actions** (2-3 urgent steps to take right now)
2. **Treatment Options**:
- Organic/biological treatments
- Chemical fungicides/bactericides (if needed)
- Cultural control methods
3. **Prevention Strategy** (4-5 preventive measures)
4. **Monitoring & Follow-up** (what to watch for)
5. **Expected Timeline** (recovery expectations)
Make it specific to {crop} cultivation and {disease} characteristics. Include product names where appropriate and be practical for farmers.
Format with clear headings and bullet points for easy reading.
"""
        # generate_content is a blocking network call. Calling it directly inside
        # this coroutine would stall the event loop (and every other user's
        # request) for the whole round-trip, so run it in a worker thread.
        response = await asyncio.to_thread(
            gemini_model.generate_content,
            prompt,
            generation_config=genai.types.GenerationConfig(
                temperature=0.7,
                max_output_tokens=1200,
            ),
        )
        # Reading .text can itself raise (e.g. when no candidate text is
        # returned), so it stays inside the try.
        text = response.text
    except Exception as e:
        # Deliberate boundary: any API problem degrades to offline advice.
        logger.exception("Error getting AI remedy suggestions: %s", e)
        return get_offline_remedy_suggestions(crop, disease)
    if not (text and text.strip()):
        logger.warning("Gemini returned an empty response; using offline suggestions.")
        return get_offline_remedy_suggestions(crop, disease)
    return f"πŸ€– **AI-Powered Treatment Plan for {crop} - {disease}:**\n\n{text}"
def get_offline_remedy_suggestions(crop: str, disease: str) -> str:
    """Build generic, network-free treatment advice for *crop* / *disease*.

    Used whenever AI suggestions are unavailable. Only the heading mentions the
    crop and disease; the rest is the same checklist for every condition.

    Returns:
        Markdown text that starts and ends with a newline.
    """
    lines = [
        "",
        f"**πŸ“‹ General Treatment Approach for {crop} - {disease}:**",
        "**🚨 Immediate Actions:**",
        "β€’ Identify and remove affected plant parts",
        "β€’ Improve growing conditions (drainage, spacing, air circulation)",
        "β€’ Apply appropriate treatment based on disease type",
        "**πŸ§ͺ Treatment Options:**",
        "β€’ For fungal diseases: Use fungicides (copper-based for organic)",
        "β€’ For bacterial diseases: Copper compounds or antibiotics",
        "β€’ For viral diseases: Remove infected plants, control insect vectors",
        "**πŸ›‘οΈ Prevention Strategy:**",
        "β€’ Choose resistant varieties when available",
        "β€’ Practice crop rotation",
        "β€’ Maintain proper plant hygiene",
        "β€’ Monitor regularly for early detection",
        "**πŸ“ž Recommendation:** Consult local agricultural extension services for region-specific advice.",
        "",
    ]
    return "\n".join(lines)
def classify_disease_enhanced(image: Image.Image) -> Tuple[str, str, float, str]:
    """Run the cached ViT classifier on *image* and describe the top prediction.

    Args:
        image: Plant/leaf image. PIL images in modes other than RGB (e.g. RGBA
            PNGs or greyscale) are converted to RGB before preprocessing.

    Returns:
        ``(crop, disease, confidence, disease_info)`` where confidence is the
        softmax probability of the predicted class. On failure (model not
        loaded, or any exception during inference) disease is the sentinel
        "Error", confidence is 0.0 and disease_info holds the error message.
    """
    if model is None or processor is None:
        return "Unknown Crop", "Error", 0.0, "Model not loaded properly."
    try:
        # ViT checkpoints are typically trained on 3-channel RGB input;
        # normalise the mode so RGBA/greyscale uploads don't break preprocessing.
        if isinstance(image, Image.Image) and image.mode != "RGB":
            image = image.convert("RGB")
        inputs = processor(image, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)
        probabilities = F.softmax(outputs.logits, dim=-1)
        # One reduction yields both the winning probability and its class index
        # (a single image is processed, so each result holds one element).
        top_prob, top_idx = probabilities.max(dim=-1)
        confidence = top_prob.item()
        predicted_class_id = top_idx.item()
        predicted_label = model.config.id2label[predicted_class_id]
        crop, disease = parse_label(predicted_label)
        disease_info = get_detailed_disease_info(crop, disease)
        logger.info("Raw prediction: %s", predicted_label)
        logger.info("Parsed: Crop=%s, Disease=%s, Confidence=%.2f%%", crop, disease, confidence * 100)
        return crop, disease, confidence, disease_info
    except Exception as e:
        # Boundary for the UI: report the failure as data, keep the traceback in logs.
        logger.exception("Error in classification")
        return "Unknown Crop", "Error", 0.0, f"Classification failed: {str(e)}"
async def process_image_enhanced(image: Image.Image) -> Tuple[str, str]:
    """Classify *image* and build the two Markdown panels shown in the UI.

    Args:
        image: Uploaded plant image, or None when nothing is uploaded.

    Returns:
        ``(analysis_markdown, remedy_markdown)``. The remedy text is empty when
        no image is given or classification failed.
    """
    if image is None:
        return "⚠️ No image provided", ""

    # Model inference is synchronous and compute-heavy; running it directly in
    # this coroutine would block the event loop for every connected user.
    crop, disease, confidence, disease_info = await asyncio.to_thread(classify_disease_enhanced, image)
    if disease == "Error":
        return f"❌ **Classification Error**\n{disease_info}", ""

    confidence_emoji = "🎯" if confidence > 0.8 else "πŸ“Š" if confidence > 0.6 else "⚠️"
    main_result = f"""
## πŸ”¬ **Plant Disease Analysis Results**
### 🌱 **Identified Crop:** `{crop}`
### 🦠 **Disease Status:** `{disease}`
### πŸ“Š **Confidence Score:** {confidence_emoji} `{confidence:.2%}`
---
### πŸ“‹ **Disease Information:**
{disease_info}
### πŸ” **Analysis Notes:**
- Higher confidence scores (>80%) indicate more reliable identification
- Multiple factors including image quality affect accuracy
- Consider consulting agricultural experts for critical decisions
---
""".strip()
    remedy_text = await get_enhanced_remedy_suggestions(crop, disease, confidence)
    return main_result, remedy_text
# Stylesheet passed to gr.Blocks(css=...) in create_optimized_interface.
# - Classes such as .main-container, .input-panel, .compact-upload,
#   .results-panel, .scrollable-results and .result-content are attached to
#   Gradio components through elem_classes.
# - .compact-navbar, .navbar-*, .status-* and .results-header style the raw
#   HTML snippets rendered with gr.HTML.
# - The @import line fetches the Inter font from Google Fonts in the browser.
# - A 768px media query stacks the two panels vertically on narrow screens.
# NOTE(review): the .tab-nav selectors target Gradio's internal DOM class names,
# which may change between Gradio releases -- confirm against the installed version.
optimized_css = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600&display=swap');
:root {
--primary-green: #7FB069;
--light-green: #A7C957;
--dark-green: #6A994E;
--accent-green: #95D5B2;
--bg-dark: #1a2e1a;
--text-light: #F8F9F0;
--glass-bg: rgba(26, 46, 26, 0.25);
}
.gradio-container {
background: linear-gradient(135deg, var(--bg-dark) 0%, #2d4a2d 50%, var(--bg-dark) 100%);
min-height: 100vh;
font-family: 'Inter', sans-serif;
padding: 0 !important;
}
/* COMPACT NAVBAR - Minimal space usage */
.compact-navbar {
background: var(--glass-bg);
backdrop-filter: blur(15px);
padding: 0.8rem 2rem;
border-bottom: 1px solid rgba(183, 228, 199, 0.15);
display: flex;
justify-content: space-between;
align-items: center;
position: sticky;
top: 0;
z-index: 1000;
height: 60px;
}
.navbar-title {
background: linear-gradient(135deg, var(--accent-green), var(--primary-green));
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
font-size: 1.4rem;
font-weight: 600;
margin: 0;
}
.navbar-subtitle {
color: rgba(248, 249, 240, 0.7);
font-size: 0.85rem;
margin: 0;
}
/* MAIN CONTENT - Maximum space utilization */
.main-container {
display: flex;
height: calc(100vh - 60px);
gap: 1rem;
padding: 1rem;
}
/* LEFT PANEL - Compact upload area */
.input-panel {
background: var(--glass-bg);
backdrop-filter: blur(15px);
border-radius: 12px;
padding: 1.5rem;
border: 1px solid rgba(183, 228, 199, 0.1);
width: 320px;
min-width: 320px;
display: flex;
flex-direction: column;
}
.input-panel h3 {
color: var(--accent-green);
font-size: 1.1rem;
margin: 0 0 1rem 0;
font-weight: 500;
}
/* COMPACT UPLOAD - Functional size */
.compact-upload {
border: 2px dashed var(--primary-green) !important;
border-radius: 8px !important;
background: rgba(127, 176, 105, 0.05) !important;
height: 200px !important;
transition: all 0.2s ease !important;
margin-bottom: 1rem !important;
}
.compact-upload:hover {
border-color: var(--light-green) !important;
background: rgba(127, 176, 105, 0.1) !important;
}
/* RIGHT PANEL - MAXIMUM SPACE FOR RESULTS */
.results-panel {
background: var(--glass-bg);
backdrop-filter: blur(15px);
border-radius: 12px;
border: 1px solid rgba(183, 228, 199, 0.1);
flex: 1;
display: flex;
flex-direction: column;
overflow: hidden;
}
.results-header {
padding: 1rem 1.5rem 0.5rem;
border-bottom: 1px solid rgba(183, 228, 199, 0.1);
background: rgba(127, 176, 105, 0.05);
}
.results-header h3 {
color: var(--accent-green);
font-size: 1.2rem;
margin: 0;
font-weight: 600;
}
/* SCROLLABLE RESULTS - Key improvement */
.scrollable-results {
flex: 1;
overflow-y: auto;
padding: 0;
max-height: calc(100vh - 200px);
}
.result-content {
padding: 1.5rem;
color: var(--text-light);
line-height: 1.6;
height: 100%;
}
/* Custom scrollbar */
.scrollable-results::-webkit-scrollbar {
width: 8px;
}
.scrollable-results::-webkit-scrollbar-track {
background: rgba(127, 176, 105, 0.1);
border-radius: 4px;
}
.scrollable-results::-webkit-scrollbar-thumb {
background: var(--primary-green);
border-radius: 4px;
}
.scrollable-results::-webkit-scrollbar-thumb:hover {
background: var(--light-green);
}
/* STATUS INDICATORS */
.status-indicator {
display: inline-block;
padding: 0.25rem 0.75rem;
border-radius: 20px;
font-size: 0.8rem;
font-weight: 500;
}
.status-analyzing {
background: rgba(167, 201, 87, 0.2);
color: var(--light-green);
}
.status-ready {
background: rgba(149, 213, 178, 0.2);
color: var(--accent-green);
}
.status-error {
background: rgba(255, 99, 99, 0.2);
color: #ff9999;
}
/* MOBILE RESPONSIVE */
@media (max-width: 768px) {
.main-container {
flex-direction: column;
height: auto;
min-height: calc(100vh - 60px);
}
.input-panel {
width: 100%;
min-width: auto;
}
.compact-upload {
height: 150px !important;
}
.navbar-title {
font-size: 1.2rem;
}
.scrollable-results {
max-height: 400px;
}
}
/* Tab improvements */
.tab-nav button {
background: rgba(127, 176, 105, 0.1) !important;
border: 1px solid rgba(127, 176, 105, 0.2) !important;
color: var(--text-light) !important;
}
.tab-nav button.selected {
background: var(--primary-green) !important;
color: var(--bg-dark) !important;
}
/* Markdown content improvements */
.result-content h1, .result-content h2, .result-content h3 {
color: var(--accent-green);
margin-top: 1.5rem;
margin-bottom: 1rem;
}
.result-content h1:first-child, .result-content h2:first-child, .result-content h3:first-child {
margin-top: 0;
}
.result-content code {
background: rgba(127, 176, 105, 0.2);
padding: 0.2rem 0.4rem;
border-radius: 4px;
color: var(--light-green);
}
.result-content strong {
color: var(--accent-green);
}
.result-content ul, .result-content ol {
padding-left: 1.5rem;
}
.result-content li {
margin: 0.5rem 0;
}
"""
# Load the classifier once, at import time; every request then reuses the
# module-level model/processor cache. The flag gates the UI build below.
logger.info("Initializing Enhanced Crop Disease Detector...")
model_loaded: bool = load_model()
def create_optimized_interface():
    """Build the AgriVision Gradio UI: navbar, upload panel and tabbed results.

    Layout: a compact navbar; a left column with the image upload and usage
    tips; a right column with three tabs (classification, treatment plan, model
    details). Any change of the uploaded image (including clearing it) runs
    ``enhanced_analysis`` and fills the first two tabs.

    If the module-level ``model_loaded`` flag is false, only the navbar (with an
    error badge) and a visible error message are rendered.

    Returns:
        The assembled ``gr.Blocks`` interface (not yet launched).
    """
    # Reflect the real model state in the navbar badge; optimized_css defines
    # both .status-ready and .status-error.
    if model_loaded:
        status_class, status_text = "status-ready", "Ready"
    else:
        status_class, status_text = "status-error", "Model Error"

    with gr.Blocks(css=optimized_css, theme=gr.themes.Soft(), title="🌱 AgriVision AI") as interface:
        # COMPACT NAVBAR instead of large header
        gr.HTML(f"""
<div class="compact-navbar">
<div>
<h1 class="navbar-title">🌱 AgriVision AI</h1>
<p class="navbar-subtitle">Advanced Crop Disease Detection</p>
</div>
<div class="status-indicator {status_class}">{status_text}</div>
</div>
""")
        if not model_loaded:
            load_error = "⚠️ Model failed to load. Please refresh the page or contact support."
            # gr.Warning only shows a toast while an event handler is running;
            # called here, at build time, it never reaches the browser. Render the
            # message as page content instead and log it for the operator.
            logger.error(load_error)
            gr.Markdown(load_error)
            return interface

        # Take the class count from the loaded checkpoint rather than hard-coding
        # it, so the tips and the details tab always agree with the model.
        num_classes = len(model.config.id2label)

        # MAIN CONTAINER - Maximized space
        with gr.Row(elem_classes="main-container"):
            # LEFT PANEL - Compact input area
            with gr.Column(scale=0, min_width=320, elem_classes="input-panel"):
                gr.HTML("<h3>πŸ“€ Upload Plant Image</h3>")
                # COMPACT UPLOAD - Functional size only
                image_input = gr.Image(
                    type="pil",
                    label="",
                    elem_classes="compact-upload",
                    show_label=False,
                    height=200
                )
                # Quick info - minimal space
                gr.HTML(f"""
<div style="font-size: 0.85rem; color: rgba(248, 249, 240, 0.7); line-height: 1.4; margin-top: 1rem;">
πŸ’‘ <strong>Tips:</strong><br>
β€’ Clear, well-lit images work best<br>
β€’ Focus on affected plant parts<br>
β€’ Analysis starts automatically on upload<br>
β€’ Supports {num_classes} crop-disease combinations
</div>
""")
            # RIGHT PANEL - MAXIMUM SPACE for results
            with gr.Column(scale=1, elem_classes="results-panel"):
                # Results header
                gr.HTML("""
<div class="results-header">
<h3>πŸ”¬ AI Analysis Results</h3>
</div>
""")
                # TABBED SCROLLABLE RESULTS
                with gr.Tabs():
                    with gr.TabItem("🎯 Disease Classification", elem_id="classification-tab"):
                        disease_output = gr.Markdown(
                            elem_classes=["scrollable-results", "result-content"],
                            show_label=False
                        )
                    with gr.TabItem("πŸ’Š Treatment Plan", elem_id="treatment-tab"):
                        remedy_output = gr.Markdown(
                            elem_classes=["scrollable-results", "result-content"],
                            show_label=False
                        )
                    with gr.TabItem("πŸ“Š Model Details", elem_id="details-tab"):
                        # Reference text rendered once at build time; no event updates it.
                        gr.Markdown(
                            elem_classes=["scrollable-results", "result-content"],
                            show_label=False,
                            value=f"""
**πŸ€– Model Information:**
**Architecture:** Vision Transformer (ViT)
**Model:** cabrel09/crop_leaf_disease_detector
**Dataset:** PlantVillage Dataset
**Classes Supported:** {num_classes} crop-disease combinations
**Input Resolution:** 224x224 pixels
**Processing Time:** ~2-3 seconds
**Supported Crops:**
- Corn/Maize, Rice
- Potato, Wheat
**Disease Categories:**
- Fungal diseases (Blight, Rust, Mold)
- Bacterial infections (Bacterial Spot)
- Viral infections (Mosaic Virus, YLCV)
- Healthy plant detection
**Usage Instructions:**
1. Upload a clear image of the affected plant
2. Ensure good lighting and focus on diseased areas
3. Review the classification results and confidence score
4. Follow the AI-generated treatment recommendations
5. Consult agricultural experts for critical decisions
**Note:** Results are based on visual patterns learned from thousands of plant images. Higher confidence scores (>80%) indicate more reliable identification.
"""
                        )

        async def enhanced_analysis(image):
            """Event handler: return (classification markdown, treatment markdown) for *image*."""
            if image is None:
                return (
                    "⚠️ **No Image Uploaded**\n\nPlease upload a plant image to begin analysis.",
                    "**πŸ”„ Waiting for Image Upload**\n\nOnce you upload an image, the AI will:\n\n1. Process the image through the Vision Transformer\n2. Classify the crop and disease\n3. Generate treatment recommendations\n4. Provide confidence scores and analysis details",
                )
            try:
                return await process_image_enhanced(image)
            except Exception as e:
                # Last-resort UI boundary: show a friendly message, but keep the
                # traceback in the server log so failures can be diagnosed.
                logger.exception("Image analysis failed")
                error_msg = f"❌ **Analysis Error**\n\n{str(e)}\n\nPlease try uploading a different image or check the image format."
                return error_msg, "**Error Details:**\n\nThe analysis could not be completed. Common issues:\n- Unsupported image format\n- Image too small or unclear\n- Network connectivity issues\n\nPlease try again with a clear, well-lit plant image."

        # Event handlers
        image_input.change(
            fn=enhanced_analysis,
            inputs=[image_input],
            outputs=[disease_output, remedy_output]
        )
    return interface
# Script entry point: build the UI and serve it.
if __name__ == "__main__":
    launch_kwargs = {
        "share": False,              # no public gradio.live tunnel
        "server_name": "0.0.0.0",    # listen on all network interfaces
        "server_port": 7860,
        "show_error": True,
        "debug": True,
    }
    interface = create_optimized_interface()
    interface.launch(**launch_kwargs)