File size: 8,240 Bytes
885f8ec b644be2 885f8ec b644be2 885f8ec b644be2 885f8ec b644be2 253cf58 885f8ec 253cf58 885f8ec b644be2 885f8ec b644be2 885f8ec b644be2 253cf58 885f8ec b644be2 885f8ec b644be2 885f8ec b644be2 885f8ec b644be2 885f8ec b644be2 885f8ec 253cf58 885f8ec 253cf58 b644be2 885f8ec b644be2 885f8ec b644be2 885f8ec b644be2 885f8ec b644be2 885f8ec b644be2 885f8ec b644be2 253cf58 885f8ec |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 |
import streamlit as st
from PIL import Image
import pandas as pd
import io
import os
import requests
from autogluon.multimodal import MultiModalPredictor
from huggingface_hub import snapshot_download
import logging
import datetime
import re
# Configure logging: prediction records are appended to a local file,
# each prefixed with a timestamp (see log_predictions below).
log_filename = "model_predictions.log"
logging.basicConfig(filename=log_filename, level=logging.INFO, format='%(asctime)s - %(message)s')
# Set the page config for the Streamlit app (title + wide layout).
st.set_page_config(page_title="Honey Bee Image Classification", layout="wide")
@st.cache_resource
def load_model():
    """Download the honeybee classifier snapshot from the HF Hub and load it.

    Cached by Streamlit (`st.cache_resource`), so the download and model
    load happen once per server process.

    Raises:
        FileNotFoundError: if the snapshot is missing its expected files.
    """
    snapshot_dir = snapshot_download("Honey-Bee-Society/honeybee_ml_v1")
    expected_files = ("assets.json", "model.ckpt")
    missing = [name for name in expected_files
               if not os.path.exists(os.path.join(snapshot_dir, name))]
    if missing:
        raise FileNotFoundError("Required model files not found in the downloaded directory.")
    return MultiModalPredictor.load(snapshot_dir)
def resize_image_proportionally(image, max_size_mb=1):
    """Shrink *image* until its PNG encoding fits under *max_size_mb* megabytes.

    The aspect ratio is preserved; the image is returned unchanged when it is
    already small enough.
    """
    buffer = io.BytesIO()
    image.save(buffer, format='PNG')
    size_mb = len(buffer.getvalue()) / (1024 * 1024)
    if size_mb <= max_size_mb:
        return image
    # Encoded size scales roughly with pixel area (width * height), so each
    # side is scaled by the square root of the size ratio.
    factor = (max_size_mb / size_mb) ** 0.5
    return image.resize((int(image.width * factor), int(image.height * factor)))
def predict_image(image, predictor):
    """Encode *image* as PNG bytes and return the predictor's class probabilities."""
    buffer = io.BytesIO()
    image.save(buffer, format='PNG')
    frame = pd.DataFrame({"image": [buffer.getvalue()]})
    return predictor.predict_proba(frame, realtime=True)
def save_image(image, img_name, target_size_kb=500):
    """Save *image* as a JPEG under processed_images/, near *target_size_kb*.

    Lowers the JPEG quality in steps of 5 (from 95 down to 15) until the
    encoded size fits the target, then writes the last encode to disk.

    Args:
        image: PIL-style image object to persist.
        img_name: file name to use inside the processed_images directory.
        target_size_kb: soft upper bound on the stored size, in kilobytes.

    Returns:
        Path of the written file.
    """
    # exist_ok avoids a check-then-create race between concurrent sessions.
    os.makedirs("processed_images", exist_ok=True)
    processed_image_path = os.path.join("processed_images", img_name)
    # JPEG cannot store an alpha channel or palette; PNG uploads are commonly
    # RGBA or P mode and would make image.save(..., 'JPEG') raise.
    if getattr(image, "mode", "RGB") not in ("RGB", "L"):
        image = image.convert("RGB")
    quality = 95
    img_byte_array = io.BytesIO()
    while quality > 10:
        # BUG FIX: use a fresh buffer each attempt. seek(0) alone does not
        # truncate, so a smaller re-encode left stale trailing bytes from the
        # previous, larger encode in the stream (corrupt, oversized output).
        img_byte_array = io.BytesIO()
        image.save(img_byte_array, format='JPEG', quality=quality)
        if len(img_byte_array.getvalue()) / 1024 <= target_size_kb:
            break
        quality -= 5
    with open(processed_image_path, "wb") as f:
        f.write(img_byte_array.getvalue())
    return processed_image_path
