ArnavLatiyan commited on
Commit
24cd083
·
verified ·
1 Parent(s): b7c9d01

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +146 -62
app.py CHANGED
@@ -1,92 +1,176 @@
1
- import gradio as gr
2
  import torch
3
  import torch.nn as nn
4
- from torchvision import models, transforms
 
5
  from PIL import Image
6
- import numpy as np
7
- import warnings
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
- # SUPPRESS ALL WARNINGS - NO OTHER CHANGES MADE
10
- warnings.filterwarnings("ignore")
11
- torch.set_warn_always(False)
 
 
 
12
 
13
- # YOUR ORIGINAL CODE EXACTLY AS YOU HAD IT:
14
- class_names = [
15
  'Apple___Apple_scab', 'Apple___Black_rot', 'Apple___Cedar_apple_rust', 'Apple___healthy',
16
  'Blueberry___healthy', 'Cherry_(including_sour)___Powdery_mildew', 'Cherry_(including_sour)___healthy',
17
- 'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot', 'Corn_(maize)___Common_rust_',
18
  'Corn_(maize)___Northern_Leaf_Blight', 'Corn_(maize)___healthy', 'Grape___Black_rot',
19
  'Grape___Esca_(Black_Measles)', 'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)', 'Grape___healthy',
20
  'Orange___Haunglongbing_(Citrus_greening)', 'Peach___Bacterial_spot', 'Peach___healthy',
21
  'Pepper,_bell___Bacterial_spot', 'Pepper,_bell___healthy', 'Potato___Early_blight',
22
  'Potato___Late_blight', 'Potato___healthy', 'Raspberry___healthy', 'Soybean___healthy',
23
  'Squash___Powdery_mildew', 'Strawberry___Leaf_scorch', 'Strawberry___healthy',
24
- 'Tomato___Bacterial_spot', 'Tomato___Early_blight', 'Tomato___Late_blight',
25
  'Tomato___Leaf_Mold', 'Tomato___Septoria_leaf_spot', 'Tomato___Spider_mites Two-spotted_spider_mite',
26
  'Tomato___Target_Spot', 'Tomato___Tomato_Yellow_Leaf_Curl_Virus', 'Tomato___Tomato_mosaic_virus',
27
  'Tomato___healthy'
28
  ]
29
 
30
- def load_model(model_path):
31
- model = models.vgg16(pretrained=False)
32
- num_features = model.classifier[6].in_features
33
- model.classifier[6] = nn.Linear(num_features, len(class_names))
34
- model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
35
- model.eval()
36
- return model
37
-
38
  transform = transforms.Compose([
39
- transforms.Resize((128, 128)),
40
  transforms.ToTensor(),
41
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
42
  ])
43
 
 
 
 
 
 
 
 
44
  def predict(image):
45
- image = Image.fromarray(image.astype('uint8'), 'RGB')
46
- image = transform(image).unsqueeze(0)
47
-
48
- with torch.no_grad():
49
- output = model(image)
50
- probabilities = torch.nn.functional.softmax(output[0], dim=0)
51
- confidences = {class_names[i]: float(probabilities[i]) for i in range(len(class_names))}
52
-
53
- top_pred = max(confidences, key=confidences.get)
54
- plant = top_pred.split('___')[0]
55
- disease = top_pred.split('___')[1]
56
-
57
- return {
58
- "Plant": plant,
59
- "Disease": disease,
60
- "Confidence": confidences[top_pred],
61
- "All Predictions": confidences
62
- }
63
-
64
- model_path = "model/vgg_model_ft.pth"
65
- model = load_model(model_path)
66
-
67
- title = "Plant Disease Classifier"
68
- description = """
69
- Upload an image of a plant leaf to classify its health status. The model can detect diseases across 14 plant types and 38 disease categories.
 
 
 
 
 
 
 
 
 
 
70
  """
 
 
 
 
71
 
