Testing-Pipeline-API / yolo_predictor.py
ahadhassan's picture
Update yolo_predictor.py
59519c9 verified
raw
history blame
10.1 kB
# yolo_predictor.py
import os
import numpy as np
import rasterio
from ultralytics import YOLO
from ndvi_predictor import normalize_rgb, predict_ndvi
import tempfile
from rasterio.transform import from_bounds
from PIL import Image
import tifffile
def load_yolo_model(model_path):
    """Instantiate an Ultralytics YOLO model from a weights file.

    Args:
        model_path: Path to the ``.pt`` weights file.

    Returns:
        The ``YOLO`` object constructed from *model_path*.
    """
    model = YOLO(model_path)
    return model
def predict_ndvi_from_rgb(ndvi_model, rgb_array):
    """Estimate an NDVI map from an RGB image.

    The image is passed through ``normalize_rgb`` first; the normalized
    result is then handed to ``predict_ndvi`` together with the model.

    Args:
        ndvi_model: Loaded NDVI prediction model.
        rgb_array: RGB image as a numpy array, expected shape (H, W, 3).

    Returns:
        Whatever ``predict_ndvi`` returns; callers treat it as an (H, W)
        NDVI array.
    """
    return predict_ndvi(ndvi_model, normalize_rgb(rgb_array))
def _count_image_channels(image_path):
    """Return the number of channels (bands) stored in the image at *image_path*.

    PIL's band list is consulted first. PIL's ``n_frames`` is deliberately
    not used: it counts TIFF pages, not channels, so a single-page RGBA
    TIFF would report 1. If PIL cannot open the file or does not report
    4 bands, the file is re-read with tifffile and the channel axis is
    inferred from the array shape (channel-first or channel-last).
    """
    try:
        with Image.open(image_path) as img:
            if len(img.getbands()) == 4:
                return 4
    except OSError:
        # PIL cannot parse every TIFF layout; tifffile gets a second try and
        # will raise itself if the file is genuinely unreadable.
        pass
    img_array = tifffile.imread(image_path)
    if img_array.ndim != 3:
        # Same convention as before: anything that is not 3-D counts as 1.
        return 1
    first, _, last = img_array.shape
    if 4 in (first, last):
        return 4
    return first if first <= 4 else last


def predict_yolo(yolo_model, image_path, conf=0.001):
    """Run YOLO inference on a 4-channel (e.g. RGB + NDVI) TIFF image.

    Args:
        yolo_model: Loaded YOLO model; called with a list of image paths.
        image_path: Path to the 4-channel TIFF image.
        conf: Confidence threshold forwarded to the model call.

    Returns:
        The first element of the results returned by ``yolo_model``.

    Raises:
        ValueError: If the image cannot be inspected, or if it does not
            have exactly 4 channels.
    """
    try:
        channels = _count_image_channels(image_path)
    except Exception as e:
        raise ValueError(f"Error reading image channels: {str(e)}") from e
    # Checked outside the try so this message is not re-wrapped as a read error.
    if channels != 4:
        raise ValueError(f"YOLO model expects 4-channel images, but got {channels} channels")
    results = yolo_model([image_path], conf=conf)
    return results[0]  # single input path -> single result
