# Testing-Pipeline-API / ndvi_predictor.py
# Last commit: 29e9dd7 "Update ndvi_predictor.py" by ahadhassan (verified)
# File size: 4.15 kB
# ndvi_predictor.py
# NDVI prediction from RGB imagery: model loading, input normalization,
# tiled inference, and a side-by-side PNG visualization.
import os
# Both environment variables are set *before* the TensorFlow /
# segmentation_models imports below so those libraries see them when they
# initialise:
#   - TF_ENABLE_ONEDNN_OPTS=0 turns off TensorFlow's oneDNN custom operations.
#   - SM_FRAMEWORK=tf.keras selects the tf.keras backend for segmentation_models.
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
os.environ["SM_FRAMEWORK"] = "tf.keras"
# NOTE(review): `sm` is not referenced in this file; presumably it is imported
# for its side effects (e.g. custom layers/losses needed when the saved model is
# loaded) — confirm before removing.
import segmentation_models as sm
import tensorflow as tf
import numpy as np
# NOTE(review): `rasterio` and `Image` are not referenced in this file; they may
# be leftovers or expected by callers — verify before removing.
import rasterio
import matplotlib.pyplot as plt
from PIL import Image
import io
def load_model(model_path):
    """
    Load the saved Keras model used for NDVI prediction.

    The model is loaded with ``compile=False``: this module only runs
    inference (``model.predict``), so no optimizer/loss state is restored.

    Args:
        model_path: Path to the saved Keras model.
    Returns:
        model: The loaded ``tf.keras`` model, ready for ``predict``.
    """
    model = tf.keras.models.load_model(model_path, compile=False)
    return model
def normalize_rgb(rgb):
    """
    Normalize an RGB image to the [0, 1] range using a per-band percentile stretch.

    Args:
        rgb: Image array of shape (H, W, C) with C >= 3. If any value exceeds 1
            the data is treated as 0-255 style input and first scaled by 1/255.
    Returns:
        rgb_norm: New float32 array (the input is never modified). Each of the
            first three bands is linearly stretched so that its 1st..99th
            percentile range maps onto 0..1, then clipped to [0, 1]. Channels
            beyond the third (if any) are only scaled, not stretched.
    """
    source = np.asarray(rgb)
    # astype() always returns a fresh array, so the caller's data is untouched.
    rgb_norm = source.astype(np.float32)
    # Handle different input ranges: values above 1 are assumed to be 0-255 data.
    if source.max() > 1:
        rgb_norm /= 255.0
    for b in range(3):
        band = rgb_norm[:, :, b]
        min_val, max_val = np.percentile(band, [1, 99])
        if min_val < max_val:
            band = (band - min_val) / (max_val - min_val)
        # Clip unconditionally: a flat (constant) band skips the stretch above,
        # and could otherwise stay outside [0, 1] (e.g. 16-bit data / 255).
        rgb_norm[:, :, b] = np.clip(band, 0, 1)
    return rgb_norm
def _tile_starts(length, tile_size, stride):
    """Return tile offsets along one axis that together cover [0, length) (assumes length >= tile_size)."""
    starts = list(range(0, length - tile_size + 1, stride))
    # Force a final tile flush with the far edge so the trailing strip is never skipped
    # when (length - tile_size) is not a multiple of stride.
    if starts[-1] != length - tile_size:
        starts.append(length - tile_size)
    return starts


def _blend_weights(tile_size, ramp=50, floor=1e-3):
    """Return a (tile_size, tile_size) float32 mask ramping from the tile border to 1 over `ramp` pixels."""
    y, x = np.mgrid[0:tile_size, 0:tile_size]
    dist = np.minimum(np.minimum(x, tile_size - x - 1), np.minimum(y, tile_size - y - 1))
    weights = np.clip(dist, 0, ramp) / ramp
    # Floor at a small positive value instead of 0: pixels on the image border only
    # ever sit on a tile border, and a zero weight would leave them unpredicted (0).
    return np.maximum(weights, floor).astype(np.float32)


