Testing-Pipeline-API / ndvi_predictor.py
ahadhassan's picture
Merge branch 'pr/12' into pr/13
563ccc9
raw
history blame
6.86 kB
# ndvi_predictor.py
import os
# os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
os.environ["SM_FRAMEWORK"] = "tf.keras"
import segmentation_models as sm
import tensorflow as tf
from tensorflow.keras.models import model_from_json
from efficientnet.tfkeras import EfficientNetB2
import numpy as np
import rasterio
import matplotlib.pyplot as plt
from PIL import Image
import io
from tensorflow.keras.models import model_from_json
import traceback
import gc
# Custom loss functions and activation functions
def balanced_mse_loss(y_true, y_pred):
    """Weighted MSE loss for NDVI regression.

    Pixels with strongly negative true NDVI (y_true < -0.2, e.g. water/shadow)
    and pixels with extreme magnitude (|y_true| > 0.5) each receive a 1.5x
    weight; a pixel matching both conditions gets 2.25x.

    Args:
        y_true: Ground-truth NDVI tensor, values in [-1, 1].
        y_pred: Predicted NDVI tensor, same shape as y_true.

    Returns:
        Scalar tensor: mean of the weighted squared errors.
    """
    mse = tf.square(y_true - y_pred)
    # Emphasize strongly negative NDVI regions.
    negative_weight = tf.where(y_true < -0.2, 1.5, 1.0)
    # Emphasize extreme-valued (boundary) NDVI regions.
    boundary_weight = tf.where(tf.abs(y_true) > 0.5, 1.5, 1.0)
    weights = negative_weight * boundary_weight
    weighted_mse = weights * mse
    # Bug fix: the original returned tf.reduce_mean(mse), silently discarding
    # the weights computed above (dead code). Return the weighted mean instead.
    return tf.reduce_mean(weighted_mse)
def custom_mae(y_true, y_pred):
    """Return the mean absolute error between true and predicted NDVI."""
    return tf.reduce_mean(tf.abs(y_true - y_pred))
def load_model(models_dir):
    """Rebuild the NDVI prediction model from its saved architecture and weights.

    Reads `model_architecture.json` and `best_model_weights.weights.h5` from
    `models_dir`, reconstructs the Keras model with the custom loss/metric
    objects, and compiles it with AdamW.

    Args:
        models_dir: Directory containing the architecture JSON and weights file.

    Returns:
        The compiled Keras model, or None if any step fails (the traceback is
        printed to stderr).
    """
    custom_objects = {
        'balanced_mse_loss': balanced_mse_loss,
        'custom_mae': custom_mae,
    }
    architecture_path = os.path.join(models_dir, "model_architecture.json")
    weights_path = os.path.join(models_dir, "best_model_weights.weights.h5")
    try:
        with open(architecture_path, "r") as fh:
            architecture_json = fh.read()
        model = model_from_json(architecture_json, custom_objects=custom_objects)
        model.load_weights(weights_path)
        # Compile with the same custom loss/metrics the model was trained with.
        model.compile(
            optimizer=tf.keras.optimizers.AdamW(learning_rate=0.0005, weight_decay=1e-4),
            loss=balanced_mse_loss,
            metrics=[custom_mae, 'mse'],
        )
        return model
    except Exception:
        # Best-effort loader: report the failure, signal it with None.
        traceback.print_exc()
        return None
def normalize_rgb(rgb):
    """Normalize an RGB image to [0, 1] using a per-band 1-99 percentile stretch.

    Args:
        rgb: (H, W, 3) array; values may be in [0, 1] or 8-bit-style [0, 255].

    Returns:
        float32 array of the same shape with each band stretched to [0, 1].
        A band with a degenerate (flat) percentile range is left unscaled.
    """
    normalized = rgb.copy().astype(np.float32)
    # Bring 0-255 style inputs down into roughly [0, 1] first.
    if rgb.max() > 1:
        normalized /= 255.0
    for band_idx in range(3):
        lo, hi = np.percentile(normalized[:, :, band_idx], [1, 99])
        if lo < hi:
            stretched = (normalized[:, :, band_idx] - lo) / (hi - lo)
            normalized[:, :, band_idx] = np.clip(stretched, 0, 1)
    return normalized
