# bee_ml_1 / app.py
# JackRabbit
# api updates
# 253cf58
import streamlit as st
from PIL import Image
import pandas as pd
import io
import os
import requests
from autogluon.multimodal import MultiModalPredictor
from huggingface_hub import snapshot_download
import logging
import datetime
import re
# Configure logging: records at INFO level and above go to the file named
# below, each line prefixed with the record's timestamp.
log_filename = "model_predictions.log"
logging.basicConfig(filename=log_filename, level=logging.INFO, format='%(asctime)s - %(message)s')
# Set the page config: browser-tab title and the "wide" page layout.
st.set_page_config(page_title="Honey Bee Image Classification", layout="wide")
@st.cache_resource
def load_model():
    """Fetch the classifier snapshot from the Hugging Face Hub and load it.

    Decorated with ``st.cache_resource`` so the loaded predictor can be
    reused rather than rebuilt each time the script runs.

    Returns:
        Whatever ``MultiModalPredictor.load`` returns for the snapshot
        directory.

    Raises:
        FileNotFoundError: If ``assets.json`` or ``model.ckpt`` is missing
            from the downloaded snapshot directory.
    """
    snapshot_dir = snapshot_download("Honey-Bee-Society/honeybee_ml_v1")

    # Fail early with a clear message if the snapshot is incomplete.
    required_files = ("assets.json", "model.ckpt")
    if not all(
        os.path.exists(os.path.join(snapshot_dir, file_name))
        for file_name in required_files
    ):
        raise FileNotFoundError("Required model files not found in the downloaded directory.")

    return MultiModalPredictor.load(snapshot_dir)
def resize_image_proportionally(image, max_size_mb=1):
    """Shrink *image* so that its PNG encoding is roughly *max_size_mb* or less.

    The image is encoded to PNG in memory to measure its size. If that size
    exceeds the limit, both dimensions are scaled by the same factor so the
    aspect ratio is kept. Only one resize pass is made, so the result is an
    approximation rather than a hard guarantee.

    Args:
        image: A PIL-style image exposing ``width``, ``height``,
            ``save(fp, format=...)`` and ``resize((w, h))``.
        max_size_mb: Target upper bound for the PNG-encoded size, in MiB.
            Must be positive.

    Returns:
        The original image object when it already fits, otherwise the
        resized image returned by ``image.resize``.

    Raises:
        ValueError: If ``max_size_mb`` is not positive.
    """
    if max_size_mb <= 0:
        # A zero or negative budget would give a zero or complex scale factor.
        raise ValueError("max_size_mb must be positive")

    png_buffer = io.BytesIO()
    image.save(png_buffer, format='PNG')
    size_mb = len(png_buffer.getvalue()) / (1024 * 1024)

    if size_mb <= max_size_mb:
        return image

    # Encoded size grows roughly with the pixel count (width * height), so
    # the per-axis scale is the square root of the size ratio.
    scale_factor = (max_size_mb / size_mb) ** 0.5

    # int() truncates toward zero; clamp so very thin images never end up
    # with a 0-pixel side, which resize() rejects.
    new_width = max(1, int(image.width * scale_factor))
    new_height = max(1, int(image.height * scale_factor))
    return image.resize((new_width, new_height))
def predict_image(image, predictor):
    """Run the predictor on a single image and return its class probabilities.

    The image is serialised to PNG bytes in memory and handed to the
    predictor as a one-row DataFrame with a single ``"image"`` column.

    Args:
        image: An object exposing ``save(fp, format=...)``.
        predictor: An object exposing ``predict_proba(df, realtime=...)``.

    Returns:
        Whatever ``predictor.predict_proba`` returns for the one-row frame.
    """
    with io.BytesIO() as png_buffer:
        image.save(png_buffer, format='PNG')
        png_bytes = png_buffer.getvalue()

    single_row = pd.DataFrame({"image": [png_bytes]})
    return predictor.predict_proba(single_row, realtime=True)
def save_image(image, img_name, target_size_kb=500):
    """Save *image* as a JPEG under ``processed_images/``, aiming for a size cap.

    The image is encoded at decreasing JPEG quality (95, 90, ... 15) until
    the encoded size is at most *target_size_kb*. If no quality level fits,
    the lowest-quality attempt is written anyway.

    Args:
        image: A PIL-style image exposing ``mode``, ``convert(mode)`` and
            ``save(fp, format=..., quality=...)``.
        img_name: File name to create inside ``processed_images``.
        target_size_kb: Desired maximum encoded size in KiB.

    Returns:
        The relative path of the file that was written.
    """
    output_dir = "processed_images"
    # exist_ok avoids the check-then-create race between concurrent sessions.
    os.makedirs(output_dir, exist_ok=True)
    processed_image_path = os.path.join(output_dir, img_name)

    # JPEG has no alpha or palette support; uploads such as RGBA or P mode
    # PNGs must be converted first or the encoder raises.
    if image.mode not in ("RGB", "L", "CMYK"):
        image = image.convert("RGB")

    encoded = b""
    for quality in range(95, 10, -5):
        # Use a fresh buffer per attempt. Rewinding a shared buffer leaves
        # the tail of the previous, larger encoding in place, so the
        # measured size never shrinks and the file gets trailing garbage.
        attempt = io.BytesIO()
        image.save(attempt, format='JPEG', quality=quality)
        encoded = attempt.getvalue()
        if len(encoded) / 1024 <= target_size_kb:
            break

    with open(processed_image_path, "wb") as f:
        f.write(encoded)
    return processed_image_path
