File size: 6,855 Bytes
336fa4a
 
a13fdf1
336fa4a
 
 
cdf9e57
 
336fa4a
 
 
 
 
a13fdf1
 
 
336fa4a
cdf9e57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a13fdf1
cdf9e57
 
 
 
 
 
a13fdf1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
336fa4a
 
29e9dd7
336fa4a
29e9dd7
 
 
 
 
336fa4a
 
 
 
 
29e9dd7
336fa4a
 
 
29e9dd7
a13fdf1
29e9dd7
 
 
 
 
 
 
 
336fa4a
a13fdf1
 
336fa4a
a13fdf1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29e9dd7
 
336fa4a
 
29e9dd7
a13fdf1
 
 
 
 
 
 
 
29e9dd7
a13fdf1
 
 
 
 
 
 
 
29e9dd7
a13fdf1
29e9dd7
a13fdf1
 
 
 
 
29e9dd7
a13fdf1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29e9dd7
a13fdf1
 
29e9dd7
a13fdf1
 
 
 
 
 
29e9dd7
 
336fa4a
 
29e9dd7
336fa4a
 
 
29e9dd7
 
 
 
 
 
 
 
 
 
336fa4a
29e9dd7
 
336fa4a
 
 
 
29e9dd7
 
336fa4a
 
 
 
29e9dd7
 
336fa4a
 
29e9dd7
336fa4a
 
29e9dd7
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
# ndvi_predictor.py
import os
# os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
os.environ["SM_FRAMEWORK"] = "tf.keras"
import segmentation_models as sm
import tensorflow as tf
from tensorflow.keras.models import model_from_json
from efficientnet.tfkeras import EfficientNetB2
import numpy as np
import rasterio
import matplotlib.pyplot as plt
from PIL import Image
import io
from tensorflow.keras.models import model_from_json
import traceback
import gc

# Custom loss functions and activation functions
def balanced_mse_loss(y_true, y_pred):
    """Weighted mean squared error that up-weights selected target ranges.

    Each per-element squared error is multiplied by:
      * 1.5 where the target is below -0.2, and
      * 1.5 where |target| > 0.5.
    The two factors compound (2.25) where both conditions hold. The weighted
    errors are then averaged over all elements.

    Args:
        y_true: Target tensor.
        y_pred: Prediction tensor, broadcast-compatible with ``y_true``.

    Returns:
        Scalar tensor with the weighted mean squared error.
    """
    squared_error = tf.square(y_true - y_pred)
    # Build the weights from tensors of the loss dtype (instead of Python floats)
    # so the multiplication below works for float32 and float64 alike.
    ones = tf.ones_like(squared_error)
    negative_weight = tf.where(y_true < -0.2, 1.5 * ones, ones)
    boundary_weight = tf.where(tf.abs(y_true) > 0.5, 1.5 * ones, ones)
    weights = negative_weight * boundary_weight
    # Previously the weights were computed but the *unweighted* MSE was returned.
    return tf.reduce_mean(weights * squared_error)

def custom_mae(y_true, y_pred):
    """Return the mean absolute error between ``y_true`` and ``y_pred`` as a scalar tensor."""
    return tf.reduce_mean(tf.abs(y_true - y_pred))

def load_model(models_dir):
    """Load and compile the NDVI prediction model stored in ``models_dir``.

    Reads the Keras JSON architecture from ``model_architecture.json`` and the
    weights from ``best_model_weights.weights.h5``, both inside ``models_dir``.
    The custom loss/metric functions are registered so the architecture can be
    deserialized. The model is then compiled with AdamW, ``balanced_mse_loss``
    and the ``custom_mae``/``mse`` metrics.

    Args:
        models_dir: Directory containing the architecture and weights files.

    Returns:
        The compiled Keras model, or None if any step fails. This is a
        best-effort loader: on failure the traceback is printed to stderr
        instead of raising.
    """
    custom_objects = {
        'balanced_mse_loss': balanced_mse_loss,
        'custom_mae': custom_mae
    }

    try:
        architecture_path = os.path.join(models_dir, "model_architecture.json")
        with open(architecture_path, "r", encoding="utf-8") as json_file:
            model_json = json_file.read()

        model = model_from_json(model_json, custom_objects=custom_objects)
        model.load_weights(os.path.join(models_dir, "best_model_weights.weights.h5"))

        optimizer = tf.keras.optimizers.AdamW(learning_rate=0.0005, weight_decay=1e-4)
        model.compile(
            optimizer=optimizer,
            loss=balanced_mse_loss,
            metrics=[custom_mae, 'mse']
        )
        return model
    except Exception:
        # Deliberate best-effort behavior: callers check for None.
        traceback.print_exc()
        return None

