File size: 8,240 Bytes
885f8ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b644be2
 
885f8ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b644be2
885f8ec
 
b644be2
885f8ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b644be2
 
 
253cf58
 
 
 
885f8ec
253cf58
 
 
885f8ec
 
b644be2
 
885f8ec
 
 
 
 
 
 
 
 
b644be2
885f8ec
 
b644be2
253cf58
885f8ec
 
b644be2
885f8ec
b644be2
885f8ec
 
 
 
b644be2
885f8ec
b644be2
885f8ec
 
 
 
 
 
 
 
 
 
b644be2
885f8ec
 
 
 
 
 
 
 
 
 
 
 
 
253cf58
885f8ec
 
 
 
 
 
253cf58
b644be2
885f8ec
 
 
 
 
b644be2
885f8ec
 
 
 
 
b644be2
 
 
885f8ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b644be2
885f8ec
 
 
 
b644be2
885f8ec
b644be2
885f8ec
 
 
 
 
 
 
 
 
b644be2
253cf58
885f8ec
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
import streamlit as st
from PIL import Image
import pandas as pd
import io
import os
import requests
from autogluon.multimodal import MultiModalPredictor
from huggingface_hub import snapshot_download
import logging
import datetime
import re

# Configure logging
# Predictions are appended to this file in the app's working directory;
# the format prefixes each entry with a timestamp.
log_filename = "model_predictions.log"
logging.basicConfig(filename=log_filename, level=logging.INFO, format='%(asctime)s - %(message)s')

# Set the page config (must be the first Streamlit call in the script)
st.set_page_config(page_title="Honey Bee Image Classification", layout="wide")

@st.cache_resource
def load_model():
    """Fetch the honeybee classifier snapshot from the HF Hub and load it.

    Decorated with st.cache_resource so the download and load happen only
    once per Streamlit server process.

    Raises:
        FileNotFoundError: if the snapshot is missing the predictor files.
    """
    snapshot_dir = snapshot_download("Honey-Bee-Society/honeybee_ml_v1")

    # Sanity-check the snapshot before handing it to AutoGluon.
    required = ("assets.json", "model.ckpt")
    if any(not os.path.exists(os.path.join(snapshot_dir, name)) for name in required):
        raise FileNotFoundError("Required model files not found in the downloaded directory.")

    return MultiModalPredictor.load(snapshot_dir)

def resize_image_proportionally(image, max_size_mb=1):
    """Shrink *image* so its PNG encoding fits under *max_size_mb* megabytes.

    The aspect ratio is preserved. Images that already fit are returned
    unchanged (same object, no re-encode).
    """
    buffer = io.BytesIO()
    image.save(buffer, format='PNG')
    size_mb = len(buffer.getvalue()) / (1024 * 1024)

    if size_mb <= max_size_mb:
        return image

    # Encoded size grows roughly with pixel count, i.e. with the square of
    # the linear dimensions — hence the square root on the size ratio.
    factor = (max_size_mb / size_mb) ** 0.5
    new_size = (int(image.width * factor), int(image.height * factor))
    return image.resize(new_size)

def predict_image(image, predictor):
    """Run *predictor* on a single PIL image and return its class probabilities.

    The image is PNG-encoded to bytes and wrapped in a one-row DataFrame with
    an "image" column, the input shape AutoGluon's MultiModalPredictor expects.
    """
    buffer = io.BytesIO()
    image.save(buffer, format='PNG')
    frame = pd.DataFrame({"image": [buffer.getvalue()]})
    return predictor.predict_proba(frame, realtime=True)