def log_predictions(image_path, honeybee_score, bumblebee_score, vespidae_score):
    """Write one INFO log line with the image path and the three class scores.

    Args:
        image_path: Path of the stored image the scores belong to.
        honeybee_score: Honey bee score as a percentage.
        bumblebee_score: Bumblebee score as a percentage.
        vespidae_score: Vespidae score as a percentage.
    """
    # Each score is rendered with two decimals and a trailing percent sign.
    fields = (
        f"Image Path: {image_path}",
        f"Honeybee: {honeybee_score:.2f}%",
        f"Bumblebee: {bumblebee_score:.2f}%",
        f"Vespidae: {vespidae_score:.2f}%",
    )
    logging.info(", ".join(fields))
def sanitize_filename(filename):
    """Return *filename* with every unsafe character replaced by ``_``.

    Only ASCII letters, digits, underscore, dot and hyphen are kept; each
    other character (spaces, path separators, non-ASCII, ...) becomes one
    underscore, so the result has the same length as the input.
    """
    unsafe_chars = re.compile(r'[^A-Za-z0-9_.-]')
    return unsafe_chars.sub('_', filename)
def check_file_size(uploaded_file, max_size_mb=10):
    """Report whether a seekable upload is within the size limit.

    The file position is moved to the end to measure the size and is always
    rewound to the start before returning. When the file is too large an
    error is shown through ``st.error``.

    Args:
        uploaded_file: A seekable file-like object (``seek``/``tell``).
        max_size_mb: Largest accepted size in MiB.

    Returns:
        ``True`` if the file is within the limit, ``False`` otherwise.
    """
    # Measure by jumping to the end, then rewind for later readers.
    uploaded_file.seek(0, os.SEEK_END)
    size_in_bytes = uploaded_file.tell()
    uploaded_file.seek(0)

    too_large = size_in_bytes / (1024 * 1024) > max_size_mb
    if too_large:
        st.error(f"File size exceeds {max_size_mb}MB limit. Please upload a smaller file.")
    return not too_large
def run_api(predictor):
    """
    'API mode' for this Streamlit app.
    Expects a query param ?api=1&image_url=<PUBLIC_IMAGE_URL>
    Example usage:
    curl "https://YOUR-SPACE.hf.space/?api=1&image_url=<some_image_url>"
    WARNING: You will still get HTML with embedded JSON. That's a Streamlit limitation.

    The image at ``image_url`` is downloaded, shrunk if needed, classified
    with *predictor*, and the three scores plus a label are emitted via
    ``st.json``. Every exit path ends with ``st.stop()`` so the normal UI
    never renders in API mode.

    Args:
        predictor: Model object passed through to ``predict_image``.
    """
    # Use st.query_params (not st.experimental_get_query_params)
    params = st.query_params
    # st.query_params maps each key to a single string, so the value must be
    # used as-is: indexing it with [0] keeps only the URL's first character.
    # A list/tuple value is still tolerated for safety.
    image_url = params.get("image_url")
    if isinstance(image_url, (list, tuple)):
        image_url = image_url[0] if image_url else None
    if not image_url:
        st.json({"error": "No 'image_url' provided. Usage: ?api=1&image_url=<URL>"})
        st.stop()

    # Only plain web URLs are fetched.
    # NOTE(security): the server still requests a caller-supplied host, so
    # internal addresses remain reachable (SSRF). Add a host allow-list if
    # this endpoint is exposed to untrusted callers.
    if not image_url.lower().startswith(("http://", "https://")):
        st.json({"error": "'image_url' must be an http(s) URL."})
        st.stop()

    # Download the image. The timeout stops a slow or silent host from
    # hanging the script run indefinitely.
    request_timeout_s = 10
    try:
        response = requests.get(
            image_url,
            headers={"User-Agent": "HoneyBeeClassification/1.0 (+https://honeybeeclassification.streamlit.app)"},
            timeout=request_timeout_s,
        )
    except requests.RequestException as e:
        st.json({"error": f"Failed to retrieve image from {image_url}: {e}"})
        st.stop()
    if response.status_code != 200:
        st.json({"error": f"Failed to retrieve image from {image_url}. HTTP {response.status_code}"})
        st.stop()
    image_bytes = response.content

    # Check file size (limit 10MB)
    image_size_mb = len(image_bytes) / (1024 * 1024)
    if image_size_mb > 10:
        st.json({"error": f"Image size {image_size_mb:.2f}MB exceeds 10MB limit."})
        st.stop()

    # Convert to PIL and resize. Pillow decodes lazily, so a truncated or
    # corrupt file may only fail during the resize step's re-encode; both
    # steps therefore share one error handler.
    try:
        image = Image.open(io.BytesIO(image_bytes))
        image = resize_image_proportionally(image)
    except Exception as e:
        st.json({"error": f"Could not open image: {e}"})
        st.stop()

    # Predict
    try:
        probabilities = predict_image(image, predictor)
        # Assumes columns 1, 2 and 3 of the returned frame are honey bee,
        # bumblebee and vespidae respectively. TODO: confirm against the
        # model's label mapping.
        honeybee_score = float(probabilities[1].iloc[0]) * 100
        bumblebee_score = float(probabilities[2].iloc[0]) * 100
        vespidae_score = float(probabilities[3].iloc[0]) * 100
    except Exception as e:
        st.json({"error": f"Prediction failed: {e}"})
        st.stop()

    # Determine highest-scoring label; below 80% nothing is reported as a bee.
    highest_score = max(honeybee_score, bumblebee_score, vespidae_score)
    if highest_score < 80:
        prediction_label = "No bee detected (scores too low)."
    elif honeybee_score == highest_score:
        prediction_label = "Honey Bee"
    elif bumblebee_score == highest_score:
        prediction_label = "Bumblebee"
    else:
        prediction_label = "Vespidae (wasp/hornet)"

    # Return results as JSON, but note that Streamlit wraps this in HTML
    st.json({
        "honeybee_score": honeybee_score,
        "bumblebee_score": bumblebee_score,
        "vespidae_score": vespidae_score,
        "prediction_label": prediction_label
    })
    # Stop execution so the normal UI won't render
    st.stop()
