Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, UploadFile, File, HTTPException | |
| from pydantic import BaseModel | |
| import uvicorn | |
| import io | |
| import logging | |
| import datetime | |
| import re | |
| import os | |
| import requests | |
| import pandas as pd | |
| from PIL import Image | |
| from autogluon.multimodal import MultiModalPredictor | |
| from huggingface_hub import snapshot_download | |
###############################################################################
# Logging configuration (optional)
###############################################################################
# Prediction records are written through the root logger into this file
# (relative to the process working directory). Only the timestamp and the
# message are recorded for each entry.
log_filename = "model_predictions.log"
logging.basicConfig(
    format='%(asctime)s - %(message)s',
    level=logging.INFO,
    filename=log_filename,
)
###############################################################################
# Model loading
###############################################################################
def load_model(repo_id="Honey-Bee-Society/honeybee_ml_v1"):
    """Download a model snapshot from the Hugging Face Hub and load it.

    Args:
        repo_id: Hub repository containing the exported AutoGluon
            ``MultiModalPredictor``. Defaults to the repository this service
            was originally built against, so existing ``load_model()`` calls
            behave exactly as before.

    Returns:
        The predictor returned by ``MultiModalPredictor.load`` for the
        downloaded snapshot directory.

    Raises:
        FileNotFoundError: if ``assets.json`` or ``model.ckpt`` is absent from
            the downloaded snapshot. The message lists which ones are missing.
    """
    local_dir = snapshot_download(repo_id)
    # Check the expected artifacts up front so a malformed snapshot fails with
    # a clear message instead of an opaque error from inside the loader.
    required_files = ("assets.json", "model.ckpt")
    missing = [
        name for name in required_files
        if not os.path.exists(os.path.join(local_dir, name))
    ]
    if missing:
        raise FileNotFoundError(
            "Required model files not found in the downloaded directory."
            f" Missing: {', '.join(missing)}"
        )
    return MultiModalPredictor.load(local_dir)
###############################################################################
# Image processing and prediction routines
###############################################################################
# Image modes Pillow's PNG encoder can write. Anything else (e.g. CMYK JPEGs)
# has to be converted before it can be PNG-encoded here or in predict_image.
_PNG_WRITABLE_MODES = frozenset({"1", "L", "LA", "I", "I;16", "P", "RGB", "RGBA"})


def resize_image_proportionally(image, max_size_mb=1):
    """Shrink ``image`` if its PNG-encoded size exceeds ``max_size_mb``.

    The size is measured by encoding the image to PNG in memory. When it is
    over the limit, both sides are scaled by ``sqrt(max_size_mb / size)``,
    which preserves aspect ratio and targets roughly the requested byte size.
    PNG compression is not linear in pixel count, so the result is approximate.

    Images in a mode PNG cannot store (such as CMYK) are converted to RGB
    first; previously they made ``save`` raise ``OSError``.

    Args:
        image: PIL image to measure and possibly resize.
        max_size_mb: Size threshold in megabytes (1 MB = 1024 * 1024 bytes).

    Returns:
        The original image (or its RGB conversion) when it is small enough,
        otherwise a resized copy. Each side is at least 1 pixel.
    """
    if image.mode not in _PNG_WRITABLE_MODES:
        image = image.convert("RGB")

    img_byte_array = io.BytesIO()
    image.save(img_byte_array, format='PNG')
    img_size = len(img_byte_array.getvalue()) / (1024 * 1024)

    if img_size > max_size_mb:
        scale_factor = (max_size_mb / img_size) ** 0.5
        # Clamp to 1 px so very thin images never collapse to a zero-size side,
        # which Pillow rejects.
        new_width = max(1, int(image.width * scale_factor))
        new_height = max(1, int(image.height * scale_factor))
        image = image.resize((new_width, new_height))
    return image
def predict_image(image: Image.Image, predictor: MultiModalPredictor):
    """Score a single image with the AutoGluon predictor.

    The image is serialised to PNG bytes and placed in a one-row DataFrame
    under the ``image`` column, which is passed to ``predict_proba`` with
    ``realtime=True``.

    Returns:
        The object returned by ``predictor.predict_proba`` (the rest of this
        module indexes it as a DataFrame of per-class probabilities).
    """
    png_buffer = io.BytesIO()
    image.save(png_buffer, format='PNG')
    single_row = pd.DataFrame({"image": [png_buffer.getvalue()]})
    return predictor.predict_proba(single_row, realtime=True)
