Spaces:
Sleeping
Sleeping
| # test app.py | |
| from fastapi.responses import StreamingResponse, JSONResponse | |
| from fastapi import FastAPI, File, UploadFile, HTTPException | |
| from ndvi_predictor import load_model, normalize_rgb, predict_ndvi, create_visualization | |
| from yolo_predictor import load_yolo_model, predict_yolo, predict_pipeline | |
| from PIL import Image | |
| from io import BytesIO | |
| import numpy as np | |
| import zipfile | |
| import json | |
| import rasterio | |
| from rasterio.transform import from_bounds | |
| import tempfile | |
| import os | |
| import logging | |
| from resize_image import resize_image_optimized, resize_image_simple | |
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# NOTE(review): none of the handler functions in this copy of the file has a
# visible ``@app.get(...)`` / ``@app.post(...)`` decorator (and no
# ``app.add_api_route`` call is visible), so as shown no endpoint is registered
# on ``app``; confirm the decorators were not lost when the file was copied.
app = FastAPI()


def _load_or_none(loader, path, label):
    """Load a model with ``loader(path)``, returning None if loading fails.

    A failed load is logged (with its traceback) instead of crashing the app
    at import time; the endpoints check their model for None and answer with
    a 500 error in that case.

    Args:
        loader: Callable that takes ``path`` and returns a model.
        path: Model location, passed through to ``loader`` unchanged.
        label: Human-readable model name used in the log messages.

    Returns:
        The object returned by ``loader``, or None when ``loader`` raised.
    """
    try:
        model = loader(path)
    except Exception as e:
        # logger.exception records the traceback that logger.error dropped.
        logger.exception(f"Failed to load {label} model: {e}")
        return None
    logger.info(f"{label} model loaded successfully")
    return model


# Load models at startup
ndvi_model = _load_or_none(load_model, "ndvi_best_model", "NDVI")
yolo_model = _load_or_none(load_yolo_model, "best_yolo_model.pt", "YOLO")
async def root():
    """Answer the API root with a fixed welcome payload."""
    welcome_text = "Welcome to the NDVI and YOLO prediction API!"
    payload = {"message": welcome_text}
    return payload
async def predict_ndvi_api(file: UploadFile = File(...)):
    """Predict NDVI from an uploaded RGB image and return it as a ZIP archive.

    The upload is decoded with Pillow and converted to RGB, resized to
    640x640 with ``resize_image_optimized``, normalized with ``normalize_rgb``
    and passed to ``predict_ndvi`` together with the module-level
    ``ndvi_model``.

    Args:
        file: Uploaded image in any format Pillow can open.

    Returns:
        On success, a streamed ZIP (``ndvi_output.zip``) containing
        ``ndvi_image.png`` (the bytes produced by ``create_visualization``)
        and ``ndvi_band.npy`` (the raw prediction written with ``np.save``).
        Otherwise a JSON ``{"error": ...}`` body with status 400 when the
        upload cannot be identified as an image (including an empty upload),
        or status 500 when the model is not loaded or any later step fails.
    """
    from PIL import UnidentifiedImageError

    if ndvi_model is None:
        return JSONResponse(status_code=500, content={"error": "NDVI model not loaded"})
    # Define target size (height, width)
    target_size = (640, 640)
    try:
        contents = await file.read()
        try:
            img = Image.open(BytesIO(contents)).convert("RGB")
        except UnidentifiedImageError as e:
            # Undecodable (or empty) uploads are a client error, not a server fault.
            logger.warning(f"Rejected upload in predict_ndvi_api: {e}")
            return JSONResponse(status_code=400, content={"error": str(e)})

        # NOTE(review): the model helpers below are called synchronously inside
        # this async handler, so the event loop is blocked until they return.
        rgb_resized = resize_image_optimized(np.array(img), target_size)
        norm_img = normalize_rgb(rgb_resized)
        pred_ndvi = predict_ndvi(ndvi_model, norm_img)

        # Visualization image as PNG.
        vis_img_bytes = create_visualization(norm_img, pred_ndvi)
        vis_img_bytes.seek(0)

        # Raw NDVI band in NumPy's .npy format.
        ndvi_bytes = BytesIO()
        np.save(ndvi_bytes, pred_ndvi)

        # Bundle both files into a single in-memory ZIP archive.
        zip_buf = BytesIO()
        with zipfile.ZipFile(zip_buf, "w") as zip_file:
            zip_file.writestr("ndvi_image.png", vis_img_bytes.read())
            zip_file.writestr("ndvi_band.npy", ndvi_bytes.getvalue())
        zip_buf.seek(0)
        return StreamingResponse(
            zip_buf,
            media_type="application/x-zip-compressed",
            headers={"Content-Disposition": "attachment; filename=ndvi_output.zip"},
        )
    except Exception as e:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception(f"Error in predict_ndvi_api: {e}")
        return JSONResponse(status_code=500, content={"error": str(e)})
