import base64
import io
import math

import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model

MODEL_PATH = "models/Forest_Segmentation_Best.keras"
EPS = 1e-6  # added to index denominators to avoid division by zero

# Lazily populated by load_segmentation_model(); stays None until first use.
model = None


# ----------------------------
# Load model once
# ----------------------------
def load_segmentation_model():
    """Load the Keras model from MODEL_PATH into the module-level `model`.

    The model is loaded with ``compile=False`` and only on the first call;
    later calls are no-ops because the global is already set.
    """
    global model
    if model is None:
        model = load_model(MODEL_PATH, compile=False)


# ----------------------------
# Decode Landsat band from base64
# ----------------------------
def decode_band_float32(b64):
    """Decode base64-encoded float32 data into a square 2D array.

    Args:
        b64: Base64 text whose decoded bytes are a flat float32 buffer.

    Returns:
        A (side, side) float32 array. It is a read-only view over the decoded
        bytes, as produced by ``np.frombuffer``.

    Raises:
        ValueError: If the number of decoded values is not a perfect square
            (the tile is assumed to be square), or if the byte length is not
            a multiple of the float32 item size.
    """
    raw = base64.b64decode(b64)
    arr = np.frombuffer(raw, dtype=np.float32)

    # Exact integer square root: a truncated float sqrt would hide
    # non-square payloads until reshape fails with an unrelated message.
    side = math.isqrt(arr.size)
    if side * side != arr.size:
        raise ValueError(
            f"Cannot decode band: {arr.size} float32 values do not form "
            "a square tile"
        )
    return arr.reshape((side, side))


# ----------------------------
# Spectral Indices (matching training pipeline)
# ----------------------------
def ndvi(red, nir):
    """Normalized Difference Vegetation Index"""
    return (nir - red) / (nir + red + EPS)


def ndwi(green, nir):
    """Normalized Difference Water Index"""
    return (green - nir) / (green + nir + EPS)


def nbr(nir, swir2):
    """Normalized Burn Ratio"""
    return (nir - swir2) / (nir + swir2 + EPS)


# ----------------------------
# Build 9-channel tensor from Landsat 8
# ----------------------------
def _as_band_array(value):
    """Return `value` decoded from base64 if it is a str, else unchanged."""
    if isinstance(value, str):
        return decode_band_float32(value)
    return value


def _index_or_default(bands, key, compute):
    """Return the supplied index band `key`, or `compute()` if absent/None."""
    if key in bands and bands[key] is not None:
        return _as_band_array(bands[key])
    return compute()


def build_input_tensor(bands):
    """
    Build 9-channel input tensor from Landsat 8 Collection 2 Level 2 data

    Args:
        bands: Dictionary with keys:
            - Blue, Green, Red: Optical bands (channels 0-2), required
            - NIR, SWIR1, SWIR2: Infrared bands (channels 3-5), required
            - NDVI, NDWI, NBR: Indices (channels 6-8), optional; any that
              is missing or None is computed from the bands above

            Values can be:
            - Base64-encoded float32 strings (from API)
            - Numpy arrays (from direct processing)

    Returns:
        (1, H, W, 9) float32 array ready for model inference

    Raises:
        KeyError: If one of the six required bands is missing.
        ValueError: If a base64 band cannot be decoded to a square tile, or
            the bands do not all share the same shape.

    Expected value range (not enforced here):
        - Optical bands: [-0.2, 0.6]
        - Indices: [-1, 1]
    """
    # Extract and decode the six required bands
    blue = _as_band_array(bands["Blue"])
    green = _as_band_array(bands["Green"])
    red = _as_band_array(bands["Red"])
    nir = _as_band_array(bands["NIR"])
    swir1 = _as_band_array(bands["SWIR1"])
    swir2 = _as_band_array(bands["SWIR2"])

    # Use pre-calculated indices if provided, otherwise compute them
    ndvi_map = _index_or_default(bands, "NDVI", lambda: ndvi(red, nir))
    ndwi_map = _index_or_default(bands, "NDWI", lambda: ndwi(green, nir))
    nbr_map = _index_or_default(bands, "NBR", lambda: nbr(nir, swir2))

    # Stack into (H, W, 9) - channel order must match the training data
    stacked = np.stack(
        [
            blue, green, red,
            nir, swir1, swir2,
            ndvi_map, ndwi_map, nbr_map,
        ],
        axis=-1,
    )
    stacked = stacked.astype(np.float32)
    return np.expand_dims(stacked, axis=0)  # (1, H, W, 9)


# ----------------------------
# Inference
# ----------------------------
def predict_segmentation(bands):
    """
    Predict forest segmentation mask

    Args:
        bands: Dictionary with Landsat 8 bands (see build_input_tensor)

    Returns:
        Dictionary with:
            - mask: (H, W) nested list, 255 where prediction > 0.5 else 0
            - forest_percentage: % of pixels classified as forest
            - forest_confidence: average prediction over forest pixels,
              0.0 when no pixel is classified as forest
            - mean_prediction: average prediction over all pixels
            - classes: class labels
            - model_info: static description of training data and inputs
    """
    load_segmentation_model()

    x = build_input_tensor(bands)
    # First batch item, first output channel -> (H, W) prediction map
    pred = model.predict(x, verbose=0)[0, :, :, 0]

    # Threshold once and reuse for the mask and every statistic
    is_forest = pred > 0.5

    # Generate binary mask
    mask = is_forest.astype(np.uint8) * 255

    # Calculate statistics
    if is_forest.any():
        forest_confidence = float(np.mean(pred[is_forest]))
    else:
        forest_confidence = 0.0
    forest_percentage = float(is_forest.sum() / pred.size * 100)

    return {
        "mask": mask.tolist(),
        "forest_percentage": forest_percentage,
        "forest_confidence": forest_confidence,
        "mean_prediction": float(pred.mean()),
        "classes": ["forest", "non-forest"],
        "model_info": {
            "training_data": "Landsat 8 Collection 2 Level 2",
            "bands": [
                "Blue", "Green", "Red",
                "NIR", "SWIR1", "SWIR2",
                "NDVI", "NDWI", "NBR",
            ],
            "patch_size": 256,
            "value_range": "[-0.2, 0.6] for optical, [-1, 1] for indices",
        },
    }