Spaces:
Sleeping
Sleeping
Commit ·
abb6af6
1
Parent(s): 3143e36
New endpoints
Browse files- app.py +104 -0
- yolo_predictor.py +164 -3
app.py
CHANGED
|
@@ -231,4 +231,108 @@ async def predict_pipeline_api(file: UploadFile = File(...)):
|
|
| 231 |
|
| 232 |
except Exception as e:
|
| 233 |
logger.error(f"Error in predict_pipeline_api: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 234 |
return JSONResponse(status_code=500, content={"error": str(e)})
|
|
|
|
| 231 |
|
| 232 |
except Exception as e:
|
| 233 |
logger.error(f"Error in predict_pipeline_api: {e}")
|
| 234 |
+
return JSONResponse(status_code=500, content={"error": str(e)})
|
| 235 |
+
|
| 236 |
+
# New endpoints to add to your FastAPI app
|
| 237 |
+
from yolo_predictor import predict_yolo_with_image, predict_pipeline_with_image, pil_image_to_bytes
|
| 238 |
+
|
| 239 |
+
@app.post("/predict_yolo_image/")
|
| 240 |
+
async def predict_yolo_image_api(file: UploadFile = File(...)):
|
| 241 |
+
"""Predict YOLO results from 4-channel TIFF image and return annotated image"""
|
| 242 |
+
if yolo_model is None:
|
| 243 |
+
return JSONResponse(status_code=500, content={"error": "YOLO model not loaded"})
|
| 244 |
+
|
| 245 |
+
try:
|
| 246 |
+
# Save uploaded file temporarily with proper extension
|
| 247 |
+
file_extension = '.tiff' if file.filename and file.filename.lower().endswith(('.tif', '.tiff')) else '.tiff'
|
| 248 |
+
|
| 249 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as tmp_file:
|
| 250 |
+
contents = await file.read()
|
| 251 |
+
tmp_file.write(contents)
|
| 252 |
+
tmp_file.flush() # Ensure data is written
|
| 253 |
+
tmp_file_path = tmp_file.name
|
| 254 |
+
|
| 255 |
+
try:
|
| 256 |
+
# Verify the file was written correctly
|
| 257 |
+
if not os.path.exists(tmp_file_path) or os.path.getsize(tmp_file_path) == 0:
|
| 258 |
+
raise ValueError("Failed to create temporary file")
|
| 259 |
+
|
| 260 |
+
logger.info(f"Processing YOLO prediction with image output for file: {file.filename}, temp path: {tmp_file_path}")
|
| 261 |
+
|
| 262 |
+
# Additional validation: check if file has 4 channels
|
| 263 |
+
try:
|
| 264 |
+
import tifffile
|
| 265 |
+
test_array = tifffile.imread(tmp_file_path)
|
| 266 |
+
if len(test_array.shape) == 3:
|
| 267 |
+
if test_array.shape[0] == 4 or test_array.shape[2] == 4:
|
| 268 |
+
channels = 4
|
| 269 |
+
else:
|
| 270 |
+
channels = test_array.shape[0] if test_array.shape[0] <= 4 else test_array.shape[2]
|
| 271 |
+
else:
|
| 272 |
+
channels = 1
|
| 273 |
+
|
| 274 |
+
if channels != 4:
|
| 275 |
+
raise ValueError(f"YOLO model expects 4-channel images, but uploaded file has {channels} channels")
|
| 276 |
+
|
| 277 |
+
except Exception as validation_error:
|
| 278 |
+
logger.warning(f"Could not validate channels: {validation_error}")
|
| 279 |
+
|
| 280 |
+
# Predict using YOLO model and get annotated image
|
| 281 |
+
annotated_image = predict_yolo_with_image(yolo_model, tmp_file_path)
|
| 282 |
+
|
| 283 |
+
# Convert PIL Image to bytes for response
|
| 284 |
+
img_bytes = pil_image_to_bytes(annotated_image, format='PNG')
|
| 285 |
+
|
| 286 |
+
logger.info(f"YOLO prediction with image output completed successfully")
|
| 287 |
+
|
| 288 |
+
return StreamingResponse(
|
| 289 |
+
img_bytes,
|
| 290 |
+
media_type="image/png",
|
| 291 |
+
headers={"Content-Disposition": f"attachment; filename=yolo_annotated_{file.filename}.png"}
|
| 292 |
+
)
|
| 293 |
+
|
| 294 |
+
finally:
|
| 295 |
+
# Clean up temporary file
|
| 296 |
+
if os.path.exists(tmp_file_path):
|
| 297 |
+
os.unlink(tmp_file_path)
|
| 298 |
+
|
| 299 |
+
except Exception as e:
|
| 300 |
+
logger.error(f"Error in predict_yolo_image_api: {e}")
|
| 301 |
+
return JSONResponse(status_code=500, content={"error": str(e)})
|
| 302 |
+
|
| 303 |
+
@app.post("/predict_pipeline_image/")
|
| 304 |
+
async def predict_pipeline_image_api(file: UploadFile = File(...)):
|
| 305 |
+
"""Full pipeline with image output: RGB -> NDVI -> 32-bit 4-channel TIFF (RGB+NDVI) -> YOLO prediction -> Annotated Image"""
|
| 306 |
+
if ndvi_model is None or yolo_model is None:
|
| 307 |
+
return JSONResponse(status_code=500, content={"error": "Models not loaded properly"})
|
| 308 |
+
|
| 309 |
+
try:
|
| 310 |
+
logger.info(f"Starting full pipeline with image output for file: {file.filename}")
|
| 311 |
+
|
| 312 |
+
# Read uploaded RGB image
|
| 313 |
+
contents = await file.read()
|
| 314 |
+
logger.info(f"Read {len(contents)} bytes from uploaded file")
|
| 315 |
+
|
| 316 |
+
# Convert to PIL Image and then to numpy array
|
| 317 |
+
img = Image.open(BytesIO(contents)).convert("RGB")
|
| 318 |
+
rgb_array = np.array(img)
|
| 319 |
+
logger.info(f"Converted to RGB array with shape: {rgb_array.shape}")
|
| 320 |
+
|
| 321 |
+
# Run the full pipeline with image output (includes resizing internally)
|
| 322 |
+
annotated_image = predict_pipeline_with_image(ndvi_model, yolo_model, rgb_array)
|
| 323 |
+
logger.info("Pipeline processing with image output completed successfully")
|
| 324 |
+
|
| 325 |
+
# Convert PIL Image to bytes for response
|
| 326 |
+
img_bytes = pil_image_to_bytes(annotated_image, format='PNG')
|
| 327 |
+
|
| 328 |
+
logger.info(f"Pipeline prediction with image output completed successfully")
|
| 329 |
+
|
| 330 |
+
return StreamingResponse(
|
| 331 |
+
img_bytes,
|
| 332 |
+
media_type="image/png",
|
| 333 |
+
headers={"Content-Disposition": f"attachment; filename=pipeline_annotated_{file.filename}.png"}
|
| 334 |
+
)
|
| 335 |
+
|
| 336 |
+
except Exception as e:
|
| 337 |
+
logger.error(f"Error in predict_pipeline_image_api: {e}")
|
| 338 |
return JSONResponse(status_code=500, content={"error": str(e)})
|
yolo_predictor.py
CHANGED
|
@@ -4,6 +4,9 @@ import logging
|
|
| 4 |
import tempfile
|
| 5 |
import numpy as np
|
| 6 |
import tifffile
|
|
|
|
|
|
|
|
|
|
| 7 |
from rasterio.transform import from_bounds
|
| 8 |
from ultralytics import YOLO
|
| 9 |
from ndvi_predictor import normalize_rgb, predict_ndvi
|
|
@@ -18,7 +21,7 @@ def load_yolo_model(model_path):
|
|
| 18 |
logger.info(f"Loading YOLO model from: {model_path}")
|
| 19 |
return YOLO(model_path)
|
| 20 |
|
| 21 |
-
def predict_yolo(yolo_model, image_path, conf=0.
|
| 22 |
"""
|
| 23 |
Predict using YOLO model on 4-channel TIFF image
|
| 24 |
|
|
@@ -104,7 +107,7 @@ def create_4channel_tiff(rgb_array, ndvi_array, output_path):
|
|
| 104 |
tifffile.imwrite(output_path, four_channel)
|
| 105 |
logger.info(f"Successfully saved 4-channel TIFF (RGB+NDVI format) to: {output_path}")
|
| 106 |
|
| 107 |
-
def predict_pipeline(ndvi_model, yolo_model, rgb_array, conf=0.
|
| 108 |
"""
|
| 109 |
Full pipeline: RGB -> NDVI -> 32-bit 4-channel TIFF (RGB+NDVI) -> YOLO prediction
|
| 110 |
|
|
@@ -196,4 +199,162 @@ def validate_4channel_tiff(tiff_path):
|
|
| 196 |
|
| 197 |
except Exception as e:
|
| 198 |
logger.error(f"Error validating TIFF file: {e}")
|
| 199 |
-
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
import tempfile
|
| 5 |
import numpy as np
|
| 6 |
import tifffile
|
| 7 |
+
from io import BytesIO
|
| 8 |
+
import cv2
|
| 9 |
+
from PIL import Image
|
| 10 |
from rasterio.transform import from_bounds
|
| 11 |
from ultralytics import YOLO
|
| 12 |
from ndvi_predictor import normalize_rgb, predict_ndvi
|
|
|
|
| 21 |
logger.info(f"Loading YOLO model from: {model_path}")
|
| 22 |
return YOLO(model_path)
|
| 23 |
|
| 24 |
+
def predict_yolo(yolo_model, image_path, conf=0.25):
|
| 25 |
"""
|
| 26 |
Predict using YOLO model on 4-channel TIFF image
|
| 27 |
|
|
|
|
| 107 |
tifffile.imwrite(output_path, four_channel)
|
| 108 |
logger.info(f"Successfully saved 4-channel TIFF (RGB+NDVI format) to: {output_path}")
|
| 109 |
|
| 110 |
+
def predict_pipeline(ndvi_model, yolo_model, rgb_array, conf=0.25):
|
| 111 |
"""
|
| 112 |
Full pipeline: RGB -> NDVI -> 32-bit 4-channel TIFF (RGB+NDVI) -> YOLO prediction
|
| 113 |
|
|
|
|
| 199 |
|
| 200 |
except Exception as e:
|
| 201 |
logger.error(f"Error validating TIFF file: {e}")
|
| 202 |
+
return False
|
| 203 |
+
|
| 204 |
+
# Additional functions for yolo_predictor.py
|
| 205 |
+
|
| 206 |
+
def predict_yolo_with_image(yolo_model, image_path, conf=0.25, save_path=None):
|
| 207 |
+
"""
|
| 208 |
+
Predict using YOLO model on 4-channel TIFF image and return annotated image
|
| 209 |
+
|
| 210 |
+
Args:
|
| 211 |
+
yolo_model: Loaded YOLO model
|
| 212 |
+
image_path: Path to 4-channel TIFF image
|
| 213 |
+
conf: Confidence threshold
|
| 214 |
+
save_path: Optional path to save the annotated image
|
| 215 |
+
|
| 216 |
+
Returns:
|
| 217 |
+
annotated_image: PIL Image object with annotations
|
| 218 |
+
"""
|
| 219 |
+
logger.info(f"Starting YOLO prediction with image output on: {image_path} with confidence: {conf}")
|
| 220 |
+
|
| 221 |
+
# Verify file exists and has correct format
|
| 222 |
+
if not os.path.exists(image_path):
|
| 223 |
+
raise FileNotFoundError(f"Image file not found: {image_path}")
|
| 224 |
+
|
| 225 |
+
try:
|
| 226 |
+
# Quick validation of the TIFF file
|
| 227 |
+
test_array = tifffile.imread(image_path)
|
| 228 |
+
logger.info(f"TIFF file shape: {test_array.shape}, dtype: {test_array.dtype}")
|
| 229 |
+
|
| 230 |
+
# Validate channels
|
| 231 |
+
if len(test_array.shape) == 3:
|
| 232 |
+
channels = test_array.shape[0] if test_array.shape[0] <= 4 else test_array.shape[2]
|
| 233 |
+
else:
|
| 234 |
+
channels = 1
|
| 235 |
+
|
| 236 |
+
if channels != 4:
|
| 237 |
+
raise ValueError(f"Expected 4-channel image, got {channels} channels")
|
| 238 |
+
|
| 239 |
+
except Exception as e:
|
| 240 |
+
logger.error(f"Error validating TIFF file: {e}")
|
| 241 |
+
raise
|
| 242 |
+
|
| 243 |
+
logger.info("Running YOLO model inference with image output...")
|
| 244 |
+
|
| 245 |
+
# Run YOLO prediction directly on the input file
|
| 246 |
+
results = yolo_model([image_path], conf=conf)
|
| 247 |
+
result = results[0]
|
| 248 |
+
|
| 249 |
+
# Create temporary file for saving annotated image
|
| 250 |
+
if save_path is None:
|
| 251 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as tmp_file:
|
| 252 |
+
save_path = tmp_file.name
|
| 253 |
+
|
| 254 |
+
try:
|
| 255 |
+
# Save the annotated image using ultralytics built-in method
|
| 256 |
+
result.save(save_path)
|
| 257 |
+
logger.info(f"Annotated image saved to: {save_path}")
|
| 258 |
+
|
| 259 |
+
# Load the saved image and convert to PIL Image
|
| 260 |
+
annotated_image = Image.open(save_path).convert('RGB')
|
| 261 |
+
logger.info(f"YOLO prediction with image output completed successfully")
|
| 262 |
+
|
| 263 |
+
return annotated_image
|
| 264 |
+
|
| 265 |
+
except Exception as e:
|
| 266 |
+
logger.error(f"Error saving annotated image: {e}")
|
| 267 |
+
raise
|
| 268 |
+
finally:
|
| 269 |
+
# Clean up temporary file if we created it
|
| 270 |
+
if save_path.endswith('.png') and os.path.exists(save_path):
|
| 271 |
+
try:
|
| 272 |
+
os.unlink(save_path)
|
| 273 |
+
logger.info(f"Cleaned up temporary annotated image file: {save_path}")
|
| 274 |
+
except Exception as cleanup_error:
|
| 275 |
+
logger.warning(f"Failed to clean up temporary file: {cleanup_error}")
|
| 276 |
+
|
| 277 |
+
def predict_pipeline_with_image(ndvi_model, yolo_model, rgb_array, conf=0.25):
|
| 278 |
+
"""
|
| 279 |
+
Full pipeline with image output: RGB -> NDVI -> 32-bit 4-channel TIFF (RGB+NDVI) -> YOLO prediction -> Annotated Image
|
| 280 |
+
|
| 281 |
+
Args:
|
| 282 |
+
ndvi_model: Loaded NDVI prediction model
|
| 283 |
+
yolo_model: Loaded YOLO model
|
| 284 |
+
rgb_array: RGB image as numpy array (H, W, 3)
|
| 285 |
+
conf: Confidence threshold for YOLO
|
| 286 |
+
|
| 287 |
+
Returns:
|
| 288 |
+
annotated_image: PIL Image object with YOLO annotations
|
| 289 |
+
"""
|
| 290 |
+
logger.info("Starting full prediction pipeline with image output")
|
| 291 |
+
logger.info(f"Input RGB array shape: {rgb_array.shape}, dtype: {rgb_array.dtype}")
|
| 292 |
+
|
| 293 |
+
# Step 1: Resize RGB image to target size
|
| 294 |
+
logger.info("Step 1: Resizing RGB image to target size")
|
| 295 |
+
target_size = (640, 640) # (height, width)
|
| 296 |
+
rgb_resized = resize_image_optimized(rgb_array, target_size)
|
| 297 |
+
logger.info(f"Resized RGB shape: {rgb_resized.shape}")
|
| 298 |
+
|
| 299 |
+
# Step 2: Normalize RGB image
|
| 300 |
+
logger.info("Step 2: Normalizing RGB image")
|
| 301 |
+
normalized_rgb = normalize_rgb(rgb_resized)
|
| 302 |
+
logger.info(f"Normalized RGB shape: {normalized_rgb.shape}, range: [{normalized_rgb.min():.3f}, {normalized_rgb.max():.3f}]")
|
| 303 |
+
|
| 304 |
+
# Step 3: Predict NDVI
|
| 305 |
+
logger.info("Step 3: Predicting NDVI from RGB")
|
| 306 |
+
ndvi_prediction = predict_ndvi(ndvi_model, normalized_rgb)
|
| 307 |
+
logger.info(f"NDVI prediction shape: {ndvi_prediction.shape}, range: [{ndvi_prediction.min():.3f}, {ndvi_prediction.max():.3f}]")
|
| 308 |
+
|
| 309 |
+
# Step 4: Create 4-channel TIFF file
|
| 310 |
+
logger.info("Step 4: Creating 4-channel TIFF file (RGB+NDVI)")
|
| 311 |
+
|
| 312 |
+
# Create temporary file for the 4-channel TIFF
|
| 313 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix='.tiff') as tmp_file:
|
| 314 |
+
tiff_path = tmp_file.name
|
| 315 |
+
|
| 316 |
+
try:
|
| 317 |
+
# Create the 4-channel TIFF using resized RGB and predicted NDVI
|
| 318 |
+
create_4channel_tiff(rgb_resized, ndvi_prediction, tiff_path)
|
| 319 |
+
|
| 320 |
+
# Verify the created file
|
| 321 |
+
if not os.path.exists(tiff_path):
|
| 322 |
+
raise FileNotFoundError(f"Failed to create 4-channel TIFF at: {tiff_path}")
|
| 323 |
+
|
| 324 |
+
file_size = os.path.getsize(tiff_path)
|
| 325 |
+
logger.info(f"Created 4-channel TIFF file size: {file_size} bytes")
|
| 326 |
+
|
| 327 |
+
# Step 5: Run YOLO prediction on the 4-channel TIFF and get annotated image
|
| 328 |
+
logger.info("Step 5: Running YOLO prediction on 4-channel TIFF with image output")
|
| 329 |
+
annotated_image = predict_yolo_with_image(yolo_model, tiff_path, conf=conf)
|
| 330 |
+
|
| 331 |
+
logger.info("Full pipeline with image output completed successfully")
|
| 332 |
+
return annotated_image
|
| 333 |
+
|
| 334 |
+
except Exception as e:
|
| 335 |
+
logger.error(f"Error in pipeline with image output: {e}")
|
| 336 |
+
raise
|
| 337 |
+
finally:
|
| 338 |
+
# Clean up temporary file
|
| 339 |
+
if os.path.exists(tiff_path):
|
| 340 |
+
try:
|
| 341 |
+
os.unlink(tiff_path)
|
| 342 |
+
logger.info(f"Cleaned up temporary file: {tiff_path}")
|
| 343 |
+
except Exception as cleanup_error:
|
| 344 |
+
logger.warning(f"Failed to clean up temporary file: {cleanup_error}")
|
| 345 |
+
|
| 346 |
+
def pil_image_to_bytes(image, format='PNG'):
|
| 347 |
+
"""
|
| 348 |
+
Convert PIL Image to bytes for API response
|
| 349 |
+
|
| 350 |
+
Args:
|
| 351 |
+
image: PIL Image object
|
| 352 |
+
format: Image format ('PNG', 'JPEG', etc.)
|
| 353 |
+
|
| 354 |
+
Returns:
|
| 355 |
+
BytesIO: Image as bytes buffer
|
| 356 |
+
"""
|
| 357 |
+
img_bytes = BytesIO()
|
| 358 |
+
image.save(img_bytes, format=format)
|
| 359 |
+
img_bytes.seek(0)
|
| 360 |
+
return img_bytes
|