jacksonwambali commited on
Commit
27b4dc6
·
verified ·
1 Parent(s): 1aeb569

Update app

Browse files
Files changed (1) hide show
  1. app +8 -455
app CHANGED
@@ -1,461 +1,14 @@
1
- import os
2
- import gradio as gr
3
- import re
4
- import folium
5
  from fastai.vision.all import *
6
- from groq import Groq
7
- from PIL import Image
8
-
9
- # Load the trained model
10
  learn = load_learner('export.pkl')
11
- labels = learn.dls.vocab
12
-
13
- # Initialize Groq client
14
- client = Groq(
15
- api_key=os.environ.get("GROQ_API_KEY"),
16
- )
17
-
18
- def clean_bird_name(name):
19
- """Clean bird name by removing numbers and special characters, and fix formatting"""
20
- # Remove numbers and dots at the beginning
21
- cleaned = re.sub(r'^\d+\.', '', name)
22
- # Replace underscores with spaces
23
- cleaned = cleaned.replace('_', ' ')
24
- # Remove any remaining special characters
25
- cleaned = re.sub(r'[^\w\s]', '', cleaned)
26
- # Fix spacing
27
- cleaned = ' '.join(cleaned.split())
28
- return cleaned
29
-
30
- def get_bird_habitat_map(bird_name, check_tanzania=True):
31
- """Get habitat map locations for the bird using Groq API"""
32
- clean_name = clean_bird_name(bird_name)
33
-
34
- # First check if the bird is endemic to Tanzania
35
- if check_tanzania:
36
- tanzania_check_prompt = f"""
37
- Is the {clean_name} bird native to or commonly found in Tanzania?
38
- Answer with ONLY "yes" or "no".
39
- """
40
-
41
- try:
42
- tanzania_check = client.chat.completions.create(
43
- messages=[{"role": "user", "content": tanzania_check_prompt}],
44
- model="llama-3.3-70b-versatile",
45
- )
46
- is_in_tanzania = "yes" in tanzania_check.choices[0].message.content.lower()
47
- except:
48
- # Default to showing Tanzania if we can't determine
49
- is_in_tanzania = True
50
- else:
51
- is_in_tanzania = True
52
-
53
- # Now get the habitat locations
54
- prompt = f"""
55
- Provide a JSON array of the main habitat locations for the {clean_name} bird in the world.
56
- Return ONLY a JSON array with 3-5 entries, each containing:
57
- 1. "name": Location name
58
- 2. "lat": Latitude (numeric value)
59
- 3. "lon": Longitude (numeric value)
60
- 4. "description": Brief description of why this is a key habitat (2-3 sentences)
61
-
62
- Example format:
63
- [
64
- {{"name": "Example Location", "lat": 12.34, "lon": 56.78, "description": "Brief description"}},
65
- ...
66
- ]
67
-
68
- {'' if is_in_tanzania else 'DO NOT include any locations in Tanzania as this bird is not native to or commonly found there.'}
69
- """
70
-
71
- try:
72
- chat_completion = client.chat.completions.create(
73
- messages=[
74
- {
75
- "role": "user",
76
- "content": prompt,
77
- }
78
- ],
79
- model="llama-3.3-70b-versatile",
80
- )
81
- response = chat_completion.choices[0].message.content
82
-
83
- # Extract JSON from response (in case there's additional text)
84
- import json
85
- import re
86
-
87
- # Find JSON pattern in response
88
- json_match = re.search(r'\[.*\]', response, re.DOTALL)
89
- if json_match:
90
- locations = json.loads(json_match.group())
91
- else:
92
- # Fallback if JSON extraction fails
93
- locations = [
94
- {"name": "Primary habitat region", "lat": 0, "lon": 0,
95
- "description": "Could not retrieve specific habitat information for this bird."}
96
- ]
97
-
98
- return locations, is_in_tanzania
99
-
100
- except Exception as e:
101
- return [{"name": "Error retrieving data", "lat": 0, "lon": 0,
102
- "description": "Please try again or check your connection."}], False
103
-
104
- def create_habitat_map(habitat_locations):
105
- """Create a folium map with the habitat locations"""
106
- # Find center point based on valid coordinates
107
- valid_coords = [(loc.get("lat", 0), loc.get("lon", 0))
108
- for loc in habitat_locations
109
- if loc.get("lat", 0) != 0 or loc.get("lon", 0) != 0]
110
-
111
- if valid_coords:
112
- # Calculate the average of the coordinates
113
- avg_lat = sum(lat for lat, _ in valid_coords) / len(valid_coords)
114
- avg_lon = sum(lon for _, lon in valid_coords) / len(valid_coords)
115
- # Create map centered on the average coordinates
116
- m = folium.Map(location=[avg_lat, avg_lon], zoom_start=3)
117
- else:
118
- # Default world map if no valid coordinates
119
- m = folium.Map(location=[20, 0], zoom_start=2)
120
-
121
- # Add markers for each habitat location
122
- for location in habitat_locations:
123
- name = location.get("name", "Unknown")
124
- lat = location.get("lat", 0)
125
- lon = location.get("lon", 0)
126
- description = location.get("description", "No description available")
127
-
128
- # Skip invalid coordinates
129
- if lat == 0 and lon == 0:
130
- continue
131
-
132
- # Add marker
133
- folium.Marker(
134
- location=[lat, lon],
135
- popup=folium.Popup(f"<b>{name}</b><br>{description}", max_width=300),
136
- tooltip=name
137
- ).add_to(m)
138
-
139
- # Save map to HTML
140
- map_html = m._repr_html_()
141
- return map_html
142
 
