|
|
import streamlit as st |
|
|
from PIL import Image |
|
|
import pandas as pd |
|
|
import io |
|
|
import os |
|
|
import requests |
|
|
from autogluon.multimodal import MultiModalPredictor |
|
|
from huggingface_hub import snapshot_download |
|
|
import logging |
|
|
import datetime |
|
|
import re |
|
|
|
|
|
|
|
|
# Destination file for prediction audit records (one line per classification).
log_filename = "model_predictions.log"

# Root logger appends timestamped INFO-level entries to the log file.
logging.basicConfig(filename=log_filename, level=logging.INFO, format='%(asctime)s - %(message)s')

# Streamlit page chrome; must be the first Streamlit call in the script.
st.set_page_config(page_title="Honey Bee Image Classification", layout="wide")
|
|
|
|
|
@st.cache_resource
def load_model():
    """Download the classifier snapshot from the Hugging Face Hub and load it.

    Cached with ``st.cache_resource`` so the download and model load happen
    once per server process rather than on every rerun.

    Returns:
        MultiModalPredictor: the loaded AutoGluon predictor.

    Raises:
        FileNotFoundError: if the snapshot lacks the expected model files.
    """
    local_dir = snapshot_download("Honey-Bee-Society/honeybee_ml_v1")

    # Sanity-check that the snapshot contains the files the predictor needs.
    required = ("assets.json", "model.ckpt")
    if any(not os.path.exists(os.path.join(local_dir, name)) for name in required):
        raise FileNotFoundError("Required model files not found in the downloaded directory.")

    return MultiModalPredictor.load(local_dir)
|
|
|
|
|
def resize_image_proportionally(image, max_size_mb=1):
    """Shrink *image* so its PNG encoding fits under *max_size_mb* megabytes.

    The aspect ratio is preserved: both dimensions are scaled by
    ``sqrt(max_size / current_size)``. Images already within the limit are
    returned unchanged.
    """
    buffer = io.BytesIO()
    image.save(buffer, format='PNG')
    size_mb = len(buffer.getvalue()) / (1024 * 1024)

    if size_mb <= max_size_mb:
        return image

    # File size scales roughly with pixel area, i.e. the square of the
    # linear dimensions — hence the square root.
    factor = (max_size_mb / size_mb) ** 0.5
    new_size = (int(image.width * factor), int(image.height * factor))
    return image.resize(new_size)
|
|
|
|
|
def predict_image(image, predictor):
    """Encode *image* as PNG bytes and score it with *predictor*.

    Builds a one-row DataFrame (column ``image`` holding the raw PNG bytes)
    and returns whatever ``predictor.predict_proba`` produces for it.
    """
    buffer = io.BytesIO()
    image.save(buffer, format='PNG')
    frame = pd.DataFrame({"image": [buffer.getvalue()]})
    return predictor.predict_proba(frame, realtime=True)
|
|
|
|
|
def save_image(image, img_name, target_size_kb=500):
    """Save *image* as a JPEG under ``processed_images/``, targeting a size.

    Re-encodes at decreasing quality (95 down to 15, in steps of 5) until the
    encoded size is at most *target_size_kb* kilobytes, then writes the last
    attempt to disk.

    Args:
        image: PIL-like image exposing ``save`` (and optionally ``mode`` /
            ``convert``).
        img_name: file name to write inside ``processed_images/``.
        target_size_kb: soft upper bound on the written file size.

    Returns:
        str: path of the written file.
    """
    os.makedirs("processed_images", exist_ok=True)
    processed_image_path = os.path.join("processed_images", img_name)

    # JPEG cannot store an alpha channel or palette; flatten such modes to
    # RGB, otherwise PIL raises on save (e.g. for RGBA PNG uploads).
    if getattr(image, "mode", "RGB") not in ("RGB", "L"):
        image = image.convert("RGB")

    quality = 95
    img_byte_array = io.BytesIO()

    while quality > 10:
        # Fresh buffer per attempt: the original seek(0) without truncate()
        # left stale bytes from a previous, larger encoding at the end of the
        # stream, inflating the measured size and corrupting the output file.
        img_byte_array = io.BytesIO()
        image.save(img_byte_array, format='JPEG', quality=quality)
        img_size_kb = len(img_byte_array.getvalue()) / 1024
        if img_size_kb <= target_size_kb:
            break
        quality -= 5

    with open(processed_image_path, "wb") as f:
        f.write(img_byte_array.getvalue())

    return processed_image_path
|
|
|
|
|
def log_predictions(image_path, honeybee_score, bumblebee_score, vespidae_score):
    """Append one INFO record with the image path and all three class scores."""
    record = (
        f"Image Path: {image_path}, "
        f"Honeybee: {honeybee_score:.2f}%, "
        f"Bumblebee: {bumblebee_score:.2f}%, "
        f"Vespidae: {vespidae_score:.2f}%"
    )
    logging.info(record)
|
|
|
|
|
def sanitize_filename(filename):
    """Replace every character outside ``[A-Za-z0-9_.-]`` with an underscore."""
    return re.sub(r'[^A-Za-z0-9_.-]', '_', filename)
|
|
|
|
|
def check_file_size(uploaded_file, max_size_mb=10):
    """Return True when the uploaded file is within *max_size_mb* megabytes.

    Shows a Streamlit error and returns False otherwise. The file cursor is
    rewound to the start in both cases so the caller can read from byte 0.
    """
    uploaded_file.seek(0, os.SEEK_END)
    size_mb = uploaded_file.tell() / (1024 * 1024)
    uploaded_file.seek(0)

    if size_mb <= max_size_mb:
        return True

    st.error(f"File size exceeds {max_size_mb}MB limit. Please upload a smaller file.")
    return False
