ddecosmo commited on
Commit
515b9d8
Β·
verified Β·
1 Parent(s): 9c521df

Upload 2 files

Browse files
Files changed (2) hide show
  1. requirements (4).txt +12 -0
  2. updated_proto_kde_saving.py +407 -0
requirements (4).txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ autogluon.multimodal
2
+ datasets
3
+ folium
4
+ geopy
5
+ gradio
6
+ huggingface-hub
7
+ matplotlib
8
+ numpy
9
+ pandas
10
+ Pillow
11
+ scikit-learn
12
+ scipy
updated_proto_kde_saving.py ADDED
@@ -0,0 +1,407 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """Updated_proto_KDE_saving.ipynb
3
+
4
+ Automatically generated by Colab.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/drive/1FaE0wh8yJYv3lxVbhyN4r9eHUNBHWAOX
8
+ """
9
+
10
# NOTE: In the Colab notebook these were shell magics ("!pip install ...").
# A bare `pip install ...` line is a SyntaxError in a plain .py module, so the
# commands are kept only as comments. Run them from a shell (or rely on
# requirements.txt) before starting the app:
#
#   pip install -q numpy pandas Pillow gradio huggingface-hub tensorflow scipy matplotlib folium autogluon.multimodal
#   pip install geopy
13
+
14
+ # ==============================================================================
15
+ # CELL 1: SETUP AND CONSOLIDATED IMPORTS
16
+ # ==============================================================================
17
+ import gradio as gr
18
+ import os
19
+ import json
20
+ import uuid
21
+ import shutil
22
+ import zipfile
23
+ import pathlib
24
+ import tempfile
25
+ import pandas as pd
26
+ import PIL.Image
27
+ from datetime import datetime
28
+ import huggingface_hub
29
+ import autogluon.multimodal
30
+ import numpy as np
31
+ import matplotlib.pyplot as plt
32
+ import matplotlib.cm as cm
33
+ import matplotlib.colors
34
+ import folium
35
+ from scipy.stats import gaussian_kde
36
+ from datasets import load_dataset
37
+ from geopy.geocoders import Nominatim
38
+ from geopy.extra.rate_limiter import RateLimiter
39
+
40
+ # ==============================================================================
41
+ # CELL 2: CORE LOGIC FOR TAB 1 (UNCHANGED)
42
+ # ==============================================================================
43
+
44
+ # --- Functions for Data Capture ---
45
def get_current_time():
    """Return the current local (naive) time as an ISO 8601 string."""
    now = datetime.now()
    return now.isoformat()
47
+
48
def handle_time_capture():
    """Capture the current time for the UI.

    Returns:
        A ``(status_markdown, timestamp)`` tuple: the Markdown status line
        that embeds the timestamp, and the raw ISO timestamp string itself.
    """
    captured_at = get_current_time()
    return f"πŸ• **Time Captured**: {captured_at}", captured_at
52
+
53
def get_gps_js():
    """Return the browser-side JavaScript used by the "Get GPS" button.

    The script asks the browser for the current position and writes latitude,
    longitude, accuracy and an ISO timestamp into the textareas found under
    the element ids ``lat``, ``lon``, ``accuracy`` and ``device_ts``. After
    setting each value it dispatches an ``input`` event on that textarea.
    """
    gps_script = """
    () => {
        if (!navigator.geolocation) { alert("Geolocation not supported"); return; }
        navigator.geolocation.getCurrentPosition(
            function(position) {
                const latBox = document.querySelector('#lat textarea');
                const lonBox = document.querySelector('#lon textarea');
                const accuracyBox = document.querySelector('#accuracy textarea');
                const timestampBox = document.querySelector('#device_ts textarea');
                if (latBox && lonBox && accuracyBox && timestampBox) {
                    latBox.value = position.coords.latitude.toString();
                    lonBox.value = position.coords.longitude.toString();
                    accuracyBox.value = position.coords.accuracy.toString();
                    timestampBox.value = new Date().toISOString();
                    latBox.dispatchEvent(new Event('input', { bubbles: true }));
                    lonBox.dispatchEvent(new Event('input', { bubbles: true }));
                    accuracyBox.dispatchEvent(new Event('input', { bubbles: true }));
                    timestampBox.dispatchEvent(new Event('input', { bubbles: true }));
                } else { alert("Error: Could not find GPS input fields"); }
            },
            function(err) { alert("GPS Error: " + err.message); },
            { enableHighAccuracy: true, timeout: 10000 }
        );
    }
    """
    return gps_script