143
- def format_bird_info(raw_info):
144
- """Improve the formatting of bird information"""
145
- # Add proper line breaks between sections and ensure consistent heading levels
146
- formatted = raw_info
147
-
148
- # Fix heading levels (make all main sections h3)
149
- formatted = re.sub(r'#+\s+NOT TYPICALLY FOUND IN TANZANIA',
150
- '<div class="alert alert-warning"><strong>⚠️ NOT TYPICALLY FOUND IN TANZANIA</strong></div>',
151
- formatted)
152
-
153
- # Replace markdown headings with HTML headings for better control
154
- formatted = re.sub(r'#+\s+(.*)', r'<h3>\1</h3>', formatted)
155
-
156
- # Add paragraph tags for better spacing
157
- formatted = re.sub(r'\n\*\s+(.*)', r'<p>• \1</p>', formatted)
158
- formatted = re.sub(r'\n([^<\n].*)', r'<p>\1</p>', formatted)
159
-
160
- # Remove any duplicate paragraph tags
161
- formatted = formatted.replace('<p><p>', '<p>')
162
- formatted = formatted.replace('</p></p>', '</p>')
163
-
164
- return formatted
165
 
166
- def get_bird_info(bird_name):
167
- """Get detailed information about a bird using Groq API"""
168
- clean_name = clean_bird_name(bird_name)
169
-
170
- prompt = f"""
171
- Provide detailed information about the {clean_name} bird, including:
172
- 1. Physical characteristics and appearance
173
- 2. Habitat and distribution
174
- 3. Diet and behavior
175
- 4. Migration patterns (emphasize if this pattern has changed in recent years due to climate change)
176
- 5. Conservation status
177
-
178
- If this bird is not commonly found in Tanzania, explicitly flag that this bird is "NOT TYPICALLY FOUND IN TANZANIA" at the beginning of your response and explain why its presence might be unusual.
179
-
180
- Format your response in markdown for better readability.
181
- """
182
-
183
- try:
184
- chat_completion = client.chat.completions.create(
185
- messages=[
186
- {
187
- "role": "user",
188
- "content": prompt,
189
- }
190
- ],
191
- model="llama-3.3-70b-versatile",
192
- )
193
- return chat_completion.choices[0].message.content
194
- except Exception as e:
195
- return f"Error fetching information: {str(e)}"
196
-
197
- def predict_and_get_info(img):
198
- """Predict bird species and get detailed information"""
199
- # Process the image
200
  img = PILImage.create(img)
