# FruitClassifier / main.py
# Sriomdash's picture
# Added Code files
# 0855c44 verified
from fastapi import FastAPI, File, UploadFile, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
from fastapi.responses import JSONResponse
import numpy as np
import tensorflow as tf
from PIL import Image
import io
import os
import slowapi
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
import logging
# Log to a file (api.log in the working directory) so requests and
# predictions remain auditable after the process exits.
logging.basicConfig(
    filename='api.log',
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
# FastAPI application. root_path="/api" suggests the app is served behind a
# proxy that strips an /api prefix — NOTE(review): confirm against deployment.
app = FastAPI(
    title="Cloud Inventory AI API",
    description="API for scanning fruits and returning the Fruit Name and Quality.",
    root_path="/api"
)

# Allow browser clients from any origin. Credentials are disabled, which is
# what makes the wildcard origin acceptable; only GET/POST are exposed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
)

# Rate limiting (slowapi), keyed per client IP.
# Use get_remote_address which is safer and handles proxies better than x.client.ip
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
# Register the exception handler so rate-limited users get a proper HTTP 429 response
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
class CastLayer(tf.keras.layers.Layer):
    """Identity layer that casts its inputs to float32.

    Registered under the name 'Cast' when deserializing the saved model,
    standing in for the custom layer baked into the .h5 file.
    """

    def call(self, inputs):
        # Whatever dtype arrives, hand downstream layers float32 tensors.
        return tf.cast(inputs, dtype=tf.float32)
# Path to the trained Keras model, relative to the working directory.
MODEL_PATH = "best_fruit_model.h5"
# Stays None when the file is absent or loading fails; every endpoint
# checks this before using the model.
model = None

# Load eagerly at import time so the first request doesn't pay the cost.
if os.path.exists(MODEL_PATH):
    try:
        # The saved model references a custom layer named 'Cast'; map it to
        # CastLayer so deserialization succeeds. compile=False because the
        # model is used for inference only (no optimizer/loss needed).
        custom_objects = {'Cast': CastLayer}
        model = tf.keras.models.load_model(MODEL_PATH, custom_objects=custom_objects, compile=False)
        logging.info("Model loaded successfully!")
    except Exception as e:
        logging.error(f"Error loading model: {e}")
else:
    logging.warning(f"Warning: Model not found at {MODEL_PATH}")
# Model output classes in training order: indices 0-9 are fresh produce,
# 10-19 the same produce rotten. Each label is "<quality>_<fruit>", which
# /predict splits on the first underscore.
CLASS_NAMES = [
    'fresh_apple', 'fresh_banana', 'fresh_cucumber', 'fresh_grape',
    'fresh_guava', 'fresh_mango', 'fresh_orange', 'fresh_pomegranate',
    'fresh_strawberry', 'fresh_tomato', 'rotten_apple', 'rotten_banana',
    'rotten_cucumber', 'rotten_grape', 'rotten_guava', 'rotten_mango',
    'rotten_orange', 'rotten_pomegranate', 'rotten_strawberry', 'rotten_tomato'
]
@app.get("/")
@limiter.limit("40/minute")
async def root_call(request: Request):
    """Landing endpoint: confirms the service is running and points at /docs."""
    client_host = request.client.host
    logging.info(f"Root endpoint accessed by {client_host}")
    return {"message": "Fruit Quality API is running. Go to /docs to test it."}
@app.get("/health")
@limiter.limit("40/minute")
async def health_call(request: Request):
    """Health probe: reports unhealthy whenever the model failed to load."""
    if model is not None:
        return {"status": "healthy", "model_loaded": True}
    logging.warning("Health check failed: Model not loaded.")
    return {"status": "unhealthy", "reason": "Model missing or failed to load."}
@app.post("/predict")
@limiter.limit("40/minute")
async def predict_image(request: Request, file: UploadFile = File(...)):
    """Classify an uploaded fruit image.

    Returns JSON {"fruit": <Name>, "quality": "Fresh"|"Rotten"}.

    Raises:
        HTTPException 503: model missing or failed to load.
        HTTPException 400: non-image content type, or undecodable image data.
        HTTPException 500: unexpected server-side failure.
    """
    logging.info(f"Prediction request received from {request.client.host} for file {file.filename}")
    if model is None:
        logging.error("Prediction attempted, but model is not loaded.")
        raise HTTPException(status_code=503, detail="Model is not loaded.")
    # content_type may be None when the client omits the header; the original
    # code crashed with AttributeError (→ 500) in that case.
    if not (file.content_type or "").startswith("image/"):
        logging.warning(f"Invalid file type uploaded: {file.content_type}")
        raise HTTPException(status_code=400, detail="Invalid file. Upload an image.")
    try:
        contents = await file.read()
        try:
            img = Image.open(io.BytesIO(contents)).convert('RGB')
        except Exception as decode_err:
            # A corrupt or non-image payload is a client error, not a 500.
            logging.warning(f"Undecodable image upload: {decode_err}")
            raise HTTPException(status_code=400, detail="Invalid file. Upload an image.")
        # Preprocessing mirrors training: 224x224 RGB scaled to [0, 1],
        # with a leading batch axis — TODO confirm against the training pipeline.
        img = img.resize((224, 224))
        img_arr = np.expand_dims(np.array(img) / 255.0, axis=0)
        preds = model.predict(img_arr, verbose=0)
        idx = int(np.argmax(preds[0]))
        raw_label = CLASS_NAMES[idx]  # e.g. "fresh_apple" or "rotten_banana"
        # Labels are "<quality>_<fruit>"; split only on the first underscore
        # so multi-word fruit names would survive intact.
        parts = raw_label.split('_', 1)
        quality = parts[0].capitalize()
        fruit_name = parts[1].title()
        logging.info(f"Prediction successful: {quality} {fruit_name}")
        return JSONResponse(content={
            "fruit": fruit_name,
            "quality": quality
        })
    except HTTPException:
        # Deliberate HTTP errors raised above must not be re-wrapped as 500.
        raise
    except Exception as e:
        logging.error(f"Server error during prediction: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Server error: {str(e)}")