# yolo_predictor.py
import os
import logging
import tempfile

import numpy as np
import tifffile
from rasterio.transform import from_bounds  # noqa: F401  # NOTE(review): not used in this module
from ultralytics import YOLO

from ndvi_predictor import normalize_rgb, predict_ndvi

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def load_yolo_model(model_path):
    """Load YOLO model from .pt file"""
    logger.info(f"Loading YOLO model from: {model_path}")
    return YOLO(model_path)


def _count_channels(array):
    """Return the channel count of an image array using a layout heuristic.

    A 3-D array is treated as channel-first ``(C, H, W)`` when its first axis
    has at most 4 entries, otherwise as channel-last ``(H, W, C)``. Arrays of
    any other rank are reported as having a single channel.
    """
    if array.ndim == 3:
        return array.shape[0] if array.shape[0] <= 4 else array.shape[2]
    return 1


def predict_yolo(yolo_model, image_path, conf=0.01):
    """
    Predict using YOLO model on 4-channel TIFF image

    Args:
        yolo_model: Loaded YOLO model
        image_path: Path to 4-channel TIFF image
        conf: Confidence threshold

    Returns:
        results: the first element of the list returned by ``yolo_model([image_path], conf=conf)``

    Raises:
        FileNotFoundError: if ``image_path`` does not exist.
        ValueError: if the TIFF does not have exactly 4 channels.
        Exception: any error raised by ``tifffile.imread`` is logged and re-raised.
    """
    logger.info(f"Starting YOLO prediction on: {image_path} with confidence: {conf}")

    # Verify file exists and has correct format
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image file not found: {image_path}")

    try:
        # Quick validation of the TIFF file (note: reads the full image into memory)
        test_array = tifffile.imread(image_path)
        logger.info(f"TIFF file shape: {test_array.shape}, dtype: {test_array.dtype}")

        channels = _count_channels(test_array)
        if channels != 4:
            raise ValueError(f"Expected 4-channel image, got {channels} channels")
    except Exception as e:
        logger.error(f"Error validating TIFF file: {e}")
        raise

    logger.info("Running YOLO model inference...")

    # Run YOLO prediction directly on the input file
    results = yolo_model([image_path], conf=conf)

    logger.info(f"YOLO prediction completed. Results type: {type(results[0])}")
    return results[0]  # Return first result


def create_4channel_tiff(rgb_array, ndvi_array, output_path):
    """
    Create a 4-channel uint8 TIFF file with RGB channels + NDVI channel.

    Args:
        rgb_array: RGB image array (H, W, 3); extra trailing channels are ignored.
            Non-uint8 input with max <= 1.0 is treated as [0, 1] and scaled by 255;
            otherwise values are used as-is. Either way they are clipped to
            [0, 255] before the uint8 cast so out-of-range values cannot wrap.
        ndvi_array: NDVI array (H, W) (a trailing singleton axis (H, W, 1) is
            also accepted) with values expected in [-1, 1]. Values are clipped to
            [-1, 1] (NaN -> 0, +inf -> 1, -inf -> -1) and mapped to [0, 255].
        output_path: Path to save the 4-channel TIFF

    Raises:
        ValueError: if ``rgb_array`` is not (H, W, >=3) or the NDVI spatial
            shape does not match the RGB spatial shape.
    """
    logger.info(f"Creating 4-channel TIFF file at: {output_path}")
    logger.info(f"RGB shape: {rgb_array.shape}, NDVI shape: {ndvi_array.shape}")

    if rgb_array.ndim != 3 or rgb_array.shape[2] < 3:
        raise ValueError(f"Expected RGB array of shape (H, W, 3), got {rgb_array.shape}")

    # Ensure RGB is in uint8 format
    if rgb_array.dtype != np.uint8:
        if rgb_array.max() <= 1.0:
            # Normalized input: scale [0, 1] -> [0, 255]
            rgb_scaled = rgb_array.astype(np.float64) * 255
        else:
            rgb_scaled = rgb_array
        # Clip before casting: out-of-range values would otherwise wrap in uint8
        rgb_uint8 = np.clip(rgb_scaled, 0, 255).astype(np.uint8)
    else:
        rgb_uint8 = rgb_array

    ndvi = np.asarray(ndvi_array)
    # Accept a trailing singleton channel axis, e.g. (H, W, 1)
    if ndvi.ndim == 3 and ndvi.shape[-1] == 1:
        ndvi = ndvi[..., 0]
    if ndvi.shape != rgb_uint8.shape[:2]:
        raise ValueError(
            f"NDVI shape {ndvi.shape} does not match RGB spatial shape {rgb_uint8.shape[:2]}"
        )

    # Clamp to the valid NDVI range before scaling: predictions slightly outside
    # [-1, 1] (or NaN/inf) would otherwise overflow the uint8 cast and wrap.
    ndvi = np.clip(np.nan_to_num(ndvi, nan=0.0, posinf=1.0, neginf=-1.0), -1.0, 1.0)

    # Convert NDVI from [-1, 1] to [0, 255] uint8 format (same as reference code)
    ndvi_scaled = (((ndvi + 1) / 2) * 255).astype(np.uint8)

    logger.info(f"RGB range: [{rgb_uint8.min()}, {rgb_uint8.max()}]")
    logger.info(f"NDVI scaled range: [{ndvi_scaled.min()}, {ndvi_scaled.max()}]")

    # Stack RGB + NDVI to create 4-channel image
    # Format: (channels, height, width) - channel-first format
    four_channel = np.stack([
        rgb_uint8[:, :, 0],  # R channel
        rgb_uint8[:, :, 1],  # G channel
        rgb_uint8[:, :, 2],  # B channel
        ndvi_scaled,         # NDVI channel
    ], axis=0)

    logger.info(f"4-channel array shape: {four_channel.shape}, dtype: {four_channel.dtype}")
    logger.info(f"4-channel range: [{four_channel.min()}, {four_channel.max()}]")

    # Save as TIFF using tifffile
    tifffile.imwrite(output_path, four_channel)
    logger.info(f"Successfully saved 4-channel TIFF (RGB+NDVI format) to: {output_path}")


