| import gradio as gr |
| import torch |
| import torchvision |
| from torchvision import transforms |
| from PIL import Image |
| import json |
| import os |
| from typing import Dict, Tuple |
| import numpy as np |
|
|
def build_efficientnet_model(num_classes: int, device: torch.device, *, pretrained: bool = True):
    """Build a torchvision EfficientNet-B0 with a new classification head.

    The stock head is replaced by ``Dropout(p=0.3) -> Linear(in_features, num_classes)``,
    keeping the same two-module ``Sequential`` layout (``classifier.0`` /
    ``classifier.1``) so saved state_dicts keep matching.

    Args:
        num_classes: Number of output logits produced by the new linear layer.
        device: Device the returned model is moved to.
        pretrained: When True (the default, and the original behaviour), the
            backbone is initialised from ``EfficientNet_B0_Weights.DEFAULT``,
            which torchvision downloads on first use. Pass False when a full
            fine-tuned state_dict will be loaded right afterwards, to skip
            that download.

    Returns:
        The model, moved to ``device``.
    """
    weights = torchvision.models.EfficientNet_B0_Weights.DEFAULT if pretrained else None
    model = torchvision.models.efficientnet_b0(weights=weights)

    # Take the feature width from the stock head's final Linear layer instead of
    # hard-coding 1280, so the new head always matches the backbone output.
    in_features = model.classifier[-1].in_features
    model.classifier = torch.nn.Sequential(
        torch.nn.Dropout(p=0.3, inplace=True),
        torch.nn.Linear(in_features=in_features, out_features=num_classes, bias=True),
    )
    return model.to(device)
|
|
class RiceDiseaseClassifier:
    """Single-image rice disease classifier backed by a fine-tuned EfficientNet-B0.

    On construction it defines the label set and preprocessing pipeline, then
    loads the trained weights from disk and puts the model in eval mode.
    ``predict`` turns one PIL image into per-class confidences plus a Markdown
    summary for the Gradio UI.
    """

    def __init__(self, model_path: str = "efficientnet_b0.pth"):
        """Set up labels and preprocessing, then load the model.

        Args:
            model_path: Path to the saved state_dict. Defaults to
                ``"efficientnet_b0.pth"`` relative to the working directory.

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
            RuntimeError: If the weights cannot be loaded into the model.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Network output index i is reported as self.classes[i] (see predict()).
        # NOTE(review): this order must match the label order used in training.
        # torchvision's ImageFolder assigns indices in sorted folder-name order,
        # which would place 'healthy' before 'tungro'. Confirm against the
        # training pipeline.
        self.classes = ['bacterial', 'blast', 'brownspot', 'tungro', 'healthy']
        self.class_descriptions = {
            'bacterial': 'Bacterial Blight - A serious disease causing wilting and yellowing of leaves',
            'blast': 'Rice Blast - Fungal disease causing diamond-shaped lesions on leaves',
            'brownspot': 'Brown Spot - Fungal disease causing brown spots with yellow halos',
            'tungro': 'Tungro Virus - Viral disease causing stunted growth and yellowing',
            'healthy': 'Healthy - No disease detected'
        }

        self.transform = transforms.Compose([
            # Fixed 128x128 model input.
            transforms.Resize((128, 128)),
            transforms.ToTensor(),
            # Defensive: replicate a single-channel tensor to 3 channels
            # (predict() already converts inputs to RGB).
            transforms.Lambda(lambda x: x.repeat(3, 1, 1) if x.shape[0] == 1 else x),
            # Standard ImageNet channel mean/std.
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        self.model = self.load_model(model_path)
        self.model.eval()

    def load_model(self, model_path: str = "efficientnet_b0.pth"):
        """Build the EfficientNet-B0 architecture and load trained weights into it.

        Args:
            model_path: Path to a saved state_dict (a mapping of parameter
                names to tensors) for the architecture produced by
                ``build_efficientnet_model(len(self.classes), ...)``.

        Returns:
            The model with weights loaded, on ``self.device``.

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
            RuntimeError: If reading the file or loading the state_dict fails;
                the original exception is chained as ``__cause__``.
        """
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file {model_path} not found. Please upload your trained EfficientNet model.")

        print(f"Loading EfficientNet-B0 model from {model_path}")

        model = build_efficientnet_model(len(self.classes), self.device)

        try:
            # weights_only=True limits unpickling to tensors and plain containers,
            # which is all a state_dict needs, and avoids executing arbitrary
            # pickled code from the checkpoint file.
            state_dict = torch.load(model_path, map_location=self.device, weights_only=True)
            model.load_state_dict(state_dict)
        except Exception as e:
            raise RuntimeError(f"Error loading model weights: {str(e)}") from e

        print("✅ EfficientNet model loaded successfully!")
        return model

    def predict(self, image: Image.Image) -> Tuple[Dict[str, float], str]:
        """Classify one rice-plant image.

        Args:
            image: PIL image. Non-RGB modes are converted to RGB first.

        Returns:
            ``(confidences, result_text)``. ``confidences`` maps each class name
            to its softmax probability. ``result_text`` is a Markdown summary
            of the top prediction.

            Any exception raised while processing is caught rather than
            propagated. In that case every confidence is 0.0 and
            ``result_text`` carries the error message.
        """
        try:
            if image.mode != 'RGB':
                image = image.convert('RGB')

            # Preprocess and add a batch dimension of size 1.
            input_tensor = self.transform(image).unsqueeze(0).to(self.device)

            with torch.no_grad():
                logits = self.model(input_tensor)
                # One image in the batch: softmax over its class logits.
                probs = torch.nn.functional.softmax(logits[0], dim=0).cpu().numpy()

            confidences = {cls: float(p) for cls, p in zip(self.classes, probs)}

            predicted_idx = int(np.argmax(probs))
            predicted_class = self.classes[predicted_idx]
            confidence = float(probs[predicted_idx])

            result_text = (
                f"**Predicted Disease: {predicted_class.upper()}**\n\n"
                f"**Description:** {self.class_descriptions[predicted_class]}\n\n"
                f"**Confidence:** {confidence:.2%}\n\n"
            )
            if predicted_class != 'healthy':
                result_text += "**Recommendation:** Consult with an agricultural specialist for proper treatment."
            else:
                result_text += "**Status:** Your rice plant appears to be healthy!"

            return confidences, result_text

        except Exception as e:
            # Deliberate UI boundary: show the failure in the results panel
            # instead of failing the Gradio request.
            error_msg = f"Error processing image: {str(e)}"
            return {cls: 0.0 for cls in self.classes}, error_msg
|
|
| |
# Module-level instance shared by classify_rice_disease. Constructing it here
# loads the model weights at import time, so importing this module raises
# FileNotFoundError / RuntimeError if the weights file is missing or unloadable.
classifier: RiceDiseaseClassifier = RiceDiseaseClassifier()
|
|
def classify_rice_disease(image):
    """Gradio callback: classify one uploaded image with the shared classifier.

    Args:
        image: PIL image from the ``gr.Image`` input, or ``None`` when nothing
            has been uploaded.

    Returns:
        ``(confidences, markdown_text)`` for the Label and Markdown outputs.
        When ``image`` is ``None``, returns all-zero confidences and an upload
        prompt.
    """
    if image is not None:
        scores, summary = classifier.predict(image)
        return scores, summary
    return dict.fromkeys(classifier.classes, 0.0), "Please upload an image."
|
|
| |
def create_interface():
    """Build the Gradio ``Interface`` for the rice disease classifier.

    Example images are offered only for the files under ``examples/`` that
    exist on disk. If none exist, the examples panel is omitted.

    Returns:
        The configured (not yet launched) ``gr.Interface``.
    """
    # One single-element list per example, because the interface has exactly one input.
    example_paths = ("examples/bacterial.jpg", "examples/blast.jpg", "examples/healthy.jpg")
    examples = [[path] for path in example_paths if os.path.exists(path)]

    iface = gr.Interface(
        fn=classify_rice_disease,
        inputs=[
            gr.Image(type="pil", label="Upload Rice Plant Image")
        ],
        outputs=[
            # num_top_classes=5 shows the score of every class.
            gr.Label(num_top_classes=5, label="Disease Classification Confidence"),
            gr.Markdown(label="Detailed Results")
        ],
        title="🌾 Rice Disease Classification - EfficientNet Model",
        description="""
