Spaces:
Sleeping
Sleeping
| # yolo_predictor.py | |
| import os | |
| import numpy as np | |
| import rasterio | |
| from ultralytics import YOLO | |
| from ndvi_predictor import normalize_rgb, predict_ndvi | |
| import tempfile | |
| from rasterio.transform import from_bounds | |
| from PIL import Image | |
| import tifffile | |
def load_yolo_model(model_path):
    """Build an ultralytics ``YOLO`` model from a weights file.

    Args:
        model_path: Path to the model file (e.g. a ``.pt`` checkpoint).

    Returns:
        The ``YOLO`` instance constructed from ``model_path``.
    """
    model = YOLO(model_path)
    return model
def predict_ndvi_from_rgb(ndvi_model, rgb_array):
    """Estimate an NDVI channel for an RGB image.

    The image is first passed through ``normalize_rgb`` and the result is
    handed to ``predict_ndvi`` together with the model.

    Args:
        ndvi_model: Loaded NDVI prediction model.
        rgb_array: RGB image as numpy array (H, W, 3).

    Returns:
        Predicted NDVI as numpy array (H, W).
    """
    return predict_ndvi(ndvi_model, normalize_rgb(rgb_array))
def predict_yolo(yolo_model, image_path, conf=0.001):
    """Run the YOLO model on a single image file.

    Args:
        yolo_model: Loaded YOLO model (called as ``model(paths, conf=...)``).
        image_path: Path to the 4-channel TIFF image.
        conf: Confidence threshold forwarded to the model.

    Returns:
        The first element of whatever the model returns for the
        one-image batch.
    """
    # The model is invoked on a batch of one path; only its single
    # corresponding result is of interest to callers.
    batch_results = yolo_model([image_path], conf=conf)
    first_result = batch_results[0]
    return first_result
def create_4channel_tiff(rgb_array, ndvi_array, output_path):
    """Write RGB plus NDVI to ``output_path`` as one 4-band float32 TIFF.

    Args:
        rgb_array: RGB image as numpy array (H, W, 3). ``uint8`` input is
            scaled by 1/255; any other dtype is only cast to float32.
        ndvi_array: NDVI image as numpy array (H, W).
        output_path: Path to save the 4-channel TIFF.
    """
    rows, cols = rgb_array.shape[:2]

    # Band-first (C, H, W) buffer: red, green, blue, then NDVI.
    stacked = np.zeros((4, rows, cols), dtype=np.float32)

    # uint8 pixels are brought into [0, 1]; everything else is taken as-is.
    rgb_float = rgb_array.astype(np.float32)
    if rgb_array.dtype == np.uint8:
        rgb_float = rgb_float / 255.0

    for band in range(3):
        stacked[band] = rgb_float[:, :, band]
    stacked[3] = ndvi_array.astype(np.float32)

    # Written with tifffile (module-level import) rather than rasterio.
    tifffile.imwrite(output_path, stacked, photometric='rgb')
def load_4channel_tiff(image_path):
    """Read a 4-channel TIFF and split it into RGB and NDVI parts.

    rasterio is tried first; if anything in that path raises, the file is
    re-read with tifffile, which accepts either a band-first (4, H, W) or a
    band-last (H, W, 4) layout.

    Args:
        image_path: Path to 4-channel TIFF image.

    Returns:
        A ``(rgb_array, ndvi_array)`` tuple: RGB channels as numpy array
        (H, W, 3) and the NDVI channel as numpy array (H, W). When the
        stored data is ``uint8`` the NDVI channel is mapped back to
        [-1, 1] via ``value / 127.5 - 1``.

    Raises:
        ValueError: If the tifffile fallback yields an array that is not
            3-D with 4 entries on its first or last axis.
    """
    try:
        with rasterio.open(image_path) as dataset:
            bands = dataset.read()  # band-first: (4, H, W)
            rgb_array = np.transpose(bands[:3], (1, 2, 0))
            ndvi_array = bands[3]
            if bands.dtype == np.uint8:
                # uint8 storage implies NDVI was rescaled; undo it.
                ndvi_array = (ndvi_array.astype(np.float32) / 127.5) - 1
            return rgb_array, ndvi_array
    except Exception:
        # Any rasterio failure (including a missing 4th band) lands here.
        pixels = tifffile.imread(image_path)

        if len(pixels.shape) != 3 or 4 not in (pixels.shape[0], pixels.shape[2]):
            raise ValueError(f"Unexpected image shape: {pixels.shape}")

        if pixels.shape[0] == 4:
            # Band-first layout (4, H, W); checked before band-last.
            rgb_array = np.transpose(pixels[:3], (1, 2, 0))
            ndvi_array = pixels[3]
        else:
            # Band-last layout (H, W, 4).
            rgb_array = pixels[:, :, :3]
            ndvi_array = pixels[:, :, 3]

        if pixels.dtype == np.uint8:
            ndvi_array = (ndvi_array.astype(np.float32) / 127.5) - 1
        return rgb_array, ndvi_array