def predict_pipeline(ndvi_model, yolo_model, rgb_array, conf=0.01):
    """
    Full pipeline: RGB -> NDVI -> 8-bit (uint8) 4-channel TIFF (RGB+NDVI) -> YOLO prediction

    The intermediate TIFF is written to a temporary file that is removed
    afterwards, whether or not prediction succeeds.

    Args:
        ndvi_model: Loaded NDVI prediction model
        yolo_model: Loaded YOLO model
        rgb_array: RGB image as numpy array (H, W, 3)
        conf: Confidence threshold for YOLO

    Returns:
        results: the value returned by ``predict_yolo`` for the temporary TIFF
    """
    logger.info("Starting full prediction pipeline")
    logger.info(f"Input RGB array shape: {rgb_array.shape}, dtype: {rgb_array.dtype}")

    # Step 1: Normalize RGB image
    logger.info("Step 1: Normalizing RGB image")
    normalized_rgb = normalize_rgb(rgb_array)
    logger.info(f"Normalized RGB shape: {normalized_rgb.shape}, range: [{normalized_rgb.min():.3f}, {normalized_rgb.max():.3f}]")

    # Step 2: Predict NDVI
    logger.info("Step 2: Predicting NDVI from RGB")
    ndvi_prediction = predict_ndvi(ndvi_model, normalized_rgb)
    logger.info(f"NDVI prediction shape: {ndvi_prediction.shape}, range: [{ndvi_prediction.min():.3f}, {ndvi_prediction.max():.3f}]")

    # Step 3: Create 4-channel TIFF file (channels are stacked R, G, B, NDVI)
    logger.info("Step 3: Creating 4-channel TIFF file (RGB+NDVI)")

    # Create temporary file for the 4-channel TIFF; delete=False so it can be
    # reopened by name after the handle is closed (the finally block removes it).
    with tempfile.NamedTemporaryFile(delete=False, suffix='.tiff') as tmp_file:
        tiff_path = tmp_file.name

    try:
        # Create the 4-channel TIFF from the original (un-normalized) RGB
        create_4channel_tiff(rgb_array, ndvi_prediction, tiff_path)

        # Verify the created file
        if not os.path.exists(tiff_path):
            raise FileNotFoundError(f"Failed to create 4-channel TIFF at: {tiff_path}")

        file_size = os.path.getsize(tiff_path)
        logger.info(f"Created 4-channel TIFF file size: {file_size} bytes")

        # Step 4: Run YOLO prediction on the 4-channel TIFF
        logger.info("Step 4: Running YOLO prediction on 4-channel TIFF")
        results = predict_yolo(yolo_model, tiff_path, conf=conf)

        logger.info("Full pipeline completed successfully")
        return results

    except Exception as e:
        # logger.exception keeps the traceback in the log before re-raising
        logger.exception(f"Error in pipeline: {e}")
        raise
    finally:
        # Clean up temporary file (best effort: a failure is only logged)
        if os.path.exists(tiff_path):
            try:
                os.unlink(tiff_path)
                logger.info(f"Cleaned up temporary file: {tiff_path}")
            except Exception as cleanup_error:
                logger.warning(f"Failed to clean up temporary file: {cleanup_error}")


def validate_4channel_tiff(tiff_path):
    """
    Validate that a TIFF file has exactly 4 channels

    Args:
        tiff_path: Path to TIFF file

    Returns:
        bool: True if valid 4-channel TIFF, False otherwise (including when the
        file cannot be read; the error is logged rather than raised)
    """
    try:
        array = tifffile.imread(tiff_path)
        channels = _count_channels(array)
        logger.info(f"TIFF validation - Shape: {array.shape}, Channels: {channels}")
        return channels == 4
    except Exception as e:
        logger.error(f"Error validating TIFF file: {e}")
        return False