# Testing-Pipeline-API / yolo_predictor.py
# Author: Muhammad Ahad Hassan Khan
# Commit message: Added pipeline inference code
# Commit: 70135b4
# File size: 7.02 kB
# yolo_predictor.py
import os
import logging
import tempfile
import numpy as np
import tifffile
from rasterio.transform import from_bounds
from ultralytics import YOLO
from ndvi_predictor import normalize_rgb, predict_ndvi
# Set up root logging (INFO level, timestamped records) and this module's logger
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
def load_yolo_model(model_path):
    """
    Build a YOLO model from a weights file.

    Args:
        model_path: Path to the YOLO weights (.pt) file

    Returns:
        The model object produced by ``YOLO(model_path)``
    """
    logger.info(f"Loading YOLO model from: {model_path}")
    model = YOLO(model_path)
    return model
def predict_yolo(yolo_model, image_path, conf=0.01):
    """
    Run a YOLO model on a 4-channel TIFF after sanity-checking the file.

    The file must exist and decode to an array with exactly 4 channels;
    otherwise the call fails before inference is attempted.

    Args:
        yolo_model: Loaded YOLO model (invoked as ``yolo_model([path], conf=...)``)
        image_path: Path to 4-channel TIFF image
        conf: Confidence threshold forwarded to the model

    Returns:
        The first element of the results returned by the model

    Raises:
        FileNotFoundError: If ``image_path`` does not exist
        ValueError: If the image does not have exactly 4 channels
    """
    logger.info(f"Starting YOLO prediction on: {image_path} with confidence: {conf}")

    # Fail fast on a missing path
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image file not found: {image_path}")

    try:
        # Decode the TIFF once purely to inspect its layout
        tiff_data = tifffile.imread(image_path)
        logger.info(f"TIFF file shape: {tiff_data.shape}, dtype: {tiff_data.dtype}")

        # Channel count heuristic: a leading axis of size <= 4 is treated as
        # channel-first, otherwise the third axis is taken as the channel axis.
        # Anything that is not 3-D counts as single-channel.
        dims = tiff_data.shape
        if len(dims) != 3:
            channels = 1
        elif dims[0] <= 4:
            channels = dims[0]
        else:
            channels = dims[2]

        if channels != 4:
            raise ValueError(f"Expected 4-channel image, got {channels} channels")
    except Exception as e:
        # Log every validation failure (including the ValueError above), then propagate
        logger.error(f"Error validating TIFF file: {e}")
        raise

    logger.info("Running YOLO model inference...")

    # The model is given the file path (in a one-element list), not the decoded array
    results = yolo_model([image_path], conf=conf)
    first_result = results[0]
    logger.info(f"YOLO prediction completed. Results type: {type(first_result)}")
    return first_result
def create_4channel_tiff(rgb_array, ndvi_array, output_path):
    """
    Create a 4-channel TIFF file with RGB channels + NDVI channel

    The saved image is a uint8 array in channel-first layout (4, H, W):
    R, G, B, then NDVI rescaled from [-1, 1] to [0, 255].

    Values are clipped to the valid range before every uint8 cast, so
    out-of-range inputs (for example NDVI predictions slightly outside
    [-1, 1], or RGB values above 255) saturate instead of wrapping around.
    In-range inputs produce exactly the same bytes as before.

    Args:
        rgb_array: RGB image array (H, W, 3). uint8 input is used as-is;
            any other dtype is treated as [0, 1] data when its maximum is
            <= 1.0 and as [0, 255] data otherwise.
        ndvi_array: NDVI array (H, W) with values nominally in [-1, 1]
        output_path: Path to save the 4-channel TIFF
    """
    logger.info(f"Creating 4-channel TIFF file at: {output_path}")
    logger.info(f"RGB shape: {rgb_array.shape}, NDVI shape: {ndvi_array.shape}")

    # Ensure RGB is in uint8 format
    if rgb_array.dtype == np.uint8:
        rgb_uint8 = rgb_array
    else:
        if rgb_array.max() <= 1.0:
            rgb_scaled = rgb_array * 255
        else:
            rgb_scaled = rgb_array
        # Saturate before casting: a bare astype(np.uint8) wraps out-of-range
        # values (e.g. 256 -> 0). Float bounds keep the clip valid for every
        # input dtype.
        rgb_uint8 = np.clip(rgb_scaled, 0.0, 255.0).astype(np.uint8)

    # Convert NDVI from [-1, 1] to [0, 255] uint8 format. Clip first so that
    # predictions marginally outside [-1, 1] map to 0/255 rather than wrapping.
    ndvi_clipped = np.clip(ndvi_array, -1.0, 1.0)
    ndvi_scaled = (((ndvi_clipped + 1) / 2) * 255).astype(np.uint8)

    logger.info(f"RGB range: [{rgb_uint8.min()}, {rgb_uint8.max()}]")
    logger.info(f"NDVI scaled range: [{ndvi_scaled.min()}, {ndvi_scaled.max()}]")

    # Stack RGB + NDVI to create 4-channel image
    # Format: (channels, height, width) - channel-first format
    four_channel = np.stack([
        rgb_uint8[:, :, 0],  # R channel
        rgb_uint8[:, :, 1],  # G channel
        rgb_uint8[:, :, 2],  # B channel
        ndvi_scaled,         # NDVI channel
    ], axis=0)

    logger.info(f"4-channel array shape: {four_channel.shape}, dtype: {four_channel.dtype}")
    logger.info(f"4-channel range: [{four_channel.min()}, {four_channel.max()}]")

    # Save as TIFF using tifffile
    tifffile.imwrite(output_path, four_channel)
    logger.info(f"Successfully saved 4-channel TIFF (RGB+NDVI format) to: {output_path}")