|
|
|
|
|
def run_api(predictor):
    """
    'API mode' for this Streamlit app.
    Expects a query param ?api=1&image_url=<PUBLIC_IMAGE_URL>

    Example usage:
    curl "https://YOUR-SPACE.hf.space/?api=1&image_url=<some_image_url>"

    WARNING: You will still get HTML with embedded JSON. That's a Streamlit limitation.
    """
    params = st.query_params
    # BUG FIX: st.query_params maps keys to plain strings (unlike the old
    # experimental API, which returned lists), so the previous
    # `params.get("image_url", [None])[0]` yielded only the FIRST CHARACTER
    # of the URL and every request failed.
    image_url = params.get("image_url")

    if not image_url:
        st.json({"error": "No 'image_url' provided. Usage: ?api=1&image_url=<URL>"})
        st.stop()

    try:
        response = requests.get(
            image_url,
            headers={"User-Agent": "HoneyBeeClassification/1.0 (+https://honeybeeclassification.streamlit.app)"},
            timeout=30,  # don't hang the app indefinitely on an unresponsive host
        )
    except requests.RequestException as e:
        st.json({"error": f"Failed to retrieve image from {image_url}. {e}"})
        st.stop()

    if response.status_code != 200:
        st.json({"error": f"Failed to retrieve image from {image_url}. HTTP {response.status_code}"})
        st.stop()

    image_bytes = response.content

    # Enforce the same 10MB ceiling as the UI upload path.
    image_size_mb = len(image_bytes) / (1024 * 1024)
    if image_size_mb > 10:
        st.json({"error": f"Image size {image_size_mb:.2f}MB exceeds 10MB limit."})
        st.stop()

    try:
        image = Image.open(io.BytesIO(image_bytes))
    except Exception as e:
        st.json({"error": f"Could not open image: {e}"})
        st.stop()

    # Keep the tensor small; predictor input is capped at ~1MB PNG.
    image = resize_image_proportionally(image)

    try:
        probabilities = predict_image(image, predictor)
        # Columns 1/2/3 are the class labels — presumably honeybee,
        # bumblebee, vespidae in that order; confirm against training labels.
        honeybee_score = float(probabilities[1].iloc[0]) * 100
        bumblebee_score = float(probabilities[2].iloc[0]) * 100
        vespidae_score = float(probabilities[3].iloc[0]) * 100
    except Exception as e:
        st.json({"error": f"Prediction failed: {e}"})
        st.stop()

    # Below 80% confidence on every class we report "no bee" rather than guess.
    highest_score = max(honeybee_score, bumblebee_score, vespidae_score)
    if highest_score < 80:
        prediction_label = "No bee detected (scores too low)."
    elif honeybee_score == highest_score:
        prediction_label = "Honey Bee"
    elif bumblebee_score == highest_score:
        prediction_label = "Bumblebee"
    else:
        prediction_label = "Vespidae (wasp/hornet)"

    st.json({
        "honeybee_score": honeybee_score,
        "bumblebee_score": bumblebee_score,
        "vespidae_score": vespidae_score,
        "prediction_label": prediction_label
    })

    st.stop()
|
|
|
|
|
def run_ui(predictor):
    """Render the interactive uploader UI and classify the uploaded photo."""
    st.title("Honey Bee Image Classification")

    uploaded_file = st.file_uploader(
        "Upload a photo of the suspected bee...",
        type=["png", "jpg", "jpeg"]
    )

    with st.expander("ML Model Details"):
        st.write("""
        We trained a MultiModalPredictor to classify bee images
        (Honey Bee, Bumblebee, or Vespidae).
        Accuracy is ~97.5% on our test set.
        """)

    # Guard clauses: nothing to do without an upload within the size limit.
    if uploaded_file is None:
        return
    if not check_file_size(uploaded_file):
        return

    image = resize_image_proportionally(Image.open(uploaded_file))

    progress_bar = st.progress(0)
    try:
        probabilities = predict_image(image, predictor)
        progress_bar.progress(100)

        # Columns 1/2/3 are the class labels — presumably honeybee,
        # bumblebee, vespidae; confirm against training labels.
        honeybee_score = float(probabilities[1].iloc[0]) * 100
        bumblebee_score = float(probabilities[2].iloc[0]) * 100
        vespidae_score = float(probabilities[3].iloc[0]) * 100

        # Persist the processed image and append an audit-log record.
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        sanitized_filename = sanitize_filename(uploaded_file.name)
        img_name = f"processed_{sanitized_filename}_{timestamp}.jpg"
        image_path = save_image(image, img_name)
        log_predictions(image_path, honeybee_score, bumblebee_score, vespidae_score)

        highest_score = max(honeybee_score, bumblebee_score, vespidae_score)
        if highest_score < 80:
            st.warning("We are fairly confident there is no bee in this photo.")
        elif honeybee_score == highest_score:
            st.success("Yes! This is a honey bee!")
        elif bumblebee_score == highest_score:
            st.info("Likely a bumblebee, not a honey bee.")
        else:
            st.info("Likely a wasp/hornet (vespidae).")
    except Exception as e:
        st.error(f"An error occurred: {e}")
    finally:
        progress_bar.empty()
|
|
|
|
|
def main():
    """Entry point: load the cached model, then dispatch to API or UI mode."""
    predictor = load_model()

    # Presence of an `api` query parameter switches to JSON output mode.
    query_params = st.query_params
    if "api" in query_params:
        run_api(predictor)
    else:
        run_ui(predictor)
|
|
|
|
|
# Standard script entry guard; keeps the module importable without running the app.
if __name__ == '__main__':
    main()
|
|
|