JackRabbit commited on
Commit
b644be2
·
1 Parent(s): 2995c7a

api updates

Browse files
Files changed (1) hide show
  1. app.py +39 -75
app.py CHANGED
@@ -14,34 +14,23 @@ import re
14
  log_filename = "model_predictions.log"
15
  logging.basicConfig(filename=log_filename, level=logging.INFO, format='%(asctime)s - %(message)s')
16
 
17
- # Set the page to wide mode
18
- st.set_page_config(page_title="Honey Bee Image Classification")
19
 
20
- # -------------------------
21
- # MODEL LOADING
22
- # -------------------------
23
  @st.cache_resource
24
  def load_model():
25
  repo_id = "Honey-Bee-Society/honeybee_ml_v1"
26
-
27
- # Download the model files from Hugging Face
28
  local_dir = snapshot_download(repo_id)
29
 
30
- # Ensure the necessary files exist in the local directory
31
  assets_path = os.path.join(local_dir, "assets.json")
32
  model_checkpoint = os.path.join(local_dir, "model.ckpt")
33
 
34
  if not os.path.exists(assets_path) or not os.path.exists(model_checkpoint):
35
  raise FileNotFoundError("Required model files not found in the downloaded directory.")
36
 
37
- # Load the model using the downloaded directory path
38
  return MultiModalPredictor.load(local_dir)
39
 
40
- # -------------------------
41
- # HELPER FUNCTIONS
42
- # -------------------------
43
  def resize_image_proportionally(image, max_size_mb=1):
44
- """Resize the image if it exceeds max_size_mb in memory."""
45
  img_byte_array = io.BytesIO()
46
  image.save(img_byte_array, format='PNG')
47
  img_size = len(img_byte_array.getvalue()) / (1024 * 1024)
@@ -55,7 +44,6 @@ def resize_image_proportionally(image, max_size_mb=1):
55
  return image
56
 
57
  def predict_image(image, predictor):
58
- """Predict probabilities for an in-memory PIL image using the given predictor."""
59
  img_byte_array = io.BytesIO()
60
  image.save(img_byte_array, format='PNG')
61
  img_data = img_byte_array.getvalue()
@@ -64,23 +52,19 @@ def predict_image(image, predictor):
64
  return probabilities
65
 
66
  def save_image(image, img_name, target_size_kb=500):
67
- """Compress and save the image to ensure it is <= target_size_kb KB."""
68
  processed_image_path = os.path.join("processed_images", img_name)
69
-
70
  if not os.path.exists("processed_images"):
71
  os.makedirs("processed_images")
72
 
73
- quality = 95 # Start with high quality
74
  img_byte_array = io.BytesIO()
75
 
76
- while quality > 10: # Stop if quality gets too low
77
  img_byte_array.seek(0)
78
  image.save(img_byte_array, format='JPEG', quality=quality)
79
  img_size_kb = len(img_byte_array.getvalue()) / 1024
80
-
81
  if img_size_kb <= target_size_kb:
82
  break
83
-
84
  quality -= 5
85
 
86
  with open(processed_image_path, "wb") as f:
@@ -97,12 +81,10 @@ def log_predictions(image_path, honeybee_score, bumblebee_score, vespidae_score)
97
  )
98
 
99
  def sanitize_filename(filename):
100
- """Remove unsafe characters from filenames."""
101
  safe_filename = re.sub(r'[^A-Za-z0-9_.-]', '_', filename)
102
  return safe_filename
103
 
104
  def check_file_size(uploaded_file, max_size_mb=10):
105
- """Return False if file size exceeds `max_size_mb`."""
106
  uploaded_file.seek(0, os.SEEK_END)
107
  file_size = uploaded_file.tell() / (1024 * 1024)
108
  uploaded_file.seek(0)
@@ -111,27 +93,28 @@ def check_file_size(uploaded_file, max_size_mb=10):
111
  return False
112
  return True
113
 
114
- # -------------------------
115
- # API HANDLER
116
- # -------------------------
117
  def run_api(predictor):
118
  """
119
- A simple 'API-like' endpoint in Streamlit.
120
-
121
- Usage example:
122
- ?api=1&image_url=https://somewhere.com/bee.jpg
 
 
 
 
 
 
 
 
 
123
  """
