# Author: ahadhassan — commit abb6af6 ("New endpoints"), 15 kB
# test app.py
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi import FastAPI, File, UploadFile, HTTPException
from ndvi_predictor import load_model, normalize_rgb, predict_ndvi, create_visualization
from yolo_predictor import load_yolo_model, predict_yolo, predict_pipeline
from PIL import Image
from io import BytesIO
import numpy as np
import zipfile
import json
import rasterio
from rasterio.transform import from_bounds
import tempfile
import os
import logging
from resize_image import resize_image_optimized, resize_image_simple
# Logging: INFO and above are emitted; this module logs under its own name.
logging.basicConfig(level="INFO")
logger = logging.getLogger(__name__)

# The application object every endpoint below registers on.
app = FastAPI()
def _load_model_or_none(loader, path, label):
    """Load a model with ``loader(path)``, returning ``None`` on failure.

    A failed load must not crash the app at import time: the endpoints check
    the module-level model for ``None`` and answer with a 500 instead.

    Args:
        loader: Callable taking ``path`` and returning the loaded model.
        path: Model path/identifier passed to ``loader``.
        label: Human-readable model name used in log messages.

    Returns:
        The loaded model, or ``None`` if ``loader`` raised.
    """
    try:
        model = loader(path)
    except Exception as e:
        # logger.exception keeps the traceback; str(e) alone is rarely enough
        # to diagnose why a model file failed to load.
        logger.exception(f"Failed to load {label} model: {e}")
        return None
    logger.info(f"{label} model loaded successfully")
    return model


# Load models at startup
ndvi_model = _load_model_or_none(load_model, "ndvi_best_model", "NDVI")
yolo_model = _load_model_or_none(load_yolo_model, "best_yolo_model.pt", "YOLO")
@app.get("/")
async def root():
    """Return a static welcome message identifying the API."""
    greeting = "Welcome to the NDVI and YOLO prediction API!"
    return dict(message=greeting)
# NDVI endpoint: RGB image in, ZIP archive (visualization PNG + raw NDVI .npy array) out.
@app.post("/predict_ndvi/")
async def predict_ndvi_api(file: UploadFile = File(...)):
    """Predict NDVI from an uploaded RGB image.

    The upload is decoded with PIL, converted to RGB, resized to 640x640,
    normalized with ``normalize_rgb`` and passed to the NDVI model.

    Returns:
        A ZIP download (``ndvi_output.zip``) containing ``ndvi_image.png``
        (the output of ``create_visualization``) and ``ndvi_band.npy`` (the
        raw prediction written with ``np.save``).
        A 400 JSON error if the upload cannot be decoded as an image; a 500
        JSON error if the model is not loaded or processing fails.
    """
    if ndvi_model is None:
        return JSONResponse(status_code=500, content={"error": "NDVI model not loaded"})
    try:
        # Target size as (height, width).
        target_size = (640, 640)
        contents = await file.read()
        try:
            img = Image.open(BytesIO(contents)).convert("RGB")
        except (OSError, Image.DecompressionBombError) as decode_error:
            # Unidentifiable, truncated or oversized uploads are client
            # errors, not server failures.
            logger.warning(f"Rejected undecodable upload in predict_ndvi_api: {decode_error}")
            return JSONResponse(status_code=400, content={"error": f"Invalid image file: {decode_error}"})

        rgb_resized = resize_image_optimized(np.array(img), target_size)
        norm_img = normalize_rgb(rgb_resized)
        pred_ndvi = predict_ndvi(ndvi_model, norm_img)

        # Visualization PNG; rewind in case the helper left the cursor at the end.
        vis_img_bytes = create_visualization(norm_img, pred_ndvi)
        vis_img_bytes.seek(0)

        # Raw NDVI array in NumPy's .npy format.
        ndvi_bytes = BytesIO()
        np.save(ndvi_bytes, pred_ndvi)

        # Bundle both artifacts into a single in-memory ZIP.
        zip_buf = BytesIO()
        with zipfile.ZipFile(zip_buf, "w") as zip_file:
            zip_file.writestr("ndvi_image.png", vis_img_bytes.read())
            zip_file.writestr("ndvi_band.npy", ndvi_bytes.getvalue())
        zip_buf.seek(0)
        return StreamingResponse(
            zip_buf,
            media_type="application/x-zip-compressed",
            headers={"Content-Disposition": "attachment; filename=ndvi_output.zip"},
        )
    except Exception as e:
        # logger.exception records the traceback, not just the message.
        logger.exception(f"Error in predict_ndvi_api: {e}")
        return JSONResponse(status_code=500, content={"error": str(e)})
