GreenPulse / sample1.py
ArnavLatiyan's picture
Upload 9 files
a02d467 verified
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import gradio as gr
from disease_info import get_disease_info
from flask import Flask, render_template
import threading
import socket
from warnings import filterwarnings
# Suppress deprecation warnings
filterwarnings("ignore", category=UserWarning)
# ========== MODEL DEFINITION ==========
class Plant_Disease_VGG16(nn.Module):
def __init__(self):
super().__init__()
weights = models.VGG16_Weights.IMAGENET1K_V1
self.network = models.vgg16(weights=weights)
# Freeze early layers
for param in list(self.network.features.parameters())[:-5]:
param.requires_grad = False
# Modify final layer
num_ftrs = self.network.classifier[-1].in_features
self.network.classifier[-1] = nn.Linear(num_ftrs, 38) # 38 classes
def forward(self, xb):
return self.network(xb)
# Initialize model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Plant_Disease_VGG16()
model.load_state_dict(torch.load("model/vgg_model_ft.pth", map_location=device))
model.to(device)
model.eval()
# Class labels
class_labels = [
'Apple___Apple_scab', 'Apple___Black_rot', 'Apple___Cedar_apple_rust', 'Apple___healthy',
'Blueberry___healthy', 'Cherry_(including_sour)___Powdery_mildew', 'Cherry_(including_sour)___healthy',
'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot', 'Corn_(maize)___Common_rust_',
'Corn_(maize)___Northern_Leaf_Blight', 'Corn_(maize)___healthy', 'Grape___Black_rot',
'Grape___Esca_(Black_Measles)', 'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)', 'Grape___healthy',
'Orange___Haunglongbing_(Citrus_greening)', 'Peach___Bacterial_spot', 'Peach___healthy',
'Pepper,_bell___Bacterial_spot', 'Pepper,_bell___healthy', 'Potato___Early_blight',
'Potato___Late_blight', 'Potato___healthy', 'Raspberry___healthy', 'Soybean___healthy',
'Squash___Powdery_mildew', 'Strawberry___Leaf_scorch', 'Strawberry___healthy',
'Tomato___Bacterial_spot', 'Tomato___Early_blight', 'Tomato___Late_blight',
'Tomato___Leaf_Mold', 'Tomato___Septoria_leaf_spot', 'Tomato___Spider_mites Two-spotted_spider_mite',
'Tomato___Target_Spot', 'Tomato___Tomato_Yellow_Leaf_Curl_Virus', 'Tomato___Tomato_mosaic_virus',
'Tomato___healthy'
]
# Image preprocessing
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
])
def parse_class_label(class_label):
"""Extract plant and disease from class label"""
parts = class_label.split('___')
plant = parts[0].replace('_', ' ').replace(',', '')
disease = parts[1].replace('_', ' ') if len(parts) > 1 else "healthy"
return plant, disease
def predict(image):
"""Make prediction on input image"""
try:
if image is None:
return "Error: No image provided"
# Preprocess and predict
image = transform(image).unsqueeze(0).to(device)
with torch.no_grad():
preds = model(image)
probabilities = torch.nn.functional.softmax(preds[0], dim=0)
# Get top prediction
top_prob, top_idx = torch.max(probabilities, 0)
class_name = class_labels[top_idx.item()]
plant, disease = parse_class_label(class_name)
# Get disease info
disease_info = get_disease_info(plant, disease)
# Format results
result = f"""
Plant: {plant}
Disease: {disease}
Description:
{disease_info['description']}
Recommended Treatments:
{disease_info['pesticides']}
Application Timing:
{disease_info['timing']}
Prevention Measures:
{disease_info['prevention']}
"""
return result
except Exception as e:
return f"Error in prediction: {str(e)}"
# ========== WEB APPLICATION ==========
def find_available_port(start_port):
"""Find next available port from start_port"""
port = start_port
while True:
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(('0.0.0.0', port))
return port
except OSError:
port += 1
app = Flask(__name__)
# Gradio Interface
iface = gr.Interface(
fn=predict,
inputs=gr.Image(type="pil"),
outputs=gr.Textbox(label="Analysis Results", lines=20),
title="GREEN PULSE - Plant Health Analysis",
description="Upload an image of a plant leaf to detect health issues.",
examples=[
["examples/healthy_apple.jpg"],
["examples/diseased_tomato.jpg"]
]
)
def run_gradio():
"""Launch Gradio in separate thread"""
global gradio_port
gradio_port = find_available_port(7860)
print(f"\nGradio interface running on port: {gradio_port}")
iface.launch(
server_name="0.0.0.0",
server_port=gradio_port,
share=False,
prevent_thread_lock=True
)
# Start Gradio thread
gradio_port = 7860 # Default
gradio_thread = threading.Thread(target=run_gradio, daemon=True)
gradio_thread.start()
# Flask Routes
@app.route('/')
def home():
"""Main landing page"""
return render_template("index.html")
@app.route('/analyze')
def analyze():
"""Page with embedded Gradio interface"""
return render_template("analyze.html", gradio_port=gradio_port)
@app.route('/results')
def results():
"""Results display page"""
return render_template("results.html")
if __name__ == '__main__':
"""Main application entry point"""
flask_port = find_available_port(5000)
print(f"Flask server running on port: {flask_port}")
print(f"Access the app at: http://localhost:{flask_port}")
app.run(debug=True, port=flask_port, use_reloader=False)