def predict_pipeline(ndvi_model, yolo_model, rgb_array, conf=0.01):
    """
    Full pipeline: RGB -> NDVI -> 4-channel TIFF (RGB+NDVI) -> YOLO prediction

    The 4-channel image is written to a temporary .tiff file, handed to YOLO
    by path, and removed again before returning, whether or not the
    prediction succeeded.

    Args:
        ndvi_model: Loaded NDVI prediction model
        yolo_model: Loaded YOLO model
        rgb_array: RGB image as numpy array (H, W, 3)
        conf: Confidence threshold for YOLO

    Returns:
        results: YOLO results object

    Raises:
        Exception: Any error raised while building the TIFF or running YOLO
            is logged and re-raised unchanged
    """
    logger.info("Starting full prediction pipeline")
    logger.info(f"Input RGB array shape: {rgb_array.shape}, dtype: {rgb_array.dtype}")

    # Step 1: Normalize RGB image
    logger.info("Step 1: Normalizing RGB image")
    normalized_rgb = normalize_rgb(rgb_array)
    logger.info(f"Normalized RGB shape: {normalized_rgb.shape}, range: [{normalized_rgb.min():.3f}, {normalized_rgb.max():.3f}]")

    # Step 2: Predict NDVI
    logger.info("Step 2: Predicting NDVI from RGB")
    ndvi_prediction = predict_ndvi(ndvi_model, normalized_rgb)
    logger.info(f"NDVI prediction shape: {ndvi_prediction.shape}, range: [{ndvi_prediction.min():.3f}, {ndvi_prediction.max():.3f}]")

    # Step 3: Create 4-channel TIFF file
    # The channel order written by create_4channel_tiff is R, G, B, NDVI, so
    # the log message says RGB (it previously said BGR, which was misleading).
    logger.info("Step 3: Creating 4-channel TIFF file (RGB+NDVI)")

    # Reserve a temporary path; delete=False so the file outlives this
    # `with` block and can be written and read by path afterwards.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.tiff') as tmp_file:
        tiff_path = tmp_file.name

    try:
        # The TIFF is built from the original rgb_array; the normalized copy
        # is only used as input to the NDVI model.
        create_4channel_tiff(rgb_array, ndvi_prediction, tiff_path)

        # Verify the created file
        if not os.path.exists(tiff_path):
            raise FileNotFoundError(f"Failed to create 4-channel TIFF at: {tiff_path}")

        file_size = os.path.getsize(tiff_path)
        logger.info(f"Created 4-channel TIFF file size: {file_size} bytes")

        # Step 4: Run YOLO prediction on the 4-channel TIFF
        logger.info("Step 4: Running YOLO prediction on 4-channel TIFF")
        results = predict_yolo(yolo_model, tiff_path, conf=conf)

        logger.info("Full pipeline completed successfully")
        return results
    except Exception as e:
        logger.error(f"Error in pipeline: {e}")
        raise
    finally:
        # Clean up temporary file. Best effort: a failed delete is only
        # logged so that it never masks the pipeline's own result or error.
        try:
            os.unlink(tiff_path)
            logger.info(f"Cleaned up temporary file: {tiff_path}")
        except FileNotFoundError:
            # Already gone; nothing to clean up
            pass
        except OSError as cleanup_error:
            logger.warning(f"Failed to clean up temporary file: {cleanup_error}")
def validate_4channel_tiff(tiff_path):
    """
    Check whether a TIFF file decodes to exactly 4 channels.

    Any error while reading the file is logged and reported as invalid
    rather than raised.

    Args:
        tiff_path: Path to TIFF file

    Returns:
        bool: True if valid 4-channel TIFF, False otherwise
    """
    try:
        data = tifffile.imread(tiff_path)
        dims = data.shape

        # Channel count heuristic: a leading axis of size <= 4 is treated as
        # channel-first, otherwise the third axis is the channel axis.
        # Arrays that are not 3-D count as single-channel.
        if len(dims) != 3:
            channels = 1
        elif dims[0] <= 4:
            channels = dims[0]
        else:
            channels = dims[2]

        logger.info(f"TIFF validation - Shape: {data.shape}, Channels: {channels}")
        is_four_channel = channels == 4
        return is_four_channel
    except Exception as e:
        logger.error(f"Error validating TIFF file: {e}")
        return False