Spaces:
Sleeping
Sleeping
Update ndvi_predictor.py
Browse files- ndvi_predictor.py +58 -10
ndvi_predictor.py
CHANGED
|
@@ -2,7 +2,6 @@
|
|
| 2 |
import os
|
| 3 |
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
|
| 4 |
os.environ["SM_FRAMEWORK"] = "tf.keras"
|
| 5 |
-
|
| 6 |
import segmentation_models as sm
|
| 7 |
import tensorflow as tf
|
| 8 |
import numpy as np
|
|
@@ -12,25 +11,45 @@ from PIL import Image
|
|
| 12 |
import io
|
| 13 |
|
| 14 |
def load_model(model_path):
|
|
|
|
| 15 |
return tf.keras.models.load_model(model_path, compile=False)
|
| 16 |
|
| 17 |
def normalize_rgb(rgb):
|
|
|
|
| 18 |
rgb_norm = rgb.copy().astype(np.float32)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
for b in range(3):
|
| 20 |
band = rgb_norm[:, :, b]
|
| 21 |
min_val, max_val = np.percentile(band, [1, 99])
|
| 22 |
if min_val < max_val:
|
| 23 |
rgb_norm[:, :, b] = np.clip((band - min_val) / (max_val - min_val), 0, 1)
|
|
|
|
| 24 |
return rgb_norm
|
| 25 |
|
| 26 |
def predict_ndvi(model, rgb_np):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
height, width = rgb_np.shape[:2]
|
| 28 |
tile_size = 512
|
| 29 |
stride = int(tile_size * 0.7)
|
| 30 |
-
|
|
|
|
| 31 |
ndvi_pred = np.zeros((height, width), dtype=np.float32)
|
| 32 |
weight_map = np.zeros((height, width), dtype=np.float32)
|
| 33 |
-
|
|
|
|
| 34 |
if height < tile_size or width < tile_size:
|
| 35 |
pad_height = max(0, tile_size - height)
|
| 36 |
pad_width = max(0, tile_size - width)
|
|
@@ -39,38 +58,67 @@ def predict_ndvi(model, rgb_np):
|
|
| 39 |
else:
|
| 40 |
rgb_padded = rgb_np
|
| 41 |
height_padded, width_padded = height, width
|
| 42 |
-
|
|
|
|
| 43 |
for i in range(0, height_padded - tile_size + 1, stride):
|
| 44 |
for j in range(0, width_padded - tile_size + 1, stride):
|
|
|
|
| 45 |
tile = rgb_padded[i:i+tile_size, j:j+tile_size, :]
|
|
|
|
|
|
|
| 46 |
y, x = np.mgrid[0:tile_size, 0:tile_size]
|
| 47 |
weights = np.minimum(np.minimum(x, tile_size - x - 1), np.minimum(y, tile_size - y - 1))
|
| 48 |
weights = np.clip(weights, 0, 50) / 50
|
|
|
|
|
|
|
| 49 |
tile_pred = model.predict(np.expand_dims(tile, axis=0), verbose=0)[0, :, :, 0]
|
|
|
|
|
|
|
| 50 |
valid_height = min(tile_size, height - i)
|
| 51 |
valid_width = min(tile_size, width - j)
|
| 52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
weight_map[i:i+valid_height, j:j+valid_width] += weights[:valid_height, :valid_width]
|
| 54 |
-
|
|
|
|
| 55 |
mask = weight_map > 0
|
| 56 |
ndvi_pred[mask] = ndvi_pred[mask] / weight_map[mask]
|
|
|
|
| 57 |
return ndvi_pred
|
| 58 |
|
| 59 |
def create_visualization(rgb, ndvi):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
|
|
|
|
|
|
|
| 61 |
rgb_disp = np.clip(rgb / 255 if rgb.max() > 1 else rgb, 0, 1)
|
| 62 |
axes[0].imshow(rgb_disp)
|
| 63 |
axes[0].set_title("RGB Input")
|
| 64 |
axes[0].axis("off")
|
| 65 |
-
|
|
|
|
| 66 |
im = axes[1].imshow(ndvi, cmap='RdYlGn', vmin=-1, vmax=1)
|
| 67 |
axes[1].set_title("Predicted NDVI")
|
| 68 |
axes[1].axis("off")
|
| 69 |
fig.colorbar(im, ax=axes[1])
|
| 70 |
-
|
|
|
|
| 71 |
buf = io.BytesIO()
|
| 72 |
plt.tight_layout()
|
| 73 |
-
plt.savefig(buf, format="png")
|
| 74 |
plt.close(fig)
|
| 75 |
buf.seek(0)
|
| 76 |
-
|
|
|
|
|
|
| 2 |
import os
|
| 3 |
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
|
| 4 |
os.environ["SM_FRAMEWORK"] = "tf.keras"
|
|
|
|
| 5 |
import segmentation_models as sm
|
| 6 |
import tensorflow as tf
|
| 7 |
import numpy as np
|
|
|
|
| 11 |
import io
|
| 12 |
|
| 13 |
def load_model(model_path):
|
| 14 |
+
"""Load NDVI prediction model"""
|
| 15 |
return tf.keras.models.load_model(model_path, compile=False)
|
| 16 |
|
| 17 |
def normalize_rgb(rgb):
|
| 18 |
+
"""Normalize RGB image to [0, 1] range using percentile normalization"""
|
| 19 |
rgb_norm = rgb.copy().astype(np.float32)
|
| 20 |
+
|
| 21 |
+
# Handle different input ranges
|
| 22 |
+
if rgb.max() > 1:
|
| 23 |
+
rgb_norm = rgb_norm / 255.0
|
| 24 |
+
|
| 25 |
for b in range(3):
|
| 26 |
band = rgb_norm[:, :, b]
|
| 27 |
min_val, max_val = np.percentile(band, [1, 99])
|
| 28 |
if min_val < max_val:
|
| 29 |
rgb_norm[:, :, b] = np.clip((band - min_val) / (max_val - min_val), 0, 1)
|
| 30 |
+
|
| 31 |
return rgb_norm
|
| 32 |
|
| 33 |
def predict_ndvi(model, rgb_np):
|
| 34 |
+
"""
|
| 35 |
+
Predict NDVI from RGB image using tiled approach for large images
|
| 36 |
+
|
| 37 |
+
Args:
|
| 38 |
+
model: Loaded NDVI prediction model
|
| 39 |
+
rgb_np: RGB image as numpy array (H, W, 3) normalized to [0, 1]
|
| 40 |
+
|
| 41 |
+
Returns:
|
| 42 |
+
ndvi_pred: Predicted NDVI as numpy array (H, W) in range [-1, 1]
|
| 43 |
+
"""
|
| 44 |
height, width = rgb_np.shape[:2]
|
| 45 |
tile_size = 512
|
| 46 |
stride = int(tile_size * 0.7)
|
| 47 |
+
|
| 48 |
+
# Initialize output arrays
|
| 49 |
ndvi_pred = np.zeros((height, width), dtype=np.float32)
|
| 50 |
weight_map = np.zeros((height, width), dtype=np.float32)
|
| 51 |
+
|
| 52 |
+
# Handle small images by padding
|
| 53 |
if height < tile_size or width < tile_size:
|
| 54 |
pad_height = max(0, tile_size - height)
|
| 55 |
pad_width = max(0, tile_size - width)
|
|
|
|
| 58 |
else:
|
| 59 |
rgb_padded = rgb_np
|
| 60 |
height_padded, width_padded = height, width
|
| 61 |
+
|
| 62 |
+
# Process image tiles
|
| 63 |
for i in range(0, height_padded - tile_size + 1, stride):
|
| 64 |
for j in range(0, width_padded - tile_size + 1, stride):
|
| 65 |
+
# Extract tile
|
| 66 |
tile = rgb_padded[i:i+tile_size, j:j+tile_size, :]
|
| 67 |
+
|
| 68 |
+
# Create distance-based weights for blending
|
| 69 |
y, x = np.mgrid[0:tile_size, 0:tile_size]
|
| 70 |
weights = np.minimum(np.minimum(x, tile_size - x - 1), np.minimum(y, tile_size - y - 1))
|
| 71 |
weights = np.clip(weights, 0, 50) / 50
|
| 72 |
+
|
| 73 |
+
# Predict NDVI for tile
|
| 74 |
tile_pred = model.predict(np.expand_dims(tile, axis=0), verbose=0)[0, :, :, 0]
|
| 75 |
+
|
| 76 |
+
# Determine valid region (handle edge cases)
|
| 77 |
valid_height = min(tile_size, height - i)
|
| 78 |
valid_width = min(tile_size, width - j)
|
| 79 |
+
|
| 80 |
+
# Accumulate weighted predictions
|
| 81 |
+
ndvi_pred[i:i+valid_height, j:j+valid_width] += (
|
| 82 |
+
tile_pred[:valid_height, :valid_width] * weights[:valid_height, :valid_width]
|
| 83 |
+
)
|
| 84 |
weight_map[i:i+valid_height, j:j+valid_width] += weights[:valid_height, :valid_width]
|
| 85 |
+
|
| 86 |
+
# Normalize by weights
|
| 87 |
mask = weight_map > 0
|
| 88 |
ndvi_pred[mask] = ndvi_pred[mask] / weight_map[mask]
|
| 89 |
+
|
| 90 |
return ndvi_pred
|
| 91 |
|
| 92 |
def create_visualization(rgb, ndvi):
|
| 93 |
+
"""
|
| 94 |
+
Create visualization of RGB input and predicted NDVI
|
| 95 |
+
|
| 96 |
+
Args:
|
| 97 |
+
rgb: RGB image array
|
| 98 |
+
ndvi: NDVI prediction array
|
| 99 |
+
|
| 100 |
+
Returns:
|
| 101 |
+
buf: BytesIO buffer containing the visualization as PNG
|
| 102 |
+
"""
|
| 103 |
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
|
| 104 |
+
|
| 105 |
+
# Display RGB image
|
| 106 |
rgb_disp = np.clip(rgb / 255 if rgb.max() > 1 else rgb, 0, 1)
|
| 107 |
axes[0].imshow(rgb_disp)
|
| 108 |
axes[0].set_title("RGB Input")
|
| 109 |
axes[0].axis("off")
|
| 110 |
+
|
| 111 |
+
# Display NDVI with color map
|
| 112 |
im = axes[1].imshow(ndvi, cmap='RdYlGn', vmin=-1, vmax=1)
|
| 113 |
axes[1].set_title("Predicted NDVI")
|
| 114 |
axes[1].axis("off")
|
| 115 |
fig.colorbar(im, ax=axes[1])
|
| 116 |
+
|
| 117 |
+
# Save to buffer
|
| 118 |
buf = io.BytesIO()
|
| 119 |
plt.tight_layout()
|
| 120 |
+
plt.savefig(buf, format="png", dpi=150, bbox_inches='tight')
|
| 121 |
plt.close(fig)
|
| 122 |
buf.seek(0)
|
| 123 |
+
|
| 124 |
+
return buf
|