def save_image(image, img_name, target_size_kb=500):
    """Save *image* as a JPEG under processed_images/, aiming for *target_size_kb*.

    JPEG quality starts at 95 and drops in steps of 5 (floor of 15) until the
    encoded size fits the target; the last attempt is written out even if it
    is still over the target.

    Returns the path of the written file.

    FIX: the original reused one BytesIO across attempts with only seek(0)
    and no truncate(), so a smaller re-encode left stale trailing bytes from
    the previous larger encoding — inflating the measured size (breaking the
    quality loop) and writing a corrupted file with trailing garbage.
    """
    os.makedirs("processed_images", exist_ok=True)
    processed_image_path = os.path.join("processed_images", img_name)

    quality = 95
    img_byte_array = io.BytesIO()

    while quality > 10:
        # Fresh buffer per attempt so each measurement reflects exactly one encoding.
        img_byte_array = io.BytesIO()
        image.save(img_byte_array, format='JPEG', quality=quality)
        img_size_kb = len(img_byte_array.getvalue()) / 1024
        if img_size_kb <= target_size_kb:
            break
        quality -= 5

    with open(processed_image_path, "wb") as f:
        f.write(img_byte_array.getvalue())

    return processed_image_path

def log_predictions(image_path, honeybee_score, bumblebee_score, vespidae_score):
    """Append one INFO line with the image path and all three class scores."""
    message = (
        f"Image Path: {image_path}, "
        f"Honeybee: {honeybee_score:.2f}%, "
        f"Bumblebee: {bumblebee_score:.2f}%, "
        f"Vespidae: {vespidae_score:.2f}%"
    )
    logging.info(message)

def sanitize_filename(filename):
    """Replace every character outside [A-Za-z0-9_.-] with an underscore."""
    return re.sub(r'[^A-Za-z0-9_.-]', '_', filename)

def check_file_size(uploaded_file, max_size_mb=10):
    """Return True if the file-like object is within *max_size_mb* megabytes.

    On failure, shows a Streamlit error and returns False. The stream is
    rewound to the start in either case so the caller can still read it.
    """
    uploaded_file.seek(0, os.SEEK_END)
    size_mb = uploaded_file.tell() / (1024 * 1024)
    uploaded_file.seek(0)

    if size_mb <= max_size_mb:
        return True

    st.error(f"File size exceeds {max_size_mb}MB limit. Please upload a smaller file.")
    return False

def run_api(predictor):
    """
    'API mode' for this Streamlit app.
    Expects a query param ?api=1&image_url=<PUBLIC_IMAGE_URL>

    Example usage:
        curl "https://YOUR-SPACE.hf.space/?api=1&image_url=<some_image_url>"

    WARNING: You will still get HTML with embedded JSON. That's a Streamlit limitation.
    """
    params = st.query_params
    # FIX: st.query_params maps each key to a plain string (unlike the old
    # st.experimental_get_query_params, which returned lists). The previous
    # `params.get("image_url", [None])[0]` therefore truncated the URL to its
    # first character, breaking every API request.
    image_url = params.get("image_url")

    if not image_url:
        st.json({"error": "No 'image_url' provided. Usage: ?api=1&image_url=<URL>"})
        st.stop()

    # Download the image; the timeout keeps a hung remote host from blocking
    # the app thread forever.
    try:
        response = requests.get(
            image_url,
            headers={"User-Agent": "HoneyBeeClassification/1.0 (+https://honeybeeclassification.streamlit.app)"},
            timeout=30,
        )
    except requests.RequestException as e:
        st.json({"error": f"Failed to retrieve image from {image_url}. {e}"})
        st.stop()

    if response.status_code != 200:
        st.json({"error": f"Failed to retrieve image from {image_url}. HTTP {response.status_code}"})
        st.stop()

    image_bytes = response.content
    # Enforce the same 10MB ceiling the UI upload path uses.
    image_size_mb = len(image_bytes) / (1024 * 1024)
    if image_size_mb > 10:
        st.json({"error": f"Image size {image_size_mb:.2f}MB exceeds 10MB limit."})
        st.stop()

    # Convert to PIL
    try:
        image = Image.open(io.BytesIO(image_bytes))
    except Exception as e:
        st.json({"error": f"Could not open image: {e}"})
        st.stop()

    # Shrink oversized images before inference.
    image = resize_image_proportionally(image)

    # Predict. Columns 1/2/3 are read as honeybee/bumblebee/vespidae — this
    # mirrors run_ui; confirm against the predictor's actual class labels.
    try:
        probabilities = predict_image(image, predictor)
        honeybee_score = float(probabilities[1].iloc[0]) * 100
        bumblebee_score = float(probabilities[2].iloc[0]) * 100
        vespidae_score = float(probabilities[3].iloc[0]) * 100
    except Exception as e:
        st.json({"error": f"Prediction failed: {e}"})
        st.stop()

    # Determine highest-scoring label; below 80% confidence we report no bee.
    highest_score = max(honeybee_score, bumblebee_score, vespidae_score)
    if highest_score < 80:
        prediction_label = "No bee detected (scores too low)."
    elif honeybee_score == highest_score:
        prediction_label = "Honey Bee"
    elif bumblebee_score == highest_score:
        prediction_label = "Bumblebee"
    else:
        prediction_label = "Vespidae (wasp/hornet)"

    # Return results as JSON, but note that Streamlit wraps this in HTML
    st.json({
        "honeybee_score": honeybee_score,
        "bumblebee_score": bumblebee_score,
        "vespidae_score": vespidae_score,
        "prediction_label": prediction_label
    })
    # Stop execution so the normal UI won't render
    st.stop()