async def predict_yolo_api(file: UploadFile = File(...)):
    """Run the YOLO model on an uploaded 4-channel TIFF and return JSON results.

    The upload is written to a temporary ``.tiff`` file (``predict_yolo`` is
    given a file path), its channel count is checked with ``tifffile`` when
    the file can be read, and the prediction is flattened into
    JSON-serializable lists.

    Args:
        file: Uploaded TIFF image; expected to have 4 channels.

    Returns:
        JSON with ``boxes`` (``xyxyn``, ``conf``, ``cls``), ``classes``,
        ``names``, ``orig_shape``, ``speed``, ``masks`` (``data``,
        ``orig_shape``, ``xy``, ``xyn``) and, when ``boxes.data`` has more
        than six columns, ``growth_stages``. Box/mask fields are None when the
        result has no boxes/masks. Errors are JSON ``{"error": ...}`` with
        status 400 for an empty upload or a readable file whose channel count
        is not 4, and status 500 for a missing model or any other failure.
    """
    if yolo_model is None:
        return JSONResponse(status_code=500, content={"error": "YOLO model not loaded"})
    tmp_file_path = None
    try:
        # Read before creating the temp file: with delete=False, a failed read
        # inside the ``with`` would leave an orphaned file nobody removes.
        contents = await file.read()
        if not contents:
            return JSONResponse(status_code=400, content={"error": "Uploaded file is empty"})
        # The temp file always gets a .tiff suffix, whatever the upload is named.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".tiff") as tmp_file:
            tmp_file_path = tmp_file.name
            tmp_file.write(contents)
        logger.info(f"Processing YOLO prediction for file: {file.filename}, temp path: {tmp_file_path}")

        # Channel check. If the file cannot be read here (or tifffile is not
        # installed) we only log and let the model try. A wrong channel count
        # is rejected: previously that raise sat inside the same try as the
        # read and was swallowed by its ``except``, so it never took effect.
        try:
            import tifffile
            shape = tifffile.imread(tmp_file_path).shape
        except Exception as validation_error:
            logger.warning(f"Could not validate channels: {validation_error}")
        else:
            if len(shape) == 3:
                # Accept a size-4 axis at either end (channels-first or -last).
                if shape[0] == 4 or shape[2] == 4:
                    channels = 4
                else:
                    channels = shape[0] if shape[0] <= 4 else shape[2]
            else:
                channels = 1
            if channels != 4:
                message = f"YOLO model expects 4-channel images, but uploaded file has {channels} channels"
                logger.warning(message)
                return JSONResponse(status_code=400, content={"error": message})

        # NOTE(review): predict_yolo is called synchronously inside this async
        # handler, so the event loop is blocked until it returns.
        results = predict_yolo(yolo_model, tmp_file_path)
        boxes = results.boxes
        masks = results.masks

        # Convert results to JSON-serializable format
        results_dict = {
            "boxes": {
                "xyxyn": boxes.xyxyn.tolist() if boxes is not None else None,
                "conf": boxes.conf.tolist() if boxes is not None else None,
                "cls": boxes.cls.tolist() if boxes is not None else None,
            },
            "classes": boxes.cls.tolist() if boxes is not None else None,
            "names": results.names,
            "orig_shape": results.orig_shape,
            "speed": results.speed,
            "masks": {
                "data": masks.data.tolist() if masks is not None else None,
                "orig_shape": masks.orig_shape if masks is not None else None,
                "xy": [seg.tolist() for seg in masks.xy] if masks is not None else None,
                "xyn": [seg.tolist() for seg in masks.xyn] if masks is not None else None,
            },
        }
        # Columns of boxes.data beyond the first six are returned as growth stages.
        # NOTE(review): if ``results`` is an Ultralytics ``Results`` object, a
        # seventh column would normally be a tracking id — confirm what
        # predict_yolo actually places there.
        if boxes is not None and hasattr(boxes, "data") and len(boxes.data) > 0:
            if boxes.data.shape[1] > 6:
                results_dict["growth_stages"] = boxes.data[:, 6:].tolist()

        logger.info("YOLO prediction completed successfully")
        return JSONResponse(content=results_dict)
    except Exception as e:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception(f"Error in predict_yolo_api: {e}")
        return JSONResponse(status_code=500, content={"error": str(e)})
    finally:
        # Remove the temp file on every path, including early returns. A
        # cleanup failure is logged rather than replacing the response.
        if tmp_file_path is not None and os.path.exists(tmp_file_path):
            try:
                os.unlink(tmp_file_path)
            except OSError as cleanup_error:
                logger.warning(f"Could not remove temporary file {tmp_file_path}: {cleanup_error}")