124
- params = st.query_params # Replaced st.experimental_get_query_params with st.query_params
125
- # image_url = params.get("image_url", )
126
- image_url = params.get("image_url")
127
-
128
- st.write("DEBUG: We are inside run_api()!")
129
- st.write("DEBUG: st.query_params:", params)
130
-
131
 
132
  if not image_url:
133
- st.json({"error": "No 'image_url' provided. Example: ?api=1&image_url=<URL>"})
134
- return
135
 
136
  # Download the image
137
  response = requests.get(
@@ -141,23 +124,23 @@ def run_api(predictor):
141
 
142
  if response.status_code != 200:
143
  st.json({"error": f"Failed to retrieve image from {image_url}. HTTP {response.status_code}"})
144
- return
145
 
146
  image_bytes = response.content
147
- # Check file size (limit 10MB as in the UI)
148
  image_size_mb = len(image_bytes)/(1024*1024)
149
  if image_size_mb > 10:
150
  st.json({"error": f"Image size {image_size_mb:.2f}MB exceeds 10MB limit."})
151
- return
152
 
153
- # Convert to PIL for processing
154
  try:
155
  image = Image.open(io.BytesIO(image_bytes))
156
  except Exception as e:
157
  st.json({"error": f"Could not open image: {e}"})
158
- return
159
 
160
- # Optional: resize to keep memory usage low (same logic as UI)
161
  image = resize_image_proportionally(image)
162
 
163
  # Predict
@@ -168,7 +151,7 @@ def run_api(predictor):
168
  vespidae_score = float(probabilities[3].iloc[0]) * 100
169
  except Exception as e:
170
  st.json({"error": f"Prediction failed: {e}"})
171
- return
172
 
173
  # Determine highest-scoring label
174
  highest_score = max(honeybee_score, bumblebee_score, vespidae_score)
@@ -182,35 +165,28 @@ def run_api(predictor):
182
  else:
183
  prediction_label = "Vespidae (wasp/hornet)"
184
 
185
- # Return results as JSON
186
  st.json({
187
  "honeybee_score": honeybee_score,
188
  "bumblebee_score": bumblebee_score,
189
  "vespidae_score": vespidae_score,
190
  "prediction_label": prediction_label
191
  })
 
192
 
193
- # -------------------------
194
- # UI HANDLER
195
- # -------------------------
196
  def run_ui(predictor):
197
  st.title("Honey Bee Image Classification")
198
 
199
- # File uploader
200
  uploaded_file = st.file_uploader(
201
- "Upload a photo of the suspected bee to see if you have honey bees. :bee:",
202
  type=["png", "jpg", "jpeg"]
203
  )
204
 
205
  with st.expander("ML Model Details"):
206
  st.write("""
207
- We trained a MultiModalPredictor from the AutoGluon library to classify images of bees,
208
- focusing primarily on Honey Bees. The model is fine-tuned on a curated dataset from inaturalist
209
- images (70k+ images) with an accuracy of ~97.5%. It classifies the image as Honey Bee, Bumblebee,
210
- or a Vespidae (wasp/hornet).
211
-
212
- **Open Source**:
213
- [Honey-Bee-Society/honeybee_ml_v1](https://huggingface.co/Honey-Bee-Society/honeybee_ml_v1)
214
  """)
215
 
216
  if uploaded_file is not None:
@@ -227,50 +203,38 @@ def run_ui(predictor):
227
  bumblebee_score = float(probabilities[2].iloc[0]) * 100
228
  vespidae_score = float(probabilities[3].iloc[0]) * 100
229
 
230
- # Generate a safe and unique filename
231
  timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
232
  sanitized_filename = sanitize_filename(uploaded_file.name)
233
  img_name = f"processed_{sanitized_filename}_{timestamp}.jpg"
234
 
235
- # Save compressed image
236
  image_path = save_image(image, img_name)
237
-
238
- # Log predictions
239
  log_predictions(image_path, honeybee_score, bumblebee_score, vespidae_score)
240
 
241
- # Find highest score
242
  highest_score = max(honeybee_score, bumblebee_score, vespidae_score)
243
 
244
- # Display result
245
  if highest_score < 80:
246
- st.warning("We are fairly confident there is no bee in this photo. Try another image.")
247
  else:
248
  if honeybee_score == highest_score:
249
  st.success("Yes! This is a honey bee!")
250
  elif bumblebee_score == highest_score:
251
- st.info("This is likely a bumblebee, not a honey bee.")
252
  else:
253
- st.info("This is likely a member of the vespidae family (wasp, hornet, etc.).")
254
 
255
  except Exception as e:
256
  st.error(f"An error occurred: {e}")
257
  finally:
258
  progress_bar.empty()
259
 
260
- # -------------------------
261
- # MAIN ENTRY POINT
262
- # -------------------------
263
  def main():
264
  predictor = load_model()
265
 
266
- # Check if we're in "API mode" or "UI mode"
267
- query_params = st.query_params # Replaced st.experimental_get_query_params with st.query_params
268
  if "api" in query_params:
269
- # Run as an API (no UI)
270
-
271
  run_api(predictor)
272
  else:
273
- # Run the standard UI
274
  run_ui(predictor)
275
 
276
  if __name__ == '__main__':
 
14
  log_filename = "model_predictions.log"
15
  logging.basicConfig(filename=log_filename, level=logging.INFO, format='%(asctime)s - %(message)s')
16
 
17
+ # Set the page config
18
+ st.set_page_config(page_title="Honey Bee Image Classification", layout="wide")
19
 
 
 
 
20
  @st.cache_resource
21
  def load_model():
22
  repo_id = "Honey-Bee-Society/honeybee_ml_v1"
 
 
23
  local_dir = snapshot_download(repo_id)
24
 
 
25
  assets_path = os.path.join(local_dir, "assets.json")
26
  model_checkpoint = os.path.join(local_dir, "model.ckpt")
27
 
28
  if not os.path.exists(assets_path) or not os.path.exists(model_checkpoint):
29
  raise FileNotFoundError("Required model files not found in the downloaded directory.")
30
 
 
31
  return MultiModalPredictor.load(local_dir)
32
 
 
 
 
33
  def resize_image_proportionally(image, max_size_mb=1):
 
34
  img_byte_array = io.BytesIO()
35
  image.save(img_byte_array, format='PNG')
36
  img_size = len(img_byte_array.getvalue()) / (1024 * 1024)
 
44
  return image
45
 
46
  def predict_image(image, predictor):
 
47
  img_byte_array = io.BytesIO()
48
  image.save(img_byte_array, format='PNG')
49
  img_data = img_byte_array.getvalue()
 
52
  return probabilities
53
 
54
  def save_image(image, img_name, target_size_kb=500):
 
55
  processed_image_path = os.path.join("processed_images", img_name)
 
56
  if not os.path.exists("processed_images"):
57
  os.makedirs("processed_images")
58
 
59
+ quality = 95
60
  img_byte_array = io.BytesIO()
61
 
62
+ while quality > 10:
63
  img_byte_array.seek(0)
64
  image.save(img_byte_array, format='JPEG', quality=quality)
65
  img_size_kb = len(img_byte_array.getvalue()) / 1024
 
66
  if img_size_kb <= target_size_kb:
67
  break
 
68
  quality -= 5
69
 
70
  with open(processed_image_path, "wb") as f:
 
81
  )