def run_ui(predictor):
    """Render the interactive page: upload a photo, classify it, show the verdict."""
    st.title("Honey Bee Image Classification")

    uploaded_file = st.file_uploader(
        "Upload a photo of the suspected bee...",
        type=["png", "jpg", "jpeg"]
    )

    with st.expander("ML Model Details"):
        st.write("""
            We trained a MultiModalPredictor to classify bee images 
            (Honey Bee, Bumblebee, or Vespidae).
            Accuracy is ~97.5% on our test set.
        """)

    # Guard clauses: nothing uploaded, or the file is over the size limit.
    if uploaded_file is None:
        return
    if not check_file_size(uploaded_file):
        return

    picture = Image.open(uploaded_file)
    picture = resize_image_proportionally(picture)

    progress_bar = st.progress(0)
    try:
        probs = predict_image(picture, predictor)
        progress_bar.progress(100)

        # Columns 1/2/3 are read as honeybee/bumblebee/vespidae — mirrors
        # run_api; confirm against the predictor's actual class labels.
        honeybee_score = float(probs[1].iloc[0]) * 100
        bumblebee_score = float(probs[2].iloc[0]) * 100
        vespidae_score = float(probs[3].iloc[0]) * 100

        # Persist the processed image with a sanitized, timestamped name
        # and log the scores alongside its path.
        stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        img_name = f"processed_{sanitize_filename(uploaded_file.name)}_{stamp}.jpg"
        saved_path = save_image(picture, img_name)
        log_predictions(saved_path, honeybee_score, bumblebee_score, vespidae_score)

        top = max(honeybee_score, bumblebee_score, vespidae_score)
        if top < 80:
            st.warning("We are fairly confident there is no bee in this photo.")
        elif honeybee_score == top:
            st.success("Yes! This is a honey bee!")
        elif bumblebee_score == top:
            st.info("Likely a bumblebee, not a honey bee.")
        else:
            st.info("Likely a wasp/hornet (vespidae).")

    except Exception as e:
        st.error(f"An error occurred: {e}")
    finally:
        progress_bar.empty()

def main():
    """Entry point: load the cached model, then dispatch to API or UI mode."""
    predictor = load_model()

    # The presence of any 'api' query parameter selects the JSON-style endpoint.
    if "api" in st.query_params:
        run_api(predictor)
    else:
        run_ui(predictor)

if __name__ == '__main__':
    main()