Testing-Pipeline-API / yolo_predictor.py
ahadhassan's picture
Update yolo_predictor.py
65137d3 verified
raw
history blame
7.17 kB
# yolo_predictor.py
import os
import numpy as np
import rasterio
from ultralytics import YOLO
from ndvi_predictor import normalize_rgb, predict_ndvi
import tempfile
from rasterio.transform import from_bounds
from PIL import Image
import tifffile
def load_yolo_model(model_path):
    """
    Construct an ultralytics ``YOLO`` model from a weights file.

    Args:
        model_path: Path to the ``.pt`` weights file.

    Returns:
        The ``YOLO`` instance built from ``model_path``.
    """
    model = YOLO(model_path)
    return model
def predict_ndvi_from_rgb(ndvi_model, rgb_array):
    """
    Run the NDVI model on an RGB image.

    The image is first passed through ``normalize_rgb``. The normalized result
    is then handed to ``predict_ndvi`` together with the model.

    Args:
        ndvi_model: Loaded NDVI prediction model
        rgb_array: RGB image as numpy array, expected (H, W, 3)

    Returns:
        Whatever ``predict_ndvi`` returns. Callers treat it as the predicted
        NDVI map (H, W); that shape is not checked here.
    """
    return predict_ndvi(ndvi_model, normalize_rgb(rgb_array))
def predict_yolo(yolo_model, image_path, conf=0.001):
    """
    Run a YOLO model on a single image path and return its only result.

    Args:
        yolo_model: Loaded YOLO model (called as ``yolo_model([path], conf=...)``)
        image_path: Path to the (4-channel TIFF) image to run on
        conf: Confidence threshold forwarded to the model (default 0.001)

    Returns:
        The first element of the model's output, i.e. the result for
        ``image_path``.
    """
    batch = [image_path]
    # One path was submitted, so the first entry is the one we want.
    return yolo_model(batch, conf=conf)[0]
def create_4channel_tiff(rgb_array, ndvi_array, output_path):
    """
    Create a 4-channel (R, G, B, NDVI) float32 TIFF file from RGB and NDVI arrays.

    The bands are stacked in planar (C, H, W) order and written with
    ``tifffile.imwrite(..., photometric='rgb')``.

    uint8 RGB input is rescaled to [0, 1] by dividing by 255. Any other RGB
    dtype is only cast to float32, so non-uint8 data must already be in the
    intended range. NDVI is cast to float32 without rescaling.

    Args:
        rgb_array: RGB image as numpy array (H, W, C) with C >= 3; only the
            first three channels are written.
        ndvi_array: NDVI image as numpy array (H, W). (1, H, W) and (H, W, 1)
            are also accepted and reduced to (H, W).
        output_path: Path to save the 4-channel TIFF

    Raises:
        ValueError: If rgb_array is not (H, W, >=3), or ndvi_array does not
            match the RGB height and width.
    """
    rgb = np.asarray(rgb_array)
    if rgb.ndim != 3 or rgb.shape[2] < 3:
        raise ValueError(f"rgb_array must have shape (H, W, >=3), got {rgb.shape}")
    height, width = rgb.shape[:2]

    ndvi = np.asarray(ndvi_array)
    # Accept a singleton channel axis on either side (e.g. raw model output).
    if ndvi.shape == (1, height, width):
        ndvi = ndvi[0]
    elif ndvi.shape == (height, width, 1):
        ndvi = ndvi[:, :, 0]
    if ndvi.shape != (height, width):
        # Checked explicitly: plain assignment into band 3 would silently
        # broadcast e.g. a (W,) or (1, W) array across every row.
        raise ValueError(
            f"ndvi_array shape {ndvi.shape} does not match RGB size {(height, width)}"
        )

    rgb = rgb[:, :, :3]
    if rgb.dtype == np.uint8:
        rgb_normalized = rgb.astype(np.float32) / 255.0
    else:
        rgb_normalized = rgb.astype(np.float32)

    # Planar (C, H, W) layout: bands 0-2 are R, G, B and band 3 is NDVI.
    four_channel = np.empty((4, height, width), dtype=np.float32)
    four_channel[:3] = np.moveaxis(rgb_normalized, -1, 0)
    four_channel[3] = ndvi.astype(np.float32)

    # Written with tifffile (module-level import). The original note says this
    # is for better compatibility with YOLO.
    # NOTE(review): with photometric='rgb' the 4th band is presumably stored as
    # an extra sample. Confirm the YOLO loader reads it back as NDVI.
    tifffile.imwrite(output_path, four_channel, photometric='rgb')