def _rgb_from_array(img_array):
    """Return the (H, W, 3) RGB part of a decoded array, or None if the layout is unknown."""
    if len(img_array.shape) != 3:
        return None
    if img_array.shape[0] == 4:
        # Band-first (4, H, W): drop the 4th band, move bands last.
        return np.transpose(img_array[:3], (1, 2, 0))
    if img_array.shape[2] == 4:
        # Band-last (H, W, 4): drop the 4th band.
        return img_array[:, :, :3]
    if img_array.shape[2] == 3:
        # Already (H, W, 3).
        return img_array
    if img_array.shape[0] == 3:
        # Band-first (3, H, W).
        return np.transpose(img_array, (1, 2, 0))
    return None


def _read_rgb_with_tifffile(image_path):
    """Decode ``image_path`` with tifffile and extract RGB (None if layout is unknown)."""
    return _rgb_from_array(tifffile.imread(image_path))


def _read_rgb_with_rasterio(image_path):
    """Read the first three bands with rasterio (None if fewer than 3 bands)."""
    with rasterio.open(image_path) as src:
        if src.count < 3:
            return None
        # Always keep only the first three bands, so datasets with more
        # than four bands no longer leak extra channels into "RGB".
        return np.transpose(src.read()[:3], (1, 2, 0))


def _read_rgb_with_pil(image_path):
    """Open ``image_path`` with PIL and convert it to an (H, W, 3) RGB array."""
    with Image.open(image_path) as img:
        return np.array(img.convert("RGB"))


def _load_rgb_array(image_path):
    """Try tifffile, then rasterio, then PIL, and return the first RGB array obtained.

    A loader is skipped both when it raises and when it reads the file but
    cannot produce RGB data, so e.g. a single-band image that rasterio opens
    successfully still reaches the PIL ``convert("RGB")`` fallback.

    Raises:
        ValueError: If no loader produced RGB data.
    """
    loaders = (
        ("tifffile", _read_rgb_with_tifffile),
        ("rasterio", _read_rgb_with_rasterio),
        ("PIL", _read_rgb_with_pil),
    )
    errors = {}
    for label, loader in loaders:
        try:
            rgb_array = loader(image_path)
        except Exception as exc:
            # Each backend raises its own exception types; record the
            # failure and move on to the next backend.
            errors[label] = exc
            continue
        if rgb_array is not None:
            return rgb_array

    if len(errors) == len(loaders):
        raise ValueError(
            "Could not load image with any method. Errors: "
            f"tifffile={errors['tifffile']}, "
            f"rasterio={errors['rasterio']}, "
            f"PIL={errors['PIL']}"
        ) from errors["PIL"]
    raise ValueError("Failed to extract RGB data from image")


def predict_pipeline(ndvi_model, yolo_model, image_path, conf=0.001):
    """Run the full RGB -> NDVI -> 4-channel TIFF -> YOLO pipeline on one image.

    The RGB channels are extracted from ``image_path`` (any 4th channel in
    the input is discarded), an NDVI channel is predicted from them, both are
    written to a temporary 4-channel TIFF, and YOLO is run on that file. The
    temporary file is removed before returning, whether or not YOLO succeeds.

    Args:
        ndvi_model: Loaded NDVI prediction model.
        yolo_model: Loaded YOLO model.
        image_path: Path to input image (can be RGB or 4-channel TIFF).
        conf: Confidence threshold for YOLO.

    Returns:
        The result returned by ``predict_yolo`` for the generated TIFF.

    Raises:
        ValueError: If RGB data cannot be obtained from ``image_path``.
    """
    rgb_array = _load_rgb_array(image_path)

    # Values above 1 are treated as 0-255 data and scaled into [0, 1].
    # NOTE(review): 16-bit imagery would not land in [0, 1] under this
    # rule - confirm the expected input range with callers.
    if rgb_array.max() > 1:
        rgb_array = rgb_array.astype(np.float32) / 255.0

    ndvi_pred = predict_ndvi_from_rgb(ndvi_model, rgb_array)

    # Only a unique path is reserved here; the handle is closed right away
    # so the TIFF writer can reopen the file by name.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.tiff') as tmp_file:
        temp_4ch_path = tmp_file.name

    try:
        create_4channel_tiff(rgb_array, ndvi_pred, temp_4ch_path)
        return predict_yolo(yolo_model, temp_4ch_path, conf=conf)
    finally:
        # delete=False above means cleanup is this function's job.
        if os.path.exists(temp_4ch_path):
            os.unlink(temp_4ch_path)