Muhammad Ahad Hassan Khan
corrected last commit
19b89fd
raw
history blame
1.32 kB
# app.py
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import StreamingResponse, JSONResponse
from ndvi_predictor import load_model, normalize_rgb, predict_ndvi, create_visualization
from PIL import Image
from io import BytesIO
import numpy as np
import io
import os
import base64
# ASGI application object; the route decorators below register their handlers on it.
app = FastAPI()
# Load the model once at import time so every request handled by this module
# reuses the same instance instead of reloading it per call.
# NOTE(review): the path is relative — presumably resolved against the process's
# working directory by load_model; confirm against the deployment layout.
model = load_model("ndvi_best_model.keras")
@app.get("/")
async def root():
    """Serve a static greeting at the API root path."""
    greeting = {"message": "Welcome to the NDVI prediction API!"}
    return greeting
@app.post("/predict/")
async def predict_ndvi_api(file: UploadFile = File(...)):
    """Predict NDVI for an uploaded image and return it with a visualization.

    The uploaded bytes are decoded with PIL, converted to RGB, normalized with
    ``normalize_rgb`` and passed to ``predict_ndvi`` together with the
    module-level ``model``.

    Args:
        file: The uploaded image file (multipart form field ``file``).

    Returns:
        JSONResponse: On success, a body with two keys:
            ``"ndvi_array"`` - the prediction converted with ``.tolist()``;
            ``"visualization"`` - base64 text of the bytes read from the
            stream returned by ``create_visualization``.
        On failure, a body ``{"error": <message>}`` with status 400 when the
        upload is empty or cannot be decoded as an image, or status 500 for
        any other exception raised while processing.
    """
    try:
        contents = await file.read()

        # An empty upload is a client mistake, not a server fault.
        if not contents:
            return JSONResponse(
                status_code=400,
                content={"error": "Uploaded file is empty."},
            )

        # PIL signals unreadable or truncated image data with OSError
        # (UnidentifiedImageError is a subclass), so report those as 400
        # instead of letting them fall through to the generic 500 handler.
        try:
            img = Image.open(BytesIO(contents)).convert("RGB")
        except OSError as e:
            return JSONResponse(
                status_code=400,
                content={"error": f"Invalid image file: {e}"},
            )

        norm_img = normalize_rgb(np.array(img))
        # NOTE(review): this call is synchronous inside an async handler, so
        # the event loop is blocked while it runs — consider offloading if
        # inference turns out to be slow.
        pred_ndvi = predict_ndvi(model, norm_img)

        # Prepare visualization image: rewind the stream before reading so the
        # full content is encoded regardless of where the writer left it.
        vis_img_io = create_visualization(norm_img, pred_ndvi)
        vis_img_io.seek(0)
        vis_img_base64 = base64.b64encode(vis_img_io.read()).decode("utf-8")

        # Convert NDVI array to nested list (e.g., for JSON)
        ndvi_list = pred_ndvi.tolist()

        return JSONResponse(
            content={
                "ndvi_array": ndvi_list,
                "visualization": vis_img_base64,
            }
        )
    except Exception as e:
        # Top-level boundary: any unexpected failure becomes a JSON 500 reply.
        return JSONResponse(status_code=500, content={"error": str(e)})