def log_predictions(image_path, honeybee_score, bumblebee_score, vespidae_score):
    """Append one prediction record (image path + per-class percentages) to the log."""
    record = (
        f"Image Path: {image_path}, "
        f"Honeybee: {honeybee_score:.2f}%, "
        f"Bumblebee: {bumblebee_score:.2f}%, "
        f"Vespidae: {vespidae_score:.2f}%"
    )
    logging.info(record)
def sanitize_filename(filename):
    """Replace every character outside [A-Za-z0-9_.-] with '_' for a safe file name."""
    return re.sub(r'[^A-Za-z0-9_.-]', '_', filename)
def check_file_size(uploaded_file, max_size_mb=10):
    """Return True when the uploaded file is within the size limit.

    Shows a Streamlit error and returns False otherwise. The stream position
    is rewound to the start in both cases.
    """
    uploaded_file.seek(0, os.SEEK_END)
    size_mb = uploaded_file.tell() / (1024 * 1024)
    uploaded_file.seek(0)
    if size_mb <= max_size_mb:
        return True
    st.error(f"File size exceeds {max_size_mb}MB limit. Please upload a smaller file.")
    return False
def run_api(predictor):
    """
    'API mode' for this Streamlit app.
    Expects a query param ?api=1&image_url=<PUBLIC_IMAGE_URL>
    Example usage:
        curl "https://YOUR-SPACE.hf.space/?api=1&image_url=<some_image_url>"
    WARNING: You will still get HTML with embedded JSON. That's a Streamlit limitation.
    """
    params = st.query_params
    # BUG FIX: st.query_params maps each key to a plain string, unlike the old
    # experimental_get_query_params (which returned lists). The previous
    # `params.get("image_url", [None])[0]` therefore returned only the FIRST
    # CHARACTER of the URL, breaking every API call.
    image_url = params.get("image_url")
    if not image_url:
        st.json({"error": "No 'image_url' provided. Usage: ?api=1&image_url=<URL>"})
        st.stop()
    # Download the image; a timeout stops a dead host from hanging the request.
    try:
        response = requests.get(
            image_url,
            headers={"User-Agent": "HoneyBeeClassification/1.0 (+https://honeybeeclassification.streamlit.app)"},
            timeout=30,
        )
    except requests.RequestException as e:
        st.json({"error": f"Failed to retrieve image from {image_url}. {e}"})
        st.stop()
    if response.status_code != 200:
        st.json({"error": f"Failed to retrieve image from {image_url}. HTTP {response.status_code}"})
        st.stop()
    image_bytes = response.content
    # Enforce the same 10MB cap as the UI upload path.
    image_size_mb = len(image_bytes) / (1024 * 1024)
    if image_size_mb > 10:
        st.json({"error": f"Image size {image_size_mb:.2f}MB exceeds 10MB limit."})
        st.stop()
    # Convert to PIL.
    try:
        image = Image.open(io.BytesIO(image_bytes))
    except Exception as e:
        st.json({"error": f"Could not open image: {e}"})
        st.stop()
    # Shrink oversized images before inference.
    image = resize_image_proportionally(image)
    try:
        probabilities = predict_image(image, predictor)
        # Columns 1/2/3 are assumed to be the trained class labels
        # (1=honeybee, 2=bumblebee, 3=vespidae) -- TODO confirm against assets.json.
        honeybee_score = float(probabilities[1].iloc[0]) * 100
        bumblebee_score = float(probabilities[2].iloc[0]) * 100
        vespidae_score = float(probabilities[3].iloc[0]) * 100
    except Exception as e:
        st.json({"error": f"Prediction failed: {e}"})
        st.stop()
    # When no class reaches 80%, report "no bee" rather than guessing.
    highest_score = max(honeybee_score, bumblebee_score, vespidae_score)
    if highest_score < 80:
        prediction_label = "No bee detected (scores too low)."
    elif honeybee_score == highest_score:
        prediction_label = "Honey Bee"
    elif bumblebee_score == highest_score:
        prediction_label = "Bumblebee"
    else:
        prediction_label = "Vespidae (wasp/hornet)"
    # Return results as JSON, but note that Streamlit wraps this in HTML.
    st.json({
        "honeybee_score": honeybee_score,
        "bumblebee_score": bumblebee_score,
        "vespidae_score": vespidae_score,
        "prediction_label": prediction_label
    })
    # Stop execution so the normal UI won't render.
    st.stop()