@app.post("/predict_yolo/")
async def predict_yolo_api(file: UploadFile = File(...)):
    """Predict YOLO results from a 4-channel TIFF image and return them as JSON.

    The upload is written to a temporary ``.tiff`` file (``predict_yolo`` takes
    a file path), its channel count is checked with ``tifffile`` when the file
    can be read, and the YOLO result is flattened into JSON-serializable lists.

    Returns:
        JSON with ``boxes`` (``xyxyn``, ``conf``, ``cls``), ``classes``,
        ``names``, ``orig_shape``, ``speed``, ``masks`` and, when the raw box
        data has more than 6 columns, ``growth_stages``.
        400 if the upload is empty or is a readable image without 4 channels;
        500 if the model is not loaded or prediction fails.
    """
    if yolo_model is None:
        return JSONResponse(status_code=500, content={"error": "YOLO model not loaded"})
    tmp_file_path = None
    try:
        contents = await file.read()
        if not contents:
            return JSONResponse(status_code=400, content={"error": "Uploaded file is empty"})
        # The suffix is always .tiff because this endpoint only accepts TIFF
        # input. The path is recorded before writing so the finally clause
        # removes the file even if the write itself fails.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".tiff") as tmp_file:
            tmp_file_path = tmp_file.name
            tmp_file.write(contents)
        logger.info(f"Processing YOLO prediction for file: {file.filename}, temp path: {tmp_file_path}")

        # Channel check. Failing to *read* the file only logs a warning (the
        # YOLO loader still gets a chance to handle it), but a file that reads
        # fine and has the wrong channel count is rejected. Previously the
        # rejection was raised inside this same try and silently swallowed.
        channels = None
        try:
            import tifffile
            shape = tifffile.imread(tmp_file_path).shape
        except Exception as validation_error:
            logger.warning(f"Could not validate channels: {validation_error}")
        else:
            if len(shape) == 2:
                channels = 1
            elif len(shape) == 3:
                # Channels may be first (C, H, W) or last (H, W, C).
                if shape[0] == 4 or shape[2] == 4:
                    channels = 4
                else:
                    channels = shape[0] if shape[0] <= 4 else shape[2]
            else:
                logger.warning(f"Could not validate channels: unexpected array shape {shape}")
        if channels is not None and channels != 4:
            return JSONResponse(
                status_code=400,
                content={"error": f"YOLO model expects 4-channel images, but uploaded file has {channels} channels"},
            )

        results = predict_yolo(yolo_model, tmp_file_path)

        # Convert results to JSON-serializable format.
        boxes = results.boxes
        masks = results.masks
        results_dict = {
            "boxes": {
                "xyxyn": boxes.xyxyn.tolist() if boxes is not None else None,
                "conf": boxes.conf.tolist() if boxes is not None else None,
                "cls": boxes.cls.tolist() if boxes is not None else None,
            },
            # Same values as boxes.cls; kept for response compatibility.
            "classes": boxes.cls.tolist() if boxes is not None else None,
            "names": results.names,
            "orig_shape": results.orig_shape,
            "speed": results.speed,
            "masks": {
                "data": masks.data.tolist() if masks is not None else None,
                "orig_shape": masks.orig_shape if masks is not None else None,
                "xy": [seg.tolist() for seg in masks.xy] if masks is not None else None,
                "xyn": [seg.tolist() for seg in masks.xyn] if masks is not None else None,
            },
        }
        # Columns beyond the first 6 of the raw box data are reported as growth stages.
        box_data = getattr(boxes, "data", None) if boxes is not None else None
        if box_data is not None and len(box_data) > 0 and box_data.shape[1] > 6:
            results_dict["growth_stages"] = box_data[:, 6:].tolist()

        logger.info("YOLO prediction completed successfully")
        return JSONResponse(content=results_dict)
    except Exception as e:
        # logger.exception records the traceback, not just the message.
        logger.exception(f"Error in predict_yolo_api: {e}")
        return JSONResponse(status_code=500, content={"error": str(e)})
    finally:
        # Best-effort cleanup: a failure here must not replace the response.
        if tmp_file_path is not None and os.path.exists(tmp_file_path):
            try:
                os.unlink(tmp_file_path)
            except OSError as cleanup_error:
                logger.warning(f"Could not remove temporary file {tmp_file_path}: {cleanup_error}")
