jillian64 commited on
Commit
9927ccf
·
verified ·
1 Parent(s): a1559b4

Rename app.py to app

Browse files
Files changed (2) hide show
  1. app +376 -0
  2. app.py +0 -70
app ADDED
@@ -0,0 +1,376 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import re
4
+ import folium
5
+ from fastai.vision.all import *
6
+ from groq import Groq
7
+ from PIL import Image
8
+
9
+ # Load the trained model for grapevine disease classification
10
+ learn = load_learner('export.pkl') # Assumes you have a trained model for grapevine diseases
11
+ labels = learn.dls.vocab
12
+
13
+ # Initialize Groq client
14
+ client = Groq(
15
+ api_key=os.environ.get("GROQ_API_KEY"),
16
+ )
17
+
18
+ def clean_disease_name(name):
19
+ """Clean disease name by removing numbers and special characters, and fix formatting"""
20
+ # Remove numbers and dots at the beginning
21
+ cleaned = re.sub(r'^\d+\.', '', name)
22
+ # Replace underscores with spaces
23
+ cleaned = cleaned.replace('_', ' ')
24
+ # Remove any remaining special characters
25
+ cleaned = re.sub(r'[^\w\s]', '', cleaned)
26
+ # Fix spacing
27
+ cleaned = ' '.join(cleaned.split())
28
+ return cleaned
29
+
30
+ def get_disease_distribution_map(disease_name):
31
+ """Get global distribution map for the disease using Groq API"""
32
+ clean_name = clean_disease_name(disease_name)
33
+
34
+ # Get the disease distribution locations
35
+ prompt = f"""
36
+ Provide a JSON array of the main regions where {clean_name} disease is prevalent in grapevines worldwide.
37
+ Return ONLY a JSON array with 3-5 entries, each containing:
38
+ 1. "name": Region or country name
39
+ 2. "lat": Latitude (numeric value)
40
+ 3. "lon": Longitude (numeric value)
41
+ 4. "description": Brief description of why this region is affected and the severity (2-3 sentences)
42
+
43
+ Example format:
44
+ [
45
+ {{"name": "Example Location", "lat": 12.34, "lon": 56.78, "description": "Brief description"}},
46
+ ...
47
+ ]
48
+ """
49
+
50
+ try:
51
+ chat_completion = client.chat.completions.create(
52
+ messages=[
53
+ {
54
+ "role": "user",
55
+ "content": prompt,
56
+ }
57
+ ],
58
+ model="llama-3.3-70b-versatile",
59
+ )
60
+ response = chat_completion.choices[0].message.content
61
+
62
+ # Extract JSON from response (in case there's additional text)
63
+ import json
64
+ import re
65
+
66
+ # Find JSON pattern in response
67
+ json_match = re.search(r'\[.*\]', response, re.DOTALL)
68
+ if json_match:
69
+ locations = json.loads(json_match.group())
70
+ else:
71
+ # Fallback if JSON extraction fails
72
+ locations = [
73
+ {"name": "Major vineyard regions", "lat": 0, "lon": 0,
74
+ "description": "Could not retrieve specific distribution information for this disease."}
75
+ ]
76
+
77
+ return locations
78
+
79
+ except Exception as e:
80
+ return [{"name": "Error retrieving data", "lat": 0, "lon": 0,
81
+ "description": "Please try again or check your connection."}]
82
+
83
+ def create_distribution_map(locations):
84
+ """Create a folium map with the disease distribution locations"""
85
+ # Find center point based on valid coordinates
86
+ valid_coords = [(loc.get("lat", 0), loc.get("lon", 0))
87
+ for loc in locations
88
+ if loc.get("lat", 0) != 0 or loc.get("lon", 0) != 0]
89
+
90
+ if valid_coords:
91
+ # Calculate the average of the coordinates
92
+ avg_lat = sum(lat for lat, _ in valid_coords) / len(valid_coords)
93
+ avg_lon = sum(lon for _, lon in valid_coords) / len(valid_coords)
94
+ # Create map centered on the average coordinates
95
+ m = folium.Map(location=[avg_lat, avg_lon], zoom_start=3)
96
+ else:
97
+ # Default world map if no valid coordinates
98
+ m = folium.Map(location=[20, 0], zoom_start=2)
99
+
100
+ # Add markers for each location
101
+ for location in locations:
102
+ name = location.get("name", "Unknown")
103
+ lat = location.get("lat", 0)
104
+ lon = location.get("lon", 0)
105
+ description = location.get("description", "No description available")
106
+
107
+ # Skip invalid coordinates
108
+ if lat == 0 and lon == 0:
109
+ continue
110
+
111
+ # Add marker with custom color for disease (red)
112
+ folium.Marker(
113
+ location=[lat, lon],
114
+ popup=folium.Popup(f"<b>{name}</b><br>{description}", max_width=300),
115
+ tooltip=name,
116
+ icon=folium.Icon(color='red', icon='info-sign')
117
+ ).add_to(m)
118
+
119
+ # Save map to HTML
120
+ map_html = m._repr_html_()
121
+ return map_html
122
+
123
+ def format_disease_info(raw_info):
124
+ """Improve the formatting of disease information"""
125
+ # Add proper line breaks between sections and ensure consistent heading levels
126
+ formatted = raw_info
127
+
128
+ # Replace markdown headings with HTML headings for better control
129
+ formatted = re.sub(r'#+\s+(.*)', r'<h3>\1</h3>', formatted)
130
+
131
+ # Add paragraph tags for better spacing
132
+ formatted = re.sub(r'\n\*\s+(.*)', r'<p>• \1</p>', formatted)
133
+ formatted = re.sub(r'\n([^<\n].*)', r'<p>\1</p>', formatted)
134
+
135
+ # Remove any duplicate paragraph tags
136
+ formatted = formatted.replace('<p><p>', '<p>')
137
+ formatted = formatted.replace('</p></p>', '</p>')
138
+
139
+ return formatted
140
+
141
+ def get_disease_info(disease_name):
142
+ """Get detailed information about a grapevine disease using Groq API"""
143
+ clean_name = clean_disease_name(disease_name)
144
+
145
+ prompt = f"""
146
+ Provide detailed information about {clean_name} disease in grapevines, including:
147
+ 1. Pathogen information (fungus, bacteria, virus, etc.)
148
+ 2. Symptoms and visual identification
149
+ 3. Disease cycle and conditions that favor development
150
+ 4. Impact on grape production and wine quality
151
+ 5. Management and treatment strategies
152
+ 6. Any recent changes in disease prevalence or severity due to climate change
153
+
154
+ Format your response in markdown for better readability.
155
+ """
156
+
157
+ try:
158
+ chat_completion = client.chat.completions.create(
159
+ messages=[
160
+ {
161
+ "role": "user",
162
+ "content": prompt,
163
+ }
164
+ ],
165
+ model="llama-3.3-70b-versatile",
166
+ )
167
+ return chat_completion.choices[0].message.content
168
+ except Exception as e:
169
+ return f"Error fetching information: {str(e)}"
170
+
171
+ def predict_and_get_info(img):
172
+ """Predict grapevine disease and get detailed information"""
173
+ # Process the image
174
+ img = PILImage.create(img)
175
+
176
+ # Get prediction
177
+ pred, pred_idx, probs = learn.predict(img)
178
+
179
+ # Get top 5 predictions (or all if less than 5)
180
+ num_classes = min(5, len(labels))
181
+ top_indices = probs.argsort(descending=True)[:num_classes]
182
+ top_probs = probs[top_indices]
183
+ top_labels = [labels[i] for i in top_indices]
184
+
185
+ # Format as dictionary with cleaned names for display
186
+ prediction_results = {clean_disease_name(top_labels[i]): float(top_probs[i]) for i in range(num_classes)}
187
+
188
+ # Get top prediction (original format for info retrieval)
189
+ top_disease = str(pred)
190
+ # Also keep a clean version for display
191
+ clean_top_disease = clean_disease_name(top_disease)
192
+
193
+ # Get distribution locations and create map
194
+ distribution_locations = get_disease_distribution_map(top_disease)
195
+ distribution_map_html = create_distribution_map(distribution_locations)
196
+
197
+ # Get detailed information about the top predicted disease
198
+ disease_info = get_disease_info(top_disease)
199
+ formatted_info = format_disease_info(disease_info)
200
+
201
+ # Create combined info with map at the top and properly formatted information
202
+ custom_css = """
203
+ <style>
204
+ .disease-container {
205
+ font-family: Arial, sans-serif;
206
+ padding: 10px;
207
+ }
208
+ .map-container {
209
+ height: 400px;
210
+ width: 100%;
211
+ border: 1px solid #ddd;
212
+ border-radius: 8px;
213
+ overflow: hidden;
214
+ margin-bottom: 20px;
215
+ }
216
+ .info-container {
217
+ line-height: 1.6;
218
+ }
219
+ .info-container h3 {
220
+ margin-top: 20px;
221
+ margin-bottom: 10px;
222
+ color: #8B0000;
223
+ border-bottom: 1px solid #eee;
224
+ padding-bottom: 5px;
225
+ }
226
+ .info-container p {
227
+ margin-bottom: 10px;
228
+ }
229
+ .treatment-section {
230
+ background-color: #f1f8e9;
231
+ padding: 15px;
232
+ border-radius: 8px;
233
+ margin-top: 20px;
234
+ }
235
+ </style>
236
+ """
237
+
238
+ combined_info = f"""
239
+ {custom_css}
240
+ <div class="disease-container">
241
+ <h2>Global Distribution of {clean_top_disease}</h2>
242
+ <div class="map-container">
243
+ {distribution_map_html}
244
+ </div>
245
+
246
+ <div class="info-container">
247
+ <h2>Disease Information</h2>
248
+ {formatted_info}
249
+ </div>
250
+ </div>
251
+ """
252
+
253
+ return prediction_results, combined_info, clean_top_disease
254
+
255
+ def follow_up_question(question, disease_name):
256
+ """Allow vineyard managers to ask follow-up questions about the identified disease"""
257
+ if not question.strip() or not disease_name:
258
+ return "Please identify a grapevine disease first and ask a specific question about it."
259
+
260
+ prompt = f"""
261
+ The vineyard manager is asking about {disease_name} disease: "{question}"
262
+
263
+ Provide a detailed, scientific answer focusing on accurate plant pathology information.
264
+ If the question relates to treatment options, preventive measures, or climate change impacts, emphasize those aspects in your response.
265
+
266
+ IMPORTANT: Do not repeat basic introductory information about the disease that would have already been provided in a general description.
267
+ Do not start your answer with phrases like "Introduction to {disease_name}" or similar repetitive headers.
268
+ Directly answer the specific question asked.
269
+
270
+ Format your response in markdown for better readability.
271
+ """
272
+
273
+ try:
274
+ chat_completion = client.chat.completions.create(
275
+ messages=[
276
+ {
277
+ "role": "user",
278
+ "content": prompt,
279
+ }
280
+ ],
281
+ model="llama-3.3-70b-versatile",
282
+ )
283
+ return chat_completion.choices[0].message.content
284
+ except Exception as e:
285
+ return f"Error fetching information: {str(e)}"
286
+
287
+ # Create the Gradio interface
288
+ with gr.Blocks(theme=gr.themes.Soft()) as app:
289
+ gr.Markdown("# Grapevine Disease Identification Tool")
290
+ gr.Markdown("Upload an image to identify grapevine diseases and get detailed information on management strategies and treatment options.")
291
+
292
+ # Store the current disease for context
293
+ current_disease = gr.State("")
294
+
295
+ # Main identification section
296
+ with gr.Row():
297
+ with gr.Column(scale=1):
298
+ input_image = gr.Image(type="pil", label="Upload Grapevine Image")
299
+ submit_btn = gr.Button("Identify Disease", variant="primary")
300
+
301
+ with gr.Column(scale=2):
302
+ prediction_output = gr.Label(label="Top 5 Predictions", num_top_classes=5)
303
+ disease_info_output = gr.HTML(label="Disease Information")
304
+
305
+ # Clear divider
306
+ gr.Markdown("---")
307
+
308
+ # Follow-up question section with improved UI
309
+ gr.Markdown("## Follow-up Questions")
310
+
311
+ conversation_history = gr.Markdown("")
312
+
313
+ with gr.Row():
314
+ follow_up_input = gr.Textbox(
315
+ label="Ask a question about this disease",
316
+ placeholder="Example: What are the best preventative measures for this disease?",
317
+ lines=2
318
+ )
319
+
320
+ with gr.Row():
321
+ follow_up_btn = gr.Button("Submit Question", variant="primary")
322
+ clear_btn = gr.Button("Clear Conversation")
323
+
324
+ # Set up event handlers
325
+ def process_image(img):
326
+ if img is None:
327
+ return None, "Please upload an image", "", ""
328
+
329
+ try:
330
+ pred_results, info, clean_disease_name = predict_and_get_info(img)
331
+ return pred_results, info, clean_disease_name, ""
332
+ except Exception as e:
333
+ return None, f"Error processing image: {str(e)}", "", ""
334
+
335
+ def update_conversation(question, disease_name, history):
336
+ if not question.strip():
337
+ return history
338
+
339
+ answer = follow_up_question(question, disease_name)
340
+
341
+ # Format the conversation with clear separation
342
+ new_exchange = f"""
343
+ ### Question:
344
+ {question}
345
+ ### Answer:
346
+ {answer}
347
+ ---
348
+ """
349
+ updated_history = new_exchange + history
350
+ return updated_history
351
+
352
+ def clear_conversation_history():
353
+ return ""
354
+
355
+ submit_btn.click(
356
+ process_image,
357
+ inputs=[input_image],
358
+ outputs=[prediction_output, disease_info_output, current_disease, conversation_history]
359
+ )
360
+
361
+ follow_up_btn.click(
362
+ update_conversation,
363
+ inputs=[follow_up_input, current_disease, conversation_history],
364
+ outputs=[conversation_history]
365
+ ).then(
366
+ lambda: "",
367
+ outputs=follow_up_input
368
+ )
369
+
370
+ clear_btn.click(
371
+ clear_conversation_history,
372
+ outputs=[conversation_history]
373
+ )
374
+
375
+ # Launch the app
376
+ app.launch(share=True)
app.py DELETED
@@ -1,70 +0,0 @@
1
- import gradio as gr
2
- from fastai.vision.all import *
3
- import requests
4
- import os
5
-
6
- # Load the trained model
7
- learn = load_learner('export.pkl')
8
- labels = learn.dls.vocab
9
-
10
- # DeepSeek API endpoint and headers (replace with your actual API details)
11
- DEEPSEEK_API_URL = "https://api.deepseek.com/v1/chat/completions" # Example URL, replace with actual API URL
12
- DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY") # Use environment variable for security
13
-
14
- def predict(img):
15
- # Predict the disease using the FastAI model
16
- img = PILImage.create(img)
17
- pred, pred_idx, probs = learn.predict(img)
18
- predicted_label = labels[pred_idx]
19
-
20
- # Prepare the prompt for DeepSeek API
21
- prompt = f"Describe the grapevine disease {predicted_label} and suggest treatment methods."
22
-
23
- # Prepare data for DeepSeek API
24
- headers = {
25
- "Authorization": f"Bearer {DEEPSEEK_API_KEY}",
26
- "Content-Type": "application/json"
27
- }
28
- data = {
29
- "model": "deepseek-chat", # Specify the model you are using
30
- "messages": [{"role": "user", "content": prompt}]
31
- }
32
-
33
- # Call DeepSeek API
34
- try:
35
- response = requests.post(DEEPSEEK_API_URL, headers=headers, json=data)
36
- response.raise_for_status() # Raise an error for bad status codes
37
- deepseek_response = response.json()
38
-
39
- # Extract description and treatment from the API response
40
- description = deepseek_response.get("choices", [{}])[0].get("message", {}).get("content", "No description available.")
41
- treatment = "Treatment methods are included in the description." # Modify this if the API provides separate treatment info
42
- except requests.exceptions.RequestException as e:
43
- description = f"Failed to fetch description from DeepSeek API. Error: {str(e)}"
44
- treatment = "Failed to fetch treatment information from DeepSeek API."
45
-
46
- # Return both the prediction probabilities and the DeepSeek response
47
- return {
48
- "prediction": {labels[i]: float(probs[i]) for i in range(len(labels))},
49
- "description": description,
50
- "treatment": treatment
51
- }
52
-
53
- # Define the Gradio interface
54
- interface = gr.Interface(
55
- fn=predict,
56
- inputs=gr.Image(),
57
- outputs=[
58
- gr.Label(num_top_classes=3, label="Prediction Probabilities"),
59
- gr.Textbox(label="Disease Description"),
60
- gr.Textbox(label="Treatment Methods")
61
- ],
62
- title="Grapevine Disease Detection",
63
- description="Upload an image of a grapevine leaf to detect diseases and get treatment suggestions."
64
- )
65
-
66
- # Enable the queue to handle POST requests
67
- interface.queue(api_open=True)
68
-
69
- # Launch the interface
70
- interface.launch()