def create_4channel_tiff(rgb_array, ndvi_array, output_path):
    """Write RGB and NDVI data to disk as one 4-channel uint8 TIFF.

    The written array has shape (H, W, 4) with channels R, G, B, NDVI.
    NDVI is mapped linearly from [-1, 1] to [0, 255] via
    ``(ndvi + 1) * 127.5``; ``load_4channel_tiff`` applies the inverse.

    Args:
        rgb_array: RGB image, shape (H, W, 3); any channels after the third
            are ignored. uint8 is used as-is. Other dtypes are treated as
            [0, 1] floats when their maximum is <= 1.0, otherwise as
            [0, 255] values.
        ndvi_array: NDVI image, shape (H, W), nominally in [-1, 1].
        output_path: Path the TIFF is written to.
    """
    rgb_array = np.asarray(rgb_array)
    ndvi_array = np.asarray(ndvi_array)

    # Clip before every uint8 cast: out-of-range values would otherwise
    # wrap around (e.g. an NDVI of -1.02 would become 255).
    if rgb_array.dtype == np.uint8:
        rgb_u8 = rgb_array
    elif rgb_array.max() <= 1.0:
        rgb_u8 = np.clip(rgb_array * 255, 0, 255).astype(np.uint8)
    else:
        rgb_u8 = np.clip(rgb_array, 0, 255).astype(np.uint8)
    ndvi_u8 = np.clip((ndvi_array + 1) * 127.5, 0, 255).astype(np.uint8)

    # (H, W, 3) + (H, W) -> (H, W, 4) in R, G, B, NDVI order.
    four_channel = np.dstack((rgb_u8[:, :, :3], ndvi_u8))

    tifffile.imwrite(
        output_path,
        four_channel,
        photometric='rgb',
        # `compress` is tifffile's deprecated spelling; `compression` is the
        # supported keyword. zlib/deflate is encoded with the stdlib, whereas
        # LZW encoding requires the optional imagecodecs package.
        compression='zlib',
        metadata={'axes': 'YXC'},
    )
def _split_rgb_ndvi(hwc_array):
    """Split an (H, W, 4) array into (RGB, NDVI), decoding uint8 NDVI to [-1, 1]."""
    rgb_array = hwc_array[:, :, :3]
    ndvi_array = hwc_array[:, :, 3]
    if hwc_array.dtype == np.uint8:
        # Inverse of the (ndvi + 1) * 127.5 encoding used when writing.
        ndvi_array = (ndvi_array.astype(np.float32) / 127.5) - 1
    return rgb_array, ndvi_array


def load_4channel_tiff(image_path):
    """Load a 4-channel TIFF and split it into RGB and NDVI parts.

    tifffile is tried first. On any failure, rasterio is used as a
    fallback. When the pixel data is uint8, the NDVI channel is mapped from
    [0, 255] back to [-1, 1]; other dtypes are returned unchanged.

    Args:
        image_path: Path to the 4-channel TIFF image.

    Returns:
        (rgb_array, ndvi_array): the RGB channels as an (H, W, 3) array and
        the NDVI (fourth) channel as an (H, W) array.

    Raises:
        ValueError: If neither reader yields a 4-channel image.
    """
    try:
        img_array = tifffile.imread(image_path)
        if img_array.ndim != 3:
            raise ValueError(f"Expected 4 channels, got {img_array.shape}")
        if img_array.shape[0] == 4:
            # Channel-first (4, H, W) -> channel-last (H, W, 4).
            img_array = np.transpose(img_array, (1, 2, 0))
        elif img_array.shape[2] != 4:
            raise ValueError(f"Expected 4 channels, got {img_array.shape}")
        return _split_rgb_ndvi(img_array)
    except Exception as e:
        try:
            with rasterio.open(image_path) as src:
                if src.count != 4:
                    raise ValueError(f"Expected 4 channels, got {src.count}")
                bands = src.read()  # rasterio returns (bands, H, W)
            return _split_rgb_ndvi(np.transpose(bands, (1, 2, 0)))
        except Exception as e2:
            raise ValueError(
                f"Could not load 4-channel TIFF. Errors: tifffile={e}, rasterio={e2}"
            ) from e2
def _rgb_from_tiff_array(img_array):
    """Map a tifffile-decoded array to (H, W, 3) RGB, or None for unsupported layouts."""
    if img_array.ndim == 2:
        # Grayscale: replicate into three channels.
        return np.stack([img_array] * 3, axis=-1)
    if img_array.ndim != 3:
        return None
    if img_array.shape[0] == 4:
        return np.transpose(img_array[:3], (1, 2, 0))  # (4, H, W)
    if img_array.shape[0] == 3:
        return np.transpose(img_array, (1, 2, 0))  # (3, H, W)
    if img_array.shape[2] == 4:
        return img_array[:, :, :3]  # (H, W, 4)
    if img_array.shape[2] == 3:
        return img_array  # (H, W, 3)
    return None