def predict_ndvi(model, rgb_np):
    """
    Tiled NDVI prediction with overlapping tiles and distance-based blending.

    Small images (<= 512 on both sides) are padded and predicted in one shot;
    larger images are split into 512x512 tiles with 25% overlap, predicted in
    batches of 8, and blended with weights that fall off toward tile borders.

    Args:
        model: Loaded NDVI prediction model (must expose .predict returning
            an array of shape (batch, H, W, 1)).
        rgb_np: RGB image as numpy array (H, W, 3) normalized to [0, 1].

    Returns:
        ndvi_pred: Predicted NDVI as numpy array (H, W) in range [-1, 1].
    """
    height, width = rgb_np.shape[:2]
    # Larger tiles for faster processing
    tile_size = 512
    stride = int(tile_size * 0.75)  # 25% overlap

    # For smaller images, process whole image at once
    if height <= tile_size and width <= tile_size:
        pad_height = max(0, tile_size - height)
        pad_width = max(0, tile_size - width)
        if pad_height > 0 or pad_width > 0:
            rgb_padded = np.pad(rgb_np, ((0, pad_height), (0, pad_width), (0, 0)), mode='reflect')
        else:
            rgb_padded = rgb_np
        # Single prediction; crop padding back off.
        pred = model.predict(np.expand_dims(rgb_padded, axis=0), verbose=0, batch_size=1)[0, :, :, 0]
        return pred[:height, :width]

    # Accumulators for the weighted blend.
    ndvi_pred = np.zeros((height, width), dtype=np.float32)
    weight_map = np.zeros((height, width), dtype=np.float32)

    # Pre-compute blend weights: distance to the nearest tile edge,
    # ramping linearly from 0 at the border to 1 at 64 px inward.
    y, x = np.mgrid[0:tile_size, 0:tile_size]
    base_weights = np.minimum(np.minimum(x, tile_size - x - 1), np.minimum(y, tile_size - y - 1))
    base_weights = np.clip(base_weights, 0, 64) / 64
    # Bug fix: weights were exactly 0 on tile borders, so pixels on the outer
    # border of the full image (covered only by tile borders) accumulated zero
    # total weight and stayed 0 in the output. Floor the weights with a tiny
    # epsilon so every pixel receives at least one contribution; interior
    # blending is unaffected.
    base_weights = np.maximum(base_weights, 1e-3)

    # Collect all tiles first
    tiles = []
    positions = []
    for i in range(0, height, stride):
        for j in range(0, width, stride):
            # Calculate actual tile bounds
            end_i = min(i + tile_size, height)
            end_j = min(j + tile_size, width)
            actual_height = end_i - i
            actual_width = end_j - j
            tile = rgb_np[i:end_i, j:end_j, :]
            # Pad edge tiles up to the full tile size.
            if actual_height < tile_size or actual_width < tile_size:
                pad_height = tile_size - actual_height
                pad_width = tile_size - actual_width
                tile = np.pad(tile, ((0, pad_height), (0, pad_width), (0, 0)), mode='reflect')
            tiles.append(tile)
            positions.append((i, j, actual_height, actual_width))

    # Process all tiles in larger batches
    batch_size = 8  # Process 8 tiles at once
    for batch_start in range(0, len(tiles), batch_size):
        batch_end = min(batch_start + batch_size, len(tiles))
        batch_tiles = np.array(tiles[batch_start:batch_end])
        batch_preds = model.predict(batch_tiles, verbose=0, batch_size=batch_size)
        # Scatter each tile prediction into the accumulators.
        for k in range(batch_end - batch_start):
            pred = batch_preds[k, :, :, 0]
            i, j, actual_height, actual_width = positions[batch_start + k]
            weights = base_weights[:actual_height, :actual_width]
            ndvi_pred[i:i+actual_height, j:j+actual_width] += pred[:actual_height, :actual_width] * weights
            weight_map[i:i+actual_height, j:j+actual_width] += weights
        # Release the batch before the next allocation to cap peak memory.
        del batch_tiles, batch_preds

    # Normalize by accumulated weights.
    mask = weight_map > 0
    ndvi_pred[mask] = ndvi_pred[mask] / weight_map[mask]
    return ndvi_pred
def create_visualization(rgb, ndvi):
    """
    Render a side-by-side figure of the RGB input and the predicted NDVI.

    Args:
        rgb: RGB image array
        ndvi: NDVI prediction array

    Returns:
        buf: BytesIO buffer containing the visualization as PNG
    """
    fig, (ax_rgb, ax_ndvi) = plt.subplots(1, 2, figsize=(12, 6))

    # Left panel: RGB input, rescaled to [0, 1] if it looks like 0-255 data.
    scaled = rgb / 255 if rgb.max() > 1 else rgb
    ax_rgb.imshow(np.clip(scaled, 0, 1))
    ax_rgb.set_title("RGB Input")
    ax_rgb.axis("off")

    # Right panel: NDVI on a fixed [-1, 1] red-yellow-green scale.
    ndvi_image = ax_ndvi.imshow(ndvi, cmap='RdYlGn', vmin=-1, vmax=1)
    ax_ndvi.set_title("Predicted NDVI")
    ax_ndvi.axis("off")
    fig.colorbar(ndvi_image, ax=ax_ndvi)

    # Serialize the figure to an in-memory PNG and rewind for the caller.
    buf = io.BytesIO()
    plt.tight_layout()
    plt.savefig(buf, format="png", dpi=150, bbox_inches='tight')
    plt.close(fig)
    buf.seek(0)
    return buf