201
-
202
- # Get prediction
203
- pred, pred_idx, probs = learn.predict(img)
204
-
205
- # Get top 5 predictions (or all if less than 5)
206
- num_classes = min(5, len(labels))
207
- top_indices = probs.argsort(descending=True)[:num_classes]
208
- top_probs = probs[top_indices]
209
- top_labels = [labels[i] for i in top_indices]
210
-
211
- # Format as dictionary with cleaned names for display
212
- prediction_results = {clean_bird_name(top_labels[i]): float(top_probs[i]) for i in range(num_classes)}
213
-
214
- # Get top prediction (original format for info retrieval)
215
- top_bird = str(pred)
216
- # Also keep a clean version for display
217
- clean_top_bird = clean_bird_name(top_bird)
218
-
219
- # Get habitat locations and create map
220
- habitat_locations, is_in_tanzania = get_bird_habitat_map(top_bird)
221
- habitat_map_html = create_habitat_map(habitat_locations)
222
-
223
- # Get detailed information about the top predicted bird
224
- bird_info = get_bird_info(top_bird)
225
- formatted_info = format_bird_info(bird_info)
226
-
227
- # Create combined info with map at the top and properly formatted information
228
- custom_css = """
229
- <style>
230
- .bird-container {
231
- font-family: Arial, sans-serif;
232
- padding: 10px;
233
- }
234
- .map-container {
235
- height: 400px;
236
- width: 100%;
237
- border: 1px solid #ddd;
238
- border-radius: 8px;
239
- overflow: hidden;
240
- margin-bottom: 20px;
241
- }
242
- .info-container {
243
- line-height: 1.6;
244
- }
245
- .info-container h3 {
246
- margin-top: 20px;
247
- margin-bottom: 10px;
248
- color: #2c3e50;
249
- border-bottom: 1px solid #eee;
250
- padding-bottom: 5px;
251
- }
252
- .info-container p {
253
- margin-bottom: 10px;
254
- }
255
- .alert {
256
- padding: 10px;
257
- margin-bottom: 15px;
258
- border-radius: 4px;
259
- }
260
- .alert-warning {
261
- background-color: #fcf8e3;
262
- border: 1px solid #faebcc;
263
- color: #8a6d3b;
264
- }
265
- </style>
266
- """
267
-
268
- combined_info = f"""
269
- {custom_css}
270
- <div class="bird-container">
271
- <h2>Natural Habitat Map for {clean_top_bird}</h2>
272
- <div class="map-container">
273
- {habitat_map_html}
274
- </div>
275
-
276
- <div class="info-container">
277
- <h2>Detailed Information</h2>
278
- {formatted_info}
279
- </div>
280
- </div>
281
- """
282
-
283
- return prediction_results, combined_info, clean_top_bird
284
-
285
- def follow_up_question(question, bird_name):
286
- """Allow researchers to ask follow-up questions about the identified bird"""
287
- if not question.strip() or not bird_name:
288
- return "Please identify a bird first and ask a specific question about it."
289
-
290
- prompt = f"""
291
- The researcher is asking about the {bird_name} bird: "{question}"
292
-
293
- Provide a detailed, scientific answer focusing on accurate ornithological information.
294
- If the question relates to Tanzania or climate change impacts, emphasize those aspects in your response.
295
-
296
- IMPORTANT: Do not repeat basic introductory information about the bird that would have already been provided in a general description.
297
- Do not start your answer with phrases like "Introduction to the {bird_name}" or similar repetitive headers.
298
- Directly answer the specific question asked.
299
-
300
- Format your response in markdown for better readability.
301
- """
302
-
303
- try:
304
- chat_completion = client.chat.completions.create(
305
- messages=[
306
- {
307
- "role": "user",
308
- "content": prompt,
309
- }
310
- ],
311
- model="llama-3.3-70b-versatile",
312
- )
313
- return chat_completion.choices[0].message.content
314
- except Exception as e:
315
- return f"Error fetching information: {str(e)}"
316
 