def determine_label(probabilities, threshold=80.0):
    """Convert class probabilities into percentage scores and a text label.

    Args:
        probabilities: Probability table (as produced by ``predict_image``).
            Only the first row of columns ``1``, ``2`` and ``3`` is read; they
            are treated as honey bee, bumblebee and vespidae respectively.
            NOTE(review): presumably any other column (e.g. ``0``) is a
            non-bee class; confirm the label mapping against the trained model.
        threshold: Minimum winning score, in percent, needed to report a class.
            Defaults to 80, the previously hard-coded cut-off.

    Returns:
        dict with ``honeybee_score``, ``bumblebee_score`` and
        ``vespidae_score`` (floats, percent) and ``prediction_label`` (str).
        Ties go to the first class in honey bee > bumblebee > vespidae order.
    """
    honeybee_score = float(probabilities[1].iloc[0]) * 100
    bumblebee_score = float(probabilities[2].iloc[0]) * 100
    vespidae_score = float(probabilities[3].iloc[0]) * 100

    highest_score = max(honeybee_score, bumblebee_score, vespidae_score)
    if highest_score < threshold:
        prediction_label = "No bee detected (scores too low)."
    elif honeybee_score == highest_score:
        prediction_label = "Honey Bee"
    elif bumblebee_score == highest_score:
        prediction_label = "Bumblebee"
    else:
        prediction_label = "Vespidae (wasp/hornet)"

    return {
        "honeybee_score": honeybee_score,
        "bumblebee_score": bumblebee_score,
        "vespidae_score": vespidae_score,
        "prediction_label": prediction_label,
    }
def log_predictions(honeybee_score, bumblebee_score, vespidae_score, source_info):
    """Record one prediction as an INFO message on the root logger.

    Each score is rendered with two decimals and a percent sign, e.g.
    ``Source: bee.jpg, Honeybee: 95.00%, Bumblebee: 3.25%, Vespidae: 1.50%``.
    """
    template = "Source: {}, Honeybee: {:.2f}%, Bumblebee: {:.2f}%, Vespidae: {:.2f}%"
    message = template.format(source_info, honeybee_score, bumblebee_score, vespidae_score)
    logging.info(message)
###############################################################################
# Request models
###############################################################################
class ImageUrlRequest(BaseModel):
    # Request body for classifying a remote image.
    # image_url: address of the image to download and classify.
    image_url: str
###############################################################################
# FastAPI app and endpoints
###############################################################################
# The predictor is created once, when this module is imported, and shared by
# every request rather than loaded per call. If loading fails, the import
# fails and the service does not start.
predictor = load_model()
app = FastAPI(title="Honey Bee Classification API")
# NOTE(review): no route decorator was present, so this handler was never
# registered with `app`. The path is inferred from the function name; confirm
# it against existing health checks.
@app.get("/ping")
def ping():
    """
    A simple endpoint to check if the API is running.

    Always returns ``{"message": "pong"}``.
    """
    return {"message": "pong"}
# Maximum accepted image payload in megabytes, for URL downloads and uploads.
_MAX_IMAGE_MB = 10
# (connect, read) timeouts in seconds for remote images. Without a timeout,
# an unresponsive host could hang the request indefinitely.
_DOWNLOAD_TIMEOUT = (10, 30)
_DOWNLOAD_CHUNK_SIZE = 64 * 1024