79
+
80
def save_to_dataset(image, lat, lon, accuracy_m, device_ts):
    """Build a preview of the record that would be saved (test mode only).

    Nothing is persisted. When an image is present, the function returns a
    success status plus a pretty-printed JSON preview of the mock record;
    otherwise it returns an error status and an empty preview string.
    """
    if image is not None:
        record = dict(
            image="image.jpg",
            latitude=lat,
            longitude=lon,
            accuracy_m=accuracy_m,
            device_timestamp=device_ts,
            status="Saving Disabled",
        )
        preview_json = json.dumps(record, indent=2)
        return "βœ… **Test Save Successful!** (No data saved)", preview_json

    return "❌ **Error**: Please capture or upload a photo first.", ""
90
+
91
# Aliases used by the UI wiring; they point at the capture/save handlers
# defined earlier in this module.
placeholder_time_capture, placeholder_save_action = handle_time_capture, save_to_dataset
93
+
94
+ # --- Functions for Model Prediction ---
95
# Hugging Face model repo and the zip archive (inside it) holding the predictor.
MODEL_REPO_ID = "ddecosmo/lanternfly_classifier"
ZIP_FILENAME = "autogluon_image_predictor_dir.zip"

# Maps the predictor's class indices to display labels.
CLASS_LABELS = dict(enumerate(("Lanternfly", "Other Insect", "No Insect")))

# Local directories for the downloaded archive and its extracted contents.
CACHE_DIR = pathlib.Path("hf_assets")
EXTRACT_DIR = CACHE_DIR.joinpath("predictor_native")

# Filled in by the module-level loading code; stays None if loading fails.
PREDICTOR = None
101
+
102
def _prepare_predictor_dir():
    """Download the predictor archive and extract it into a fresh directory.

    The zip is fetched from ``MODEL_REPO_ID`` into ``CACHE_DIR`` (using the
    ``HF_TOKEN`` environment variable when set). ``EXTRACT_DIR`` is wiped and
    recreated before extraction.

    Returns:
        The path (as ``str``) of the single top-level directory inside the
        archive when there is exactly one; otherwise ``EXTRACT_DIR`` itself.
    """
    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    archive_path = huggingface_hub.hf_hub_download(
        repo_id=MODEL_REPO_ID,
        filename=ZIP_FILENAME,
        repo_type="model",
        token=os.getenv("HF_TOKEN", None),
        local_dir=str(CACHE_DIR),
        local_dir_use_symlinks=False,
    )

    # Always start from an empty extraction directory.
    if EXTRACT_DIR.exists():
        shutil.rmtree(EXTRACT_DIR)
    EXTRACT_DIR.mkdir(parents=True, exist_ok=True)

    with zipfile.ZipFile(archive_path, "r") as archive:
        archive.extractall(str(EXTRACT_DIR))

    entries = list(EXTRACT_DIR.iterdir())
    if len(entries) == 1 and entries[0].is_dir():
        return str(entries[0])
    return str(EXTRACT_DIR)
114
+
115
# Load the classifier once at import time. On any failure PREDICTOR is left
# as None and the reason is kept in PREDICTOR_LOAD_STATUS.
try:
    PREDICTOR_DIR = _prepare_predictor_dir()
    PREDICTOR = autogluon.multimodal.MultiModalPredictor.load(PREDICTOR_DIR)
    PREDICTOR_LOAD_STATUS = "βœ… AutoGluon Predictor loaded successfully."