def predict_ndvi(model, rgb_np, tile_size=512, stride=None):
    """
    Predict NDVI from an RGB image using overlapping, blended tiles.

    Args:
        model: Loaded NDVI prediction model. ``model.predict`` is called on a
            batch of shape (1, tile_size, tile_size, C) and channel 0 of its
            (1, tile_size, tile_size, ...) output is used as NDVI.
        rgb_np: RGB image as numpy array (H, W, 3), expected normalized to [0, 1].
        tile_size: Edge length of the square tiles fed to the model (default 512;
            presumably must match the model's input size).
        stride: Step between tile origins; defaults to ``int(tile_size * 0.7)``.
    Returns:
        ndvi_pred: float32 array (H, W). Each pixel is the blend-weighted average
            of every tile prediction covering it. Values are not clipped here;
            they lie in [-1, 1] only if the model outputs that range.
    Raises:
        ValueError: If ``tile_size`` or the effective ``stride`` is < 1.
    """
    if stride is None:
        stride = int(tile_size * 0.7)
    if tile_size < 1 or stride < 1:
        raise ValueError("tile_size and stride must be positive")

    height, width = rgb_np.shape[:2]

    # Reflect-pad any axis shorter than one tile so at least one full tile fits.
    pad_height = max(0, tile_size - height)
    pad_width = max(0, tile_size - width)
    if pad_height or pad_width:
        rgb_padded = np.pad(rgb_np, ((0, pad_height), (0, pad_width), (0, 0)), mode='reflect')
    else:
        rgb_padded = rgb_np
    height_padded, width_padded = rgb_padded.shape[:2]

    weights = _blend_weights(tile_size)  # identical for every tile, so built once
    ndvi_pred = np.zeros((height, width), dtype=np.float32)
    weight_map = np.zeros((height, width), dtype=np.float32)

    for i in _tile_starts(height_padded, tile_size, stride):
        for j in _tile_starts(width_padded, tile_size, stride):
            tile = rgb_padded[i:i + tile_size, j:j + tile_size, :]
            tile_pred = model.predict(np.expand_dims(tile, axis=0), verbose=0)[0, :, :, 0]

            # Keep only the part of the tile that lies inside the original (unpadded) image.
            valid_height = min(tile_size, height - i)
            valid_width = min(tile_size, width - j)
            tile_weights = weights[:valid_height, :valid_width]

            ndvi_pred[i:i + valid_height, j:j + valid_width] += (
                tile_pred[:valid_height, :valid_width] * tile_weights
            )
            weight_map[i:i + valid_height, j:j + valid_width] += tile_weights

    # Normalize accumulated predictions by the total blend weight per pixel.
    mask = weight_map > 0
    ndvi_pred[mask] /= weight_map[mask]
    return ndvi_pred
def create_visualization(rgb, ndvi):
    """
    Create a side-by-side PNG of the RGB input and the predicted NDVI.

    All drawing and saving goes through this function's own ``fig`` object
    rather than pyplot's implicit "current figure", so a figure created
    elsewhere (e.g. by another request in a server) cannot be saved by mistake.
    The figure is always closed, even if rendering fails, so it does not leak.

    Args:
        rgb: RGB image array (H, W, 3). Values above 1 are treated as 0-255
            and scaled by 1/255; the result is clipped to [0, 1] for display.
        ndvi: NDVI prediction array (H, W), shown with the 'RdYlGn' colormap
            on a fixed [-1, 1] scale.
    Returns:
        buf: BytesIO buffer containing the visualization as PNG, rewound to 0.
    """
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    try:
        # Display RGB image
        rgb_disp = np.clip(rgb / 255 if rgb.max() > 1 else rgb, 0, 1)
        axes[0].imshow(rgb_disp)
        axes[0].set_title("RGB Input")
        axes[0].axis("off")

        # Display NDVI with a fixed [-1, 1] color scale
        im = axes[1].imshow(ndvi, cmap='RdYlGn', vmin=-1, vmax=1)
        axes[1].set_title("Predicted NDVI")
        axes[1].axis("off")
        fig.colorbar(im, ax=axes[1])

        # Save this specific figure to an in-memory buffer
        buf = io.BytesIO()
        fig.tight_layout()
        fig.savefig(buf, format="png", dpi=150, bbox_inches='tight')
    finally:
        plt.close(fig)
    buf.seek(0)
    return buf