async def predict_pipeline_api(file: UploadFile = File(...)):
    """Full pipeline: RGB -> NDVI -> 32-bit 4-channel TIFF (RGB+NDVI) -> YOLO prediction.

    The uploaded image is decoded with Pillow, converted to an RGB NumPy
    array and handed to ``predict_pipeline`` together with both module-level
    models (per the original comments, resizing happens inside
    ``predict_pipeline``).

    Args:
        file: Uploaded RGB image in any format Pillow can open.

    Returns:
        JSON with ``boxes`` (``xyxyn``, ``conf``, ``cls``), ``classes``,
        ``names``, ``orig_shape``, ``speed``, ``masks`` (``data``,
        ``orig_shape``, ``xy``, ``xyn``) and, when ``boxes.data`` has more
        than six columns, ``growth_stages``. Errors are JSON
        ``{"error": ...}`` with status 400 when the upload cannot be
        identified as an image, or 500 when a model is missing or any other
        step fails.
    """
    from PIL import UnidentifiedImageError

    if ndvi_model is None or yolo_model is None:
        return JSONResponse(status_code=500, content={"error": "Models not loaded properly"})
    try:
        logger.info(f"Starting full pipeline for file: {file.filename}")
        contents = await file.read()
        logger.info(f"Read {len(contents)} bytes from uploaded file")
        try:
            img = Image.open(BytesIO(contents)).convert("RGB")
        except UnidentifiedImageError as e:
            # Undecodable (or empty) uploads are a client error, not a server fault.
            logger.warning(f"Rejected upload in predict_pipeline_api: {e}")
            return JSONResponse(status_code=400, content={"error": str(e)})
        rgb_array = np.array(img)
        logger.info(f"Converted to RGB array with shape: {rgb_array.shape}")

        # NOTE(review): predict_pipeline is called synchronously inside this
        # async handler, so the event loop is blocked until it returns.
        results = predict_pipeline(ndvi_model, yolo_model, rgb_array)
        logger.info("Pipeline processing completed successfully")
        boxes = results.boxes
        masks = results.masks

        # Convert results to JSON-serializable format
        results_dict = {
            "boxes": {
                "xyxyn": boxes.xyxyn.tolist() if boxes is not None else None,
                "conf": boxes.conf.tolist() if boxes is not None else None,
                "cls": boxes.cls.tolist() if boxes is not None else None,
            },
            "classes": boxes.cls.tolist() if boxes is not None else None,
            "names": results.names,
            "orig_shape": results.orig_shape,
            "speed": results.speed,
            "masks": {
                "data": masks.data.tolist() if masks is not None else None,
                "orig_shape": masks.orig_shape if masks is not None else None,
                "xy": [seg.tolist() for seg in masks.xy] if masks is not None else None,
                "xyn": [seg.tolist() for seg in masks.xyn] if masks is not None else None,
            },
        }
        # Columns of boxes.data beyond the first six are returned as growth stages.
        # NOTE(review): if ``results`` is an Ultralytics ``Results`` object, a
        # seventh column would normally be a tracking id — confirm what
        # predict_pipeline actually places there.
        if boxes is not None and hasattr(boxes, "data") and len(boxes.data) > 0:
            if boxes.data.shape[1] > 6:
                results_dict["growth_stages"] = boxes.data[:, 6:].tolist()

        detections = results_dict["boxes"]["xyxyn"]
        logger.info(f"Pipeline prediction completed successfully with {len(detections) if detections else 0} detections")
        return JSONResponse(content=results_dict)
    except Exception as e:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception(f"Error in predict_pipeline_api: {e}")
        return JSONResponse(status_code=500, content={"error": str(e)})