except Exception as load_error:
    PREDICTOR = None
    PREDICTOR_LOAD_STATUS = f"❌ Failed to load AutoGluon Predictor: {load_error}"
print(PREDICTOR_LOAD_STATUS)
124
+
125
def do_predict(pil_img: PIL.Image.Image):
    """Classify an image with the loaded AutoGluon predictor.

    Args:
        pil_img: The PIL image to classify, or ``None`` when no image is set.

    Returns:
        A 2-tuple ``(probabilities, summary)``: a dict mapping each label in
        ``CLASS_LABELS`` to its probability, and a one-line text summary.
        Every return path yields exactly two values, because the function is
        bound to two Gradio output components.
    """
    if PREDICTOR is None:
        return {"Error": 1.0}, "Model not loaded."
    if pil_img is None:
        return {"No Image": 1.0}, "No image provided."

    # The predictor reads images from disk. The temporary directory is removed
    # as soon as prediction finishes, so repeated calls do not accumulate files.
    with tempfile.TemporaryDirectory() as tmpdir:
        img_path = pathlib.Path(tmpdir) / "input.png"
        pil_img.save(img_path)
        df = pd.DataFrame({"image": [str(img_path)]})
        proba_df = PREDICTOR.predict_proba(df).rename(columns=CLASS_LABELS)

    row = proba_df.iloc[0]
    pretty_dict = {label: float(row.get(label, 0.0)) for label in CLASS_LABELS.values()}
    confidence_info = ", ".join(f"{label}: {prob:.2f}" for label, prob in pretty_dict.items())
    return pretty_dict, confidence_info
137
+
138
+ # ==============================================================================
139
+ # CELL 3: CORE LOGIC FOR TAB 2 (KDE ANALYSIS)
140
+ # ==============================================================================
141
# Latitude/longitude bounding box passed to the KDE plotting function as the
# heatmap extent and map view.
pittsburgh_lat_min, pittsburgh_lat_max = 40.43950159029883, 40.44787067820301
pittsburgh_lon_min, pittsburgh_lon_max = -79.95054304624013, -79.93588847945053
145
+
146
def load_dataframe_from_huggingface():
    """Load the sightings metadata from the Hub as a pandas DataFrame.

    Reads ``metadata/entries.jsonl`` from the ``rlogh/lanternfly-data``
    dataset. Returns the DataFrame, or ``None`` if anything goes wrong (the
    error is printed rather than raised).
    """
    print("Loading data directly from Hugging Face dataset...")
    try:
        entries = load_dataset(
            "rlogh/lanternfly-data",
            data_files="metadata/entries.jsonl",
            split="train",
        )
        frame = entries.to_pandas()
    except Exception as load_error:
        print(f"❌ Error loading data from Hugging Face: {load_error}")
        return None
    print("βœ… Data successfully loaded into a DataFrame.")
    return frame
156
+
157
def calculate_kde_from_dataframe(df):
    """Fit a 2-D Gaussian KDE on the sighting coordinates in ``df``.

    Rows with a missing latitude or longitude are ignored. The input
    DataFrame is left untouched: the rows are dropped on a copy, so the caller
    does not silently lose data.

    Args:
        df: DataFrame expected to contain ``latitude`` and ``longitude``.

    Returns:
        ``(latitudes, longitudes, kde_object, error)``. On success ``error``
        is ``None``; on failure the first three items are ``None`` and
        ``error`` is a message string. The KDE is fitted on
        ``[longitude, latitude]`` rows, so it must be evaluated in that order.
    """
    try:
        if 'latitude' not in df.columns or 'longitude' not in df.columns:
            return None, None, None, "Error: DataFrame must contain 'latitude' and 'longitude' columns."
        valid_rows = df.dropna(subset=['latitude', 'longitude'])
        latitudes = valid_rows['latitude'].values
        longitudes = valid_rows['longitude'].values
        coordinates = np.vstack([longitudes, latitudes])
        kde_object = gaussian_kde(coordinates)
        return latitudes, longitudes, kde_object, None
    except Exception as e:
        # Best-effort API: report the failure to the caller instead of raising.
        return None, None, None, f"Error calculating KDE from DataFrame: {e}"
