# Source: commit c82cafe by prshntdxt — "Deploy Forest Segmentation API with LFS"
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model
import base64
import io
# Path to the trained Keras model; being relative, it is resolved against the
# process's current working directory at load time.
MODEL_PATH: str = "models/Forest_Segmentation_Best.keras"
# Small constant added to every spectral-index denominator so that pixels where
# both bands are zero do not divide by exactly zero.
EPS: float = 1e-6
# Lazily-populated model cache; filled on first call to load_segmentation_model().
model = None
# ----------------------------
# Load model once
# ----------------------------
def load_segmentation_model():
    """Populate the module-level ``model`` cache from ``MODEL_PATH``.

    The model is loaded with ``compile=False`` (inference only). Once the
    cache is filled, further calls return immediately without reloading.
    """
    global model
    if model is not None:
        return
    model = load_model(MODEL_PATH, compile=False)
# ----------------------------
# Decode Landsat band from base64
# ----------------------------
def decode_band_float32(b64):
    """Decode base64-encoded float32 data into a square 2-D array.

    The decoded bytes are interpreted as a flat buffer of native-byte-order
    ``np.float32`` samples and reshaped to ``(side, side)``; the tile is
    assumed to be square.

    Args:
        b64: Base64 payload (``str`` or ``bytes``) holding raw float32 data.

    Returns:
        A ``(side, side)`` float32 array backed by the decoded bytes
        (read-only, because it wraps an immutable ``bytes`` object).

    Raises:
        ValueError: if the decoded byte length is not a whole number of
            float32 values, or the number of values is not a perfect square.
    """
    import math  # local import keeps this helper self-contained

    raw = base64.b64decode(b64)
    itemsize = np.dtype(np.float32).itemsize
    if len(raw) % itemsize:
        raise ValueError(
            f"Decoded payload is {len(raw)} bytes, not a multiple of "
            f"{itemsize} (float32)"
        )
    arr = np.frombuffer(raw, dtype=np.float32)
    # math.isqrt is exact for any integer; int(np.sqrt(n)) goes through a float
    # and turns a non-square size into a cryptic reshape error.
    side = math.isqrt(arr.size)
    if side * side != arr.size:
        raise ValueError(
            f"Expected a square tile, but decoded {arr.size} float32 values"
        )
    return arr.reshape((side, side))
# ----------------------------
# Spectral Indices (matching training pipeline)
# ----------------------------
def ndvi(red, nir):
    """Return the Normalized Difference Vegetation Index.

    Evaluated element-wise as ``(nir - red) / (nir + red + EPS)``; ``EPS``
    keeps the denominator non-zero where both bands are zero.
    """
    numerator = nir - red
    denominator = nir + red + EPS
    return numerator / denominator
def ndwi(green, nir):
    """Return the Normalized Difference Water Index.

    Evaluated element-wise as ``(green - nir) / (green + nir + EPS)``;
    ``EPS`` keeps the denominator non-zero where both bands are zero.
    """
    numerator = green - nir
    denominator = green + nir + EPS
    return numerator / denominator
def nbr(nir, swir2):
    """Return the Normalized Burn Ratio.

    Evaluated element-wise as ``(nir - swir2) / (nir + swir2 + EPS)``;
    ``EPS`` keeps the denominator non-zero where both bands are zero.
    """
    numerator = nir - swir2
    denominator = nir + swir2 + EPS
    return numerator / denominator
# ----------------------------
# Build 9-channel tensor from Landsat 8
# ----------------------------
def _as_band(value):
    """Return one band as a NumPy array, base64-decoding it when given as a string."""
    if isinstance(value, str):
        return decode_band_float32(value)
    return np.asarray(value)


def _optional_band(bands, key):
    """Return band *key* as an array, or None when it is missing or set to None."""
    value = bands.get(key)
    return None if value is None else _as_band(value)


