ahadhassan's picture
Update app.py
eeac6a0 verified
raw
history blame
5.85 kB
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi import FastAPI, File, UploadFile, HTTPException
from ndvi_predictor import load_model, normalize_rgb, predict_ndvi, create_visualization
from yolo_predictor import load_yolo_model, predict_yolo, predict_pipeline
from PIL import Image
from io import BytesIO
import numpy as np
import zipfile
import json
import rasterio
from rasterio.transform import from_bounds
import tempfile
import os
app = FastAPI()

# Model weight files. These are relative paths; presumably resolved against the
# process working directory by the loaders -- confirm in ndvi_predictor /
# yolo_predictor.
NDVI_MODEL_PATH = "ndvi_best_model.keras"
YOLO_MODEL_PATH = "4c_6c_regression.pt"

# Both models are loaded once, at import time, and shared by all endpoints.
ndvi_model = load_model(NDVI_MODEL_PATH)
yolo_model = load_yolo_model(YOLO_MODEL_PATH)
@app.get("/")
async def root():
    """Return a static welcome payload so clients can check the API is up."""
    welcome = "Welcome to the NDVI and YOLO prediction API!"
    return {"message": welcome}
@app.post("/predict_ndvi/")
async def predict_ndvi_api(file: UploadFile = File(...)):
    """Predict NDVI from an uploaded RGB image.

    The upload is decoded with Pillow, converted to RGB, normalized with
    ``normalize_rgb`` and passed to the NDVI model. The response is a ZIP
    archive (``ndvi_output.zip``) containing:

    * ``ndvi_image.png`` -- the bytes produced by ``create_visualization``;
    * ``ndvi_band.npy`` -- the predicted NDVI array written with ``np.save``.

    Returns:
        A streaming ZIP response on success. On failure, a JSON body
        ``{"error": <message>}`` with status 400 if the upload cannot be
        decoded as an image, or status 500 for any other error.
    """
    try:
        contents = await file.read()
        try:
            img = Image.open(BytesIO(contents)).convert("RGB")
        except OSError as e:
            # PIL raises UnidentifiedImageError (an OSError subclass) or a
            # plain OSError for truncated data: that is bad client input,
            # not a server fault.
            return JSONResponse(status_code=400, content={"error": str(e)})

        norm_img = normalize_rgb(np.array(img))
        pred_ndvi = predict_ndvi(ndvi_model, norm_img)

        # Visualization image; rewind in case the helper left the cursor at
        # the end of the buffer it wrote.
        vis_img_bytes = create_visualization(norm_img, pred_ndvi)
        vis_img_bytes.seek(0)

        # Raw NDVI band in NumPy's .npy format. getvalue() returns the whole
        # buffer regardless of the current position, so no rewind is needed.
        ndvi_bytes = BytesIO()
        np.save(ndvi_bytes, pred_ndvi)

        # Bundle both artifacts into one in-memory ZIP archive.
        zip_buf = BytesIO()
        with zipfile.ZipFile(zip_buf, "w") as zip_file:
            zip_file.writestr("ndvi_image.png", vis_img_bytes.read())
            zip_file.writestr("ndvi_band.npy", ndvi_bytes.getvalue())
        zip_buf.seek(0)

        return StreamingResponse(
            zip_buf,
            media_type="application/x-zip-compressed",
            headers={"Content-Disposition": "attachment; filename=ndvi_output.zip"},
        )
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})
@app.post("/predict_yolo/")
async def predict_yolo_api(file: UploadFile = File(...)):
    """Predict YOLO results from an uploaded 4-channel TIFF image.

    ``predict_yolo`` takes a file path, so the upload is written to a
    temporary ``.tif`` file. That file is removed once the request finishes,
    including when the write or the prediction fails.

    Returns:
        JSON with ``boxes`` (``xyxyn``, ``conf``, ``cls``), ``classes`` (same
        values as ``boxes.cls``), ``names``, ``growth_stages``,
        ``orig_shape`` and ``speed``. The box-derived fields are ``None``
        when ``results.boxes`` is ``None``. On any failure, a JSON body
        ``{"error": <message>}`` with status 500.
    """
    tmp_file_path = None
    try:
        contents = await file.read()
        with tempfile.NamedTemporaryFile(delete=False, suffix=".tif") as tmp_file:
            # Record the path before writing so cleanup below also runs if
            # the write itself fails (delete=False means nothing else will).
            tmp_file_path = tmp_file.name
            tmp_file.write(contents)

        results = predict_yolo(yolo_model, tmp_file_path)

        # Convert results to a JSON-serializable structure.
        boxes = results.boxes
        if boxes is not None:
            cls_list = boxes.cls.tolist()
            boxes_dict = {
                "xyxyn": boxes.xyxyn.tolist(),
                "conf": boxes.conf.tolist(),
                "cls": cls_list,
            }
        else:
            cls_list = None
            boxes_dict = {"xyxyn": None, "conf": None, "cls": None}

        results_dict = {
            "boxes": boxes_dict,
            "classes": cls_list,
            "names": results.names,
            "growth_stages": getattr(results, "growth_stages", None),
            "orig_shape": results.orig_shape,
            "speed": results.speed,
        }

        # Assumes that when box rows have more than 6 columns, column 6 holds
        # the growth stage (TODO confirm against the model's output layout).
        # Checking the column count via .shape instead of len(data[0]) means
        # an image with zero detections yields [] instead of an IndexError.
        data = getattr(boxes, "data", None)
        if data is not None and data.ndim == 2 and data.shape[1] > 6:
            results_dict["growth_stages"] = data[:, 6].tolist()

        return JSONResponse(content=results_dict)
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})
    finally:
        if tmp_file_path is not None:
            try:
                os.unlink(tmp_file_path)
            except OSError:
                # Best-effort cleanup: a leftover temp file must not turn an
                # already-built response into a failure.
                pass
