ahadhassan commited on
Commit
abb6af6
·
1 Parent(s): 3143e36

New endpoints

Browse files
Files changed (2) hide show
  1. app.py +104 -0
  2. yolo_predictor.py +164 -3
app.py CHANGED
@@ -231,4 +231,108 @@ async def predict_pipeline_api(file: UploadFile = File(...)):
231
 
232
  except Exception as e:
233
  logger.error(f"Error in predict_pipeline_api: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
  return JSONResponse(status_code=500, content={"error": str(e)})
 
231
 
232
  except Exception as e:
233
  logger.error(f"Error in predict_pipeline_api: {e}")
234
+ return JSONResponse(status_code=500, content={"error": str(e)})
235
+
236
+ # New endpoints to add to your FastAPI app
237
+ from yolo_predictor import predict_yolo_with_image, predict_pipeline_with_image, pil_image_to_bytes
238
+
239
+ @app.post("/predict_yolo_image/")
240
+ async def predict_yolo_image_api(file: UploadFile = File(...)):
241
+ """Predict YOLO results from 4-channel TIFF image and return annotated image"""
242
+ if yolo_model is None:
243
+ return JSONResponse(status_code=500, content={"error": "YOLO model not loaded"})
244
+
245
+ try:
246
+ # Save uploaded file temporarily with proper extension
247
+ file_extension = '.tiff' if file.filename and file.filename.lower().endswith(('.tif', '.tiff')) else '.tiff'
248
+
249
+ with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as tmp_file:
250
+ contents = await file.read()
251
+ tmp_file.write(contents)
252
+ tmp_file.flush() # Ensure data is written
253
+ tmp_file_path = tmp_file.name
254
+
255
+ try:
256
+ # Verify the file was written correctly
257
+ if not os.path.exists(tmp_file_path) or os.path.getsize(tmp_file_path) == 0:
258
+ raise ValueError("Failed to create temporary file")
259
+
260
+ logger.info(f"Processing YOLO prediction with image output for file: {file.filename}, temp path: {tmp_file_path}")
261
+
262
+ # Additional validation: check if file has 4 channels
263
+ try:
264
+ import tifffile
265
+ test_array = tifffile.imread(tmp_file_path)
266
+ if len(test_array.shape) == 3:
267
+ if test_array.shape[0] == 4 or test_array.shape[2] == 4:
268
+ channels = 4
269
+ else:
270
+ channels = test_array.shape[0] if test_array.shape[0] <= 4 else test_array.shape[2]
271
+ else:
272
+ channels = 1
273
+
274
+ if channels != 4:
275
+ raise ValueError(f"YOLO model expects 4-channel images, but uploaded file has {channels} channels")
276
+
277
+ except Exception as validation_error:
278
+ logger.warning(f"Could not validate channels: {validation_error}")
279
+
280
+ # Predict using YOLO model and get annotated image
281
+ annotated_image = predict_yolo_with_image(yolo_model, tmp_file_path)
282
+
283
+ # Convert PIL Image to bytes for response
284
+ img_bytes = pil_image_to_bytes(annotated_image, format='PNG')
285
+
286
+ logger.info(f"YOLO prediction with image output completed successfully")
287
+
288
+ return StreamingResponse(
289
+ img_bytes,
290
+ media_type="image/png",
291
+ headers={"Content-Disposition": f"attachment; filename=yolo_annotated_{file.filename}.png"}
292
+ )
293
+
294
+ finally:
295
+ # Clean up temporary file
296
+ if os.path.exists(tmp_file_path):
297
+ os.unlink(tmp_file_path)
298
+
299
+ except Exception as e:
300
+ logger.error(f"Error in predict_yolo_image_api: {e}")
301
+ return JSONResponse(status_code=500, content={"error": str(e)})
302
+
303
+ @app.post("/predict_pipeline_image/")
304
+ async def predict_pipeline_image_api(file: UploadFile = File(...)):
305
+ """Full pipeline with image output: RGB -> NDVI -> 32-bit 4-channel TIFF (RGB+NDVI) -> YOLO prediction -> Annotated Image"""
306
+ if ndvi_model is None or yolo_model is None:
307
+ return JSONResponse(status_code=500, content={"error": "Models not loaded properly"})
308
+
309
+ try:
310
+ logger.info(f"Starting full pipeline with image output for file: {file.filename}")
311
+
312
+ # Read uploaded RGB image
313
+ contents = await file.read()
314
+ logger.info(f"Read {len(contents)} bytes from uploaded file")
315
+
316
+ # Convert to PIL Image and then to numpy array
317
+ img = Image.open(BytesIO(contents)).convert("RGB")
318
+ rgb_array = np.array(img)
319
+ logger.info(f"Converted to RGB array with shape: {rgb_array.shape}")
320
+
321
+ # Run the full pipeline with image output (includes resizing internally)
322
+ annotated_image = predict_pipeline_with_image(ndvi_model, yolo_model, rgb_array)
323
+ logger.info("Pipeline processing with image output completed successfully")
324
+
325
+ # Convert PIL Image to bytes for response
326
+ img_bytes = pil_image_to_bytes(annotated_image, format='PNG')
327
+
328
+ logger.info(f"Pipeline prediction with image output completed successfully")
329
+
330
+ return StreamingResponse(
331
+ img_bytes,
332
+ media_type="image/png",
333
+ headers={"Content-Disposition": f"attachment; filename=pipeline_annotated_{file.filename}.png"}
334
+ )
335
+
336
+ except Exception as e:
337
+ logger.error(f"Error in predict_pipeline_image_api: {e}")
338
  return JSONResponse(status_code=500, content={"error": str(e)})
yolo_predictor.py CHANGED
@@ -4,6 +4,9 @@ import logging
4
  import tempfile
5
  import numpy as np
6
  import tifffile
 
 
 
7
  from rasterio.transform import from_bounds
8
  from ultralytics import YOLO
9
  from ndvi_predictor import normalize_rgb, predict_ndvi
@@ -18,7 +21,7 @@ def load_yolo_model(model_path):
18
  logger.info(f"Loading YOLO model from: {model_path}")
19
  return YOLO(model_path)
20
 
21
- def predict_yolo(yolo_model, image_path, conf=0.001):
22
  """
23
  Predict using YOLO model on 4-channel TIFF image
24
 
@@ -104,7 +107,7 @@ def create_4channel_tiff(rgb_array, ndvi_array, output_path):
104
  tifffile.imwrite(output_path, four_channel)
105
  logger.info(f"Successfully saved 4-channel TIFF (RGB+NDVI format) to: {output_path}")
106
 
107
- def predict_pipeline(ndvi_model, yolo_model, rgb_array, conf=0.001):
108
  """
109
  Full pipeline: RGB -> NDVI -> 32-bit 4-channel TIFF (RGB+NDVI) -> YOLO prediction
110
 
@@ -196,4 +199,162 @@ def validate_4channel_tiff(tiff_path):
196
 
197
  except Exception as e:
198
  logger.error(f"Error validating TIFF file: {e}")
199
- return False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  import tempfile
5
  import numpy as np
6
  import tifffile
7
+ from io import BytesIO
8
+ import cv2
9
+ from PIL import Image
10
  from rasterio.transform import from_bounds
11
  from ultralytics import YOLO
12
  from ndvi_predictor import normalize_rgb, predict_ndvi
 
21
  logger.info(f"Loading YOLO model from: {model_path}")
22
  return YOLO(model_path)
23
 
24
+ def predict_yolo(yolo_model, image_path, conf=0.25):
25
  """
26
  Predict using YOLO model on 4-channel TIFF image
27
 
 
107
  tifffile.imwrite(output_path, four_channel)
108
  logger.info(f"Successfully saved 4-channel TIFF (RGB+NDVI format) to: {output_path}")
109
 
110
+ def predict_pipeline(ndvi_model, yolo_model, rgb_array, conf=0.25):
111
  """
112
  Full pipeline: RGB -> NDVI -> 32-bit 4-channel TIFF (RGB+NDVI) -> YOLO prediction
113
 
 
199
 
200
  except Exception as e:
201
  logger.error(f"Error validating TIFF file: {e}")
202
+ return False
203
+
204
+ # Additional functions for yolo_predictor.py
205
+
206
+ def predict_yolo_with_image(yolo_model, image_path, conf=0.25, save_path=None):
207
+ """
208
+ Predict using YOLO model on 4-channel TIFF image and return annotated image
209
+
210
+ Args:
211
+ yolo_model: Loaded YOLO model
212
+ image_path: Path to 4-channel TIFF image
213
+ conf: Confidence threshold
214
+ save_path: Optional path to save the annotated image
215
+
216
+ Returns:
217
+ annotated_image: PIL Image object with annotations
218
+ """
219
+ logger.info(f"Starting YOLO prediction with image output on: {image_path} with confidence: {conf}")
220
+
221
+ # Verify file exists and has correct format
222
+ if not os.path.exists(image_path):
223
+ raise FileNotFoundError(f"Image file not found: {image_path}")
224
+
225
+ try:
226
+ # Quick validation of the TIFF file
227
+ test_array = tifffile.imread(image_path)
228
+ logger.info(f"TIFF file shape: {test_array.shape}, dtype: {test_array.dtype}")
229
+
230
+ # Validate channels
231
+ if len(test_array.shape) == 3:
232
+ channels = test_array.shape[0] if test_array.shape[0] <= 4 else test_array.shape[2]
233
+ else:
234
+ channels = 1
235
+
236
+ if channels != 4:
237
+ raise ValueError(f"Expected 4-channel image, got {channels} channels")
238
+
239
+ except Exception as e:
240
+ logger.error(f"Error validating TIFF file: {e}")
241
+ raise
242
+
243
+ logger.info("Running YOLO model inference with image output...")
244
+
245
+ # Run YOLO prediction directly on the input file
246
+ results = yolo_model([image_path], conf=conf)
247
+ result = results[0]
248
+
249
+ # Create temporary file for saving annotated image
250
+ if save_path is None:
251
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as tmp_file:
252
+ save_path = tmp_file.name
253
+
254
+ try:
255
+ # Save the annotated image using ultralytics built-in method
256
+ result.save(save_path)
257
+ logger.info(f"Annotated image saved to: {save_path}")
258
+
259
+ # Load the saved image and convert to PIL Image
260
+ annotated_image = Image.open(save_path).convert('RGB')
261
+ logger.info(f"YOLO prediction with image output completed successfully")
262
+
263
+ return annotated_image
264
+
265
+ except Exception as e:
266
+ logger.error(f"Error saving annotated image: {e}")
267
+ raise
268
+ finally:
269
+ # Clean up temporary file if we created it
270
+ if save_path.endswith('.png') and os.path.exists(save_path):
271
+ try:
272
+ os.unlink(save_path)
273
+ logger.info(f"Cleaned up temporary annotated image file: {save_path}")
274
+ except Exception as cleanup_error:
275
+ logger.warning(f"Failed to clean up temporary file: {cleanup_error}")
276
+
277
+ def predict_pipeline_with_image(ndvi_model, yolo_model, rgb_array, conf=0.25):
278
+ """
279
+ Full pipeline with image output: RGB -> NDVI -> 32-bit 4-channel TIFF (RGB+NDVI) -> YOLO prediction -> Annotated Image
280
+
281
+ Args:
282
+ ndvi_model: Loaded NDVI prediction model
283
+ yolo_model: Loaded YOLO model
284
+ rgb_array: RGB image as numpy array (H, W, 3)
285
+ conf: Confidence threshold for YOLO
286
+
287
+ Returns:
288
+ annotated_image: PIL Image object with YOLO annotations
289
+ """
290
+ logger.info("Starting full prediction pipeline with image output")
291
+ logger.info(f"Input RGB array shape: {rgb_array.shape}, dtype: {rgb_array.dtype}")
292
+
293
+ # Step 1: Resize RGB image to target size
294
+ logger.info("Step 1: Resizing RGB image to target size")
295
+ target_size = (640, 640) # (height, width)
296
+ rgb_resized = resize_image_optimized(rgb_array, target_size)
297
+ logger.info(f"Resized RGB shape: {rgb_resized.shape}")
298
+
299
+ # Step 2: Normalize RGB image
300
+ logger.info("Step 2: Normalizing RGB image")
301
+ normalized_rgb = normalize_rgb(rgb_resized)
302
+ logger.info(f"Normalized RGB shape: {normalized_rgb.shape}, range: [{normalized_rgb.min():.3f}, {normalized_rgb.max():.3f}]")
303
+
304
+ # Step 3: Predict NDVI
305
+ logger.info("Step 3: Predicting NDVI from RGB")
306
+ ndvi_prediction = predict_ndvi(ndvi_model, normalized_rgb)
307
+ logger.info(f"NDVI prediction shape: {ndvi_prediction.shape}, range: [{ndvi_prediction.min():.3f}, {ndvi_prediction.max():.3f}]")
308
+
309
+ # Step 4: Create 4-channel TIFF file
310
+ logger.info("Step 4: Creating 4-channel TIFF file (RGB+NDVI)")
311
+
312
+ # Create temporary file for the 4-channel TIFF
313
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.tiff') as tmp_file:
314
+ tiff_path = tmp_file.name
315
+
316
+ try:
317
+ # Create the 4-channel TIFF using resized RGB and predicted NDVI
318
+ create_4channel_tiff(rgb_resized, ndvi_prediction, tiff_path)
319
+
320
+ # Verify the created file
321
+ if not os.path.exists(tiff_path):
322
+ raise FileNotFoundError(f"Failed to create 4-channel TIFF at: {tiff_path}")
323
+
324
+ file_size = os.path.getsize(tiff_path)
325
+ logger.info(f"Created 4-channel TIFF file size: {file_size} bytes")
326
+
327
+ # Step 5: Run YOLO prediction on the 4-channel TIFF and get annotated image
328
+ logger.info("Step 5: Running YOLO prediction on 4-channel TIFF with image output")
329
+ annotated_image = predict_yolo_with_image(yolo_model, tiff_path, conf=conf)
330
+
331
+ logger.info("Full pipeline with image output completed successfully")
332
+ return annotated_image
333
+
334
+ except Exception as e:
335
+ logger.error(f"Error in pipeline with image output: {e}")
336
+ raise
337
+ finally:
338
+ # Clean up temporary file
339
+ if os.path.exists(tiff_path):
340
+ try:
341
+ os.unlink(tiff_path)
342
+ logger.info(f"Cleaned up temporary file: {tiff_path}")
343
+ except Exception as cleanup_error:
344
+ logger.warning(f"Failed to clean up temporary file: {cleanup_error}")
345
+
346
+ def pil_image_to_bytes(image, format='PNG'):
347
+ """
348
+ Convert PIL Image to bytes for API response
349
+
350
+ Args:
351
+ image: PIL Image object
352
+ format: Image format ('PNG', 'JPEG', etc.)
353
+
354
+ Returns:
355
+ BytesIO: Image as bytes buffer
356
+ """
357
+ img_bytes = BytesIO()
358
+ image.save(img_bytes, format=format)
359
+ img_bytes.seek(0)
360
+ return img_bytes