@app.post("/predict_pipeline/")
async def predict_pipeline_api(file: UploadFile = File(...)):
    """Full pipeline: RGB -> NDVI -> 32-bit 4-channel TIFF (RGB+NDVI) -> YOLO prediction.

    The upload is decoded with PIL, converted to RGB and handed to
    ``predict_pipeline`` together with both models (resizing happens inside
    the pipeline). The result is flattened into the same JSON layout that
    ``/predict_yolo/`` returns.

    Returns:
        JSON detections (``boxes``, ``classes``, ``names``, ``orig_shape``,
        ``speed``, ``masks`` and optionally ``growth_stages``); 400 if the
        upload cannot be decoded as an image; 500 if either model is not
        loaded or processing fails.
    """
    if ndvi_model is None or yolo_model is None:
        return JSONResponse(status_code=500, content={"error": "Models not loaded properly"})
    try:
        logger.info(f"Starting full pipeline for file: {file.filename}")
        contents = await file.read()
        logger.info(f"Read {len(contents)} bytes from uploaded file")
        try:
            img = Image.open(BytesIO(contents)).convert("RGB")
        except (OSError, Image.DecompressionBombError) as decode_error:
            # Unidentifiable, truncated or oversized uploads are client
            # errors, not server failures.
            logger.warning(f"Rejected undecodable upload in predict_pipeline_api: {decode_error}")
            return JSONResponse(status_code=400, content={"error": f"Invalid image file: {decode_error}"})
        rgb_array = np.array(img)
        logger.info(f"Converted to RGB array with shape: {rgb_array.shape}")

        results = predict_pipeline(ndvi_model, yolo_model, rgb_array)
        logger.info("Pipeline processing completed successfully")

        # Convert results to JSON-serializable format.
        boxes = results.boxes
        masks = results.masks
        results_dict = {
            "boxes": {
                "xyxyn": boxes.xyxyn.tolist() if boxes is not None else None,
                "conf": boxes.conf.tolist() if boxes is not None else None,
                "cls": boxes.cls.tolist() if boxes is not None else None,
            },
            # Same values as boxes.cls; kept for response compatibility.
            "classes": boxes.cls.tolist() if boxes is not None else None,
            "names": results.names,
            "orig_shape": results.orig_shape,
            "speed": results.speed,
            "masks": {
                "data": masks.data.tolist() if masks is not None else None,
                "orig_shape": masks.orig_shape if masks is not None else None,
                "xy": [seg.tolist() for seg in masks.xy] if masks is not None else None,
                "xyn": [seg.tolist() for seg in masks.xyn] if masks is not None else None,
            },
        }
        # Columns beyond the first 6 of the raw box data are reported as growth stages.
        box_data = getattr(boxes, "data", None) if boxes is not None else None
        if box_data is not None and len(box_data) > 0 and box_data.shape[1] > 6:
            results_dict["growth_stages"] = box_data[:, 6:].tolist()

        detections = results_dict["boxes"]["xyxyn"]
        logger.info(f"Pipeline prediction completed successfully with {len(detections) if detections else 0} detections")
        return JSONResponse(content=results_dict)
    except Exception as e:
        # logger.exception records the traceback, not just the message.
        logger.exception(f"Error in predict_pipeline_api: {e}")
        return JSONResponse(status_code=500, content={"error": str(e)})
