File size: 8,351 Bytes
c442c56
 
 
69357bd
c442c56
 
 
b0b3b40
c442c56
 
 
b0b3b40
c442c56
b0b3b40
c442c56
 
 
 
 
 
 
 
 
e9471fb
c442c56
 
 
b0b3b40
c442c56
 
 
 
b0b3b40
 
c442c56
 
 
 
 
 
 
b0b3b40
 
 
c442c56
 
 
 
b0b3b40
c442c56
 
b0b3b40
c442c56
 
 
e9471fb
c442c56
 
 
 
 
 
 
e9471fb
c442c56
 
 
 
 
 
 
 
 
 
b0b3b40
c442c56
e9471fb
c442c56
 
 
 
 
 
b0b3b40
 
 
e9471fb
b0b3b40
 
c442c56
b0b3b40
 
c442c56
b0b3b40
c442c56
b0b3b40
c442c56
e9471fb
0731f2b
b0b3b40
 
 
c442c56
b0b3b40
 
e9471fb
c442c56
0731f2b
c442c56
0731f2b
c442c56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0731f2b
c442c56
0731f2b
c442c56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0731f2b
c442c56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e4caeae
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
from fastapi import FastAPI, UploadFile, File, HTTPException
from pydantic import BaseModel
import uvicorn
import io
import logging
import datetime
import re
import os
import requests
import pandas as pd
from PIL import Image
from autogluon.multimodal import MultiModalPredictor
from huggingface_hub import snapshot_download

###############################################################################
# Logging configuration (optional)
###############################################################################
# Every prediction is appended, with a timestamp, to this file via the root logger.
log_filename = "model_predictions.log"
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(message)s",
    filename=log_filename,
)

###############################################################################
# Model loading
###############################################################################
def load_model(repo_id="Honey-Bee-Society/honeybee_ml_v1", revision=None):
    """
    Download a MultiModalPredictor from the Hugging Face Hub and load it.

    Args:
        repo_id: Hub repository holding the saved predictor. Defaults to the
            production honey bee model.
        revision: Optional branch, tag or commit hash to pin the download to,
            for reproducible deployments. ``None`` keeps the previous
            behaviour of fetching the repository's default revision.

    Returns:
        The loaded ``MultiModalPredictor``.

    Raises:
        FileNotFoundError: if ``assets.json`` or ``model.ckpt`` is missing from
            the downloaded snapshot.
    """
    local_dir = snapshot_download(repo_id, revision=revision)

    # Both files must be present for MultiModalPredictor.load to succeed;
    # check up front so the failure names what is missing.
    required_files = ("assets.json", "model.ckpt")
    missing = [
        name for name in required_files
        if not os.path.exists(os.path.join(local_dir, name))
    ]
    if missing:
        raise FileNotFoundError(
            "Required model files not found in the downloaded directory: "
            + ", ".join(missing)
        )

    return MultiModalPredictor.load(local_dir)

###############################################################################
# Image processing and prediction routines
###############################################################################
def resize_image_proportionally(image, max_size_mb=1):
    """
    Shrink ``image`` when its in-memory PNG encoding exceeds ``max_size_mb``.

    The PNG size is a proxy for image size. Encoded size grows roughly with
    pixel count, so both sides are scaled by ``sqrt(max_size_mb / actual)``.
    PNG compression is not linear, so a single pass may leave the result
    somewhat above the limit.

    Args:
        image: A PIL image (anything exposing ``save``, ``width``, ``height``
            and ``resize``).
        max_size_mb: Size threshold in mebibytes (1 MB = 1024 * 1024 bytes).

    Returns:
        ``image`` itself when it is already small enough, otherwise the
        resized image. Each side is kept at least 1 pixel wide.
    """
    buffer = io.BytesIO()
    image.save(buffer, format='PNG')
    # getbuffer() exposes the written bytes without the copy getvalue() makes.
    img_size = buffer.getbuffer().nbytes / (1024 * 1024)

    if img_size > max_size_mb:
        scale_factor = (max_size_mb / img_size) ** 0.5
        # int() truncates: clamp to 1 so an elongated image (e.g. 1000x1)
        # does not collapse to a zero-sized side, which PIL's resize rejects.
        new_width = max(1, int(image.width * scale_factor))
        new_height = max(1, int(image.height * scale_factor))
        image = image.resize((new_width, new_height))

    return image


