Spaces:
Build error
Build error
| import torch | |
| import torch.nn as nn | |
| import torchvision.models as models | |
| import torchvision.transforms as transforms | |
| from PIL import Image | |
| import gradio as gr | |
| from disease_info import get_disease_info | |
| from flask import Flask, render_template | |
| import threading | |
| import socket | |
| from warnings import filterwarnings | |
# Silence every warning whose category is UserWarning (or a subclass of it)
# for the lifetime of the process. Note that this filters UserWarning, not
# DeprecationWarning.
filterwarnings(action="ignore", category=UserWarning)
| # ========== MODEL DEFINITION ========== | |
class Plant_Disease_VGG16(nn.Module):
    """VGG16-based image classifier with a replaced final layer.

    Wraps ``torchvision.models.vgg16``, marks all but the last five
    parameter tensors of its ``features`` sub-module as non-trainable, and
    swaps the last classifier layer for a fresh ``nn.Linear`` that emits
    ``num_classes`` logits.

    Args:
        num_classes: Number of output logits. Defaults to 38, the value the
            original implementation hard-coded.
        pretrained: When True (the default, matching the original
            behaviour) the backbone starts from the ``IMAGENET1K_V1``
            weights. Pass False to build the architecture with random
            initialisation, e.g. when a full checkpoint is loaded right
            afterwards and the ImageNet weights would only be overwritten.
    """

    def __init__(self, num_classes: int = 38, pretrained: bool = True):
        super().__init__()
        weights = models.VGG16_Weights.IMAGENET1K_V1 if pretrained else None
        self.network = models.vgg16(weights=weights)

        # Freeze every feature-extractor parameter tensor except the last
        # five, so only the tail of the backbone (and the classifier) can
        # receive gradient updates.
        for param in list(self.network.features.parameters())[:-5]:
            param.requires_grad = False

        # Replace the final classifier layer so the output width matches
        # the requested number of classes.
        in_features = self.network.classifier[-1].in_features
        self.network.classifier[-1] = nn.Linear(in_features, num_classes)

    def forward(self, xb):
        """Run ``xb`` through the wrapped VGG16 and return its output."""
        return self.network(xb)
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = (
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)

# Build the network and restore the checkpoint stored next to the app.
# NOTE(review): torch.load unpickles the file, so the checkpoint must come
# from a trusted source.
model = Plant_Disease_VGG16()
model.load_state_dict(
    torch.load("model/vgg_model_ft.pth", map_location=device)
)

# Move the weights to the selected device and switch to inference mode.
model.to(device).eval()
# Human-readable label for each model output index. predict() indexes this
# list with the arg-max of the model output, so the order must not change.
# Labels use the "<plant>___<condition>" form that parse_class_label() splits.
class_labels = [
    'Apple___Apple_scab',
    'Apple___Black_rot',
    'Apple___Cedar_apple_rust',
    'Apple___healthy',
    'Blueberry___healthy',
    'Cherry_(including_sour)___Powdery_mildew',
    'Cherry_(including_sour)___healthy',
    'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot',
    'Corn_(maize)___Common_rust_',
    'Corn_(maize)___Northern_Leaf_Blight',
    'Corn_(maize)___healthy',
    'Grape___Black_rot',
    'Grape___Esca_(Black_Measles)',
    'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)',
    'Grape___healthy',
    'Orange___Haunglongbing_(Citrus_greening)',
    'Peach___Bacterial_spot',
    'Peach___healthy',
    'Pepper,_bell___Bacterial_spot',
    'Pepper,_bell___healthy',
    'Potato___Early_blight',
    'Potato___Late_blight',
    'Potato___healthy',
    'Raspberry___healthy',
    'Soybean___healthy',
    'Squash___Powdery_mildew',
    'Strawberry___Leaf_scorch',
    'Strawberry___healthy',
    'Tomato___Bacterial_spot',
    'Tomato___Early_blight',
    'Tomato___Late_blight',
    'Tomato___Leaf_Mold',
    'Tomato___Septoria_leaf_spot',
    'Tomato___Spider_mites Two-spotted_spider_mite',
    'Tomato___Target_Spot',
    'Tomato___Tomato_Yellow_Leaf_Curl_Virus',
    'Tomato___Tomato_mosaic_virus',
    'Tomato___healthy',
]
# Preprocessing pipeline applied to every uploaded image before inference:
# resize to a fixed 224x224 and convert to a tensor. No mean/std
# normalisation step is part of this pipeline.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ]
)
def parse_class_label(class_label):
    """Split a ``<plant>___<disease>`` label into readable plant and disease names.

    The label is cut at every triple underscore. The first piece becomes the
    plant name (underscores turned into spaces, commas removed); the second
    piece, if any, becomes the disease name (underscores turned into
    spaces). A label without a triple underscore yields the disease
    "healthy". Pieces after the second are ignored.

    Returns:
        A ``(plant, disease)`` tuple of strings.
    """
    plant_raw, *remainder = class_label.split('___')
    plant = plant_raw.replace('_', ' ').replace(',', '')

    if not remainder:
        return plant, "healthy"
    return plant, remainder[0].replace('_', ' ')
