# Pest_Detection / app.py
# Author: itsHamza; last commit 47b6c1c ("Update app.py", verified).
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse
import tensorflow as tf
from PIL import Image
import numpy as np
import io
from fastapi.middleware.cors import CORSMiddleware
app = FastAPI()

# Cross-origin policy for the API: any origin, method and header is accepted,
# so a browser frontend hosted anywhere can call /predict (adjust as needed).
# NOTE(review): with allow_credentials=True and a "*" origin list, Starlette
# echoes the caller's Origin back instead of sending "*", which lets any site
# make credentialed requests. /predict itself reads no cookies; consider an
# explicit origin list or allow_credentials=False — confirm with the frontend.
_cors_options = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)
# Location of the serialized Keras model. It is a relative path, so it is
# resolved against the process's current working directory.
MODEL_PATH = "growlens_efficientnet_model.h5"


def load_model(path=MODEL_PATH):
    """Load and return the Keras model used by the /predict endpoint.

    Args:
        path: File path of the saved model. Defaults to ``MODEL_PATH``, so
            existing zero-argument calls behave exactly as before.

    Returns:
        Whatever ``tf.keras.models.load_model`` returns for ``path``.
    """
    return tf.keras.models.load_model(path)


# Loaded once at import time; every request reuses this single instance.
model = load_model()
# Human-readable labels for the model's outputs. predict() indexes this list
# with the argmax of the model output, so position i must be the label of
# output unit i.
# NOTE(review): the list is alphabetical, which presumably mirrors the class
# ordering used at training time — confirm against the training pipeline.
# "catterpillar" is returned verbatim to clients and presumably matches the
# training label; change it only together with the model and its consumers.
class_names = [
    "ants",
    "bees",
    "beetle",
    "catterpillar",
    "earthworms",
    "earwig",
    "grasshopper",
    "moth",
    "slug",
    "snail",
    "wasp",
    "weevil",
]
def _preprocess_image(image_bytes, size=(224, 224)):
    """Decode raw upload bytes into a batch containing one RGB image.

    Args:
        image_bytes: Encoded image data in any format Pillow can open.
        size: Target ``(width, height)`` passed to ``Image.resize``; adjust
            to the model's input size.

    Returns:
        A float32 array of shape ``(1, height, width, 3)`` scaled to [0, 1].

    Raises:
        OSError: If Pillow cannot identify or fully decode the data
            (``PIL.UnidentifiedImageError`` is a subclass of ``OSError``).
        PIL.Image.DecompressionBombError: If the image exceeds Pillow's
            pixel-count safety limit.
    """
    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    image = image.resize(size)
    # NOTE(review): [0, 1] scaling must match the training preprocessing.
    # Stock keras.applications EfficientNet models rescale internally and
    # expect [0, 255] input — confirm which convention this model used.
    img_array = np.asarray(image, dtype=np.float32) / 255.0
    return np.expand_dims(img_array, axis=0)


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and return the top label with its score.

    On success, responds with ``{"class": <label>, "confidence": <float>}``,
    where ``confidence`` is the largest value in the model's output.
    If the upload cannot be decoded as an image, responds with HTTP 400 and
    ``{"detail": ...}``. Previously the Pillow error escaped and surfaced
    as a 500.
    """
    image_bytes = await file.read()
    try:
        batch = _preprocess_image(image_bytes)
    except (OSError, Image.DecompressionBombError):
        return JSONResponse(
            {"detail": "Uploaded file is not a readable image."},
            status_code=400,
        )

    # NOTE(review): model.predict is synchronous and runs inside this async
    # handler, so it blocks the event loop while inference runs.
    # verbose=0 keeps Keras from printing a progress bar on every request.
    preds = model.predict(batch, verbose=0)
    scores = preds[0]  # the batch holds exactly one image
    pred_class = class_names[int(np.argmax(scores))]
    confidence = float(np.max(scores))
    return JSONResponse({"class": pred_class, "confidence": confidence})