169
+
170
+ import math
171
+
172
def find_hotspot_landmark(original_latitudes, original_longitudes, kde_object):
    """Name the campus landmark closest to the densest sighting.

    The KDE is evaluated at every sighting; the sighting with the highest
    density is the "hotspot". The closest landmark is then picked from a
    fixed list of campus locations.

    Args:
        original_latitudes: Sequence of sighting latitudes in degrees.
        original_longitudes: Sequence of sighting longitudes in degrees,
            aligned with ``original_latitudes``.
        kde_object: Callable taking a ``(2, n)`` array of
            ``[longitudes, latitudes]`` and returning ``n`` densities.

    Returns:
        A Markdown message naming the closest landmark, or a message saying
        no sightings are available when the inputs are empty.
    """
    # Fixed list of campus landmarks as (latitude, longitude).
    CAMPUS_LANDMARKS = {
        "Scaife Hall": (40.441742986804336, -79.94725195600002),
        "Hunt Library": (40.44097574857165, -79.94362666281333),
        "Cohon University Center": (40.44401378993309, -79.94172335009584),
        "Gates Hillman Complex": (40.4436463605335, -79.94442701667683),
        "Wean Hall": (40.44267896399903, -79.94582169457243),
        "Gesling Stadium": (40.443038206822905, -79.94038027450188),
        "The Fence": (40.44221744932438, -79.9435687098247)
    }

    latitudes = np.asarray(original_latitudes, dtype=float)
    longitudes = np.asarray(original_longitudes, dtype=float)
    if latitudes.size == 0:
        # argmax on an empty array would raise; report the situation instead.
        return "**Hotspot Analysis**: No sightings are available to analyze."

    # Densest observed point.
    densities = kde_object(np.vstack([longitudes, latitudes]))
    hotspot_index = int(np.argmax(densities))
    hotspot_lat = float(latitudes[hotspot_index])
    hotspot_lon = float(longitudes[hotspot_index])

    # A degree of longitude covers less ground than a degree of latitude by a
    # factor of cos(latitude) (about 0.76 here). Without this scaling,
    # east-west offsets are over-weighted and the wrong landmark can win.
    lon_scale = math.cos(math.radians(hotspot_lat))

    def distance_to(landmark):
        landmark_lat, landmark_lon = CAMPUS_LANDMARKS[landmark]
        return math.hypot(hotspot_lat - landmark_lat, (hotspot_lon - landmark_lon) * lon_scale)

    closest_landmark = min(CAMPUS_LANDMARKS, key=distance_to)

    return f"πŸ“ˆ **Hotspot Analysis**: The highest concentration was found closest to **{closest_landmark}** on campus."
