Spaces:
Sleeping
Sleeping
| # yolo_predictor.py | |
| import os | |
| import numpy as np | |
| import rasterio | |
| from ultralytics import YOLO | |
| from ndvi_predictor import normalize_rgb, predict_ndvi | |
| import tempfile | |
| from rasterio.transform import from_bounds | |
| from PIL import Image | |
| import tifffile | |
def load_yolo_model(model_path):
    """Build an Ultralytics ``YOLO`` model from a weights file.

    Args:
        model_path: Path to the ``.pt`` checkpoint to load.

    Returns:
        The ``YOLO`` instance constructed from *model_path*.
    """
    model = YOLO(model_path)
    return model
def predict_ndvi_from_rgb(ndvi_model, rgb_array):
    """Estimate an NDVI map from an RGB image.

    The image is passed through ``normalize_rgb`` and the normalized result is
    handed, together with *ndvi_model*, to ``predict_ndvi``.

    Args:
        ndvi_model: NDVI prediction model understood by ``predict_ndvi``.
        rgb_array: RGB image as a numpy array; callers in this module pass
            shape (H, W, 3).

    Returns:
        The value returned by ``predict_ndvi`` — expected to be an (H, W)
        NDVI array (TODO confirm against ndvi_predictor).
    """
    return predict_ndvi(ndvi_model, normalize_rgb(rgb_array))
def predict_yolo(yolo_model, image_path, conf=0.001):
    """
    Predict using YOLO model on 4-channel TIFF image

    The file is checked for exactly four channels before inference so that a
    wrong input fails with a clear message instead of inside the model.

    Args:
        yolo_model: Loaded YOLO model, called as ``yolo_model([path], conf=...)``
        image_path: Path to 4-channel TIFF image
        conf: Confidence threshold passed through to the model

    Returns:
        results: The first element of the model's result list

    Raises:
        ValueError: If the channel count cannot be determined, or is not 4.
    """
    try:
        channels = _count_image_channels(image_path)
    except Exception as e:
        raise ValueError(f"Error reading image channels: {str(e)}") from e
    # Checked outside the try so this message is not re-wrapped as a read error.
    if channels != 4:
        raise ValueError(f"YOLO model expects 4-channel images, but got {channels} channels")
    results = yolo_model([image_path], conf=conf)
    return results[0]  # one input path -> take its single result


def _count_image_channels(image_path):
    """Return the number of channels (bands) stored in the image at *image_path*.

    PIL's band count is tried first. PIL's ``n_frames`` is deliberately not
    used: for TIFFs it counts pages, not bands. If PIL cannot open the file or
    does not report 4 bands, tifffile is consulted, since it also exposes
    channel-first (C, H, W) and multi-page layouts.

    Raises:
        Exception: Whatever tifffile raises when neither PIL nor tifffile can
            open the file.
    """
    try:
        with Image.open(image_path) as img:
            pil_channels = len(img.getbands())
    except Exception:
        # Deliberate fallback: e.g. sample formats PIL cannot decode.
        pil_channels = None
    if pil_channels == 4:
        return 4
    try:
        img_array = tifffile.imread(image_path)
    except Exception:
        if pil_channels is None:
            raise  # neither reader could open the file
        return pil_channels  # not a TIFF tifffile can read; trust PIL's count
    if img_array.ndim != 3:
        return 1
    if 4 in (img_array.shape[0], img_array.shape[2]):
        return 4
    # Guess the channel axis (leading if small, else trailing) for the error message.
    return img_array.shape[0] if img_array.shape[0] <= 4 else img_array.shape[2]