Upload an image of a rice plant to detect potential diseases using our **EfficientNet-B0** deep learning model.

**Detectable Conditions:**
- 🦠 **Bacterial Blight**: Serious bacterial infection
- 🍂 **Rice Blast**: Common fungal disease
- 🟤 **Brown Spot**: Fungal disease with characteristic spots
- 🦠 **Tungro Virus**: Viral infection causing stunting
- ✅ **Healthy**: No disease detected

*This tool provides preliminary screening only. Always consult agricultural experts for definitive diagnosis and treatment.*
""",
        # NOTE(review): the parameter counts below look like the stock
        # 1000-class EfficientNet-B0 figure (~5.3M). With the 5-class head
        # built by build_efficientnet_model, the total is closer to ~4.0M.
        # Confirm before publishing.
        article="""
### 🤖 About This Model
This rice disease classification system uses **EfficientNet-B0**, a state-of-the-art deep learning architecture optimized for accuracy and efficiency.
The model was trained on a comprehensive dataset of rice plant images with various disease conditions.

**Model Specifications:**
- Architecture: EfficientNet-B0 with custom classification head
- Input Size: 128×128 pixels
- Parameters: ~5.3M total, ~1.3M trainable
- Training: Transfer learning with data augmentation

### 📋 How to Use
1. **Upload Image**: Select a clear photo of rice plant leaves
2. **Automatic Analysis**: The model processes the image instantly
3. **Review Results**: Check confidence scores and disease description
4. **Take Action**: Follow agricultural recommendations if disease detected

### 💡 Best Practices
- Use well-lit, focused images
- Show clear view of leaves/symptoms
- Avoid blurry or distant shots
- Include multiple leaves if possible

### ⚠️ Important Limitations
- **Screening Tool Only**: Not a replacement for professional diagnosis
- **Image Quality Dependent**: Results vary with photo quality
- **Environmental Factors**: Weather and lighting can affect plant appearance
- **Professional Consultation**: Always seek expert agricultural advice for treatment

### 🎯 Model Performance
- Optimized for high accuracy across all disease classes
- Trained with extensive data augmentation
- Validated on diverse rice varieties and conditions
- Continuously improved with feedback and new data

**Developed for agricultural research and education purposes.**
""",
        examples=examples or None,
        theme=gr.themes.Soft(),
        # NOTE(review): newer Gradio releases rename this argument to
        # `flagging_mode`. Confirm against the Gradio version in use.
        allow_flagging="never"
    )

    return iface
|
|
| |
if __name__ == "__main__":
    # Listen on all network interfaces at port 7860, with no public share link.
    launch_settings = {"server_name": "0.0.0.0", "server_port": 7860, "share": False}
    demo = create_interface()
    demo.launch(**launch_settings)