File size: 4,146 Bytes
336fa4a
 
 
 
 
 
 
 
 
 
 
 
 
29e9dd7
336fa4a
 
 
29e9dd7
336fa4a
29e9dd7
 
 
 
 
336fa4a
 
 
 
 
29e9dd7
336fa4a
 
 
29e9dd7
 
 
 
 
 
 
 
 
 
336fa4a
 
 
29e9dd7
 
336fa4a
 
29e9dd7
 
336fa4a
 
 
 
 
 
 
 
29e9dd7
 
336fa4a
 
29e9dd7
336fa4a
29e9dd7
 
336fa4a
 
 
29e9dd7
 
336fa4a
29e9dd7
 
336fa4a
 
29e9dd7
 
 
 
 
336fa4a
29e9dd7
 
336fa4a
 
29e9dd7
336fa4a
 
 
29e9dd7
 
 
 
 
 
 
 
 
 
336fa4a
29e9dd7
 
336fa4a
 
 
 
29e9dd7
 
336fa4a
 
 
 
29e9dd7
 
336fa4a
 
29e9dd7
336fa4a
 
29e9dd7
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# ndvi_predictor.py
import os
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
os.environ["SM_FRAMEWORK"] = "tf.keras"
import segmentation_models as sm
import tensorflow as tf
import numpy as np
import rasterio
import matplotlib.pyplot as plt
from PIL import Image
import io

def load_model(model_path):
    """Read the saved Keras NDVI model stored at ``model_path``.

    The model is loaded without compiling it (``compile=False``), so no
    optimizer or loss configuration is restored.
    """
    keras_loader = tf.keras.models.load_model
    return keras_loader(model_path, compile=False)

def normalize_rgb(rgb, *, lower_pct=1, upper_pct=99):
    """Normalize the first three channels of an image to the [0, 1] range.

    If the image maximum exceeds 1, all values are first divided by 255.
    Each of the first three channels is then stretched independently, so that
    its ``lower_pct`` percentile maps to 0 and its ``upper_pct`` percentile
    maps to 1. Values outside that range are clipped.

    Args:
        rgb: Image array of shape (H, W, C) with C >= 3. It is not modified.
        lower_pct: Percentile mapped to 0 (default 1).
        upper_pct: Percentile mapped to 1 (default 99).

    Returns:
        A new float32 array with the same shape as ``rgb``. Channels beyond
        the third are only divided by 255 (when applicable), not stretched.
    """
    # astype() always returns a fresh array, so the caller's image is never mutated.
    rgb_norm = np.asarray(rgb).astype(np.float32)
    if rgb_norm.size == 0:
        # Nothing to normalize; .max()/percentile would raise on an empty array.
        return rgb_norm

    # Handle different input ranges (e.g. 0..255 data).
    if rgb_norm.max() > 1:
        rgb_norm /= 255.0

    for b in range(3):
        band = rgb_norm[:, :, b]
        min_val, max_val = np.percentile(band, [lower_pct, upper_pct])
        if min_val < max_val:
            rgb_norm[:, :, b] = np.clip((band - min_val) / (max_val - min_val), 0, 1)
        else:
            # A flat band cannot be stretched; still clip it so the
            # documented [0, 1] output range holds (e.g. for 16-bit input).
            rgb_norm[:, :, b] = np.clip(band, 0, 1)

    return rgb_norm

def _tile_starts(length, tile_size, stride):
    """Return tile start offsets that fully cover ``[0, length)``; requires length >= tile_size."""
    starts = list(range(0, length - tile_size + 1, stride))
    # The strided grid usually stops short of the far edge; add one final
    # tile flush with it so no trailing rows/columns are left unpredicted.
    if starts[-1] + tile_size < length:
        starts.append(length - tile_size)
    return starts


def _blend_weights(tile_size, margin):
    """Return a (tile_size, tile_size) float32 weight map ramping up from the tile border."""
    idx = np.arange(tile_size)
    edge_dist = np.minimum(idx, tile_size - 1 - idx)
    # Distance of each pixel to the nearest tile edge (Chebyshev-style).
    dist = np.minimum.outer(edge_dist, edge_dist)
    # The +1 keeps border pixels strictly positive: image-boundary pixels are
    # only ever covered by a tile border and would otherwise get weight 0.
    return ((np.clip(dist, 0, margin) + 1) / (margin + 1)).astype(np.float32)