def _download_image_bytes(image_url):
    """Fetch ``image_url`` and return its body, enforcing the size limit.

    The body is streamed, so an oversized resource is rejected as soon as it
    crosses ``_MAX_IMAGE_MB`` rather than being buffered in full first.

    Raises:
        HTTPException: 400 on a non-200 status or any download error,
            413 if the body exceeds ``_MAX_IMAGE_MB``.
    """
    limit_bytes = _MAX_IMAGE_MB * 1024 * 1024
    # NOTE(review): the URL is caller-controlled and fetched server-side, so it
    # can target internal hosts (SSRF). Consider restricting allowed hosts.
    try:
        with requests.get(
            image_url,
            headers={"User-Agent": "HoneyBeeClassification/1.0 (+https://example.com)"},
            timeout=_DOWNLOAD_TIMEOUT,
            stream=True,
        ) as response:
            if response.status_code != 200:
                raise HTTPException(
                    status_code=400,
                    detail=f"Failed to retrieve image from {image_url}. HTTP {response.status_code}"
                )
            chunks = []
            received = 0
            for chunk in response.iter_content(chunk_size=_DOWNLOAD_CHUNK_SIZE):
                received += len(chunk)
                if received > limit_bytes:
                    raise HTTPException(
                        status_code=413,
                        detail=f"Image size exceeds {_MAX_IMAGE_MB}MB limit."
                    )
                chunks.append(chunk)
    except HTTPException:
        # Our own status errors must reach the client unchanged instead of
        # being re-wrapped by the generic handler below.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=400,
            detail=f"Error downloading image from {image_url}: {e}"
        ) from e
    return b"".join(chunks)


def _open_image(data, error_prefix):
    """Decode ``data`` into a PIL image, mapping any failure to HTTP 400.

    ``Image.open`` only parses the header, so ``load()`` forces full decoding
    here. Otherwise truncated or corrupt data would only fail later, during
    resizing or prediction, and be reported as a server error.
    """
    try:
        image = Image.open(io.BytesIO(data))
        image.load()
    except Exception as e:
        raise HTTPException(
            status_code=400,
            detail=f"{error_prefix}: {e}"
        ) from e
    return image


# NOTE(review): no route decorator was present, so this handler was never
# registered with `app`. The path is inferred from the function name; confirm
# it against existing clients.
@app.post("/predict")
async def predict_endpoint(
    image_url_req: ImageUrlRequest = None,
    file: UploadFile = File(None)
):
    """
    Accepts either a JSON body with `image_url` or a multipart form-data `file`.
    Returns JSON with honeybee, bumblebee, vespidae scores, and a predicted label.

    Raises:
        HTTPException: 400 if no image is supplied or it cannot be fetched or
            decoded, 413 if it exceeds the size limit, 500 if resizing or
            prediction fails.
    """
    # NOTE(review): FastAPI parses the body as form data once a `File`
    # parameter is declared, so a pure JSON `image_url` body may never reach
    # `image_url_req`. Verify with a real request.
    # NOTE(review): `requests` and `predictor.predict_proba` are blocking calls
    # inside an `async def`, so they stall the event loop while they run.
    if image_url_req and image_url_req.image_url:
        # 1) Image referenced by URL.
        source_name = image_url_req.image_url
        image = _open_image(_download_image_bytes(source_name), "Could not open image")
    elif file is not None:
        # 2) Uploaded file. Measure it by seeking to the end, so an oversized
        # upload is rejected before its contents are read into memory.
        source_name = file.filename
        file.file.seek(0, 2)
        mb_size = file.file.tell() / (1024 * 1024)
        file.file.seek(0)
        if mb_size > _MAX_IMAGE_MB:
            raise HTTPException(
                status_code=413,
                detail=f"Uploaded file size {mb_size:.2f}MB exceeds {_MAX_IMAGE_MB}MB limit."
            )
        try:
            contents = await file.read()
        except Exception as e:
            raise HTTPException(
                status_code=400,
                detail=f"Could not open uploaded image: {e}"
            ) from e
        image = _open_image(contents, "Could not open uploaded image")
    else:
        raise HTTPException(
            status_code=400,
            detail="No image provided. Supply either `image_url` or `file`."
        )

    # Resizing re-encodes the image, so it shares the prediction error handler.
    try:
        image = resize_image_proportionally(image)
        probabilities = predict_image(image, predictor)
        results = determine_label(probabilities)
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Prediction failed: {e}"
        ) from e

    log_predictions(
        results["honeybee_score"],
        results["bumblebee_score"],
        results["vespidae_score"],
        source_info=source_name
    )
    return results
if __name__ == "__main__":
    # Local entry point (`python app.py`). Hosting platforms such as Hugging
    # Face Spaces may start the server with their own command instead. The
    # PORT environment variable selects the port and falls back to 7860.
    # `os` and `uvicorn` come from the module-level imports.
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)