317
- # Function to add custom JavaScript for camera switching
318
- def create_camera_switch_js():
319
- return """
320
- function switchCamera() {
321
- // Get all video elements
322
- const videoElements = document.querySelectorAll('video');
323
-
324
- if (videoElements.length > 0) {
325
- const video = videoElements[0];
326
-
327
- // Stop all tracks on the current stream
328
- if (video.srcObject) {
329
- const tracks = video.srcObject.getTracks();
330
- tracks.forEach(track => track.stop());
331
- }
332
-
333
- // Try to determine current camera type
334
- let usingFrontCamera = true;
335
- if (window.currentCamera === 'environment') {
336
- usingFrontCamera = false;
337
- }
338
-
339
- // Set the new camera type
340
- const facingMode = usingFrontCamera ? 'environment' : 'user';
341
- window.currentCamera = facingMode;
342
-
343
- // Request new video stream with the different camera
344
- navigator.mediaDevices.getUserMedia({
345
- video: { facingMode: facingMode }
346
- })
347
- .then(function(stream) {
348
- video.srcObject = stream;
349
- // Play the video
350
- video.play();
351
- })
352
- .catch(function(err) {
353
- console.error('Error accessing the camera: ' + err);
354
- });
355
- } else {
356
- console.warn('No video elements found.');
357
- }
358
- }
359
- """
360
-
361
- # Create the Gradio interface
362
- with gr.Blocks(theme=gr.themes.Soft()) as app:
363
- gr.Markdown("# Bird Species Identification for Researchers")
364
- gr.Markdown("Upload an image to identify bird species and get detailed information relevant to research in Tanzania and climate change studies.")
365
-
366
- # Store the current bird for context
367
- current_bird = gr.State("")
368
-
369
- # Add custom JavaScript for camera switching
370
- app.load(None, None, None, _js=create_camera_switch_js())
371
-
372
- # Main identification section
373
- with gr.Row():
374
- with gr.Column(scale=1):
375
- input_image = gr.Image(
376
- type="pil",
377
- label="Upload Bird Image",
378
- source="webcam", # Enable webcam
379
- )
380
- camera_toggle = gr.Button("Switch Camera")
381
- submit_btn = gr.Button("Identify Bird", variant="primary")
382
-
383
- with gr.Column(scale=2):
384
- prediction_output = gr.Label(label="Top 5 Predictions", num_top_classes=5)
385
- bird_info_output = gr.HTML(label="Bird Information")
386
-
387
- # Clear divider
388
- gr.Markdown("---")
389
-
390
- # Follow-up question section with improved UI
391
- gr.Markdown("## Research Questions")
392
-
393
- conversation_history = gr.Markdown("")
394
-
395
- with gr.Row():
396
- follow_up_input = gr.Textbox(
397
- label="Ask a question about this bird",
398
- placeholder="Example: How has climate change affected this bird's migration pattern?",
399
- lines=2
400
- )
401
-
402
- with gr.Row():
403
- follow_up_btn = gr.Button("Submit Question", variant="primary")
404
- clear_btn = gr.Button("Clear Conversation")
405
-
406
- # Set up event handlers
407
- def process_image(img):
408
- if img is None:
409
- return None, "Please upload an image", "", ""
410
-
411
- try:
412
- pred_results, info, clean_bird_name = predict_and_get_info(img)
413
- return pred_results, info, clean_bird_name, ""
414
- except Exception as e:
415
- return None, f"Error processing image: {str(e)}", "", ""
416
-
417
- def update_conversation(question, bird_name, history):
418
- if not question.strip():
419
- return history
420
-
421
- answer = follow_up_question(question, bird_name)
422
-
423
- # Format the conversation with clear separation
424
- new_exchange = f"""
425
- ### Question:
426
- {question}
427
- ### Answer:
428
- {answer}
429
- ---
430
- """
431
- updated_history = new_exchange + history
432
- return updated_history
433
-
434
- def clear_conversation_history():
435
- return ""
436
-
437
- submit_btn.click(
438
- process_image,
439
- inputs=[input_image],
440
- outputs=[prediction_output, bird_info_output, current_bird, conversation_history]
441
- )
442
-
443
- follow_up_btn.click(
444
- update_conversation,
445
- inputs=[follow_up_input, current_bird, conversation_history],
446
- outputs=[conversation_history]
447
- ).then(
448
- lambda: "",
449
- outputs=follow_up_input
450
- )
451
-
452
- clear_btn.click(
453
- clear_conversation_history,
454
- outputs=[conversation_history]
455
- )
456
-
457
- # Call the JavaScript function when the camera toggle button is clicked
458
- camera_toggle.click(None, None, None, _js="switchCamera")
459
-
460
- # Launch the app
461
- app.launch(share=True)
 
 
 
 
 
1
  from fastai.vision.all import *
2
+ #from fastai.vision.widgets import *
3
+ import timm
 
 
4
  learn = load_learner('export.pkl')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
+ labels = learn.dls.vocab
8
+ def predict(img):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  img = PILImage.create(img)
10
+ pred,pred_idx,probs = learn.predict(img)
11
+ return {labels[i]: float(probs[i]) for i in range(len(labels))}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
+ import gradio as gr
14
+ gr.Interface(fn=predict, inputs=gr.Image(), outputs=gr.Label(num_top_classes=3)).launch(share=True)