def _load_rgb_for_pipeline(image_path):
    """Read *image_path* as (H, W, 3) RGB via tifffile, then rasterio, then PIL.

    A reader that succeeds but yields an unsupported layout returns None.
    The later readers are only tried when an earlier one raises.

    Raises:
        ValueError: If all three readers raise.
    """
    try:
        return _rgb_from_tiff_array(tifffile.imread(image_path))
    except Exception as e1:
        try:
            with rasterio.open(image_path) as src:
                bands = src.read()  # (bands, H, W)
                if src.count >= 3:
                    return np.transpose(bands[:3], (1, 2, 0))
                if src.count == 1:
                    return np.stack([bands[0]] * 3, axis=-1)
                return None
        except Exception as e2:
            try:
                # Context manager so the file handle is released.
                with Image.open(image_path) as img:
                    if img.mode != 'RGB':
                        img = img.convert('RGB')
                    return np.array(img)
            except Exception as e3:
                raise ValueError(
                    f"Could not load image with any method. Errors: tifffile={e1}, rasterio={e2}, PIL={e3}"
                ) from e3


def _prepare_rgb_for_pipeline(rgb_array):
    """Return ``(rgb_uint8, rgb_float)``, where ``rgb_float`` is clipped to [0, 1]."""
    if rgb_array.dtype == np.uint8:
        return rgb_array, rgb_array.astype(np.float32) / 255.0
    if np.issubdtype(rgb_array.dtype, np.integer) and rgb_array.max() > 255:
        # Wider-than-8-bit integer imagery (e.g. uint16). Dividing by 255
        # would leave values > 1 that wrap around in the uint8 cast below.
        # NOTE(review): assumes such data spans the dtype's full range;
        # confirm for e.g. 12-bit sensors.
        rgb_float = rgb_array.astype(np.float32) / np.iinfo(rgb_array.dtype).max
    elif rgb_array.max() > 1.0:
        rgb_float = rgb_array / 255.0
    else:
        rgb_float = rgb_array
    rgb_float = np.clip(rgb_float, 0.0, 1.0)
    return (rgb_float * 255).astype(np.uint8), rgb_float


def predict_pipeline(ndvi_model, yolo_model, image_path, conf=0.001):
    """Run the full pipeline: load image, predict NDVI, build a 4-channel TIFF, run YOLO.

    Args:
        ndvi_model: Loaded NDVI prediction model.
        yolo_model: Loaded YOLO model.
        image_path: Path to the input image (RGB, grayscale, or 4-channel TIFF;
            only the first three channels are used).
        conf: Confidence threshold for YOLO.

    Returns:
        The result returned by ``predict_yolo`` for the generated TIFF.

    Raises:
        ValueError: If the image cannot be loaded, has no usable RGB data,
            or the generated 4-channel TIFF fails verification.
    """
    rgb_array = _load_rgb_for_pipeline(image_path)
    if rgb_array is None:
        raise ValueError("Failed to extract RGB data from image")

    rgb_array, rgb_float = _prepare_rgb_for_pipeline(rgb_array)
    ndvi_pred = predict_ndvi_from_rgb(ndvi_model, rgb_float)

    # delete=False: the file must outlive this handle so it can be rewritten
    # by path and read back by PIL/YOLO. It is removed in the finally below.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.tiff') as tmp_file:
        temp_4ch_path = tmp_file.name
    try:
        create_4channel_tiff(rgb_array, ndvi_pred, temp_4ch_path)
        # Verify PIL sees 4 bands. n_frames is deliberately not used: it
        # counts pages, and a single-page TIFF always reports 1.
        try:
            with Image.open(temp_4ch_path) as test_img:
                channels = len(test_img.getbands())
        except Exception as e:
            raise ValueError(f"Created TIFF file is not readable: {str(e)}") from e
        if channels != 4:
            raise ValueError(f"Created TIFF has {channels} channels instead of 4")
        return predict_yolo(yolo_model, temp_4ch_path, conf=conf)
    finally:
        if os.path.exists(temp_4ch_path):
            os.unlink(temp_4ch_path)