# Endpoints that return an annotated PNG image instead of JSON results.
from yolo_predictor import predict_yolo_with_image, predict_pipeline_with_image, pil_image_to_bytes
@app.post("/predict_yolo_image/")
async def predict_yolo_image_api(file: UploadFile = File(...)):
    """Predict YOLO results from a 4-channel TIFF image and return the annotated image.

    The upload is written to a temporary ``.tiff`` file
    (``predict_yolo_with_image`` takes a file path), its channel count is
    checked with ``tifffile`` when the file can be read, and the annotated
    image is streamed back as a PNG download.

    Returns:
        A PNG attachment named ``yolo_annotated_<sanitized filename>.png``;
        400 if the upload is empty or is a readable image without 4 channels;
        500 if the model is not loaded or prediction fails.
    """
    if yolo_model is None:
        return JSONResponse(status_code=500, content={"error": "YOLO model not loaded"})
    tmp_file_path = None
    try:
        contents = await file.read()
        if not contents:
            return JSONResponse(status_code=400, content={"error": "Uploaded file is empty"})
        # The suffix is always .tiff because this endpoint only accepts TIFF
        # input. The path is recorded before writing so the finally clause
        # removes the file even if the write itself fails.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".tiff") as tmp_file:
            tmp_file_path = tmp_file.name
            tmp_file.write(contents)
        logger.info(f"Processing YOLO prediction with image output for file: {file.filename}, temp path: {tmp_file_path}")

        # Channel check. Failing to *read* the file only logs a warning (the
        # YOLO loader still gets a chance to handle it), but a file that reads
        # fine and has the wrong channel count is rejected. Previously the
        # rejection was raised inside this same try and silently swallowed.
        channels = None
        try:
            import tifffile
            shape = tifffile.imread(tmp_file_path).shape
        except Exception as validation_error:
            logger.warning(f"Could not validate channels: {validation_error}")
        else:
            if len(shape) == 2:
                channels = 1
            elif len(shape) == 3:
                # Channels may be first (C, H, W) or last (H, W, C).
                if shape[0] == 4 or shape[2] == 4:
                    channels = 4
                else:
                    channels = shape[0] if shape[0] <= 4 else shape[2]
            else:
                logger.warning(f"Could not validate channels: unexpected array shape {shape}")
        if channels is not None and channels != 4:
            return JSONResponse(
                status_code=400,
                content={"error": f"YOLO model expects 4-channel images, but uploaded file has {channels} channels"},
            )

        annotated_image = predict_yolo_with_image(yolo_model, tmp_file_path)
        img_bytes = pil_image_to_bytes(annotated_image, format='PNG')

        # The client-supplied filename is untrusted. It may be None or contain
        # non-ASCII characters, quotes or path separators that break the
        # Content-Disposition header, so keep only ASCII letters, digits,
        # '.', '_' and '-'.
        original_name = os.path.basename(file.filename or "upload")
        safe_name = "".join(
            ch if (ch.isascii() and ch.isalnum()) or ch in "._-" else "_"
            for ch in original_name
        ) or "upload"

        logger.info("YOLO prediction with image output completed successfully")
        return StreamingResponse(
            img_bytes,
            media_type="image/png",
            headers={"Content-Disposition": f"attachment; filename=yolo_annotated_{safe_name}.png"},
        )
    except Exception as e:
        # logger.exception records the traceback, not just the message.
        logger.exception(f"Error in predict_yolo_image_api: {e}")
        return JSONResponse(status_code=500, content={"error": str(e)})
    finally:
        # Best-effort cleanup: a failure here must not replace the response.
        if tmp_file_path is not None and os.path.exists(tmp_file_path):
            try:
                os.unlink(tmp_file_path)
            except OSError as cleanup_error:
                logger.warning(f"Could not remove temporary file {tmp_file_path}: {cleanup_error}")
@app.post("/predict_pipeline_image/")
async def predict_pipeline_image_api(file: UploadFile = File(...)):
    """Full pipeline with image output: RGB -> NDVI -> 32-bit 4-channel TIFF (RGB+NDVI) -> YOLO prediction -> Annotated Image.

    The upload is decoded with PIL, converted to RGB and handed to
    ``predict_pipeline_with_image`` together with both models (resizing
    happens inside the pipeline). The annotated image is streamed back as a
    PNG download.

    Returns:
        A PNG attachment named ``pipeline_annotated_<sanitized filename>.png``;
        400 if the upload cannot be decoded as an image; 500 if either model
        is not loaded or processing fails.
    """
    if ndvi_model is None or yolo_model is None:
        return JSONResponse(status_code=500, content={"error": "Models not loaded properly"})
    try:
        logger.info(f"Starting full pipeline with image output for file: {file.filename}")
        contents = await file.read()
        logger.info(f"Read {len(contents)} bytes from uploaded file")
        try:
            img = Image.open(BytesIO(contents)).convert("RGB")
        except (OSError, Image.DecompressionBombError) as decode_error:
            # Unidentifiable, truncated or oversized uploads are client
            # errors, not server failures.
            logger.warning(f"Rejected undecodable upload in predict_pipeline_image_api: {decode_error}")
            return JSONResponse(status_code=400, content={"error": f"Invalid image file: {decode_error}"})
        rgb_array = np.array(img)
        logger.info(f"Converted to RGB array with shape: {rgb_array.shape}")

        annotated_image = predict_pipeline_with_image(ndvi_model, yolo_model, rgb_array)
        logger.info("Pipeline processing with image output completed successfully")
        img_bytes = pil_image_to_bytes(annotated_image, format='PNG')

        # The client-supplied filename is untrusted. It may be None or contain
        # non-ASCII characters, quotes or path separators that break the
        # Content-Disposition header, so keep only ASCII letters, digits,
        # '.', '_' and '-'.
        original_name = os.path.basename(file.filename or "upload")
        safe_name = "".join(
            ch if (ch.isascii() and ch.isalnum()) or ch in "._-" else "_"
            for ch in original_name
        ) or "upload"

        logger.info("Pipeline prediction with image output completed successfully")
        return StreamingResponse(
            img_bytes,
            media_type="image/png",
            headers={"Content-Disposition": f"attachment; filename=pipeline_annotated_{safe_name}.png"},
        )
    except Exception as e:
        # logger.exception records the traceback, not just the message.
        logger.exception(f"Error in predict_pipeline_image_api: {e}")
        return JSONResponse(status_code=500, content={"error": str(e)})