def predict(image):
    """Classify a leaf image and return a plain-text analysis report.

    Args:
        image: The uploaded picture as a PIL image (the Gradio input is
            declared with ``type="pil"``), or None when nothing was
            uploaded.

    Returns:
        A multi-line string listing the predicted plant, the disease, and
        the description / pesticides / timing / prevention entries returned
        by ``get_disease_info``. On failure, a string starting with
        "Error" is returned instead of raising, so the UI textbox always
        has something to show.
    """
    if image is None:
        return "Error: No image provided"

    try:
        # The network expects three input channels. Uploads in RGBA,
        # greyscale or palette mode would otherwise fail with a channel
        # mismatch, so normalise the mode before preprocessing.
        batch = transform(image.convert("RGB")).unsqueeze(0).to(device)

        # Inference only: no gradients are needed.
        with torch.no_grad():
            logits = model(batch)
        probabilities = torch.nn.functional.softmax(logits[0], dim=0)

        # Only the index of the most likely class is used; the probability
        # itself is not part of the report.
        _, top_idx = torch.max(probabilities, 0)
        plant, disease = parse_class_label(class_labels[top_idx.item()])

        # NOTE(review): get_disease_info is assumed to return a mapping with
        # the four keys read below; a missing key ends up in the error path.
        disease_info = get_disease_info(plant, disease)

        report_lines = [
            "",
            f"Plant: {plant}",
            f"Disease: {disease}",
            "Description:",
            f"{disease_info['description']}",
            "Recommended Treatments:",
            f"{disease_info['pesticides']}",
            "Application Timing:",
            f"{disease_info['timing']}",
            "Prevention Measures:",
            f"{disease_info['prevention']}",
            "",
        ]
        return "\n".join(report_lines)
    except Exception as e:
        # UI boundary: report the failure as text rather than crashing the
        # Gradio callback.
        return f"Error in prediction: {str(e)}"
| # ========== WEB APPLICATION ========== | |
def find_available_port(start_port, host="0.0.0.0", max_port=65535):
    """Return the first port at or above ``start_port`` that can be bound.

    Each candidate is probed by binding a throw-away TCP socket to
    ``(host, port)``; the socket is closed again before returning, so the
    result is only known to be free at the moment of the probe.

    Args:
        start_port: First port number to try.
        host: Address to probe on. Defaults to all IPv4 interfaces.
        max_port: Last port number to try (inclusive).

    Returns:
        The first port in ``start_port..max_port`` whose bind succeeded.

    Raises:
        OSError: If no port in the range could be bound. The previous
            unbounded search instead ran past 65535 and failed inside
            ``bind`` with an unrelated error.
    """
    for port in range(start_port, max_port + 1):
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
                probe.bind((host, port))
        except OSError:
            # Port is in use (or otherwise not bindable); try the next one.
            continue
        return port
    raise OSError(f"No available port in range {start_port}-{max_port}")
# Flask application object used by the page handlers and the entry point.
app = Flask(__name__)

# Gradio interface: one PIL image in, one multi-line text report out,
# with predict() as the callback.
iface = gr.Interface(
    title="GREEN PULSE - Plant Health Analysis",
    description="Upload an image of a plant leaf to detect health issues.",
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(label="Analysis Results", lines=20),
    examples=[
        ["examples/healthy_apple.jpg"],
        ["examples/diseased_tomato.jpg"],
    ],
)
def run_gradio():
    """Pick a free port starting at 7860 and launch the Gradio UI on it.

    The chosen port is stored in the module-level ``gradio_port`` so other
    code in this module can read it. The launch is non-blocking
    (``prevent_thread_lock=True``) and no public share link is requested.
    """
    global gradio_port
    gradio_port = find_available_port(7860)
    print(f"\nGradio interface running on port: {gradio_port}")

    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": gradio_port,
        "share": False,
        "prevent_thread_lock": True,
    }
    iface.launch(**launch_options)
# Port value used until run_gradio() has picked a free port and overwritten it.
gradio_port = 7860

# Launch Gradio from a daemon thread so it does not keep the process alive
# on its own once the main thread exits.
gradio_thread = threading.Thread(daemon=True, target=run_gradio)
gradio_thread.start()
| # Flask Routes | |
# Register the handler with the Flask app; without a route decorator the
# function is never reachable and every request would return 404.
@app.route("/")
def home():
    """Render the main landing page from the ``index.html`` template."""
    return render_template("index.html")
# Register the handler with the Flask app; without a route decorator the
# function is never reachable.
# NOTE(review): the URL rule is inferred from the function name; confirm it
# matches the links used in the templates.
@app.route("/analyze")
def analyze():
    """Render ``analyze.html``, the page that embeds the Gradio interface.

    The current ``gradio_port`` is passed to the template, presumably so it
    can build the URL of the Gradio UI.
    """
    return render_template("analyze.html", gradio_port=gradio_port)
# Register the handler with the Flask app; without a route decorator the
# function is never reachable.
# NOTE(review): the URL rule is inferred from the function name; confirm it
# matches the links used in the templates.
@app.route("/results")
def results():
    """Render the results display page from the ``results.html`` template."""
    return render_template("results.html")
if __name__ == '__main__':
    # Main application entry point: pick a free port starting at 5000,
    # announce it, then start the Flask server.
    flask_port = find_available_port(5000)
    startup_messages = (
        f"Flask server running on port: {flask_port}",
        f"Access the app at: http://localhost:{flask_port}",
    )
    for message in startup_messages:
        print(message)

    # The reloader is disabled, so the module is not started a second time
    # in a child process.
    # NOTE(review): debug=True turns on the interactive debugger; confirm
    # this is never used for a publicly reachable deployment.
    app.run(port=flask_port, debug=True, use_reloader=False)