def create_4channel_tiff(rgb_array, ndvi_array, output_path):
    """
    Create a 4-channel TIFF file from RGB and NDVI arrays compatible with PIL and YOLO

    Pixels are stored as uint8 in (H, W, 4) order: R, G, B, NDVI. NDVI is
    encoded as ``(ndvi + 1) * 127.5`` (truncated), i.e. [-1, 1] -> [0, 255];
    ``load_4channel_tiff`` reverses this with ``x / 127.5 - 1``.

    Args:
        rgb_array: RGB image as numpy array (H, W, 3); channels after the
            third are ignored. uint8 input is used as-is. Other dtypes are
            scaled by 255 when their maximum is <= 1.0, and in every case
            clipped to [0, 255] before conversion to uint8.
        ndvi_array: NDVI image as numpy array (H, W). Values outside [-1, 1]
            are clipped to the ends of the uint8 range.
        output_path: Path to save the 4-channel TIFF
    """
    height, width = rgb_array.shape[:2]

    # Casting out-of-range floats to uint8 wraps (or is undefined) rather than
    # saturating, so every conversion below clips first.
    if rgb_array.dtype == np.uint8:
        rgb_u8 = rgb_array
    elif rgb_array.max() <= 1.0:
        rgb_u8 = np.clip(rgb_array * 255, 0, 255).astype(np.uint8)
    else:
        rgb_u8 = np.clip(rgb_array, 0, 255).astype(np.uint8)

    # Keep truncation (not rounding) so values match TIFFs previously written
    # by this function.
    ndvi_u8 = np.clip((ndvi_array + 1) * 127.5, 0, 255).astype(np.uint8)

    four_channel = np.empty((height, width, 4), dtype=np.uint8)
    four_channel[:, :, :3] = rgb_u8[:, :, :3]  # R, G, B
    four_channel[:, :, 3] = ndvi_u8            # NDVI

    # ``compression`` is tifffile's documented keyword; the older ``compress``
    # spelling is deprecated and not accepted by current releases.
    # NOTE(review): LZW encoding may require the optional ``imagecodecs`` package.
    tifffile.imwrite(
        output_path,
        four_channel,
        photometric='rgb',
        compression='lzw',
        metadata={'axes': 'YXC'},
    )
def load_4channel_tiff(image_path):
    """Read an RGB+NDVI TIFF and split it into its RGB and NDVI parts.

    tifffile is tried first; if it raises for any reason (including an
    unrecognised channel layout), rasterio is used instead.

    When the stored pixels are uint8, the NDVI band is mapped back from
    [0, 255] to [-1, 1] via ``x / 127.5 - 1``; other dtypes are returned
    unchanged.

    Args:
        image_path: Path to the 4-channel TIFF file.

    Returns:
        (rgb_array, ndvi_array): the first three channels as (H, W, 3) and
        the fourth channel as (H, W).

    Raises:
        ValueError: If both readers fail; the message includes both errors.
    """
    def _decode(ndvi, dtype):
        # uint8 storage implies the band was encoded as (v + 1) * 127.5.
        return ndvi.astype(np.float32) / 127.5 - 1 if dtype == np.uint8 else ndvi

    try:
        pixels = tifffile.imread(image_path)
        if pixels.ndim == 3:
            if pixels.shape[0] == 4:
                pixels = pixels.transpose(1, 2, 0)  # (4, H, W) -> (H, W, 4)
            elif pixels.shape[2] != 4:
                raise ValueError(f"Expected 4 channels, got {pixels.shape}")
        return pixels[:, :, :3], _decode(pixels[:, :, 3], pixels.dtype)
    except Exception as tiff_err:
        try:
            with rasterio.open(image_path) as src:
                if src.count != 4:
                    raise ValueError(f"Expected 4 channels, got {src.count}")
                bands = src.read()  # band-first: (4, H, W)
                return bands[:3].transpose(1, 2, 0), _decode(bands[3], bands.dtype)
        except Exception as rio_err:
            raise ValueError(f"Could not load 4-channel TIFF. Errors: tifffile={tiff_err}, rasterio={rio_err}")
