File size: 10,335 Bytes
c82cafe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
from tensorflow.keras.models import load_model
import numpy as np
import base64
import logging

# Path to the trained segmentation model, relative to the process CWD — TODO confirm deploy layout.
MODEL_PATH = "models/Forest_Segmentation_Best.keras"
# Lazily-initialized Keras model; populated on first call to load().
model = None
# Small stabilizer added to denominators in the spectral-index ratios to avoid division by zero.
EPS = 1e-6

# Setup logging
logger = logging.getLogger("forest_segmentation.inference")

def load():
    """Lazily load the Keras segmentation model into the module-level ``model``.

    Safe to call repeatedly: the model is read from ``MODEL_PATH`` only on
    the first call. NOTE(review): not guarded against concurrent first calls —
    confirm single-threaded startup before relying on that.
    """
    global model
    if model is None:
        # Lazy %-args avoid building the message when INFO is disabled
        # (the original concatenated the string eagerly).
        logger.info("[INFERENCE] Loading model from: %s", MODEL_PATH)
        # compile=False: inference only, no optimizer/loss state needed.
        model = load_model(MODEL_PATH, compile=False)
        logger.info("[INFERENCE] Model loaded successfully")

def decode_band_float32(b64):
    """Decode base64-encoded float32 band data to a square 2D array.

    Args:
        b64: Base64 string whose decoded bytes form a flat float32 buffer of
            perfect-square length (side * side elements).

    Returns:
        np.ndarray of shape (side, side), dtype float32.

    Raises:
        ValueError: if the decoded element count is not a perfect square
            (the original silently attempted a reshape that would fail with
            a confusing message, or truncate nothing but mislead debugging).
    """
    raw = base64.b64decode(b64)
    arr = np.frombuffer(raw, dtype=np.float32)
    side = int(np.sqrt(arr.size))
    if side * side != arr.size:
        raise ValueError(
            f"Decoded band has {arr.size} float32 values, which is not a perfect square"
        )
    return arr.reshape((side, side))

def validate_landsat_data(bands_dict):
    """
    Validate that input data matches Landsat 8 Collection 2 Level 2 format.

    Expected range: [-0.2, 0.6] for optical bands, [-1, 1] for indices
    (ranges are not enforced here — only dimensionality and dtype are).

    Args:
        bands_dict: mapping of band name -> 2D array.

    Returns:
        The same mapping, with every band coerced to float32 in place.

    Raises:
        ValueError: if any band is not a 2D array.
    """
    for band_name, data in bands_dict.items():
        if data.ndim != 2:
            raise ValueError(f"{band_name}: Expected 2D array, got shape {data.shape}")
        if data.dtype != np.float32:
            # BUG FIX: the original rebound the loop variable only, so the
            # float32 conversion was silently discarded. Store it back.
            # (Assigning to an existing key is safe while iterating.)
            bands_dict[band_name] = data.astype(np.float32)
    return bands_dict

def ndvi(red, nir, eps=1e-6):
    """Normalized Difference Vegetation Index: (NIR - Red) / (NIR + Red).

    Args:
        red: Red-band reflectance array.
        nir: Near-infrared reflectance array.
        eps: Stabilizer added to the denominator to avoid division by zero.
            Default matches the module-level EPS, so existing callers are
            unaffected.

    Returns:
        Array of NDVI values, nominally in [-1, 1].
    """
    return (nir - red) / (nir + red + eps)

def ndwi(green, nir, eps=1e-6):
    """Normalized Difference Water Index: (Green - NIR) / (Green + NIR).

    Args:
        green: Green-band reflectance array.
        nir: Near-infrared reflectance array.
        eps: Stabilizer added to the denominator to avoid division by zero.
            Default matches the module-level EPS, so existing callers are
            unaffected.

    Returns:
        Array of NDWI values, nominally in [-1, 1].
    """
    return (green - nir) / (green + nir + eps)

def nbr(nir, swir2, eps=1e-6):
    """Normalized Burn Ratio: (NIR - SWIR2) / (NIR + SWIR2).

    Args:
        nir: Near-infrared reflectance array.
        swir2: Shortwave-infrared (band 7) reflectance array.
        eps: Stabilizer added to the denominator to avoid division by zero.
            Default matches the module-level EPS, so existing callers are
            unaffected.

    Returns:
        Array of NBR values, nominally in [-1, 1].
    """
    return (nir - swir2) / (nir + swir2 + eps)