def load_4channel_tiff(image_path):
    """
    Load a 4-channel (RGB + NDVI) TIFF image.

    rasterio is tried first. Its data is band-first: the first three bands are
    RGB and band 3 is NDVI; any further bands are ignored.

    If rasterio fails for any reason, tifffile is used instead. This includes
    the file having fewer than four bands. The tifffile path accepts either a
    (4, H, W) or an (H, W, 4) array.

    When the pixel data is uint8, NDVI is mapped back to [-1, 1] with
    ``x / 127.5 - 1``. RGB is returned in its stored dtype and range.

    Args:
        image_path: Path to 4-channel TIFF image

    Returns:
        rgb_array: RGB channels as numpy array (H, W, 3)
        ndvi_array: NDVI channel as numpy array (H, W)

    Raises:
        ValueError: If the tifffile fallback returns neither a (4, H, W) nor an
            (H, W, 4) array. Errors from tifffile.imread itself propagate.
    """
    try:
        with rasterio.open(image_path) as src:
            channels = src.read()  # band-first: (bands, H, W)
        if channels.shape[0] < 4:
            # Too few bands for RGB + NDVI: let the tifffile fallback handle it.
            raise ValueError(f"Expected at least 4 bands, got {channels.shape[0]}")
        rgb_array = np.transpose(channels[:3], (1, 2, 0))  # (H, W, 3)
        ndvi_array = channels[3]  # (H, W)
        source_dtype = channels.dtype
    except Exception:
        # Deliberate best-effort: any rasterio failure falls back to tifffile.
        img_array = tifffile.imread(image_path)
        if img_array.ndim == 3 and img_array.shape[0] == 4:
            # Band-first (4, H, W)
            rgb_array = np.transpose(img_array[:3], (1, 2, 0))
            ndvi_array = img_array[3]
        elif img_array.ndim == 3 and img_array.shape[2] == 4:
            # Band-last (H, W, 4)
            rgb_array = img_array[:, :, :3]
            ndvi_array = img_array[:, :, 3]
        else:
            raise ValueError(f"Unexpected image shape: {img_array.shape}")
        source_dtype = img_array.dtype

    # If NDVI was scaled to uint8, convert back to the [-1, 1] range.
    if source_dtype == np.uint8:
        ndvi_array = (ndvi_array.astype(np.float32) / 127.5) - 1
    return rgb_array, ndvi_array
def _rgb_from_array(img_array):
    """Return an (H, W, 3) RGB array from a 3-D band-first/band-last array, or None.

    Layouts are checked in this order: (4, H, W), (H, W, 4), (H, W, 3),
    (3, H, W). For 4-band inputs only the first three bands are kept.
    """
    if img_array.ndim != 3:
        return None
    if img_array.shape[0] == 4:
        return np.transpose(img_array[:3], (1, 2, 0))
    if img_array.shape[2] == 4:
        return img_array[:, :, :3]
    if img_array.shape[2] == 3:
        return img_array
    if img_array.shape[0] == 3:
        return np.transpose(img_array, (1, 2, 0))
    return None


def _load_rgb_tifffile(image_path):
    """Read with tifffile; return RGB via _rgb_from_array (None if no layout matches)."""
    return _rgb_from_array(tifffile.imread(image_path))


def _load_rgb_rasterio(image_path):
    """Read with rasterio; return the first three bands as (H, W, 3), or None if < 3 bands."""
    with rasterio.open(image_path) as src:
        if src.count < 3:
            return None
        channels = src.read()
    # Only the first three bands are RGB; extra bands (e.g. NDVI) are dropped.
    return np.transpose(channels[:3], (1, 2, 0))


def _load_rgb_pil(image_path):
    """Decode with PIL, forcing 3-channel RGB, and close the file afterwards."""
    with Image.open(image_path) as img:
        return np.array(img.convert("RGB"))


def _load_rgb(image_path):
    """Try tifffile, then rasterio, then PIL; return the first RGB array obtained.

    A loader that raises, or that reads the file but finds no usable RGB
    layout, hands over to the next loader.

    Raises:
        ValueError: If no loader produced RGB data; the message lists why each
            loader failed.
    """
    failures = []
    loaders = (
        ("tifffile", _load_rgb_tifffile),
        ("rasterio", _load_rgb_rasterio),
        ("PIL", _load_rgb_pil),
    )
    for name, loader in loaders:
        try:
            rgb = loader(image_path)
        except Exception as exc:  # best-effort: any loader failure -> try the next
            failures.append(f"{name}={exc}")
            continue
        if rgb is not None:
            return rgb
        failures.append(f"{name}=no RGB data in a supported layout")
    raise ValueError(
        "Could not load image with any method. Errors: " + ", ".join(failures)
    )


def predict_pipeline(ndvi_model, yolo_model, image_path, conf=0.001):
    """
    Full pipeline: Load image -> Extract RGB -> Predict NDVI ->
    Create new 4-channel TIFF with predicted NDVI -> Run YOLO prediction

    Only RGB is taken from the input. Any existing 4th band is discarded and
    replaced by the model's NDVI prediction.

    Args:
        ndvi_model: Loaded NDVI prediction model
        yolo_model: Loaded YOLO model
        image_path: Path to input image (can be RGB or 4-channel TIFF)
        conf: Confidence threshold for YOLO

    Returns:
        results: YOLO result for the temporary 4-channel image

    Raises:
        ValueError: If the image cannot be read as RGB by tifffile, rasterio
            or PIL.
    """
    rgb_array = _load_rgb(image_path)

    # Bring RGB into [0, 1]. uint8 is always on a 0-255 scale, so it is scaled
    # unconditionally. For other dtypes the max > 1 heuristic is kept.
    # NOTE(review): uint16 / higher-bit imagery is also divided by 255 here,
    # which leaves values above 1. Confirm what the NDVI model expects.
    if rgb_array.dtype == np.uint8 or rgb_array.max() > 1:
        rgb_array = rgb_array.astype(np.float32) / 255.0

    ndvi_pred = predict_ndvi_from_rgb(ndvi_model, rgb_array)

    # The temp file is closed before being written to by name, because
    # reopening a still-open NamedTemporaryFile is not portable. delete=False
    # therefore requires the manual cleanup in the finally clause.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.tiff') as tmp_file:
        temp_4ch_path = tmp_file.name
    try:
        create_4channel_tiff(rgb_array, ndvi_pred, temp_4ch_path)
        return predict_yolo(yolo_model, temp_4ch_path, conf=conf)
    finally:
        if os.path.exists(temp_4ch_path):
            os.unlink(temp_4ch_path)