JackRabbit committed on
Commit
885f8ec
·
1 Parent(s): 41e69d7

added app file

Browse files
Files changed (1) hide show
  1. app.py +272 -0
app.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from PIL import Image
3
+ import pandas as pd
4
+ import io
5
+ import os
6
+ import requests
7
+ from autogluon.multimodal import MultiModalPredictor
8
+ from huggingface_hub import snapshot_download
9
+ import logging
10
+ import datetime
11
+ import re
12
+
13
# Logging setup: prediction records are appended to a local log file.
log_filename = "model_predictions.log"
logging.basicConfig(
    filename=log_filename,
    level=logging.INFO,
    format='%(asctime)s - %(message)s',
)

# Set the browser-tab title. NOTE: no `layout` argument is passed here, so the
# page uses Streamlit's default layout rather than wide mode.
st.set_page_config(page_title="Honey Bee Image Classification")
19
+
20
+ # -------------------------
21
+ # MODEL LOADING
22
+ # -------------------------
23
@st.cache_resource
def load_model():
    """Download the honey bee model snapshot and load it as a predictor.

    The snapshot is fetched with ``snapshot_download``; the download directory
    must contain both ``assets.json`` and ``model.ckpt`` before it is handed
    to ``MultiModalPredictor.load``.

    Returns:
        The predictor returned by ``MultiModalPredictor.load``.

    Raises:
        FileNotFoundError: If either required file is missing from the
            downloaded directory.
    """
    local_dir = snapshot_download("Honey-Bee-Society/honeybee_ml_v1")

    # Both files must be present before attempting to load the predictor.
    required_files = ("assets.json", "model.ckpt")
    have_all_files = all(
        os.path.exists(os.path.join(local_dir, file_name))
        for file_name in required_files
    )
    if not have_all_files:
        raise FileNotFoundError("Required model files not found in the downloaded directory.")

    return MultiModalPredictor.load(local_dir)
39
+
40
+ # -------------------------
41
+ # HELPER FUNCTIONS
42
+ # -------------------------
43
def resize_image_proportionally(image, max_size_mb=1):
    """Downscale *image* when its in-memory PNG encoding exceeds ``max_size_mb``.

    The image is encoded to PNG once to measure its size. If it is over
    budget, both dimensions are multiplied by ``sqrt(max_size_mb / size)`` so
    the pixel count shrinks roughly in proportion to the overshoot. This is a
    single pass: the resized image is not re-measured, so the result is only
    approximately within budget.

    Args:
        image: A PIL-style image exposing ``save``, ``width``, ``height``,
            ``resize`` and (for modes PNG cannot store) ``convert``.
        max_size_mb: Size budget in megabytes; must be positive.

    Returns:
        The original image when it already fits, otherwise a resized copy.

    Raises:
        ValueError: If ``max_size_mb`` is not positive.
    """
    if max_size_mb <= 0:
        raise ValueError("max_size_mb must be positive")

    probe = io.BytesIO()
    try:
        image.save(probe, format='PNG')
    except OSError:
        # PNG cannot store every mode (e.g. CMYK JPEGs). Measure an RGB copy
        # instead; the caller still gets the image back in its own mode.
        probe = io.BytesIO()
        image.convert("RGB").save(probe, format='PNG')
    img_size = len(probe.getvalue()) / (1024 * 1024)

    if img_size <= max_size_mb:
        return image

    scale_factor = (max_size_mb / img_size) ** 0.5
    # Clamp to one pixel: int() truncation can yield 0 for very thin images,
    # which resize() rejects.
    new_width = max(1, int(image.width * scale_factor))
    new_height = max(1, int(image.height * scale_factor))
    return image.resize((new_width, new_height))
56
+
57
def predict_image(image, predictor):
    """Return class probabilities for an in-memory image.

    The image is serialized to PNG bytes, wrapped in a one-row DataFrame with
    an ``image`` column, and passed to ``predictor.predict_proba`` with
    ``realtime=True``.

    Args:
        image: A PIL-style image exposing ``save`` (and ``convert`` for modes
            PNG cannot store).
        predictor: Object exposing ``predict_proba(df, realtime=...)``.

    Returns:
        Whatever ``predictor.predict_proba`` returns for the single image.
    """
    buffer = io.BytesIO()
    try:
        image.save(buffer, format='PNG')
    except OSError:
        # PNG cannot store every mode (e.g. CMYK JPEGs); retry with an RGB
        # copy in a fresh buffer so no partial output is kept.
        buffer = io.BytesIO()
        image.convert("RGB").save(buffer, format='PNG')

    df = pd.DataFrame({"image": [buffer.getvalue()]})
    return predictor.predict_proba(df, realtime=True)