| # New endpoints to add to your FastAPI app | |
| from yolo_predictor import predict_yolo_with_image, predict_pipeline_with_image, pil_image_to_bytes | |
async def predict_yolo_image_api(file: UploadFile = File(...)):
    """Run the YOLO model on an uploaded 4-channel TIFF and return an annotated PNG.

    The upload is written to a temporary ``.tiff`` file (the predictor is
    given a file path), its channel count is checked with ``tifffile`` when
    the file can be read, and the image returned by
    ``predict_yolo_with_image`` is encoded with ``pil_image_to_bytes``.

    Args:
        file: Uploaded TIFF image; expected to have 4 channels.

    Returns:
        A streamed ``image/png`` download named
        ``yolo_annotated_<sanitized upload name>.png``. Errors are JSON
        ``{"error": ...}`` with status 400 for an empty upload or a readable
        file whose channel count is not 4, and status 500 for a missing model
        or any other failure.
    """
    import re

    if yolo_model is None:
        return JSONResponse(status_code=500, content={"error": "YOLO model not loaded"})
    tmp_file_path = None
    try:
        # Read before creating the temp file: with delete=False, a failed read
        # inside the ``with`` would leave an orphaned file nobody removes.
        contents = await file.read()
        if not contents:
            return JSONResponse(status_code=400, content={"error": "Uploaded file is empty"})
        # The temp file always gets a .tiff suffix, whatever the upload is named.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".tiff") as tmp_file:
            tmp_file_path = tmp_file.name
            tmp_file.write(contents)
        logger.info(f"Processing YOLO prediction with image output for file: {file.filename}, temp path: {tmp_file_path}")

        # Channel check. If the file cannot be read here (or tifffile is not
        # installed) we only log and let the model try. A wrong channel count
        # is rejected: previously that raise sat inside the same try as the
        # read and was swallowed by its ``except``, so it never took effect.
        try:
            import tifffile
            shape = tifffile.imread(tmp_file_path).shape
        except Exception as validation_error:
            logger.warning(f"Could not validate channels: {validation_error}")
        else:
            if len(shape) == 3:
                # Accept a size-4 axis at either end (channels-first or -last).
                if shape[0] == 4 or shape[2] == 4:
                    channels = 4
                else:
                    channels = shape[0] if shape[0] <= 4 else shape[2]
            else:
                channels = 1
            if channels != 4:
                message = f"YOLO model expects 4-channel images, but uploaded file has {channels} channels"
                logger.warning(message)
                return JSONResponse(status_code=400, content={"error": message})

        # NOTE(review): predict_yolo_with_image is called synchronously inside
        # this async handler, so the event loop is blocked until it returns.
        annotated_image = predict_yolo_with_image(yolo_model, tmp_file_path)
        img_bytes = pil_image_to_bytes(annotated_image, format='PNG')

        # Derive the download name from the upload name, keeping only
        # characters that are safe in an unquoted Content-Disposition value.
        # Starlette encodes header values as latin-1 (other characters used to
        # raise and turn a finished prediction into a 500), and ';', '"' or
        # spaces would corrupt the header. A missing name no longer gives "None".
        safe_name = re.sub(r"[^A-Za-z0-9._-]", "_", os.path.basename(file.filename or "upload"))

        logger.info("YOLO prediction with image output completed successfully")
        return StreamingResponse(
            img_bytes,
            media_type="image/png",
            headers={"Content-Disposition": f"attachment; filename=yolo_annotated_{safe_name}.png"},
        )
    except Exception as e:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception(f"Error in predict_yolo_image_api: {e}")
        return JSONResponse(status_code=500, content={"error": str(e)})
    finally:
        # Remove the temp file on every path, including early returns. A
        # cleanup failure is logged rather than replacing the response.
        if tmp_file_path is not None and os.path.exists(tmp_file_path):
            try:
                os.unlink(tmp_file_path)
            except OSError as cleanup_error:
                logger.warning(f"Could not remove temporary file {tmp_file_path}: {cleanup_error}")
