Spaces:
Sleeping
Sleeping
import json
import logging
import os
import tempfile
import zipfile
from io import BytesIO

import numpy as np
import rasterio
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.responses import JSONResponse, Response, StreamingResponse
from PIL import Image, UnidentifiedImageError
from rasterio.transform import from_bounds

from ndvi_predictor import create_visualization, load_model, normalize_rgb, predict_ndvi
from yolo_predictor import load_yolo_model, predict_pipeline, predict_yolo
app = FastAPI()

# Load both models once, at import time, so every request handler reuses the
# same module-level instances instead of reloading them per request.
# The file paths can be overridden through the environment (useful when the
# weights live outside the working directory); the defaults are the original
# hard-coded filenames, so behaviour is unchanged when the variables are unset.
ndvi_model = load_model(os.environ.get("NDVI_MODEL_PATH", "ndvi_best_model.keras"))
yolo_model = load_yolo_model(os.environ.get("YOLO_MODEL_PATH", "4c_6c_regression.pt"))
async def root():
    """Return a short greeting that identifies this prediction API."""
    # NOTE(review): no route decorator (e.g. @app.get) is visible for this
    # handler in this copy of the file -- confirm it is registered.
    payload = {"message": "Welcome to the NDVI and YOLO prediction API!"}
    return payload
async def predict_ndvi_api(file: UploadFile = File(...)):
    """Predict NDVI from an uploaded RGB image.

    The upload is decoded with Pillow and converted to RGB. It is then
    normalized with ``normalize_rgb`` and passed to ``predict_ndvi``. The
    response is a ZIP attachment named ``ndvi_output.zip`` containing:

    * ``ndvi_image.png`` -- the bytes produced by ``create_visualization``;
    * ``ndvi_band.npy`` -- the predicted NDVI array, written with ``np.save``.

    Error responses are JSON of the form ``{"error": <message>}``:
    400 if the upload is not a decodable image, 500 for any other failure.
    """
    # NOTE(review): no route decorator (e.g. @app.post) is visible for this
    # handler in this copy of the file -- confirm it is registered.
    try:
        contents = await file.read()
        try:
            img = Image.open(BytesIO(contents)).convert("RGB")
        except UnidentifiedImageError as e:
            # Empty or non-image uploads are a client error, not a server fault.
            return JSONResponse(status_code=400, content={"error": str(e)})

        norm_img = normalize_rgb(np.array(img))
        pred_ndvi = predict_ndvi(ndvi_model, norm_img)

        # create_visualization returns a file-like object; rewind it before reading.
        vis_img_bytes = create_visualization(norm_img, pred_ndvi)
        vis_img_bytes.seek(0)

        ndvi_bytes = BytesIO()
        np.save(ndvi_bytes, pred_ndvi)

        zip_buf = BytesIO()
        with zipfile.ZipFile(zip_buf, "w") as zip_file:
            zip_file.writestr("ndvi_image.png", vis_img_bytes.read())
            zip_file.writestr("ndvi_band.npy", ndvi_bytes.getvalue())

        # The archive is already fully in memory, so send it as a single body.
        # This also sets Content-Length; streaming a BytesIO would iterate the
        # binary data line by line. "application/zip" is the registered ZIP
        # media type.
        return Response(
            content=zip_buf.getvalue(),
            media_type="application/zip",
            headers={"Content-Disposition": "attachment; filename=ndvi_output.zip"},
        )
    except Exception as e:
        logging.getLogger(__name__).exception("NDVI prediction failed")
        return JSONResponse(status_code=500, content={"error": str(e)})
def _save_upload_to_temp(contents, suffix):
    """Write ``contents`` to a new named temporary file and return its path.

    The file is created with ``delete=False``. That keeps the path valid after
    the handle is closed, so the path can be handed to a predictor that opens
    the file itself. The caller owns the file and must remove it. If the write
    fails, the file is removed before the exception propagates.
    """
    tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
    try:
        with tmp_file:
            tmp_file.write(contents)
    except BaseException:
        os.unlink(tmp_file.name)
        raise
    return tmp_file.name