def analyze_input_bands(bands):
    """Compute and log summary statistics for every recognized input band.

    Returns a dict keyed by band name with min/max/mean/std entries, plus a
    top-level 'vegetation_coverage_pct' entry when an NDVI band is present.
    """
    known_bands = ['Blue', 'Green', 'Red', 'NIR', 'SWIR1', 'SWIR2', 'NDVI', 'NDWI', 'NBR']
    stats = {}

    logger.info("[ANALYSIS] === INPUT BAND ANALYSIS ===")

    for name in known_bands:
        if name not in bands:
            continue
        arr = bands[name]
        entry = {
            "min": float(arr.min()),
            "max": float(arr.max()),
            "mean": float(arr.mean()),
            "std": float(arr.std()),
        }
        stats[name] = entry
        logger.info(f"[ANALYSIS] {name}: min={entry['min']:.4f}, max={entry['max']:.4f}, mean={entry['mean']:.4f}")

    # Pixels whose NDVI exceeds 0.5 are counted as vegetation cover.
    if 'NDVI' in bands:
        ndvi_arr = bands['NDVI']
        veg_pixels = np.sum(ndvi_arr > 0.5)
        veg_pct = (veg_pixels / ndvi_arr.size) * 100
        logger.info(f"[ANALYSIS] NDVI > 0.5 (vegetation): {veg_pct:.2f}% of pixels")
        stats['vegetation_coverage_pct'] = veg_pct

    return stats

def preprocess_for_model(bands, clip_optical=False, clip_indices=False):
    """
    Preprocess bands to match model training expectations.

    Args:
        bands: Dictionary of band arrays.
        clip_optical: If True, clip optical bands to [-0.2, 0.6].
        clip_indices: If True, clip indices to [-1, 1].

    Returns:
        A new dictionary containing only the recognized bands present in
        ``bands``. BUG FIX: the original clipping paths inserted an explicit
        ``None`` for missing bands (via ``bands.get`` on a known-absent key),
        inconsistent with the non-clipping paths and fatal to downstream
        ``np.stack``; missing bands are now omitted in all cases.
        Unrecognized keys are dropped, as before.
    """
    optical_names = ('Blue', 'Green', 'Red', 'NIR', 'SWIR1', 'SWIR2')
    index_names = ('NDVI', 'NDWI', 'NBR')
    processed = {}

    if clip_optical:
        logger.info("[PREPROCESS] Clipping optical bands to [-0.2, 0.6]")
    for name in optical_names:
        if name in bands:
            data = bands[name]
            processed[name] = np.clip(data, -0.2, 0.6) if clip_optical else data

    if clip_indices:
        logger.info("[PREPROCESS] Clipping indices to [-1.0, 1.0]")
    for name in index_names:
        if name in bands:
            data = bands[name]
            processed[name] = np.clip(data, -1.0, 1.0) if clip_indices else data

    return processed

def build_input_tensor(bands):
    """
    Build 9-channel input tensor from Landsat 8 bands.

    Expected band dict keys:
      - Blue, Green, Red: Optical bands (channels 0-2)
      - NIR, SWIR1, SWIR2: Infrared bands (channels 3-5)
      - NDVI, NDWI, NBR: Pre-calculated or computed indices (channels 6-8)

    Each value may be a 2D float32 array or a base64-encoded float32 payload
    (decoded via decode_band_float32). The three indices are optional and are
    derived from the spectral bands when absent.

    Returns:
        (1, H, W, 9) float32 array ready for model inference. (The original
        docstring claimed a fixed 256x256 size; the code is shape-agnostic.)

    Raises:
        KeyError: if any of the six required spectral bands is missing.
    """
    def _as_array(value):
        # Accept both raw 2D arrays and base64-encoded float32 payloads.
        return decode_band_float32(value) if isinstance(value, str) else value

    def _index_or(name, compute):
        # Use a pre-calculated index when supplied (string or array);
        # otherwise derive it from the spectral bands.
        value = bands.get(name)
        if isinstance(value, (str, np.ndarray)):
            return _as_array(value)
        return compute()

    blue = _as_array(bands["Blue"])
    green = _as_array(bands["Green"])
    red = _as_array(bands["Red"])
    nir = _as_array(bands["NIR"])
    swir1 = _as_array(bands["SWIR1"])
    swir2 = _as_array(bands["SWIR2"])

    ndvi_map = _index_or("NDVI", lambda: ndvi(red, nir))
    ndwi_map = _index_or("NDWI", lambda: ndwi(green, nir))
    nbr_map = _index_or("NBR", lambda: nbr(nir, swir2))

    # Stack into 9-channel tensor: (H, W, 9)
    stacked = np.stack(
        [blue, green, red, nir, swir1, swir2, ndvi_map, ndwi_map, nbr_map],
        axis=-1,
    ).astype(np.float32)

    # Validate data range matches training expectations (log-only check).
    opt_min, opt_max = np.min(stacked[..., :6]), np.max(stacked[..., :6])
    if opt_min < -0.3 or opt_max > 1.0:
        logger.warning(f"[BUILD] WARNING: Optical bands range [{opt_min:.4f}, {opt_max:.4f}] outside expected [-0.2, 0.6]")

    # Add batch dimension: (1, H, W, 9)
    return np.expand_dims(stacked, axis=0)