def predict_pipeline(ndvi_model, yolo_model, image_path, conf=0.001):
    """
    Full pipeline: Load image -> Extract RGB -> Predict NDVI ->
    Create 4-channel TIFF -> Run YOLO prediction

    Args:
        ndvi_model: Loaded NDVI prediction model
        yolo_model: Loaded YOLO model
        image_path: Path to input image (can be RGB or 4-channel TIFF)
        conf: Confidence threshold for YOLO

    Returns:
        results: YOLO results object, as returned by ``predict_yolo``

    Raises:
        ValueError: If the image cannot be read, has no usable RGB layout, or
            the intermediate 4-channel TIFF cannot be read back with 4 bands.
    """
    rgb_array = _load_rgb_array(image_path)
    if rgb_array is None:
        raise ValueError("Failed to extract RGB data from image")

    rgb_uint8, rgb_float = _rgb_to_uint8_and_float(rgb_array)
    ndvi_pred = predict_ndvi_from_rgb(ndvi_model, rgb_float)

    # Only the temp file's path is kept; the handle is closed immediately and
    # the file is removed in the ``finally`` below.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.tiff') as tmp_file:
        temp_4ch_path = tmp_file.name
    try:
        create_4channel_tiff(rgb_uint8, ndvi_pred, temp_4ch_path)
        _verify_4channel_tiff(temp_4ch_path)
        return predict_yolo(yolo_model, temp_4ch_path, conf=conf)
    finally:
        if os.path.exists(temp_4ch_path):
            os.unlink(temp_4ch_path)


def _load_rgb_array(image_path):
    """Read *image_path* and return its first three bands as (H, W, 3), or None.

    Readers are tried in order: tifffile, rasterio, PIL. The first reader that
    opens the file decides the result. None means the file was read but its
    layout was not recognised (for example 2 bands). Single-band images are
    replicated into three identical channels.

    Raises:
        ValueError: If no reader can open the file.
    """
    try:
        img_array = tifffile.imread(image_path)
    except Exception as e1:
        try:
            with rasterio.open(image_path) as src:
                bands = src.read()  # band-first: (count, H, W)
        except Exception as e2:
            try:
                with Image.open(image_path) as img:
                    # Converted inside the ``with``: PIL decodes pixel data lazily.
                    return np.array(img if img.mode == 'RGB' else img.convert('RGB'))
            except Exception as e3:
                raise ValueError(
                    "Could not load image with any method. "
                    f"Errors: tifffile={e1}, rasterio={e2}, PIL={e3}"
                ) from e3
        if bands.shape[0] >= 3:
            return np.transpose(bands[:3], (1, 2, 0))
        if bands.shape[0] == 1:
            return np.stack([bands[0]] * 3, axis=-1)
        return None

    if img_array.ndim == 2:
        return np.stack([img_array] * 3, axis=-1)
    if img_array.ndim == 3:
        if img_array.shape[0] in (3, 4):  # channel-first (C, H, W)
            return np.transpose(img_array[:3], (1, 2, 0))
        if img_array.shape[2] in (3, 4):  # channel-last (H, W, C)
            return img_array[:, :, :3]
    return None


def _rgb_to_uint8_and_float(rgb_array):
    """Return (uint8 copy in [0, 255], float32 copy in [0, 1]) of *rgb_array*.

    uint8 input is used as-is. Other dtypes are treated as 0-255 values when
    their maximum exceeds 1.0 and as [0, 1] values otherwise. The float copy is
    clipped to [0, 1] so the uint8 conversion cannot wrap around (e.g. for
    uint16 data whose values exceed 255).
    """
    if rgb_array.dtype == np.uint8:
        return rgb_array, rgb_array.astype(np.float32) / 255.0
    rgb_float = rgb_array.astype(np.float32)
    if rgb_array.max() > 1.0:
        rgb_float /= 255.0
    rgb_float = np.clip(rgb_float, 0.0, 1.0)
    return (rgb_float * 255).astype(np.uint8), rgb_float


def _verify_4channel_tiff(path):
    """Raise ValueError unless PIL can open *path* and reports exactly 4 bands.

    ``getbands()`` is used rather than ``n_frames``, which counts TIFF pages
    (1 for a single-page file) and so could never equal 4 here.
    """
    try:
        with Image.open(path) as test_img:
            channels = len(test_img.getbands())
    except Exception as e:
        raise ValueError(f"Created TIFF file is not readable: {str(e)}") from e
    if channels != 4:
        raise ValueError(f"Created TIFF has {channels} channels instead of 4")