82
 
83
  def sanitize_filename(filename):
 
84
  safe_filename = re.sub(r'[^A-Za-z0-9_.-]', '_', filename)
85
  return safe_filename
86
 
87
  def check_file_size(uploaded_file, max_size_mb=10):
 
88
  uploaded_file.seek(0, os.SEEK_END)
89
  file_size = uploaded_file.tell() / (1024 * 1024)
90
  uploaded_file.seek(0)
 
93
  return False
94
  return True
95
 
 
 
 
96
  def run_api(predictor):
97
  """
98
+ 'API mode' for this Streamlit app.
99
+ Expects a query param ?api=1&image_url=<PUBLIC_IMAGE_URL>
100
+
101
+ Example usage (from command line):
102
+ curl -X GET "https://your-username-your-app.hf.space/?api=1&image_url=https://raw.githubusercontent.com/yourimage.jpg"
103
+
104
+ The response is HTML with an embedded JSON, but you can often parse it directly in Python:
105
+ >>> import requests
106
+ >>> response = requests.get("https://your-username-your-app.hf.space/?api=1&image_url=...")
107
+ >>> print(response.text) # prints the entire HTML with JSON
108
+ # or sometimes:
109
+ >>> data = response.json() # may work depending on how the client interprets the response
110
+ >>> print(data)
111
  """