207
+
208
def plot_kde_and_points_for_gradio(min_lat, max_lat, min_lon, max_lon, original_latitudes, original_longitudes, kde_object):
    """Render the KDE as a static heatmap and as a Folium point map.

    Args:
        min_lat, max_lat, min_lon, max_lon: Bounding box used as the heatmap
            extent, the map view, and the filter for which points are drawn.
        original_latitudes, original_longitudes: Aligned sighting coordinates.
        kde_object: Callable taking a ``(2, n)`` array of
            ``[longitudes, latitudes]`` and returning ``n`` densities.

    Returns:
        ``(heatmap_path, map_html)``: the path of the saved PNG heatmap and
        the HTML representation of the Folium map.
    """
    heatmap_path = "lanternfly_kde_heatmap.png"

    # --- Static heatmap on a 100x100 grid over the bounding box ---
    x, y = np.mgrid[min_lon:max_lon:100j, min_lat:max_lat:100j]
    positions = np.vstack([x.ravel(), y.ravel()])
    z = kde_object(positions).reshape(x.shape)
    z_normalized = (z - z.min()) / (z.max() - z.min()) if z.max() > z.min() else np.zeros_like(z)
    fig, ax = plt.subplots(figsize=(8, 8))
    im = ax.imshow(z_normalized.T, origin='lower', extent=[min_lon, max_lon, min_lat, max_lat], cmap='hot', aspect='auto')
    fig.colorbar(im, ax=ax, label='Normalized Density (0-1)')
    ax.set_title('Lanternfly Sightings KDE Heatmap (Static)')
    fig.savefig(heatmap_path, bbox_inches='tight')
    plt.close(fig)

    # --- Interactive map with one density-colored marker per sighting ---
    m_colored_points = folium.Map()
    m_colored_points.fit_bounds([[min_lat, min_lon], [max_lat, max_lon]])

    latitudes = np.asarray(original_latitudes)
    longitudes = np.asarray(original_longitudes)
    if latitudes.size:  # min()/max() below would raise on empty input
        # Evaluate the KDE once for all points; the same values drive both
        # the marker color and the tooltip.
        density_at_points = kde_object(np.vstack([longitudes, latitudes]))
        min_density = density_at_points.min()
        max_density = density_at_points.max()
        density_normalized_for_color = (density_at_points - min_density) / (max_density - min_density + 1e-9)
        if max_density > 0:
            tooltip_densities = density_at_points / max_density
        else:
            tooltip_densities = np.zeros_like(density_at_points)

        # plt.get_cmap is used because matplotlib.cm.get_cmap was removed in
        # Matplotlib 3.9.
        colormap = plt.get_cmap('viridis')

        for lat, lon, density_norm_color, tooltip_density in zip(
            latitudes, longitudes, density_normalized_for_color, tooltip_densities
        ):
            if min_lat <= lat <= max_lat and min_lon <= lon <= max_lon:
                color = matplotlib.colors.rgb2hex(colormap(density_norm_color))
                folium.CircleMarker(
                    location=[float(lat), float(lon)], radius=5, color=color, fill=True,
                    fill_color=color, fill_opacity=0.7,
                    tooltip=f"Normalized Density: {tooltip_density:.4f}"
                ).add_to(m_colored_points)

    return heatmap_path, m_colored_points._repr_html_()
243
+
244
+ import joblib # Make sure this is imported at the top
245
+
246
def load_kde_from_hub():
    """Fetch the pre-trained KDE model from the Hugging Face Hub and load it.

    Returns the loaded object, or ``None`` when the download or load fails
    (the error is printed rather than raised).
    """
    print("Downloading pre-trained KDE model...")
    try:
        downloaded_file = huggingface_hub.hf_hub_download(
            repo_id="ddecosmo/lanternfly-kde-model",  # Use the same repo_id from the upload script
            filename="kde_model.joblib",
            repo_type="model",
        )
        # NOTE(review): joblib.load unpickles the file, which can execute
        # arbitrary code; only keep this if the model repo is fully trusted.
        loaded_model = joblib.load(downloaded_file)
    except Exception as load_error:
        print(f"❌ Failed to load KDE model from Hub: {load_error}")
        return None
    print("βœ… Pre-trained KDE model loaded.")
    return loaded_model