def predict_image(image: Image.Image, predictor: MultiModalPredictor):
    """
    Score a single PIL image with the AutoGluon predictor.

    The image is PNG-encoded in memory. The raw bytes go into the ``image``
    column of a one-row DataFrame, which is scored with ``predict_proba`` in
    realtime mode.

    Returns:
        The result of ``predictor.predict_proba`` for that frame: the
        per-class probabilities. ``determine_label`` reads it as a DataFrame
        with one column per class.
    """
    with io.BytesIO() as png_buffer:
        image.save(png_buffer, format='PNG')
        encoded = png_buffer.getvalue()
    frame = pd.DataFrame({"image": [encoded]})
    return predictor.predict_proba(frame, realtime=True)


def determine_label(probabilities, threshold=80):
    """
    Convert class probabilities into percentage scores and a text label.

    Args:
        probabilities: Probability table from ``predict_image``. Only the
            first row is read. Columns labelled ``1``, ``2`` and ``3`` are
            taken as the honey bee, bumblebee and vespidae probabilities. Any
            other column (e.g. ``0``) is ignored; presumably it is a "no bee"
            class, to be confirmed against the model's label mapping.
        threshold: Minimum winning score, in percent, needed to report a bee.
            Defaults to 80, the previously hard-coded cut-off.

    Returns:
        dict with ``honeybee_score``, ``bumblebee_score``, ``vespidae_score``
        (floats on a 0-100 scale) and ``prediction_label`` (str). Ties go to
        honey bee first, then bumblebee, then vespidae.
    """
    honeybee_score = float(probabilities[1].iloc[0]) * 100
    bumblebee_score = float(probabilities[2].iloc[0]) * 100
    vespidae_score = float(probabilities[3].iloc[0]) * 100

    highest_score = max(honeybee_score, bumblebee_score, vespidae_score)
    if highest_score < threshold:
        prediction_label = "No bee detected (scores too low)."
    elif honeybee_score == highest_score:
        prediction_label = "Honey Bee"
    elif bumblebee_score == highest_score:
        prediction_label = "Bumblebee"
    else:
        prediction_label = "Vespidae (wasp/hornet)"

    return {
        "honeybee_score": honeybee_score,
        "bumblebee_score": bumblebee_score,
        "vespidae_score": vespidae_score,
        "prediction_label": prediction_label
    }


def log_predictions(honeybee_score, bumblebee_score, vespidae_score, source_info):
    """
    Emit one INFO record summarising the three class scores for an image.

    Scores are printed with two decimals and a ``%`` suffix. The record goes
    through the root ``logging`` logger, so it lands wherever that logger is
    configured to write.
    """
    fields = (
        f"Source: {source_info}",
        f"Honeybee: {honeybee_score:.2f}%",
        f"Bumblebee: {bumblebee_score:.2f}%",
        f"Vespidae: {vespidae_score:.2f}%",
    )
    logging.info(", ".join(fields))

###############################################################################
# Request models
###############################################################################
class ImageUrlRequest(BaseModel):
    """Request payload naming a remote image to classify."""

    # URL the server downloads the image from.
    image_url: str

###############################################################################
# FastAPI app and endpoints
###############################################################################
# The predictor is loaded once, when this module is imported, and shared by
# every request handled below.
predictor = load_model()

app = FastAPI(title="Honey Bee Classification API")

@app.get("/ping")
def ping():
    """Liveness check: always answers with a fixed ``pong`` message."""
    payload = {"message": "pong"}
    return payload

# Upper bound on image size, for both downloaded and uploaded images.
_MAX_IMAGE_MB = 10
_MAX_IMAGE_BYTES = _MAX_IMAGE_MB * 1024 * 1024
# (connect, read) timeouts in seconds for image downloads. The read timeout
# applies per socket read, not to the whole transfer.
_DOWNLOAD_TIMEOUT_S = (5, 30)
_DOWNLOAD_USER_AGENT = "HoneyBeeClassification/1.0 (+https://example.com)"
# PIL modes that cannot be written as PNG (which the resize and prediction
# helpers do), so they are converted to RGB right after decoding.
_NON_PNG_MODES = ("CMYK", "YCbCr")