65
+
66
def save_image(image, img_name, target_size_kb=500):
    """Write *image* as a JPEG under ``processed_images/``, aiming for <= ``target_size_kb``.

    JPEG quality starts at 95 and drops in steps of 5 (down to 15) until the
    encoded size fits the target. If no attempt fits, the last (lowest
    quality) encoding is written anyway.

    Args:
        image: A PIL-style image exposing ``save`` and, for modes JPEG cannot
            store, ``convert``.
        img_name: File name to use inside the ``processed_images`` directory.
        target_size_kb: Desired maximum file size in kilobytes.

    Returns:
        The relative path of the written file.
    """
    output_dir = "processed_images"
    os.makedirs(output_dir, exist_ok=True)
    processed_image_path = os.path.join(output_dir, img_name)

    # JPEG has no alpha channel or palette support, so RGBA/LA/P images (common
    # for PNG uploads) must be converted before encoding.
    if getattr(image, "mode", "RGB") not in ("RGB", "L", "CMYK"):
        image = image.convert("RGB")

    encoded = b""
    for quality in range(95, 10, -5):
        # A fresh buffer per attempt: reusing one without truncating would keep
        # stale bytes from a larger earlier attempt, so the measured size could
        # never shrink.
        buffer = io.BytesIO()
        image.save(buffer, format='JPEG', quality=quality)
        encoded = buffer.getvalue()
        if len(encoded) / 1024 <= target_size_kb:
            break

    with open(processed_image_path, "wb") as f:
        f.write(encoded)

    return processed_image_path
90
+
91
def log_predictions(image_path, honeybee_score, bumblebee_score, vespidae_score):
    """Log one prediction at INFO level.

    The message holds the image path followed by the three class scores, each
    formatted as a percentage with two decimal places.
    """
    fields = (
        f"Image Path: {image_path}",
        f"Honeybee: {honeybee_score:.2f}%",
        f"Bumblebee: {bumblebee_score:.2f}%",
        f"Vespidae: {vespidae_score:.2f}%",
    )
    message = ", ".join(fields)
    logging.info(message)
98
+
99
def sanitize_filename(filename):
    """Return *filename* with every character outside ``A-Za-z0-9_.-`` replaced by ``_``."""
    unsafe_chars = re.compile(r'[^A-Za-z0-9_.-]')
    return unsafe_chars.sub('_', filename)
103
+
104
def check_file_size(uploaded_file, max_size_mb=10):
    """Report whether *uploaded_file* is within ``max_size_mb`` megabytes.

    The size is measured by seeking to the end of the file object; the
    position is reset to the start afterwards. When the file is too large, an
    error is shown via ``st.error`` and ``False`` is returned.
    """
    bytes_per_mb = 1024 * 1024

    # Measure by seeking to the end, then rewind so the caller can read it.
    uploaded_file.seek(0, os.SEEK_END)
    size_in_mb = uploaded_file.tell() / bytes_per_mb
    uploaded_file.seek(0)

    too_large = size_in_mb > max_size_mb
    if too_large:
        st.error(f"File size exceeds {max_size_mb}MB limit. Please upload a smaller file.")
    return not too_large
113
+
114
+ # -------------------------
115
+ # API HANDLER
116
+ # -------------------------
117
def run_api(predictor):
    """
    A simple 'API-like' endpoint in Streamlit.

    Reads ``image_url`` from the query string, downloads the image (at most
    10MB), classifies it, and renders the scores and a label with ``st.json``.
    Every failure is reported as ``{"error": ...}`` instead of raising.

    Usage example:
    ?api=1&image_url=https://somewhere.com/bee.jpg

    Args:
        predictor: Loaded model passed through to ``predict_image``.
    """
    max_size_mb = 10
    max_bytes = max_size_mb * 1024 * 1024

    params = st.query_params
    image_url = params.get("image_url")

    if not image_url:
        st.json({"error": "No 'image_url' provided. Example: ?api=1&image_url=<URL>"})
        return

    # NOTE(security): image_url is caller-controlled, so this server will fetch
    # whatever address it names (including internal hosts). Restrict allowed
    # hosts if this runs somewhere that matters.
    try:
        # The timeout stops a slow or silent host from hanging the request;
        # streaming lets the download stop as soon as the size cap is hit
        # rather than buffering an arbitrarily large body first.
        with requests.get(
            image_url,
            headers={"User-Agent": "HoneyBeeClassification/1.0 (+https://honeybeeclassification.streamlit.app)"},
            timeout=(5, 30),
            stream=True,
        ) as response:
            if response.status_code != 200:
                st.json({"error": f"Failed to retrieve image from {image_url}. HTTP {response.status_code}"})
                return

            chunks = []
            received = 0
            for chunk in response.iter_content(chunk_size=64 * 1024):
                received += len(chunk)
                if received > max_bytes:
                    st.json({"error": f"Image size exceeds {max_size_mb}MB limit."})
                    return
                chunks.append(chunk)
    except requests.RequestException as e:
        # Covers malformed URLs, unsupported schemes, DNS/connection failures
        # and timeouts.
        st.json({"error": f"Failed to retrieve image from {image_url}: {e}"})
        return

    image_bytes = b"".join(chunks)

    # Convert to PIL and shrink. Image.open is lazy, so corrupt data may only
    # surface when the resize step touches the pixels; keep both guarded.
    try:
        image = Image.open(io.BytesIO(image_bytes))
        image = resize_image_proportionally(image)
    except Exception as e:
        st.json({"error": f"Could not open image: {e}"})
        return

    # Predict
    try:
        probabilities = predict_image(image, predictor)
        honeybee_score = float(probabilities[1].iloc[0]) * 100
        bumblebee_score = float(probabilities[2].iloc[0]) * 100
        vespidae_score = float(probabilities[3].iloc[0]) * 100
    except Exception as e:
        st.json({"error": f"Prediction failed: {e}"})
        return

    # Determine the highest-scoring label; below 80% nothing is reported.
    highest_score = max(honeybee_score, bumblebee_score, vespidae_score)
    if highest_score < 80:
        prediction_label = "No bee detected (scores too low)."
    elif honeybee_score == highest_score:
        prediction_label = "Honey Bee"
    elif bumblebee_score == highest_score:
        prediction_label = "Bumblebee"
    else:
        prediction_label = "Vespidae (wasp/hornet)"

    # Return results as JSON
    st.json({
        "honeybee_score": honeybee_score,
        "bumblebee_score": bumblebee_score,
        "vespidae_score": vespidae_score,
        "prediction_label": prediction_label
    })
