"""FastAPI service that classifies uploaded images as healthy or rotten with a Keras model."""

import logging
from io import BytesIO
from typing import List

import numpy as np
import tensorflow as tf
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from PIL import Image, UnidentifiedImageError

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# NOTE(review): a wildcard origin combined with allow_credentials=True lets any
# site make credentialed cross-origin requests (Starlette echoes the caller's
# Origin in that case). Nothing in this file reads cookies or auth headers, so
# consider allow_credentials=False or an explicit origin list.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

MODEL_PATH = "model.keras"
# (width, height) passed to PIL's resize before inference; presumably the
# resolution the model was trained on — TODO confirm against training code.
IMAGE_SIZE = (128, 128)
# Per-file upload limit: 10 MiB.
MAX_FILE_SIZE_BYTES = 10 * 1024 * 1024
# Labels indexed by the position of the model's output unit.
CLASS_NAMES = ['healthy', 'rotten']

try:
    logger.info("Loading model from %s", MODEL_PATH)
    model = tf.keras.models.load_model(MODEL_PATH)
except Exception:
    # Fail fast at import time: the service is useless without a model.
    logger.exception("Failed to load model from %s", MODEL_PATH)
    raise


def process_image(image: Image.Image, target_size=IMAGE_SIZE) -> np.ndarray:
    """Convert a PIL image into a normalized RGB array for the model.

    Args:
        image: A PIL image in any mode.
        target_size: ``(width, height)`` to resize to; defaults to ``IMAGE_SIZE``.

    Returns:
        A float32 array of shape ``(height, width, 3)`` with values in ``[0, 1]``.

    Raises:
        ValueError: If the image cannot be decoded, converted or resized.
    """
    try:
        # convert('RGB') always yields exactly three channels, so no alpha
        # channel can survive to the array below.
        if image.mode != 'RGB':
            image = image.convert('RGB')
        image = image.resize(target_size)
        # float32 is what Keras models compute in by default; avoids a float64 copy.
        return np.asarray(image, dtype=np.float32) / 255.0
    except Exception as e:
        logger.error("Error processing image: %s", e)
        raise ValueError(f"Error processing image: {e}") from e


def _error_response(status_code: int, message: str) -> JSONResponse:
    """Build the error envelope returned by ``/predict``."""
    return JSONResponse(
        status_code=status_code,
        content={
            "predictions": [],
            "status": "error",
            "error": message,
        },
    )


async def _predict_file(file: UploadFile, index: int) -> dict:
    """Validate, decode and classify one uploaded file.

    Args:
        file: The uploaded file.
        index: Position of the file in the request; used to name unnamed uploads.

    Returns:
        A dict with ``filename``, ``class`` and ``confidence`` keys.

    Raises:
        HTTPException: 400 if the file is too large or is not a usable image.
    """
    logger.info("Processing file: %s", file.filename)

    # UploadFile.size may be None (size not known up front), so this is only a
    # cheap early rejection; the bounded read below is the real enforcement.
    if file.size is not None and file.size > MAX_FILE_SIZE_BYTES:
        logger.warning("File %s too large: %s bytes", file.filename, file.size)
        raise HTTPException(status_code=400, detail=f"File {file.filename} too large")

    # Read at most one byte past the limit so oversize uploads are detected
    # without buffering them entirely in memory.
    content = await file.read(MAX_FILE_SIZE_BYTES + 1)
    if len(content) > MAX_FILE_SIZE_BYTES:
        logger.warning("File %s too large: more than %d bytes", file.filename, MAX_FILE_SIZE_BYTES)
        raise HTTPException(status_code=400, detail=f"File {file.filename} too large")

    try:
        with Image.open(BytesIO(content)) as image:
            image_array = process_image(image)
    except (UnidentifiedImageError, Image.DecompressionBombError) as e:
        logger.warning("File %s is not a usable image: %s", file.filename, e)
        raise HTTPException(
            status_code=400, detail=f"File {file.filename} is not a valid image"
        ) from e
    except ValueError as e:
        # process_image wraps decode/convert failures in ValueError.
        raise HTTPException(status_code=400, detail=str(e)) from e

    # Add the leading batch axis expected by model.predict.
    batch = np.expand_dims(image_array, axis=0)

    logger.info("Making prediction for %s", file.filename)
    # NOTE(review): model.predict is blocking and runs on the event-loop thread
    # here, stalling other requests during inference; consider offloading it
    # (e.g. a threadpool) once the model's thread-safety is confirmed.
    probabilities = model.predict(batch, verbose=0)[0]
    # NOTE(review): argmax assumes one output unit per entry in CLASS_NAMES; a
    # single-unit sigmoid head would always map to index 0 — confirm the model.
    predicted_index = int(np.argmax(probabilities))

    return {
        "filename": file.filename or f"image_{index}.jpg",
        "class": CLASS_NAMES[predicted_index],
        "confidence": round(float(probabilities[predicted_index]), 4),
    }


@app.post("/predict")
async def predict(files: List[UploadFile] = File(...)):
    """Classify each uploaded image.

    Returns:
        On success, ``{"predictions": [...], "status": "success"}``.
        On failure, ``{"predictions": [], "status": "error", "error": msg}``:
        with status 400 for client errors (no files, oversize or unreadable
        image), or 500 with a generic message for unexpected server errors.
    """
    try:
        if not files:
            logger.warning("No files provided in request")
            raise HTTPException(status_code=400, detail="No files provided")

        results = []
        for index, file in enumerate(files):
            results.append(await _predict_file(file, index))
    except HTTPException as e:
        # Preserve the intended status code instead of collapsing it to 400.
        logger.warning("Rejected predict request: %s", e.detail)
        return _error_response(e.status_code, str(e.detail))
    except Exception:
        # Unexpected failure (e.g. inside the model): a server error, not the
        # client's fault; log the traceback but don't leak internals.
        logger.exception("Error in predict endpoint")
        return _error_response(500, "Internal server error")

    logger.info("Returning %d predictions", len(results))
    return JSONResponse({
        "predictions": results,
        "status": "success",
    })