# ndvi_predictor.py
"""Predict NDVI from RGB imagery with a Keras model and render the result."""
import io
import os

# Set before importing TensorFlow / segmentation_models so that both
# libraries see these settings at import time.
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
os.environ["SM_FRAMEWORK"] = "tf.keras"

import segmentation_models as sm
import tensorflow as tf
import numpy as np
import rasterio
import matplotlib.pyplot as plt
from PIL import Image


def load_model(model_path):
    """Load the NDVI prediction model for inference.

    Args:
        model_path: Path passed straight to ``tf.keras.models.load_model``.

    Returns:
        The loaded Keras model. It is loaded with ``compile=False``, so no
        optimizer/loss state is restored.
    """
    return tf.keras.models.load_model(model_path, compile=False)


def normalize_rgb(rgb):
    """Normalize an RGB image to [0, 1] using per-band percentile stretching.

    Args:
        rgb: Array indexed as (H, W, band) with at least 3 bands. Values may
            be in [0, 1] or on a 0-255 scale.

    Returns:
        A new float32 array of the same shape. Each of the first three bands
        is stretched so its 1st percentile maps to 0 and its 99th percentile
        maps to 1 (clipped). A band whose 1st and 99th percentiles coincide
        is left unstretched.
    """
    # ``astype`` already returns a copy; no separate ``copy()`` is needed.
    rgb_norm = np.asarray(rgb).astype(np.float32)

    # Nothing to normalize, and max()/percentile would raise on empty input.
    if rgb_norm.size == 0:
        return rgb_norm

    # Handle different input ranges: anything above 1 is treated as 0-255.
    if rgb_norm.max() > 1:
        rgb_norm = rgb_norm / 255.0

    for b in range(3):
        band = rgb_norm[:, :, b]
        min_val, max_val = np.percentile(band, [1, 99])
        # Skip flat bands to avoid dividing by zero.
        if min_val < max_val:
            rgb_norm[:, :, b] = np.clip(
                (band - min_val) / (max_val - min_val), 0, 1
            )

    return rgb_norm


def _tile_starts(length, tile_size, stride):
    """Return tile start offsets along one axis that fully cover ``length``."""
    last = max(length - tile_size, 0)
    starts = list(range(0, last + 1, stride))
    # A regular stride usually stops short of the far edge; add one final
    # tile flush with the edge so no rows/columns are left unpredicted.
    if starts[-1] != last:
        starts.append(last)
    return starts


def _blend_weights(tile_size, ramp=50):
    """Return a (tile_size, tile_size) weight grid rising from the tile edge."""
    y, x = np.mgrid[0:tile_size, 0:tile_size]
    edge_dist = np.minimum(
        np.minimum(x, tile_size - x - 1),
        np.minimum(y, tile_size - y - 1),
    )
    # ``+ 1`` keeps the outermost ring strictly positive; with a zero weight
    # there, image border pixels covered by a single tile would never
    # receive a prediction.
    return (np.clip(edge_dist + 1, 1, ramp) / ramp).astype(np.float32)


def predict_ndvi(model, rgb_np, tile_size=512, stride=None):
    """Predict NDVI from an RGB image using overlapping, blended tiles.

    Args:
        model: Object with a Keras-style ``predict(batch, verbose=0)`` method
            whose output is indexed as ``[batch, row, col, channel]``.
        rgb_np: RGB image as numpy array (H, W, 3), expected to be
            normalized to [0, 1].
        tile_size: Side length of the square tiles fed to the model.
        stride: Step between tile origins. Defaults to 70% of ``tile_size``.

    Returns:
        ndvi_pred: float32 array (H, W) holding the weighted average of the
        model's per-tile predictions (intended range [-1, 1]; the actual
        range is whatever the model outputs).

    Raises:
        ValueError: If ``tile_size`` or ``stride`` is smaller than 1.
    """
    if stride is None:
        stride = int(tile_size * 0.7)
    if tile_size < 1 or stride < 1:
        raise ValueError("tile_size and stride must both be >= 1")

    height, width = rgb_np.shape[:2]

    # Accumulators for the weighted predictions and the weights themselves.
    ndvi_pred = np.zeros((height, width), dtype=np.float32)
    weight_map = np.zeros((height, width), dtype=np.float32)
    if height == 0 or width == 0:
        return ndvi_pred

    # Pad images smaller than one tile so the model always sees full tiles.
    pad_height = max(0, tile_size - height)
    pad_width = max(0, tile_size - width)
    if pad_height or pad_width:
        rgb_padded = np.pad(
            rgb_np,
            ((0, pad_height), (0, pad_width), (0, 0)),
            mode='reflect',
        )
    else:
        rgb_padded = rgb_np
    height_padded, width_padded = rgb_padded.shape[:2]

    # The blending weights are identical for every tile; build them once.
    weights = _blend_weights(tile_size)

    for i in _tile_starts(height_padded, tile_size, stride):
        for j in _tile_starts(width_padded, tile_size, stride):
            tile = rgb_padded[i:i + tile_size, j:j + tile_size, :]

            # Predict NDVI for the tile (batch of one, first output channel).
            tile_pred = model.predict(
                np.expand_dims(tile, axis=0), verbose=0
            )[0, :, :, 0]

            # Only the part of the tile inside the unpadded image is kept.
            valid_height = min(tile_size, height - i)
            valid_width = min(tile_size, width - j)
            tile_weights = weights[:valid_height, :valid_width]

            ndvi_pred[i:i + valid_height, j:j + valid_width] += (
                tile_pred[:valid_height, :valid_width] * tile_weights
            )
            weight_map[i:i + valid_height, j:j + valid_width] += tile_weights

    # Turn the weighted sums into weighted averages.
    mask = weight_map > 0
    ndvi_pred[mask] = ndvi_pred[mask] / weight_map[mask]

    return ndvi_pred


def create_visualization(rgb, ndvi):
    """Render the RGB input next to the predicted NDVI as a PNG image.

    Args:
        rgb: RGB image array; values above 1 are treated as a 0-255 scale.
        ndvi: NDVI prediction array, drawn with the 'RdYlGn' colormap on a
            fixed [-1, 1] color scale.

    Returns:
        buf: BytesIO buffer containing the visualization as PNG, positioned
        at the start of the data.
    """
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    try:
        # Display RGB image
        rgb_disp = np.clip(rgb / 255 if rgb.max() > 1 else rgb, 0, 1)
        axes[0].imshow(rgb_disp)
        axes[0].set_title("RGB Input")
        axes[0].axis("off")

        # Display NDVI with color map
        im = axes[1].imshow(ndvi, cmap='RdYlGn', vmin=-1, vmax=1)
        axes[1].set_title("Predicted NDVI")
        axes[1].axis("off")
        fig.colorbar(im, ax=axes[1])

        # Use the figure's own methods rather than pyplot's "current figure"
        # so an unrelated open figure can never be saved by mistake.
        buf = io.BytesIO()
        fig.tight_layout()
        fig.savefig(buf, format="png", dpi=150, bbox_inches='tight')
    finally:
        # Always release the figure, even if rendering or saving fails.
        plt.close(fig)

    buf.seek(0)
    return buf