Spaces:
Sleeping
Sleeping
JackRabbit committed on
Commit ·
c442c56
1
Parent(s): e9471fb
new fastapi app
Browse files- app.py +202 -103
- requirements.txt +5 -4
app.py
CHANGED
|
@@ -1,143 +1,242 @@
|
|
| 1 |
-
import
|
| 2 |
-
|
| 3 |
-
import
|
| 4 |
-
from PIL import Image
|
| 5 |
import io
|
|
|
|
|
|
|
|
|
|
| 6 |
import os
|
| 7 |
-
|
| 8 |
-
|
|
|
|
| 9 |
from autogluon.multimodal import MultiModalPredictor
|
|
|
|
| 10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
-
###########################
|
| 13 |
-
#
|
| 14 |
-
###########################
|
| 15 |
def load_model():
|
| 16 |
-
"""
|
|
|
|
|
|
|
|
|
|
| 17 |
repo_id = "Honey-Bee-Society/honeybee_ml_v1"
|
| 18 |
local_dir = snapshot_download(repo_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
predictor = MultiModalPredictor.load(local_dir)
|
| 20 |
return predictor
|
| 21 |
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
# Utility Functions
|
| 27 |
-
###########################
|
| 28 |
-
def resize_image_if_large(image, max_size_mb=1):
|
| 29 |
-
"""Resizes the image if it is larger than `max_size_mb` MB."""
|
| 30 |
-
img_bytes = io.BytesIO()
|
| 31 |
-
image.save(img_bytes, format='PNG')
|
| 32 |
-
size_mb = len(img_bytes.getvalue()) / (1024 * 1024)
|
| 33 |
-
if size_mb > max_size_mb:
|
| 34 |
-
scale_factor = (max_size_mb / size_mb) ** 0.5
|
| 35 |
-
new_w = int(image.width * scale_factor)
|
| 36 |
-
new_h = int(image.height * scale_factor)
|
| 37 |
-
image = image.resize((new_w, new_h))
|
| 38 |
-
return image
|
| 39 |
-
|
| 40 |
-
def classify_image(image: Image.Image):
|
| 41 |
"""
|
| 42 |
-
|
| 43 |
-
|
| 44 |
"""
|
| 45 |
-
|
| 46 |
-
image
|
|
|
|
| 47 |
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
|
|
|
|
|
|
|
|
|
| 52 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
probabilities = predictor.predict_proba(df, realtime=True)
|
|
|
|
| 54 |
|
| 55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
honeybee_score = float(probabilities[1].iloc[0]) * 100
|
| 57 |
bumblebee_score = float(probabilities[2].iloc[0]) * 100
|
| 58 |
vespidae_score = float(probabilities[3].iloc[0]) * 100
|
| 59 |
|
| 60 |
highest_score = max(honeybee_score, bumblebee_score, vespidae_score)
|
| 61 |
if highest_score < 80:
|
| 62 |
-
|
| 63 |
else:
|
| 64 |
if honeybee_score == highest_score:
|
| 65 |
-
|
| 66 |
elif bumblebee_score == highest_score:
|
| 67 |
-
|
| 68 |
else:
|
| 69 |
-
|
| 70 |
|
| 71 |
return {
|
| 72 |
"honeybee_score": honeybee_score,
|
| 73 |
"bumblebee_score": bumblebee_score,
|
| 74 |
"vespidae_score": vespidae_score,
|
| 75 |
-
"prediction_label":
|
| 76 |
}
|
| 77 |
|
| 78 |
|
| 79 |
-
|
| 80 |
-
# The Main Predict Function
|
| 81 |
-
###########################
|
| 82 |
-
def predict(uploaded_image_dict, fallback_url):
|
| 83 |
"""
|
| 84 |
-
|
| 85 |
-
fetch the image from the internet.
|
| 86 |
-
2) Otherwise, Gradio will have downloaded it to a local file path
|
| 87 |
-
and `uploaded_image_dict["name"]` is that path. (or "path" is local)
|
| 88 |
-
3) If the user provides nothing in the first input, try `fallback_url`.
|
| 89 |
"""
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
try:
|
| 110 |
-
|
| 111 |
-
resp.raise_for_status()
|
| 112 |
-
image = Image.open(io.BytesIO(resp.content))
|
| 113 |
except Exception as e:
|
| 114 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
else:
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
#
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import FastAPI, UploadFile, File, HTTPException
|
| 2 |
+
from pydantic import BaseModel
|
| 3 |
+
import uvicorn
|
|
|
|
| 4 |
import io
|
| 5 |
+
import logging
|
| 6 |
+
import datetime
|
| 7 |
+
import re
|
| 8 |
import os
|
| 9 |
+
import requests
|
| 10 |
+
import pandas as pd
|
| 11 |
+
from PIL import Image
|
| 12 |
from autogluon.multimodal import MultiModalPredictor
|
| 13 |
+
from huggingface_hub import snapshot_download
|
| 14 |
|
| 15 |
+
###############################################################################
|
| 16 |
+
# Logging configuration (optional)
|
| 17 |
+
###############################################################################
|
| 18 |
+
log_filename = "model_predictions.log"
|
| 19 |
+
logging.basicConfig(
|
| 20 |
+
filename=log_filename,
|
| 21 |
+
level=logging.INFO,
|
| 22 |
+
format='%(asctime)s - %(message)s'
|
| 23 |
+
)
|
| 24 |
|
| 25 |
+
###############################################################################
|
| 26 |
+
# Model loading
|
| 27 |
+
###############################################################################
|
| 28 |
def load_model():
    """
    Download the classifier snapshot from the Hugging Face Hub and load it.

    Returns:
        MultiModalPredictor: predictor restored from the snapshot directory.

    Raises:
        FileNotFoundError: if the snapshot lacks the required model files.
    """
    repo_id = "Honey-Bee-Society/honeybee_ml_v1"
    local_dir = snapshot_download(repo_id)

    # Fail fast with a clear error when the snapshot is incomplete.
    required_files = ("assets.json", "model.ckpt")
    if any(not os.path.exists(os.path.join(local_dir, name)) for name in required_files):
        raise FileNotFoundError("Required model files not found in the downloaded directory.")

    return MultiModalPredictor.load(local_dir)
|
| 44 |
|
| 45 |
+
###############################################################################
|
| 46 |
+
# Image processing and prediction routines
|
| 47 |
+
###############################################################################
|
| 48 |
+
def resize_image_proportionally(image, max_size_mb=1):
    """
    Shrink *image* proportionally when its PNG-encoded size exceeds
    ``max_size_mb`` megabytes; otherwise return it unchanged.
    """
    # Measure the in-memory PNG size, since that is what gets fed downstream.
    buffer = io.BytesIO()
    image.save(buffer, format='PNG')
    size_mb = len(buffer.getvalue()) / (1024 * 1024)

    if size_mb <= max_size_mb:
        return image

    # Area scales with the square of linear dimensions, hence the sqrt.
    scale = (max_size_mb / size_mb) ** 0.5
    new_size = (int(image.width * scale), int(image.height * scale))
    return image.resize(new_size)
|
| 64 |
|
| 65 |
+
|
| 66 |
+
def predict_image(image: Image.Image, predictor: MultiModalPredictor):
    """
    Score *image* with the AutoGluon MultiModalPredictor.

    The image is serialized to PNG bytes and wrapped in a one-row
    DataFrame, the input format the predictor expects.

    Returns:
        pandas.DataFrame: per-class probabilities for the single row.
    """
    buffer = io.BytesIO()
    image.save(buffer, format='PNG')
    frame = pd.DataFrame({"image": [buffer.getvalue()]})
    return predictor.predict_proba(frame, realtime=True)
|
| 77 |
|
| 78 |
+
|
| 79 |
+
def determine_label(probabilities, threshold=80.0):
    """
    Turn a probability DataFrame into percentage scores and a text label.

    Args:
        probabilities: DataFrame whose columns 1, 2 and 3 hold the
            honeybee, bumblebee and vespidae class probabilities (0..1)
            for a single-row prediction.
        threshold: minimum top score (in percent) required to report a
            class; below it the label states that no bee was detected.
            Defaults to 80, matching the previous hard-coded cutoff.

    Returns:
        dict: ``honeybee_score``, ``bumblebee_score``, ``vespidae_score``
        (percentages) and ``prediction_label``.
    """
    honeybee_score = float(probabilities[1].iloc[0]) * 100
    bumblebee_score = float(probabilities[2].iloc[0]) * 100
    vespidae_score = float(probabilities[3].iloc[0]) * 100

    highest_score = max(honeybee_score, bumblebee_score, vespidae_score)
    # Ties resolve in honeybee > bumblebee > vespidae order, as before.
    if highest_score < threshold:
        prediction_label = "No bee detected (scores too low)."
    elif honeybee_score == highest_score:
        prediction_label = "Honey Bee"
    elif bumblebee_score == highest_score:
        prediction_label = "Bumblebee"
    else:
        prediction_label = "Vespidae (wasp/hornet)"

    return {
        "honeybee_score": honeybee_score,
        "bumblebee_score": bumblebee_score,
        "vespidae_score": vespidae_score,
        "prediction_label": prediction_label
    }
|
| 105 |
|
| 106 |
|
| 107 |
+
def log_predictions(honeybee_score, bumblebee_score, vespidae_score, source_info):
    """
    Append one prediction record (source plus the three scores) to the
    configured log via ``logging.info``.
    """
    record = (
        f"Source: {source_info}, "
        f"Honeybee: {honeybee_score:.2f}%, "
        f"Bumblebee: {bumblebee_score:.2f}%, "
        f"Vespidae: {vespidae_score:.2f}%"
    )
    logging.info(record)
|
| 117 |
+
|
| 118 |
+
###############################################################################
|
| 119 |
+
# Request models
|
| 120 |
+
###############################################################################
|
| 121 |
+
class ImageUrlRequest(BaseModel):
    """Request body for /predict when the image is referenced by URL."""
    # Absolute URL of the image to download and classify.
    image_url: str
|
| 123 |
+
|
| 124 |
+
###############################################################################
|
| 125 |
+
# FastAPI app and endpoints
|
| 126 |
+
###############################################################################
|
| 127 |
+
# FastAPI application instance; endpoints are registered on it below.
app = FastAPI(title="Honey Bee Classification API")

# Load the model at startup (only once). This blocks process startup
# until the Hub snapshot is downloaded and the predictor is restored.
predictor = load_model()
|
| 131 |
+
|
| 132 |
+
@app.get("/ping")
def ping():
    """
    A simple endpoint to check if the API is running.

    Returns:
        dict: always ``{"message": "pong"}``.
    """
    return {"message": "pong"}
|
| 138 |
+
|
| 139 |
+
@app.post("/predict")
async def predict_endpoint(
    image_url_req: ImageUrlRequest = None,
    file: UploadFile = File(None)
):
    """
    Accepts either a JSON body with `image_url` or a multipart form-data `file`.
    Returns JSON with honeybee, bumblebee, vespidae scores, and a predicted label.

    Raises:
        HTTPException 400: failed download, unreadable image, or no input given.
        HTTPException 413: image exceeds the 10MB size limit.
        HTTPException 500: the model failed to produce a prediction.
    """
    # 1) If user provided an image URL
    if image_url_req and image_url_req.image_url:
        image_url = image_url_req.image_url
        # Download the image. Only network-level failures are caught here;
        # the timeout keeps a stalled remote host from hanging the worker.
        try:
            response = requests.get(
                image_url,
                headers={"User-Agent": "HoneyBeeClassification/1.0 (+https://example.com)"},
                timeout=15
            )
        except requests.RequestException as e:
            raise HTTPException(
                status_code=400,
                detail=f"Error downloading image from {image_url}: {e}"
            )
        # Checked outside the try so this detail is not swallowed and
        # re-wrapped by the download-error handler above.
        if response.status_code != 200:
            raise HTTPException(
                status_code=400,
                detail=f"Failed to retrieve image from {image_url}. HTTP {response.status_code}"
            )

        image_bytes = response.content
        image_size_mb = len(image_bytes) / (1024 * 1024)
        if image_size_mb > 10:
            raise HTTPException(
                status_code=413,
                detail=f"Image size {image_size_mb:.2f}MB exceeds 10MB limit."
            )
        # Convert to PIL Image
        try:
            image = Image.open(io.BytesIO(image_bytes))
        except Exception as e:
            raise HTTPException(
                status_code=400,
                detail=f"Could not open image: {e}"
            )
        source_info = image_url

    # 2) If user instead provided a file
    elif file is not None:
        # Check file size by seeking, without reading the payload yet.
        file.file.seek(0, 2)  # move to end
        file_size = file.file.tell()
        file.file.seek(0)  # reset pointer
        mb_size = file_size / (1024 * 1024)
        if mb_size > 10:
            raise HTTPException(
                status_code=413,
                detail=f"Uploaded file size {mb_size:.2f}MB exceeds 10MB limit."
            )

        # Convert to PIL Image
        try:
            contents = await file.read()
            image = Image.open(io.BytesIO(contents))
        except Exception as e:
            raise HTTPException(
                status_code=400,
                detail=f"Could not open uploaded image: {e}"
            )
        source_info = file.filename
    else:
        raise HTTPException(
            status_code=400,
            detail="No image provided. Supply either `image_url` or `file`."
        )

    # Resize the image if needed
    image = resize_image_proportionally(image)

    # Predict
    try:
        probabilities = predict_image(image, predictor)
        results = determine_label(probabilities)
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Prediction failed: {e}"
        )

    # Log which input produced these scores (URL or uploaded filename).
    log_predictions(
        results["honeybee_score"],
        results["bumblebee_score"],
        results["vespidae_score"],
        source_info=source_info
    )

    return results
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
# If running locally, uncomment to start the server via `python app.py`
|
| 240 |
+
# (On Hugging Face Spaces, a separate command may be used.)
|
| 241 |
+
# if __name__ == "__main__":
|
| 242 |
+
# uvicorn.run(app, host="0.0.0.0", port=7860)
|
requirements.txt
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
-
|
|
|
|
|
|
|
|
|
|
| 2 |
pandas
|
| 3 |
autogluon.multimodal
|
| 4 |
-
huggingface_hub
|
| 5 |
-
requests
|
| 6 |
-
gradio
|
|
|
|
| 1 |
+
fastapi
|
| 2 |
+
uvicorn
|
| 3 |
+
pillow
|
| 4 |
+
requests
|
| 5 |
pandas
|
| 6 |
autogluon.multimodal
|
| 7 |
+
huggingface_hub
|
|
|
|
|
|