72
- examples = [
73
- ["example_images/healthy_apple.jpg"],
74
- ["example_images/diseased_tomato.jpg"]
75
- ]
 
 
 
 
 
 
 
76
 
 
 
 
77
  iface = gr.Interface(
78
  fn=predict,
79
- inputs=gr.Image(label="Upload Plant Leaf Image"),
80
- outputs=[
81
- gr.Label(label="Plant"),
82
- gr.Label(label="Disease Status"),
83
- gr.Label(label="Confidence"),
84
- gr.Label(label="All Predictions")
85
- ],
86
- title=title,
87
- description=description,
88
- examples=examples,
89
- allow_flagging="never"
90
  )
91
 
92
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  import torch.nn as nn
3
+ import torchvision.models as models
4
+ import torchvision.transforms as transforms
5
  from PIL import Image
6
+ import gradio as gr
7
+ from disease_info import get_disease_info
8
+ from flask import Flask, render_template
9
+ import threading
10
+ import socket
11
+ from warnings import filterwarnings
12
+
13
+ # Suppress deprecation warnings
14
+ filterwarnings("ignore", category=UserWarning)
15
+
16
+ # ========== MODEL DEFINITION ==========
17
+ class Plant_Disease_VGG16(nn.Module):
18
+ def __init__(self):
19
+ super().__init__()
20
+ weights = models.VGG16_Weights.IMAGENET1K_V1
21
+ self.network = models.vgg16(weights=weights)
22
+ # Freeze early layers
23
+ for param in list(self.network.features.parameters())[:-5]:
24
+ param.requires_grad = False
25
+ # Modify final layer
26
+ num_ftrs = self.network.classifier[-1].in_features
27
+ self.network.classifier[-1] = nn.Linear(num_ftrs, 38) # 38 classes
28
+
29
+ def forward(self, xb):
30
+ return self.network(xb)
31
 
32
+ # Initialize model
33
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
34
+ model = Plant_Disease_VGG16()
35
+ model.load_state_dict(torch.load("model/vgg_model_ft.pth", map_location=device))
36
+ model.to(device)
37
+ model.eval()
38
 
39
+ # Class labels
40
+ class_labels = [
41
  'Apple___Apple_scab', 'Apple___Black_rot', 'Apple___Cedar_apple_rust', 'Apple___healthy',
42
  'Blueberry___healthy', 'Cherry_(including_sour)___Powdery_mildew', 'Cherry_(including_sour)___healthy',
43
+ 'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot', 'Corn_(maize)___Common_rust_',
44
  'Corn_(maize)___Northern_Leaf_Blight', 'Corn_(maize)___healthy', 'Grape___Black_rot',
45
  'Grape___Esca_(Black_Measles)', 'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)', 'Grape___healthy',
46
  'Orange___Haunglongbing_(Citrus_greening)', 'Peach___Bacterial_spot', 'Peach___healthy',
47
  'Pepper,_bell___Bacterial_spot', 'Pepper,_bell___healthy', 'Potato___Early_blight',
48
  'Potato___Late_blight', 'Potato___healthy', 'Raspberry___healthy', 'Soybean___healthy',
49
  'Squash___Powdery_mildew', 'Strawberry___Leaf_scorch', 'Strawberry___healthy',
50
+ 'Tomato___Bacterial_spot', 'Tomato___Early_blight', 'Tomato___Late_blight',
51
  'Tomato___Leaf_Mold', 'Tomato___Septoria_leaf_spot', 'Tomato___Spider_mites Two-spotted_spider_mite',
52
  'Tomato___Target_Spot', 'Tomato___Tomato_Yellow_Leaf_Curl_Virus', 'Tomato___Tomato_mosaic_virus',
53
  'Tomato___healthy'
54
  ]
55
 
56
+ # Image preprocessing
 
 
 
 
 
 
 
57
  transform = transforms.Compose([
58
+ transforms.Resize((224, 224)),
59
  transforms.ToTensor(),
 
60
  ])
61
 