112
+ params = st.experimental_get_query_params() # or st.query_params in Streamlit 1.19+
113
+ image_url = params.get("image_url", [None])[0]
 
 
 
 
 
114
 
115
  if not image_url:
116
+ st.json({"error": "No 'image_url' provided. Usage: ?api=1&image_url=<URL>"})
117
+ st.stop()
118
 
119
  # Download the image
120
  response = requests.get(
 
124
 
125
  if response.status_code != 200:
126
  st.json({"error": f"Failed to retrieve image from {image_url}. HTTP {response.status_code}"})
127
+ st.stop()
128
 
129
  image_bytes = response.content
130
+ # Check file size (limit 10MB)
131
  image_size_mb = len(image_bytes)/(1024*1024)
132
  if image_size_mb > 10:
133
  st.json({"error": f"Image size {image_size_mb:.2f}MB exceeds 10MB limit."})
134
+ st.stop()
135
 
136
+ # Convert to PIL
137
  try:
138
  image = Image.open(io.BytesIO(image_bytes))
139
  except Exception as e:
140
  st.json({"error": f"Could not open image: {e}"})
141
+ st.stop()
142
 
143
+ # Resize
144
  image = resize_image_proportionally(image)
145
 
146
  # Predict
 
151
  vespidae_score = float(probabilities[3].iloc[0]) * 100
152
  except Exception as e:
153
  st.json({"error": f"Prediction failed: {e}"})
154
+ st.stop()
155
 
156
  # Determine highest-scoring label
157
  highest_score = max(honeybee_score, bumblebee_score, vespidae_score)
 
165
  else:
166
  prediction_label = "Vespidae (wasp/hornet)"
167
 
168
+ # Return results as JSON and stop further Streamlit processing
169
  st.json({
170
  "honeybee_score": honeybee_score,
171
  "bumblebee_score": bumblebee_score,
172
  "vespidae_score": vespidae_score,
173
  "prediction_label": prediction_label
174
  })
175
+ st.stop()
176
 
 
 
 
177
  def run_ui(predictor):
178
  st.title("Honey Bee Image Classification")
179
 
 
180
  uploaded_file = st.file_uploader(
181
+ "Upload a photo of the suspected bee...",
182
  type=["png", "jpg", "jpeg"]
183
  )
184
 
185
  with st.expander("ML Model Details"):
186
  st.write("""
187
+ We trained a MultiModalPredictor to classify bee images
188
+ (Honey Bee, Bumblebee, or Vespidae).
189
+ Accuracy is ~97.5% on our test set.
 
 
 
 
190
  """)
191
 
192
  if uploaded_file is not None:
 
203
  bumblebee_score = float(probabilities[2].iloc[0]) * 100
204
  vespidae_score = float(probabilities[3].iloc[0]) * 100
205
 
 
206
  timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
207
  sanitized_filename = sanitize_filename(uploaded_file.name)
208
  img_name = f"processed_{sanitized_filename}_{timestamp}.jpg"
209
 
 
210
  image_path = save_image(image, img_name)
 
 
211
  log_predictions(image_path, honeybee_score, bumblebee_score, vespidae_score)
212
 
 
213
  highest_score = max(honeybee_score, bumblebee_score, vespidae_score)
214
 
 
215
  if highest_score < 80:
216
+ st.warning("We are fairly confident there is no bee in this photo.")
217
  else:
218
  if honeybee_score == highest_score:
219
  st.success("Yes! This is a honey bee!")
220
  elif bumblebee_score == highest_score:
221
+ st.info("Likely a bumblebee, not a honey bee.")
222
  else:
223
+ st.info("Likely a wasp/hornet (vespidae).")
224
 
225
  except Exception as e:
226
  st.error(f"An error occurred: {e}")
227
  finally:
228
  progress_bar.empty()
229
 
 
 
 
230
  def main():
231
  predictor = load_model()
232
 
233
+ # Decide whether we are in 'API mode' or normal UI mode
234
+ query_params = st.experimental_get_query_params()
235
  if "api" in query_params:
 
 
236
  run_api(predictor)
237
  else:
 
238
  run_ui(predictor)
239
 
240
  if __name__ == '__main__':