def predict_ndvi(model, rgb_np, *, tile_size=512, stride_ratio=0.7, blend_margin=50):
    """
    Predict NDVI from an RGB image using overlapping, blended tiles.

    Args:
        model: Loaded NDVI prediction model. ``model.predict`` is called on a
            batch of one (1, tile_size, tile_size, 3) tile, and channel 0 of
            its 4-D output is used. The output is assumed to be
            tile_size x tile_size spatially.
        rgb_np: RGB image as numpy array (H, W, 3), presumably normalized to [0, 1].
        tile_size: Side length of the square tiles fed to the model.
        stride_ratio: Tile stride as a fraction of ``tile_size`` (0.7 gives 30% overlap).
        blend_margin: Width in pixels of the linear weight ramp at tile borders.

    Returns:
        ndvi_pred: float32 array (H, W) holding the weighted average of all
            tile predictions covering each pixel. Its range depends on the
            model output (presumably [-1, 1]).
    """
    height, width = rgb_np.shape[:2]
    stride = max(1, int(tile_size * stride_ratio))

    # Reflect-pad images smaller than one tile so the model always sees full tiles.
    pad_height = max(0, tile_size - height)
    pad_width = max(0, tile_size - width)
    if pad_height or pad_width:
        rgb_padded = np.pad(rgb_np, ((0, pad_height), (0, pad_width), (0, 0)), mode="reflect")
    else:
        rgb_padded = rgb_np
    height_padded, width_padded = rgb_padded.shape[:2]

    # Loop-invariant blending weights, computed once.
    weights = _blend_weights(tile_size, blend_margin)
    ndvi_pred = np.zeros((height, width), dtype=np.float32)
    weight_map = np.zeros((height, width), dtype=np.float32)

    for i in _tile_starts(height_padded, tile_size, stride):
        # Only the part of a tile lying inside the original (unpadded) image is kept.
        valid_height = min(tile_size, height - i)
        for j in _tile_starts(width_padded, tile_size, stride):
            valid_width = min(tile_size, width - j)
            tile = rgb_padded[i:i + tile_size, j:j + tile_size, :]
            tile_pred = model.predict(tile[np.newaxis], verbose=0)[0, :, :, 0]

            w = weights[:valid_height, :valid_width]
            ndvi_pred[i:i + valid_height, j:j + valid_width] += (
                tile_pred[:valid_height, :valid_width] * w
            )
            weight_map[i:i + valid_height, j:j + valid_width] += w

    # Normalize accumulated predictions by the accumulated weights.
    mask = weight_map > 0
    ndvi_pred[mask] /= weight_map[mask]

    return ndvi_pred

def create_visualization(rgb, ndvi):
    """
    Create a side-by-side figure of the RGB input and the predicted NDVI.

    Args:
        rgb: RGB image array. If its maximum exceeds 1 it is divided by 255
            for display; values are clipped to [0, 1].
        ndvi: NDVI prediction array, drawn with the 'RdYlGn' colormap on a
            fixed [-1, 1] scale.

    Returns:
        buf: BytesIO buffer containing the visualization as PNG, rewound to position 0.
    """
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    try:
        # Display RGB image
        rgb_disp = np.clip(rgb / 255 if rgb.max() > 1 else rgb, 0, 1)
        axes[0].imshow(rgb_disp)
        axes[0].set_title("RGB Input")
        axes[0].axis("off")

        # Display NDVI with color map
        im = axes[1].imshow(ndvi, cmap='RdYlGn', vmin=-1, vmax=1)
        axes[1].set_title("Predicted NDVI")
        axes[1].axis("off")
        fig.colorbar(im, ax=axes[1])

        # Lay out and save *this* figure explicitly; plt.tight_layout()/plt.savefig()
        # act on pyplot's current figure, which may not be `fig`.
        buf = io.BytesIO()
        fig.tight_layout()
        fig.savefig(buf, format="png", dpi=150, bbox_inches='tight')
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)

    buf.seek(0)
    return buf