263
+
264
def run_full_analysis_and_update_ui():
    """Load the pre-trained KDE and the raw sightings, then build all outputs.

    Returns:
        A 3-tuple of Gradio component updates, in the order
        ``(heatmap image, interactive map HTML, hotspot message)``. When the
        model or usable data cannot be obtained, the image and message are
        hidden and the HTML component carries an error notice.
    """
    def error_outputs(message):
        # Same output order as the success path: image, HTML, markdown.
        return (
            gr.Image(visible=False),
            gr.HTML(f"<h3>{message}</h3>", visible=True),
            gr.Markdown(visible=False),
        )

    # --- Load both the pre-trained model and the raw data ---
    kde_object = load_kde_from_hub()
    lanternfly_df = load_dataframe_from_huggingface()

    if kde_object is None or lanternfly_df is None:
        return error_outputs("Error: Could not load model or data from Hub.")

    if 'latitude' not in lanternfly_df.columns or 'longitude' not in lanternfly_df.columns:
        return error_outputs("Error: Dataset is missing 'latitude' or 'longitude' columns.")

    # Rows without coordinates are dropped here, as they are when the KDE is
    # fitted; otherwise NaN values reach the density evaluation and the map
    # markers.
    valid_rows = lanternfly_df.dropna(subset=['latitude', 'longitude'])
    if valid_rows.empty:
        return error_outputs("Error: Dataset contains no usable coordinates.")

    # The raw lat/lon are needed to display the points on the Folium map.
    latitudes = valid_rows['latitude'].values
    longitudes = valid_rows['longitude'].values

    print("Generating visualizations with pre-trained model...")
    heatmap_path, interactive_map_html = plot_kde_and_points_for_gradio(
        pittsburgh_lat_min, pittsburgh_lat_max,
        pittsburgh_lon_min, pittsburgh_lon_max,
        latitudes, longitudes, kde_object
    )

    print("Finding hotspot landmark...")
    hotspot_message = find_hotspot_landmark(latitudes, longitudes, kde_object)

    return (
        gr.Image(value=heatmap_path, visible=True),
        gr.HTML(value=interactive_map_html, visible=True),
        gr.Markdown(value=hotspot_message, visible=True)
    )
296
+
297
+ # ==============================================================================
298
+ # CELL 4: GRADIO UI DEFINITIONS
299
+ # ==============================================================================
300
+
301
def field_capture_ui(camera):
    """Build the field-capture panel (GPS, timestamp, and test-mode save).

    Args:
        camera: The image component whose value is passed to the save handler.

    NOTE(review): the nesting of rows/columns below is reconstructed from the
    statement order; confirm it matches the intended layout.
    """
    with gr.Blocks():
        # Markdown headings need a space after '#'; without it the text is
        # rendered literally instead of as a heading.
        gr.Markdown("# Lanternfly Data Logging")
        with gr.Column(scale=1):
            gr.Markdown("### πŸ“ Location Data")
            gps_btn = gr.Button("πŸ“ Get GPS", variant="primary")
            with gr.Row():
                lat_box = gr.Textbox(label="Latitude", interactive=True, value="0.0", elem_id="lat")
                lon_box = gr.Textbox(label="Longitude", interactive=True, value="0.0", elem_id="lon")
            with gr.Row():
                accuracy_box = gr.Textbox(label="Accuracy (meters)", interactive=True, value="0.0", elem_id="accuracy")
                device_ts_box = gr.Textbox(label="Device Timestamp", interactive=True, elem_id="device_ts")
            time_btn = gr.Button("πŸ• Get Current Time")
            save_btn = gr.Button("πŸ’Ύ Save (Test Mode)")
            status = gr.Markdown("πŸ”„ **Ready**")
            preview = gr.JSON(label="Preview JSON")

        # The GPS button runs only browser-side JavaScript, which fills the
        # textboxes via their elem_id values.
        gps_btn.click(fn=None, inputs=[], outputs=[], js=get_gps_js())
        time_btn.click(fn=placeholder_time_capture, inputs=[], outputs=[status, device_ts_box])
        save_btn.click(
            fn=placeholder_save_action,
            inputs=[camera, lat_box, lon_box, accuracy_box, device_ts_box],
            outputs=[status, preview],
        )