def run_ui(predictor):
    """Render the interactive upload-and-classify page.

    Flow: accept an image upload, validate its size, shrink it if needed,
    run the classifier, persist a JPEG copy, log the scores, and display
    the verdict to the user.
    """
    st.title("Honey Bee Image Classification")
    uploaded_file = st.file_uploader(
        "Upload a photo of the suspected bee...",
        type=["png", "jpg", "jpeg"]
    )
    with st.expander("ML Model Details"):
        st.write("""
        We trained a MultiModalPredictor to classify bee images
        (Honey Bee, Bumblebee, or Vespidae).
        Accuracy is ~97.5% on our test set.
        """)
    if uploaded_file is not None:
        # Reject uploads over 10MB (check_file_size shows the st.error itself).
        if check_file_size(uploaded_file):
            image = Image.open(uploaded_file)
            # Shrink large uploads so prediction stays responsive.
            image = resize_image_proportionally(image)
            progress_bar = st.progress(0)
            try:
                probabilities = predict_image(image, predictor)
                progress_bar.progress(100)
                # Columns 1/2/3 are assumed to be the trained class labels
                # (1=honeybee, 2=bumblebee, 3=vespidae) -- TODO confirm against assets.json.
                honeybee_score = float(probabilities[1].iloc[0]) * 100
                bumblebee_score = float(probabilities[2].iloc[0]) * 100
                vespidae_score = float(probabilities[3].iloc[0]) * 100
                # Persist a timestamped JPEG copy and record the scores in the log.
                timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
                sanitized_filename = sanitize_filename(uploaded_file.name)
                img_name = f"processed_{sanitized_filename}_{timestamp}.jpg"
                image_path = save_image(image, img_name)
                log_predictions(image_path, honeybee_score, bumblebee_score, vespidae_score)
                # When no class reaches 80%, report "no bee" instead of guessing.
                highest_score = max(honeybee_score, bumblebee_score, vespidae_score)
                if highest_score < 80:
                    st.warning("We are fairly confident there is no bee in this photo.")
                else:
                    if honeybee_score == highest_score:
                        st.success("Yes! This is a honey bee!")
                    elif bumblebee_score == highest_score:
                        st.info("Likely a bumblebee, not a honey bee.")
                    else:
                        st.info("Likely a wasp/hornet (vespidae).")
            except Exception as e:
                st.error(f"An error occurred: {e}")
            finally:
                # Always clear the progress bar, even when prediction fails.
                progress_bar.empty()
def main():
    """Entry point: load the cached model, then dispatch to API or UI mode."""
    predictor = load_model()
    # The presence of an ?api=... query param switches into machine-readable mode.
    if "api" in st.query_params:
        run_api(predictor)
    else:
        run_ui(predictor)


if __name__ == '__main__':
    main()
|