Spaces:
Running
Running
import io
import logging

import numpy as np
import tensorflow as tf
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware

from utility import preprocess_for_model
logger = logging.getLogger("banana-api")

app = FastAPI(title="Banana Ripeness & Shelf Life API")

# CORS setup for the Streamlit frontend.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Replace with specific domain in production
    allow_methods=["*"],
    allow_headers=["*"],
)

# Maps the classification model's output index to a human-readable label.
# NOTE(review): the order must match the label encoding used at training time
# (it is alphabetical, which suggests a directory/label-encoder ordering) — confirm.
CLASS_LABELS = ["Overripe", "Ripe", "Rotten", "Unripe"]

# Load both models once at import time so every request reuses them.
# compile=False: the models are only used for inference, so no optimizer or
# loss needs to be restored.
try:
    model_cls = tf.keras.models.load_model("banana_classification_model.h5", compile=False)
    model_reg = tf.keras.models.load_model("banana_shelf_life_regression_model.h5", compile=False)
except Exception as e:
    # Fail fast: the API is useless without its models. Chain the original
    # error so the real cause (missing file, version mismatch, ...) stays in
    # the traceback.
    raise RuntimeError(f"Failed to load models: {e}") from e
# Root endpoint for cron jobs or a basic liveness ping.
# Registered on the app so the route is actually served; "/" follows from this
# being the root endpoint.
@app.get("/")
def root():
    """Return a static payload confirming the API process is up.

    Does not touch the models; use the health check for model readiness.
    """
    return {"status": "OK", "message": "Banana API is running."}
# Health check endpoint: verifies both models can run a forward pass.
# NOTE(review): the "/health" path is reconstructed — confirm against the
# deployment's health-probe / cron configuration.
@app.get("/health")
def health_check():
    """Run a dummy prediction through both models to verify readiness.

    Returns:
        ``{"status": "healthy", "models_loaded": True}`` when both dummy
        predictions succeed.

    Raises:
        HTTPException: status 500 if either model fails on the dummy input.
    """
    # Assumes both models accept a (1, 224, 224, 3) float32 batch — TODO
    # confirm against the input shapes produced by preprocess_for_model.
    dummy = np.zeros((1, 224, 224, 3), dtype=np.float32)
    try:
        # verbose=0 stops Keras from printing a progress bar on every ping.
        model_cls.predict(dummy, verbose=0)
        model_reg.predict(dummy, verbose=0)
    except Exception as e:
        logger.exception("Health check failed")
        raise HTTPException(status_code=500, detail=f"Health check failed: {e}") from e
    return {"status": "healthy", "models_loaded": True}
def _run_inference(image_bytes: bytes) -> dict:
    """Run both models on one image and build the response payload (blocking).

    Args:
        image_bytes: Raw bytes of the uploaded image.

    Returns:
        Dict with ``ripeness_stage`` (label from CLASS_LABELS), ``confidence``
        (probability of that label, rounded to 4 places) and
        ``days_until_rotten`` (non-negative int).
    """
    img_cls = preprocess_for_model(image_bytes, mode="classification")
    # Assumes preprocess_for_model yields a batch of exactly one image — TODO
    # confirm. Row 0 is that image's class scores.
    probs = model_cls.predict(img_cls, verbose=0)[0]
    class_idx = int(np.argmax(probs))
    ripeness_stage = CLASS_LABELS[class_idx]
    confidence = float(probs[class_idx])

    img_reg = preprocess_for_model(image_bytes, mode="regression")
    # float() turns the numpy scalar into a Python float. round() then returns
    # a plain int, which FastAPI can JSON-encode. On NumPy 1.x, round() of a
    # float32 stays a numpy scalar and cannot be encoded.
    pred_days = float(model_reg.predict(img_reg, verbose=0)[0][0])
    # Clamp negative regression outputs to zero days.
    days_until_rotten = max(0, round(pred_days))

    return {
        "ripeness_stage": ripeness_stage,
        "confidence": round(confidence, 4),
        "days_until_rotten": days_until_rotten,
    }


# Prediction endpoint.
# NOTE(review): the "/predict" POST route is reconstructed — confirm against
# the frontend client.
@app.post("/predict")
async def predict(file: UploadFile = File(...)) -> dict:
    """Classify banana ripeness and estimate days until it rots.

    Args:
        file: Uploaded image file (multipart form field ``file``).

    Returns:
        The payload built by ``_run_inference``.

    Raises:
        HTTPException: 400 if the upload is empty; 500 if preprocessing or
            inference fails.
    """
    try:
        image_bytes = await file.read()
        if not image_bytes:
            raise HTTPException(status_code=400, detail="Uploaded file is empty.")
        # TensorFlow inference is blocking; running it inline in this async
        # handler would stall the event loop for every concurrent request.
        return await run_in_threadpool(_run_inference, image_bytes)
    except HTTPException:
        # Deliberate client errors pass through unchanged.
        raise
    except Exception as e:
        logger.exception("Prediction error: %s", e)
        raise HTTPException(status_code=500, detail=f"Prediction failed: {e}") from e