def predict_forest(bands, debug=False, clip_optical=False, clip_indices=False):
    """
    Predict forest segmentation mask from Landsat 8 9-band input.

    Args:
        bands: Dictionary with keys: Blue, Green, Red, NIR, SWIR1, SWIR2,
            NDVI, NDWI, NBR (arrays or base64-encoded float32 payloads).
        debug: If True, include detailed output statistics under "debug".
        clip_optical: If True, clip optical bands to [-0.2, 0.6].
        clip_indices: If True, clip indices to [-1, 1].

    Returns:
        Dictionary with the binary mask (0/255, nested lists), confidence
        scores, forest coverage percentage, and optional debug data.
    """
    load()

    logger.info("[PREDICT] Starting prediction...")
    input_stats = analyze_input_bands(bands)

    if clip_optical or clip_indices:
        logger.info("[PREDICT] Applying preprocessing (clip_optical=%s, clip_indices=%s)...", clip_optical, clip_indices)
        bands = preprocess_for_model(bands, clip_optical=clip_optical, clip_indices=clip_indices)

    logger.info("[PREDICT] Building input tensor...")
    x = build_input_tensor(bands)

    logger.info("[PREDICT] Running model inference...")
    pred = model.predict(x, verbose=0)[0, :, :, 0]  # Extract (H, W) from (1, H, W, 1)

    # Hoist the 0.5-threshold mask: it is reused for logging, the binary
    # mask, and both confidence statistics below.
    positive = pred > 0.5
    n_positive = int(positive.sum())

    # Lazy %-args produce byte-identical messages without eager formatting.
    logger.info("[PREDICT] === RAW MODEL OUTPUT ===")
    logger.info("[PREDICT] Output shape: %s, dtype: %s", pred.shape, pred.dtype)
    logger.info("[PREDICT] Output range: [%.4f, %.4f]", pred.min(), pred.max())
    logger.info("[PREDICT] Output mean: %.4f, std: %.4f", pred.mean(), pred.std())
    logger.info("[PREDICT] Pixels > 0.5: %s / %s (%.2f%%)",
                format(n_positive, ","), format(pred.size, ","),
                100 * n_positive / pred.size)
    logger.info("[PREDICT] Pixels > 0.8: %s / %s",
                format(int(np.sum(pred > 0.8)), ","), format(pred.size, ","))

    # Binary mask: forest pixels -> 255, non-forest -> 0.
    mask = positive.astype(np.uint8) * 255

    # Mean probability over forest-classified pixels (0.0 when none).
    forest_confidence = float(pred[positive].mean()) if n_positive else 0.0
    forest_percentage = float(n_positive / pred.size * 100)

    result = {
        "mask": mask.tolist(),
        "forest_confidence": forest_confidence,
        "forest_percentage": forest_percentage,
        "mean_prediction": float(pred.mean()),
        "classes": ["forest", "non-forest"],
        "model_version": "landsat8_trained"
    }

    if debug:
        logger.info("[PREDICT] Adding debug information...")
        # One vectorized histogram pass replaces ten pairwise range counts.
        # Literal edges reproduce the original half-open bins [a, b), with
        # the final bin closed at 1.0 — np.histogram's exact semantics.
        edges = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
        counts, _ = np.histogram(pred, bins=edges)
        p10, p25, p50, p75, p90 = (
            float(v) for v in np.percentile(pred, [10, 25, 50, 75, 90])
        )
        result["debug"] = {
            "input_stats": input_stats,
            "output_distribution": {
                "min": float(pred.min()),
                "max": float(pred.max()),
                "mean": float(pred.mean()),
                "std": float(pred.std()),
                "percentile_10": p10,
                "percentile_25": p25,
                "percentile_50": p50,
                "percentile_75": p75,
                "percentile_90": p90,
                "histogram": {
                    f"{edges[i]:.1f}-{edges[i + 1]:.1f}": int(counts[i])
                    for i in range(10)
                }
            }
        }

    logger.info("[PREDICT] Forest prediction: %.2f%%", forest_percentage)
    return result