@app.post("/predict_pipeline/")
async def predict_pipeline_api(file: UploadFile = File(...)):
    """Full pipeline: RGB -> NDVI -> 4-channel -> YOLO prediction.

    ``predict_pipeline`` takes a file path, so the upload is written to a
    temporary file. The file always gets a ``.tif`` suffix regardless of the
    uploaded format; presumably the pipeline sniffs the content (TODO
    confirm). It is removed once the request finishes, including when the
    write or the prediction fails.

    Returns:
        JSON with ``boxes`` (``xyxyn``, ``conf``, ``cls``), ``classes`` (same
        values as ``boxes.cls``), ``names``, ``growth_stages``,
        ``orig_shape`` and ``speed``. The box-derived fields are ``None``
        when ``results.boxes`` is ``None``. On any failure, a JSON body
        ``{"error": <message>}`` with status 500.
    """
    tmp_file_path = None
    try:
        contents = await file.read()
        with tempfile.NamedTemporaryFile(delete=False, suffix=".tif") as tmp_file:
            # Record the path before writing so cleanup below also runs if
            # the write itself fails (delete=False means nothing else will).
            tmp_file_path = tmp_file.name
            tmp_file.write(contents)

        results = predict_pipeline(ndvi_model, yolo_model, tmp_file_path)

        # Convert results to a JSON-serializable structure.
        boxes = results.boxes
        if boxes is not None:
            cls_list = boxes.cls.tolist()
            boxes_dict = {
                "xyxyn": boxes.xyxyn.tolist(),
                "conf": boxes.conf.tolist(),
                "cls": cls_list,
            }
        else:
            cls_list = None
            boxes_dict = {"xyxyn": None, "conf": None, "cls": None}

        results_dict = {
            "boxes": boxes_dict,
            "classes": cls_list,
            "names": results.names,
            "growth_stages": getattr(results, "growth_stages", None),
            "orig_shape": results.orig_shape,
            "speed": results.speed,
        }

        # Assumes that when box rows have more than 6 columns, column 6 holds
        # the growth stage (TODO confirm against the model's output layout).
        # Checking the column count via .shape instead of len(data[0]) means
        # an image with zero detections yields [] instead of an IndexError.
        data = getattr(boxes, "data", None)
        if data is not None and data.ndim == 2 and data.shape[1] > 6:
            results_dict["growth_stages"] = data[:, 6].tolist()

        return JSONResponse(content=results_dict)
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})
    finally:
        if tmp_file_path is not None:
            try:
                os.unlink(tmp_file_path)
            except OSError:
                # Best-effort cleanup: a leftover temp file must not turn an
                # already-built response into a failure.
                pass