async def predict_pipeline_image_api(file: UploadFile = File(...)):
    """Full pipeline with image output: RGB -> NDVI -> 32-bit 4-channel TIFF (RGB+NDVI) -> YOLO prediction -> Annotated Image.

    The uploaded image is decoded with Pillow, converted to an RGB NumPy
    array and handed to ``predict_pipeline_with_image`` together with both
    module-level models (per the original comments, resizing happens inside
    that helper). The returned image is encoded with ``pil_image_to_bytes``.

    Args:
        file: Uploaded RGB image in any format Pillow can open.

    Returns:
        A streamed ``image/png`` download named
        ``pipeline_annotated_<sanitized upload name>.png``. Errors are JSON
        ``{"error": ...}`` with status 400 when the upload cannot be
        identified as an image, or 500 when a model is missing or any other
        step fails.
    """
    import re

    from PIL import UnidentifiedImageError

    if ndvi_model is None or yolo_model is None:
        return JSONResponse(status_code=500, content={"error": "Models not loaded properly"})
    try:
        logger.info(f"Starting full pipeline with image output for file: {file.filename}")
        contents = await file.read()
        logger.info(f"Read {len(contents)} bytes from uploaded file")
        try:
            img = Image.open(BytesIO(contents)).convert("RGB")
        except UnidentifiedImageError as e:
            # Undecodable (or empty) uploads are a client error, not a server fault.
            logger.warning(f"Rejected upload in predict_pipeline_image_api: {e}")
            return JSONResponse(status_code=400, content={"error": str(e)})
        rgb_array = np.array(img)
        logger.info(f"Converted to RGB array with shape: {rgb_array.shape}")

        # NOTE(review): predict_pipeline_with_image is called synchronously
        # inside this async handler, so the event loop is blocked until it returns.
        annotated_image = predict_pipeline_with_image(ndvi_model, yolo_model, rgb_array)
        logger.info("Pipeline processing with image output completed successfully")
        img_bytes = pil_image_to_bytes(annotated_image, format='PNG')

        # Derive the download name from the upload name, keeping only
        # characters that are safe in an unquoted Content-Disposition value.
        # Starlette encodes header values as latin-1 (other characters used to
        # raise and turn a finished prediction into a 500), and ';', '"' or
        # spaces would corrupt the header. A missing name no longer gives "None".
        safe_name = re.sub(r"[^A-Za-z0-9._-]", "_", os.path.basename(file.filename or "upload"))

        logger.info("Pipeline prediction with image output completed successfully")
        return StreamingResponse(
            img_bytes,
            media_type="image/png",
            headers={"Content-Disposition": f"attachment; filename=pipeline_annotated_{safe_name}.png"},
        )
    except Exception as e:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception(f"Error in predict_pipeline_image_api: {e}")
        return JSONResponse(status_code=500, content={"error": str(e)})