def run_ui(predictor):
    """Render the interactive upload-and-classify page.

    Shows the title, a file uploader and a model-details expander. When a
    file within the size limit is uploaded, it is opened, shrunk if needed,
    classified with *predictor*, archived to disk, logged, and the verdict
    is displayed. Any failure while opening or classifying the image is
    shown with ``st.error`` instead of crashing the page.

    Args:
        predictor: Model object passed through to ``predict_image``.
    """
    st.title("Honey Bee Image Classification")
    uploaded_file = st.file_uploader(
        "Upload a photo of the suspected bee...",
        type=["png", "jpg", "jpeg"]
    )
    with st.expander("ML Model Details"):
        # Text is kept flush-left so the literal carries no leading indentation.
        st.write("""
We trained a MultiModalPredictor to classify bee images
(Honey Bee, Bumblebee, or Vespidae).
Accuracy is ~97.5% on our test set.
""")

    # Nothing more to do until a file within the size limit is provided.
    if uploaded_file is None or not check_file_size(uploaded_file):
        return

    progress_bar = st.progress(0)
    try:
        # Opening and resizing happen inside the try block so a corrupt or
        # unsupported upload produces a friendly error, not a traceback.
        image = Image.open(uploaded_file)
        image = resize_image_proportionally(image)

        probabilities = predict_image(image, predictor)
        progress_bar.progress(100)
        # Assumes columns 1, 2 and 3 of the returned frame are honey bee,
        # bumblebee and vespidae respectively. TODO: confirm against the
        # model's label mapping.
        honeybee_score = float(probabilities[1].iloc[0]) * 100
        bumblebee_score = float(probabilities[2].iloc[0]) * 100
        vespidae_score = float(probabilities[3].iloc[0]) * 100

        # Archiving is best-effort: a failed write must not hide a
        # prediction that succeeded, and the scores are logged either way.
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        sanitized_filename = sanitize_filename(uploaded_file.name)
        img_name = f"processed_{sanitized_filename}_{timestamp}.jpg"
        image_path = "<not saved>"
        try:
            image_path = save_image(image, img_name)
        except OSError:
            logging.exception("Could not save processed image %s", img_name)
        log_predictions(image_path, honeybee_score, bumblebee_score, vespidae_score)

        # Below 80% for every class, the photo is reported as not a bee.
        highest_score = max(honeybee_score, bumblebee_score, vespidae_score)
        if highest_score < 80:
            st.warning("We are fairly confident there is no bee in this photo.")
        elif honeybee_score == highest_score:
            st.success("Yes! This is a honey bee!")
        elif bumblebee_score == highest_score:
            st.info("Likely a bumblebee, not a honey bee.")
        else:
            st.info("Likely a wasp/hornet (vespidae).")
    except Exception as e:
        st.error(f"An error occurred: {e}")
    finally:
        # Always remove the progress bar, whether or not prediction worked.
        progress_bar.empty()
def main():
    """Load the model, then dispatch to API mode or the interactive UI.

    API mode is selected whenever the request's query string contains an
    ``api`` key, whatever its value; otherwise the normal page is rendered.
    """
    predictor = load_model()
    # st.query_params replaces st.experimental_get_query_params.
    handler = run_api if "api" in st.query_params else run_ui
    handler(predictor)
# Script entry point: only start the app when this file is run directly.
if __name__ == "__main__":
    main()