62
+ def parse_class_label(class_label):
63
+ """Extract plant and disease from class label"""
64
+ parts = class_label.split('___')
65
+ plant = parts[0].replace('_', ' ').replace(',', '')
66
+ disease = parts[1].replace('_', ' ') if len(parts) > 1 else "healthy"
67
+ return plant, disease
68
+
69
  def predict(image):
70
+ """Make prediction on input image"""
71
+ try:
72
+ if image is None:
73
+ return "Error: No image provided"
74
+
75
+ # Preprocess and predict
76
+ image = transform(image).unsqueeze(0).to(device)
77
+ with torch.no_grad():
78
+ preds = model(image)
79
+ probabilities = torch.nn.functional.softmax(preds[0], dim=0)
80
+
81
+ # Get top prediction
82
+ top_prob, top_idx = torch.max(probabilities, 0)
83
+ class_name = class_labels[top_idx.item()]
84
+ plant, disease = parse_class_label(class_name)
85
+
86
+ # Get disease info
87
+ disease_info = get_disease_info(plant, disease)
88
+
89
+ # Format results
90
+ result = f"""
91
+ Plant: {plant}
92
+ Disease: {disease}
93
+
94
+ Description:
95
+ {disease_info['description']}
96
+
97
+ Recommended Treatments:
98
+ {disease_info['pesticides']}
99
+
100
+ Application Timing:
101
+ {disease_info['timing']}
102
+
103
+ Prevention Measures:
104
+ {disease_info['prevention']}
105
  """
106
+ return result
107
+
108
+ except Exception as e:
109
+ return f"Error in prediction: {str(e)}"
110
 
111
+ # ========== WEB APPLICATION ==========
112
+ def find_available_port(start_port):
113
+ """Find next available port from start_port"""
114
+ port = start_port
115
+ while True:
116
+ try:
117
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
118
+ s.bind(('0.0.0.0', port))
119
+ return port
120
+ except OSError:
121
+ port += 1
122
 
123
+ app = Flask(__name__)
124
+
125
+ # Gradio Interface
126
  iface = gr.Interface(
127
  fn=predict,
128
+ inputs=gr.Image(type="pil"),
129
+ outputs=gr.Textbox(label="Analysis Results", lines=20),
130
+ title="GREEN PULSE - Plant Health Analysis",
131
+ description="Upload an image of a plant leaf to detect health issues.",
132
+ examples=[
133
+ ["examples/healthy_apple.jpg"],
134
+ ["examples/diseased_tomato.jpg"]
135
+ ]
 
 
 
136
  )
137
 
138
+ def run_gradio():
139
+ """Launch Gradio in separate thread"""
140
+ global gradio_port
141
+ gradio_port = find_available_port(7860)
142
+ print(f"\nGradio interface running on port: {gradio_port}")
143
+ iface.launch(
144
+ server_name="0.0.0.0",
145
+ server_port=gradio_port,
146
+ share=False,
147
+ prevent_thread_lock=True
148
+ )
149
+
150
+ # Start Gradio thread
151
+ gradio_port = 7860 # Default
152
+ gradio_thread = threading.Thread(target=run_gradio, daemon=True)
153
+ gradio_thread.start()
154
+
155
+ # Flask Routes
156
+ @app.route('/')
157
+ def home():
158
+ """Main landing page"""
159
+ return render_template("index.html")
160
+
161
+ @app.route('/analyze')
162
+ def analyze():
163
+ """Page with embedded Gradio interface"""
164
+ return render_template("analyze.html", gradio_port=gradio_port)
165
+
166
+ @app.route('/results')
167
+ def results():
168
+ """Results display page"""
169
+ return render_template("results.html")
170
+
171
+ if __name__ == '__main__':
172
+ """Main application entry point"""
173
+ flask_port = find_available_port(5000)
174
+ print(f"Flask server running on port: {flask_port}")
175
+ print(f"Access the app at: http://localhost:{flask_port}")
176
+ app.run(debug=True, port=flask_port, use_reloader=False)