187
+
188
+ # -------------------------
189
+ # UI HANDLER
190
+ # -------------------------
191
def run_ui(predictor):
    """Render the interactive upload page and classify the uploaded photo.

    Shows a file uploader and a model-details expander. When a file within
    the size limit is uploaded, it is classified and the verdict is displayed.
    A compressed copy of the image is also saved and the scores are logged;
    that archiving step is best-effort and never hides the verdict.

    Args:
        predictor: Loaded model passed through to ``predict_image``.
    """
    st.title("Honey Bee Image Classification")

    # File uploader
    uploaded_file = st.file_uploader(
        "Upload a photo of the suspected bee to see if you have honey bees. :bee:",
        type=["png", "jpg", "jpeg"]
    )

    with st.expander("ML Model Details"):
        st.write("""
        We trained a MultiModalPredictor from the AutoGluon library to classify images of bees,
        focusing primarily on Honey Bees. The model is fine-tuned on a curated dataset from inaturalist
        images (70k+ images) with an accuracy of ~97.5%. It classifies the image as Honey Bee, Bumblebee,
        or a Vespidae (wasp/hornet).

        **Open Source**:
        [Honey-Bee-Society/honeybee_ml_v1](https://huggingface.co/Honey-Bee-Society/honeybee_ml_v1)
        """)

    if uploaded_file is None or not check_file_size(uploaded_file):
        return

    progress_bar = st.progress(0)
    try:
        # Opening and resizing sit inside the try so a corrupt or unsupported
        # upload yields a friendly error instead of an unhandled traceback.
        image = Image.open(uploaded_file)
        image = resize_image_proportionally(image)

        probabilities = predict_image(image, predictor)
        progress_bar.progress(100)

        honeybee_score = float(probabilities[1].iloc[0]) * 100
        bumblebee_score = float(probabilities[2].iloc[0]) * 100
        vespidae_score = float(probabilities[3].iloc[0]) * 100

        # Generate a safe filename tagged with the current timestamp
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        sanitized_filename = sanitize_filename(uploaded_file.name)
        img_name = f"processed_{sanitized_filename}_{timestamp}.jpg"

        # Archiving is best-effort: a disk or encoding failure is logged but
        # must not stop the user from seeing the prediction.
        try:
            image_path = save_image(image, img_name)
            log_predictions(image_path, honeybee_score, bumblebee_score, vespidae_score)
        except (OSError, ValueError):
            logging.exception("Could not archive or log prediction for %s", img_name)

        # Find highest score
        highest_score = max(honeybee_score, bumblebee_score, vespidae_score)

        # Display result
        if highest_score < 80:
            st.warning("We are fairly confident there is no bee in this photo. Try another image.")
        elif honeybee_score == highest_score:
            st.success("Yes! This is a honey bee!")
        elif bumblebee_score == highest_score:
            st.info("This is likely a bumblebee, not a honey bee.")
        else:
            st.info("This is likely a member of the vespidae family (wasp, hornet, etc.).")

    except Exception as e:
        st.error(f"An error occurred: {e}")
    finally:
        progress_bar.empty()
254
+
255
+ # -------------------------
256
+ # MAIN ENTRY POINT
257
+ # -------------------------
258
def main():
    """Load the model, then dispatch to API mode or the interactive UI.

    API mode is selected when the query string contains an ``api`` key;
    otherwise the standard upload page is rendered.
    """
    predictor = load_model()

    # Presence of the "api" key alone selects API mode.
    handler = run_api if "api" in st.query_params else run_ui
    handler(predictor)


if __name__ == '__main__':
    main()