def normalize_rgb(rgb):
    """Normalize an RGB image to the [0, 1] range with a per-band percentile stretch.

    The image is converted to float32. If any value exceeds 1, the whole image
    is first divided by 255. Each of the first three bands is then linearly
    stretched so its 1st/99th percentiles map to 0/1, and clipped to [0, 1].

    A band whose 1st and 99th percentiles are equal (e.g. a constant band)
    cannot be stretched. It is still clipped to [0, 1], so the output range
    holds even for inputs with values above 255 (such as 16-bit imagery).

    Args:
        rgb: Image array-like of shape (H, W, C) with C >= 3. Not modified.

    Returns:
        A new float32 array of the same shape. Bands beyond the first three are
        only affected by the optional division by 255.
    """
    rgb_norm = np.array(rgb, dtype=np.float32)  # np.array always copies

    # Handle different input ranges
    if rgb_norm.max() > 1:
        rgb_norm /= 255.0

    for b in range(3):
        band = rgb_norm[:, :, b]
        min_val, max_val = np.percentile(band, [1, 99])
        if min_val < max_val:
            band = (band - min_val) / (max_val - min_val)
        # Clip in every case so degenerate bands also end up in [0, 1].
        rgb_norm[:, :, b] = np.clip(band, 0, 1)

    return rgb_norm

def _pad_to_tile(tile, tile_size):
    """Reflect-pad ``tile`` at the bottom/right up to (tile_size, tile_size); return it unchanged if already large enough."""
    pad_height = max(0, tile_size - tile.shape[0])
    pad_width = max(0, tile_size - tile.shape[1])
    if pad_height == 0 and pad_width == 0:
        return tile
    return np.pad(tile, ((0, pad_height), (0, pad_width), (0, 0)), mode='reflect')


def _tile_starts(length, tile_size, stride):
    """Return tile start offsets along one axis, with the last tile flush to the end."""
    if length <= tile_size:
        return [0]
    starts = list(range(0, length - tile_size, stride))
    # Align the last tile with the image end so it is full-size (no padding)
    # instead of adding small, mostly-padded trailing tiles.
    starts.append(length - tile_size)
    return starts


def _blend_weights(tile_size, ramp=64):
    """Return a (tile_size, tile_size) float32 weight map rising linearly from the tile border to 1 over ``ramp`` pixels."""
    y, x = np.mgrid[0:tile_size, 0:tile_size]
    dist = np.minimum(np.minimum(x, tile_size - x - 1), np.minimum(y, tile_size - y - 1))
    # The +1 keeps border pixels strictly positive. With a zero weight there,
    # image pixels covered only by a tile edge (e.g. row 0 / column 0 of the
    # image) would receive no contribution and come out as 0.
    return (np.clip(dist + 1, 1, ramp) / ramp).astype(np.float32)