def _results_to_dict(results):
    """Convert a prediction result into a JSON-serializable dict.

    NOTE(review): assumes ``results`` is a single ultralytics-style ``Results``
    object. It must expose ``boxes`` (whose ``xyxyn``/``conf``/``cls``/``data``
    support ``tolist()``), ``names``, ``orig_shape`` and ``speed``. Confirm
    this against ``yolo_predictor``.

    The box fields and ``classes`` are ``None`` when ``results.boxes`` is None.
    If ``boxes.data`` is non-empty and has more than six columns, the extra
    columns are reported under ``"growth_stages"``.
    """
    boxes = results.boxes
    has_boxes = boxes is not None
    results_dict = {
        "boxes": {
            "xyxyn": boxes.xyxyn.tolist() if has_boxes else None,
            "conf": boxes.conf.tolist() if has_boxes else None,
            "cls": boxes.cls.tolist() if has_boxes else None,
        },
        "classes": boxes.cls.tolist() if has_boxes else None,
        "names": results.names,
        "orig_shape": results.orig_shape,
        "speed": results.speed,
    }
    data = getattr(boxes, "data", None) if has_boxes else None
    # Columns beyond the first six (presumably xyxy, conf, cls) carry the
    # growth-stage outputs when the model produces them.
    if data is not None and len(data) > 0 and data.shape[1] > 6:
        results_dict["growth_stages"] = data[:, 6:].tolist()
    return results_dict


async def _predict_upload(file, suffix, predict):
    """Run ``predict(path)`` on an uploaded file and return a JSON response.

    The upload is written to a temporary file with the given ``suffix``. That
    file is always removed afterwards. Error responses are
    ``{"error": <message>}``: 400 for an empty upload, 500 (also logged) for
    any other failure.
    """
    try:
        contents = await file.read()
        if not contents:
            return JSONResponse(status_code=400, content={"error": "Uploaded file is empty"})
        tmp_path = _save_upload_to_temp(contents, suffix)
        try:
            # NOTE(review): ``predict`` is called synchronously inside this
            # coroutine, so it blocks the event loop while it runs.
            results = predict(tmp_path)
            return JSONResponse(content=_results_to_dict(results))
        finally:
            if os.path.exists(tmp_path):
                os.unlink(tmp_path)
    except Exception as e:
        logging.getLogger(__name__).exception("Prediction failed for upload %r", file.filename)
        return JSONResponse(status_code=500, content={"error": str(e)})


async def predict_yolo_api(file: UploadFile = File(...)):
    """Predict YOLO results from a 4-channel TIFF image.

    The upload is saved with a ``.tiff`` suffix and passed to ``predict_yolo``.
    The JSON response contains ``boxes`` (``xyxyn``, ``conf``, ``cls``),
    ``classes``, ``names``, ``orig_shape``, ``speed`` and, when the model emits
    extra columns, ``growth_stages``.
    """
    # NOTE(review): no route decorator (e.g. @app.post) is visible for this
    # handler in this copy of the file -- confirm it is registered.
    return await _predict_upload(file, ".tiff", lambda path: predict_yolo(yolo_model, path))


async def predict_pipeline_api(file: UploadFile = File(...)):
    """Full pipeline: RGB -> NDVI -> 4-channel -> YOLO prediction.

    Uploads named ``*.tif``/``*.tiff`` are saved with a ``.tiff`` suffix.
    Anything else, including uploads with no filename, is saved as ``.jpg``.
    The saved file is passed to ``predict_pipeline``. The JSON response has
    the same shape as the YOLO endpoint's response.
    """
    # NOTE(review): no route decorator (e.g. @app.post) is visible for this
    # handler in this copy of the file -- confirm it is registered.
    # UploadFile.filename may be None; treat that as "not a TIFF".
    filename = (file.filename or "").lower()
    suffix = ".tiff" if filename.endswith((".tif", ".tiff")) else ".jpg"
    return await _predict_upload(
        file, suffix, lambda path: predict_pipeline(ndvi_model, yolo_model, path)
    )