def _download_image_bytes(image_url):
    """Fetch ``image_url`` and return its body, enforcing the size cap while streaming.

    Raises:
        HTTPException: 400 if the request fails or the server does not reply
            200; 413 if the body exceeds ``_MAX_IMAGE_MB``.
    """
    # NOTE(review): the URL is user-controlled, so the server can be made to
    # fetch arbitrary (including internal) addresses (SSRF). Consider an
    # allow-list of hosts/schemes.
    try:
        response = requests.get(
            image_url,
            headers={"User-Agent": _DOWNLOAD_USER_AGENT},
            timeout=_DOWNLOAD_TIMEOUT_S,
            # Stream so an oversized body is rejected without buffering it all.
            stream=True,
        )
    except requests.RequestException as e:
        raise HTTPException(
            status_code=400,
            detail=f"Error downloading image from {image_url}: {e}"
        ) from e

    with response:
        # Raised outside any broad `except`, so the detail is not re-wrapped.
        if response.status_code != 200:
            raise HTTPException(
                status_code=400,
                detail=f"Failed to retrieve image from {image_url}. HTTP {response.status_code}"
            )
        chunks = []
        received = 0
        try:
            for chunk in response.iter_content(chunk_size=64 * 1024):
                received += len(chunk)
                if received > _MAX_IMAGE_BYTES:
                    raise HTTPException(
                        status_code=413,
                        detail=f"Image size exceeds {_MAX_IMAGE_MB}MB limit."
                    )
                chunks.append(chunk)
        except requests.RequestException as e:
            raise HTTPException(
                status_code=400,
                detail=f"Error downloading image from {image_url}: {e}"
            ) from e
    return b"".join(chunks)


def _open_image(image_bytes, error_prefix):
    """Decode ``image_bytes`` into a PIL image, mapping any failure to HTTP 400.

    ``error_prefix`` starts the error detail, e.g. ``"Could not open image"``.
    """
    try:
        image = Image.open(io.BytesIO(image_bytes))
        # Image.open is lazy: decode now so corrupt or truncated data becomes a
        # 400 here instead of an unhandled error later during resizing.
        image.load()
        if image.mode in _NON_PNG_MODES:
            image = image.convert("RGB")
    except Exception as e:
        raise HTTPException(
            status_code=400,
            detail=f"{error_prefix}: {e}"
        ) from e
    return image


@app.post("/predict")
async def predict_endpoint(
    image_url_req: ImageUrlRequest = None,
    file: UploadFile = File(None)
):
    """
    Classify an image supplied either by URL or as an uploaded file.

    Accepts either a JSON body with `image_url` or a multipart form-data `file`.
    The URL is used when both are present.

    Returns:
        JSON with honeybee, bumblebee and vespidae scores and a predicted label
        (the dict built by ``determine_label``).

    Raises:
        HTTPException: 400 when no image is given or it cannot be downloaded
            or decoded; 413 when it exceeds the size limit; 500 when resizing
            or prediction fails.
    """
    # NOTE(review): FastAPI cannot receive a JSON body and a File in the same
    # request. Because `file` is declared, the body is parsed as form data, so
    # a JSON {"image_url": ...} payload likely never populates `image_url_req`.
    # Verify, and consider a separate JSON endpoint.
    # NOTE(review): the download and model inference are blocking calls inside
    # an async handler, so they stall the event loop for concurrent requests.
    if image_url_req and image_url_req.image_url:
        # 1) The caller provided an image URL.
        source_info = image_url_req.image_url
        image_bytes = _download_image_bytes(source_info)
        image = _open_image(image_bytes, "Could not open image")
    elif file is not None:
        # 2) The caller uploaded a file. Measure it before reading it into memory.
        file.file.seek(0, 2)  # move to end
        file_size = file.file.tell()
        file.file.seek(0)     # reset pointer
        mb_size = file_size / (1024 * 1024)
        if mb_size > _MAX_IMAGE_MB:
            raise HTTPException(
                status_code=413,
                detail=f"Uploaded file size {mb_size:.2f}MB exceeds {_MAX_IMAGE_MB}MB limit."
            )
        try:
            contents = await file.read()
        except Exception as e:
            raise HTTPException(
                status_code=400,
                detail=f"Could not open uploaded image: {e}"
            ) from e
        image = _open_image(contents, "Could not open uploaded image")
        source_info = file.filename
    else:
        raise HTTPException(
            status_code=400,
            detail="No image provided. Supply either `image_url` or `file`."
        )

    try:
        image = resize_image_proportionally(image)
        probabilities = predict_image(image, predictor)
        results = determine_label(probabilities)
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Prediction failed: {e}"
        ) from e

    log_predictions(
        results["honeybee_score"],
        results["bumblebee_score"],
        results["vespidae_score"],
        source_info=source_info
    )

    return results


# Start the server directly with `python app.py`. The PORT environment
# variable, when set, overrides the default port 7860. (On Hugging Face
# Spaces a separate launch command may be used instead.)
if __name__ == "__main__":
    # `os` and `uvicorn` are already imported at module level.
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)