banana-ML / app.py
Crcs1225
Add application file
22872d8
import io
import logging

import numpy as np
import tensorflow as tf
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware

from utility import preprocess_for_model
logger = logging.getLogger("banana-api")
app = FastAPI(title="Banana Ripeness & Shelf Life API")

# CORS setup for the Streamlit frontend.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Replace with specific domain in production
    allow_methods=["*"],
    allow_headers=["*"],
)

# Index -> label mapping for the classifier's output vector.
# NOTE(review): presumably this matches the class order used at training time
# (the list is alphabetical, as directory-based Keras loaders produce) — verify.
CLASS_LABELS = ["Overripe", "Ripe", "Rotten", "Unripe"]

# Load both models once, at import time, so every request reuses them.
# The paths are relative to the process's current working directory.
# compile=False: only inference is performed here, so no compiled training
# configuration (optimizer/loss) is needed.
try:
    model_cls = tf.keras.models.load_model("banana_classification_model.h5", compile=False)
    model_reg = tf.keras.models.load_model("banana_shelf_life_regression_model.h5", compile=False)
except Exception as e:
    # Fail fast: the API cannot serve anything without its models. Chain the
    # original error so its traceback is preserved as the direct cause.
    raise RuntimeError(f"Failed to load models: {e}") from e
# πŸ” Root endpoint for cronjobs or basic ping
@app.get("/")
def root():
    """Lightweight ping endpoint (e.g. for cron jobs).

    Returns a fixed status payload and does not touch the models.
    """
    payload = {
        "status": "OK",
        "message": "Banana API is running.",
    }
    return payload
# 🩺 Health check endpoint
@app.get("/health")
def health_check():
    """Readiness check: run one dummy forward pass through each model.

    Returns:
        ``{"status": "healthy", "models_loaded": True}`` when both the
        classification and the regression model produce output for an
        all-zeros input batch.

    Raises:
        HTTPException: status 500 if either forward pass fails. The failure
        is also logged with its traceback, because FastAPI does not log
        HTTPExceptions on its own.
    """
    # NOTE(review): assumes both models accept a (1, 224, 224, 3) float32
    # batch — confirm against the training config / preprocess_for_model.
    dummy = np.zeros((1, 224, 224, 3), dtype=np.float32)
    try:
        # verbose=0 stops Keras from printing a progress bar on every ping.
        model_cls.predict(dummy, verbose=0)
        model_reg.predict(dummy, verbose=0)
    except Exception as e:
        logger.exception("Health check failed: %s", e)
        raise HTTPException(status_code=500, detail=f"Health check failed: {str(e)}") from e
    return {"status": "healthy", "models_loaded": True}
# 🍌 Prediction endpoint
def _run_inference(image_bytes: bytes) -> dict:
    """Run both models on raw image bytes and build the /predict payload.

    This function is synchronous and CPU-bound (image preprocessing plus
    TensorFlow inference), so the caller runs it in a worker thread rather
    than on the event loop.
    """
    img_cls = preprocess_for_model(image_bytes, mode="classification")
    # predict() returns a batch of outputs; take the first (presumably only) row.
    probs = model_cls.predict(img_cls, verbose=0)[0]
    class_idx = int(np.argmax(probs))
    ripeness_stage = CLASS_LABELS[class_idx]
    confidence = float(np.max(probs))

    img_reg = preprocess_for_model(image_bytes, mode="regression")
    # Convert the NumPy scalar to a Python float first. round() on a Python
    # float then yields a plain int, which serializes cleanly to JSON.
    pred_days = float(model_reg.predict(img_reg, verbose=0)[0][0])
    days_until_rotten = max(0, round(pred_days))  # clamp negative predictions to 0

    return {
        "ripeness_stage": ripeness_stage,
        "confidence": round(confidence, 4),
        "days_until_rotten": days_until_rotten,
    }


@app.post("/predict")
async def predict(file: UploadFile = File(...)) -> dict:
    """Classify banana ripeness and estimate days until it rots.

    Args:
        file: uploaded image; its raw bytes are handed to
            ``preprocess_for_model`` for each model.

    Returns:
        dict with ``ripeness_stage`` (one of CLASS_LABELS), ``confidence``
        (top-class score rounded to 4 places) and ``days_until_rotten``
        (non-negative int).

    Raises:
        HTTPException: status 500 on any failure while reading, preprocessing
        or predicting.
    """
    try:
        image_bytes = await file.read()
        # Preprocessing and model.predict() block. Calling them directly in
        # this coroutine would stall the event loop, and with it every other
        # request such as /health, for the whole inference.
        return await run_in_threadpool(_run_inference, image_bytes)
    except Exception as e:
        # logger.exception records the traceback, not just the message.
        logger.exception("Prediction error: %s", e)
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}") from e