def build_input_tensor(bands):
    """
    Build 9-channel input tensor from Landsat 8 Collection 2 Level 2 data

    Args:
        bands: Dictionary with keys:
            - Blue, Green, Red: Optical bands (channels 0-2), required
            - NIR, SWIR1, SWIR2: Infrared bands (channels 3-5), required
            - NDVI, NDWI, NBR: Indices (channels 6-8), optional; any index
              that is missing or None is computed from the bands above
        Values can be:
            - Base64-encoded float32 strings (from API)
            - Numpy arrays or nested lists (from direct processing)
    Returns:
        (1, H, W, 9) float32 array ready for model inference
    Raises:
        KeyError: if one of the six required bands is missing.
        ValueError: if a band is not 2-D or the bands differ in shape.
    Expected value range:
        - Optical bands: [-0.2, 0.6]
        - Indices: [-1, 1]
    """
    blue, green, red, nir, swir1, swir2 = (
        _as_band(bands[key])
        for key in ("Blue", "Green", "Red", "NIR", "SWIR1", "SWIR2")
    )

    # Use pre-calculated indices if provided, otherwise compute them
    ndvi_map = _optional_band(bands, "NDVI")
    if ndvi_map is None:
        ndvi_map = ndvi(red, nir)
    ndwi_map = _optional_band(bands, "NDWI")
    if ndwi_map is None:
        ndwi_map = ndwi(green, nir)
    nbr_map = _optional_band(bands, "NBR")
    if nbr_map is None:
        nbr_map = nbr(nir, swir2)

    # Insertion order here is the channel order of the stacked tensor.
    channels = {
        "Blue": blue,
        "Green": green,
        "Red": red,
        "NIR": nir,
        "SWIR1": swir1,
        "SWIR2": swir2,
        "NDVI": ndvi_map,
        "NDWI": ndwi_map,
        "NBR": nbr_map,
    }

    # Validate up front: a 3-D band would otherwise stack into a 5-D tensor,
    # and a shape mismatch would surface as an anonymous np.stack error.
    expected_shape = blue.shape
    if len(expected_shape) != 2:
        raise ValueError(
            f"Band 'Blue' must be 2-D (H, W), got shape {expected_shape}"
        )
    for name, channel in channels.items():
        if channel.shape != expected_shape:
            raise ValueError(
                f"Band {name!r} has shape {channel.shape}, expected "
                f"{expected_shape} (all bands must share one 2-D shape)"
            )

    stacked = np.stack(list(channels.values()), axis=-1).astype(np.float32)
    return np.expand_dims(stacked, axis=0)  # (1, H, W, 9)
# ----------------------------
# Inference
# ----------------------------
def predict_segmentation(bands, *, threshold=0.5):
    """
    Predict forest segmentation mask

    Args:
        bands: Dictionary with Landsat 8 bands, in any form accepted by
            build_input_tensor (base64 float32 strings or arrays).
        threshold: Model output above which a pixel counts as forest.
            Defaults to 0.5, the previously hard-coded cut-off.
    Returns:
        Dictionary with:
        - mask: (H, W) nested list of uint8 values, 255 = forest, 0 = other
        - forest_percentage: % of pixels classified as forest
        - forest_confidence: mean model output over forest pixels
          (0.0 when no pixel is classified as forest)
        - mean_prediction: mean model output over all pixels
        - classes: class label names
        - model_info: static description of the model's training inputs
    """
    load_segmentation_model()
    x = build_input_tensor(bands)
    # x is a batch of one; keep that sample's first output channel as the score map.
    pred = model.predict(x, verbose=0)[0, :, :, 0]

    # Classify once and reuse the boolean map for the mask and every statistic.
    forest = pred > threshold
    mask = forest.astype(np.uint8) * 255

    forest_confidence = float(pred[forest].mean()) if forest.any() else 0.0
    forest_percentage = float(np.count_nonzero(forest) / pred.size * 100)

    return {
        "mask": mask.tolist(),
        "forest_percentage": forest_percentage,
        "forest_confidence": forest_confidence,
        "mean_prediction": float(pred.mean()),
        "classes": ["forest", "non-forest"],
        "model_info": {
            "training_data": "Landsat 8 Collection 2 Level 2",
            "bands": ["Blue", "Green", "Red", "NIR", "SWIR1", "SWIR2", "NDVI", "NDWI", "NBR"],
            "patch_size": 256,
            "value_range": "[-0.2, 0.6] for optical, [-1, 1] for indices"
        }
    }