def predict_ndvi(model, rgb_np, *, tile_size=512, batch_size=8):
    """Predict NDVI for an RGB image, using overlapping tiles for large inputs.

    Images no larger than ``tile_size`` in both dimensions are reflect-padded
    to one tile and predicted in a single call. Larger images are split into
    tiles with 25% overlap. Tiles are predicted in batches and blended with
    weights that taper toward tile borders, to reduce seams.

    Args:
        model: Object with a Keras-style ``predict(x, verbose=..., batch_size=...)``
            returning an array of shape (N, tile_size, tile_size, C). Channel 0
            is taken as the NDVI prediction.
        rgb_np: RGB image array of shape (H, W, 3). Presumably already
            normalized to [0, 1] (e.g. by ``normalize_rgb``) — this function
            does not check.
        tile_size: Side length (pixels) of the square tiles fed to the model.
        batch_size: Number of tiles passed to ``model.predict`` per call.

    Returns:
        Array of shape (H, W) with the predicted NDVI. The value range follows
        the model's output (expected [-1, 1]).
    """
    height, width = rgb_np.shape[:2]

    # Small images: one (padded) tile, one prediction.
    if height <= tile_size and width <= tile_size:
        padded = _pad_to_tile(rgb_np, tile_size)
        pred = model.predict(np.expand_dims(padded, axis=0), verbose=0, batch_size=1)[0, :, :, 0]
        return pred[:height, :width]

    stride = max(1, int(tile_size * 0.75))  # 25% overlap
    weights_full = _blend_weights(tile_size)
    positions = [
        (i, j)
        for i in _tile_starts(height, tile_size, stride)
        for j in _tile_starts(width, tile_size, stride)
    ]

    ndvi_pred = np.zeros((height, width), dtype=np.float32)
    weight_map = np.zeros((height, width), dtype=np.float32)

    for batch_start in range(0, len(positions), batch_size):
        batch_positions = positions[batch_start:batch_start + batch_size]
        # Build only this batch's tiles to keep peak memory bounded.
        batch_tiles = np.stack([
            _pad_to_tile(rgb_np[i:i + tile_size, j:j + tile_size], tile_size)
            for i, j in batch_positions
        ])
        batch_preds = model.predict(batch_tiles, verbose=0, batch_size=batch_size)

        for (i, j), pred in zip(batch_positions, batch_preds):
            tile_h = min(tile_size, height - i)
            tile_w = min(tile_size, width - j)
            weights = weights_full[:tile_h, :tile_w]
            ndvi_pred[i:i + tile_h, j:j + tile_w] += pred[:tile_h, :tile_w, 0] * weights
            weight_map[i:i + tile_h, j:j + tile_w] += weights

        del batch_tiles, batch_preds

    # Weighted average of overlapping predictions.
    mask = weight_map > 0
    ndvi_pred[mask] /= weight_map[mask]

    return ndvi_pred

def create_visualization(rgb, ndvi):
    """Render the RGB input and the predicted NDVI side by side as a PNG image.

    Args:
        rgb: RGB image array. If its maximum exceeds 1 it is treated as 0-255
            data and divided by 255; it is then clipped to [0, 1] for display.
        ndvi: 2-D NDVI array, drawn with the 'RdYlGn' colormap on a fixed
            [-1, 1] color scale with a colorbar.

    Returns:
        io.BytesIO positioned at offset 0, containing the PNG (150 dpi).
    """
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    try:
        # Display RGB image
        rgb_disp = np.clip(rgb / 255 if rgb.max() > 1 else rgb, 0, 1)
        axes[0].imshow(rgb_disp)
        axes[0].set_title("RGB Input")
        axes[0].axis("off")

        # Display NDVI with a fixed color scale so images are comparable
        im = axes[1].imshow(ndvi, cmap='RdYlGn', vmin=-1, vmax=1)
        axes[1].set_title("Predicted NDVI")
        axes[1].axis("off")
        fig.colorbar(im, ax=axes[1])

        # Save to buffer, operating on this figure explicitly rather than
        # pyplot's "current figure".
        buf = io.BytesIO()
        fig.tight_layout()
        fig.savefig(buf, format="png", dpi=150, bbox_inches='tight')
    finally:
        # Always release the figure; pyplot otherwise keeps it alive, which
        # leaks memory in long-running processes if plotting raises.
        plt.close(fig)

    buf.seek(0)
    return buf