320
+
321
def image_model_ui(image_in):
    """Build the classification panel and wire it to ``image_in``.

    Args:
        image_in: The shared image component; a change to its value triggers
            ``do_predict`` and fills the two output components.
    """
    with gr.Blocks():
        gr.Markdown("# Image Classification Results")
        gr.Markdown("Uses an EfficientNetB1 model to classify the uploaded image.")

        if PREDICTOR is None:
            # gr.Warning is meant to be called from inside an event handler;
            # at layout-build time it is not shown in the page. Render the
            # load status as a visible component instead.
            gr.Markdown(PREDICTOR_LOAD_STATUS)

        with gr.Row():
            proba_pretty = gr.Label(num_top_classes=2, label="Class Probabilities")
            confidence_output = gr.Textbox(label="Prediction Summary")

        # Attach prediction logic to the passed-in image component.
        image_in.change(
            fn=do_predict,
            inputs=[image_in],
            outputs=[proba_pretty, confidence_output]
        )

        # Example images are optional: only offer the ones present on disk.
        # NOTE(review): missing example files are believed to make
        # gr.Examples fail at startup — confirm against the Gradio version
        # in use.
        example_paths = [
            "examples/lanternfly_example.jpg",
            "examples/other_insect_example.jpg",
            "examples/no_insect_example.jpg"
        ]
        available_examples = [path for path in example_paths if os.path.exists(path)]
        if available_examples:
            gr.Examples(
                examples=available_examples,
                inputs=[image_in],
                label="Click an Example to Classify",
                examples_per_page=3
            )
352
+
353
def kde_analysis_ui():
    """Build the KDE tab: explanatory text and a button on top, outputs below.

    The outputs start hidden; the button's click handler
    (``run_full_analysis_and_update_ui``) supplies the updates for them.
    """
    # Introductory text, rendered in order above the button.
    intro_paragraphs = (
        "# Spotted Lanternfly Kernel Density Estimation Analysis",
        "Click the button to generate a Kernel Density Estimation (KDE) analysis based on the data gathered from the classification tab.",
        "This data can be found at rlogh/lanternfly-data on Hugging Face and contains images, geolocal, and temporal data for all samples.",
        "This dataset is public and available for use for any research or learning purposes.",
    )
    for paragraph in intro_paragraphs:
        gr.Markdown(paragraph)

    generate_btn = gr.Button("Generate KDE Visualizations")

    # Output areas, created after the button so they appear beneath it.
    message_output = gr.Markdown(visible=False)
    with gr.Row():
        heatmap_output = gr.Image(label="KDE Heatmap (Static)", visible=False)
        map_output = gr.HTML(label="Interactive Density Map", visible=False)

    # Output order must match the handler's return order.
    generate_btn.click(
        fn=run_full_analysis_and_update_ui,
        inputs=None,
        outputs=[heatmap_output, map_output, message_output],
    )
378
+
379
# Top-level application layout: an introduction followed by two tabs.
with gr.Blocks(title="Unified Lanternfly App") as app:
    gr.Markdown("# Lanternfly Tracker")
    gr.Markdown("This application allows for the tracking of concentrated lanternflies, mainly around Carnegie Mellon University.")
    gr.Markdown("It combines two tools: (1) A field capture and AI Image classifier for identifying lanternflies, and (2) a Kernel Density Estimation (KDE) ML model to visualize lanternfly hotspots on campus.")
    gr.Markdown("Photos can be taken and classified as Lanternflies in the Capture & Classification tab. In future this data can be saved in real time to the dataset")
    gr.Markdown("To view the overall distribution of lanternflies based on collected data, use the Spatial Analysis (KDE) tab.")

    # TAB 1: capture + classification share one image input.
    with gr.Tab("Capture & Classification"):
        # gr.Info is meant to be called from inside an event handler; at
        # layout-build time it is not shown in the page, so the notice is
        # rendered as a regular component.
        gr.Markdown("GPS functionality is now enabled! Data saving is in test mode.")
        shared_image_input = gr.Image(
            streaming=False, height=380, label="πŸ“· Upload Photo (or use camera)",
            type="pil", sources=["webcam", "upload"]
        )
        with gr.Row():
            with gr.Column(scale=1):
                image_model_ui(shared_image_input)
            with gr.Column(scale=1):
                field_capture_ui(shared_image_input)

    # TAB 2: KDE analysis; the helper builds the entire tab.
    with gr.Tab("Spatial Analysis (KDE)"):
        kde_analysis_ui()

# Launch the app only when run as a script.
if __name__ == "__main__":
    app.launch()