# ndvi_predictor.py
import os
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
os.environ["SM_FRAMEWORK"] = "tf.keras"
import segmentation_models as sm
import tensorflow as tf
import numpy as np
import rasterio
import matplotlib.pyplot as plt
from PIL import Image
import io
def load_model(model_path):
    """Return the Keras NDVI model stored at ``model_path``.

    The model is loaded with ``compile=False``, so any saved training
    configuration (optimizer, loss, metrics) is not restored.
    """
    keras_models = tf.keras.models
    return keras_models.load_model(model_path, compile=False)
def normalize_rgb(rgb):
    """Normalize the first three (RGB) channels of an image to [0, 1].

    Each of the first three channels is contrast-stretched between its 1st and
    99th percentile and then clipped to [0, 1].

    Args:
        rgb: Image array of shape (H, W, C) with C >= 3. If its maximum value
            exceeds 1, the whole array is divided by 255 first.

    Returns:
        A new float32 array with the same shape as ``rgb``; the input is not
        modified. Channels beyond the third (if any) are scaled like the rest
        but neither stretched nor clipped.
    """
    rgb = np.asarray(rgb)
    # astype() always returns a new array, so no separate .copy() is needed
    rgb_norm = rgb.astype(np.float32)
    # Inputs with values above 1 are divided by 255 (8-bit assumption). For
    # non-constant bands the percentile stretch below makes the exact scale
    # irrelevant; it only matters for constant bands.
    if rgb.max() > 1:
        rgb_norm /= 255.0
    for b in range(3):
        band = rgb_norm[:, :, b]
        low, high = np.percentile(band, [1, 99])
        if low < high:
            band = (band - low) / (high - low)
        # Clip unconditionally: a constant band (low == high) skips the stretch
        # and could otherwise keep values outside [0, 1] (e.g. uint16 / 255).
        rgb_norm[:, :, b] = np.clip(band, 0, 1)
    return rgb_norm
def _tile_starts(length, tile_size, stride):
    """Return tile origins along one axis that together cover [0, length).

    Assumes ``length >= tile_size``. Origins step by ``stride``; the final
    origin is forced to ``length - tile_size`` so the last tile ends exactly
    at the edge instead of leaving an unpredicted strip.
    """
    last = length - tile_size
    starts = list(range(0, last + 1, stride))
    if starts[-1] != last:
        starts.append(last)
    return starts


def _blend_weights(tile_size, margin=50, floor=1e-3):
    """Return a (tile_size, tile_size) float32 mask for blending overlapping tiles.

    The weight ramps linearly from the tile border up to 1 at ``margin``
    pixels inward. It never drops below ``floor``, so pixels on the image
    border (covered only by a tile edge) still receive a prediction.
    """
    ramp = np.arange(tile_size)
    ramp = np.minimum(ramp, ramp[::-1])  # distance to the nearest edge along one axis
    dist = np.minimum.outer(ramp, ramp)  # distance to the nearest tile edge in 2-D
    weights = np.clip(dist, 0, margin) / margin
    return np.maximum(weights, floor).astype(np.float32)


def predict_ndvi(model, rgb_np, tile_size=512, stride=None):
    """
    Predict NDVI from an RGB image using overlapping, blended tiles.

    The image is processed in ``tile_size`` x ``tile_size`` windows spaced
    ``stride`` apart. The last window on each axis is aligned to the image
    edge, so every pixel is predicted. Overlapping predictions are averaged
    with weights that ramp up over 50 px from each tile border. Images smaller
    than a tile are reflect-padded on the bottom/right, and the padding is
    cropped from the result.

    Args:
        model: Object with a Keras-style ``predict(batch, verbose=0)`` whose
            result is indexable as ``[0, :, :, 0]`` with the tile's spatial
            size (assumed from how the output is indexed; confirm for new models).
        rgb_np: RGB image as numpy array (H, W, 3), presumably normalized to
            [0, 1] (e.g. by ``normalize_rgb``).
        tile_size: Tile edge length in pixels (default 512, as before).
        stride: Step between tile origins; defaults to ``int(0.7 * tile_size)``.

    Returns:
        ndvi_pred: float32 array (H, W) of blended model outputs. Its value
        range is whatever the model emits (expected NDVI in [-1, 1]).

    Raises:
        ValueError: if ``tile_size`` or ``stride`` is not positive.
    """
    height, width = rgb_np.shape[:2]
    if stride is None:
        stride = max(1, int(tile_size * 0.7))
    if tile_size <= 0 or stride <= 0:
        raise ValueError("tile_size and stride must be positive")

    # Pad (bottom/right only) so each axis is at least one tile long
    pad_height = max(0, tile_size - height)
    pad_width = max(0, tile_size - width)
    if pad_height or pad_width:
        rgb_padded = np.pad(rgb_np, ((0, pad_height), (0, pad_width), (0, 0)), mode='reflect')
    else:
        rgb_padded = rgb_np
    height_padded, width_padded = rgb_padded.shape[:2]

    # The blending mask is identical for every tile, so build it once
    weights = _blend_weights(tile_size)
    ndvi_pred = np.zeros((height, width), dtype=np.float32)
    weight_map = np.zeros((height, width), dtype=np.float32)

    for i in _tile_starts(height_padded, tile_size, stride):
        for j in _tile_starts(width_padded, tile_size, stride):
            tile = rgb_padded[i:i + tile_size, j:j + tile_size, :]
            tile_pred = model.predict(np.expand_dims(tile, axis=0), verbose=0)[0, :, :, 0]
            # Only the part of the tile inside the original (unpadded) image is kept
            valid_height = min(tile_size, height - i)
            valid_width = min(tile_size, width - j)
            w = weights[:valid_height, :valid_width]
            ndvi_pred[i:i + valid_height, j:j + valid_width] += (
                tile_pred[:valid_height, :valid_width] * w
            )
            weight_map[i:i + valid_height, j:j + valid_width] += w

    # Turn the weighted sums into weighted averages
    covered = weight_map > 0
    ndvi_pred[covered] /= weight_map[covered]
    return ndvi_pred
def create_visualization(rgb, ndvi):
    """
    Render the RGB input and the predicted NDVI side by side as a PNG.

    Args:
        rgb: Image array accepted by ``imshow``. If its maximum exceeds 1 it is
            divided by 255 for display; it is then clipped to [0, 1].
        ndvi: 2-D NDVI prediction array, drawn with the 'RdYlGn' colormap on a
            fixed [-1, 1] scale with a colorbar.

    Returns:
        buf: io.BytesIO positioned at offset 0 containing the PNG image.
    """
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    try:
        # Display RGB image
        rgb_disp = np.clip(rgb / 255 if rgb.max() > 1 else rgb, 0, 1)
        axes[0].imshow(rgb_disp)
        axes[0].set_title("RGB Input")
        axes[0].axis("off")

        # Display NDVI with a fixed [-1, 1] color scale
        im = axes[1].imshow(ndvi, cmap='RdYlGn', vmin=-1, vmax=1)
        axes[1].set_title("Predicted NDVI")
        axes[1].axis("off")
        fig.colorbar(im, ax=axes[1])

        buf = io.BytesIO()
        # Use this figure's own methods rather than pyplot's global "current
        # figure", so a figure created elsewhere (e.g. another request in a
        # server) cannot be laid out or saved by mistake.
        fig.tight_layout()
        fig.savefig(buf, format="png", dpi=150, bbox_inches='tight')
    finally:
        # Always release the figure, even if rendering fails
        plt.close(fig)
    buf.seek(0)
    return buf