Spaces:
Sleeping
Sleeping
File size: 5,050 Bytes
c82cafe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model
import base64
import io
# Path of the Keras model file; load_segmentation_model() loads it with compile=False.
MODEL_PATH: str = "models/Forest_Segmentation_Best.keras"
# Added to the denominator of each spectral index so it is not exactly zero
# when the two bands sum to 0.
EPS: float = 1e-6
# Module-level model cache; stays None until load_segmentation_model() runs.
model = None
# ----------------------------
# Load model once
# ----------------------------
def load_segmentation_model():
    """Populate the module-level ``model`` cache on first use.

    Loads the Keras model stored at ``MODEL_PATH`` with ``compile=False``.
    Once ``model`` is set, later calls return immediately without reloading.
    Always returns ``None``; callers read the global ``model`` afterwards.
    """
    global model
    if model is not None:
        # Already loaded: nothing to do.
        return
    model = load_model(MODEL_PATH, compile=False)
# ----------------------------
# Decode Landsat band from base64
# ----------------------------
def decode_band_float32(b64):
    """Decode a base64-encoded float32 buffer into a square 2-D array.

    Args:
        b64: base64 text (``str`` or ``bytes``) whose decoded payload is a
            packed run of float32 values in the platform's native byte order,
            laid out row-major for a square tile.

    Returns:
        ``np.ndarray`` of dtype float32 and shape ``(side, side)``. It is a
        read-only view over the decoded bytes.

    Raises:
        binascii.Error: if ``b64`` has invalid base64 padding
            (a ``ValueError`` subclass).
        ValueError: if the payload length is not a multiple of 4 bytes, or
            the number of decoded values is not a perfect square.
    """
    raw = base64.b64decode(b64)
    values = np.frombuffer(raw, dtype=np.float32)

    # Round the square root, then verify exactly: a non-square tile gets a
    # descriptive error instead of an opaque reshape failure.
    side = int(np.rint(np.sqrt(values.size)))
    if side * side != values.size:
        raise ValueError(
            f"Expected a square tile, but decoded {values.size} float32 values"
        )
    return values.reshape((side, side))
# ----------------------------
# Spectral Indices (matching training pipeline)
# ----------------------------
def ndvi(red, nir):
    """Normalized Difference Vegetation Index: (NIR - Red) / (NIR + Red + EPS).

    Accepts scalars or NumPy arrays (anything supporting ``-``, ``+`` and
    ``/``). ``EPS`` keeps the denominator from being exactly zero where the
    two bands sum to 0.
    """
    numerator = nir - red
    denominator = nir + red + EPS
    return numerator / denominator
def ndwi(green, nir):
    """Normalized Difference Water Index: (Green - NIR) / (Green + NIR + EPS).

    Accepts scalars or NumPy arrays (anything supporting ``-``, ``+`` and
    ``/``). ``EPS`` keeps the denominator from being exactly zero where the
    two bands sum to 0.
    """
    numerator = green - nir
    denominator = green + nir + EPS
    return numerator / denominator
def nbr(nir, swir2):
    """Normalized Burn Ratio: (NIR - SWIR2) / (NIR + SWIR2 + EPS).

    Accepts scalars or NumPy arrays (anything supporting ``-``, ``+`` and
    ``/``). ``EPS`` keeps the denominator from being exactly zero where the
    two bands sum to 0.
    """
    numerator = nir - swir2
    denominator = nir + swir2 + EPS
    return numerator / denominator
# ----------------------------
# Build 9-channel tensor from Landsat 8
# ----------------------------
def _as_band(value):
    """Return *value* unchanged, or base64-decode it first when it is a string."""
    if isinstance(value, str):
        return decode_band_float32(value)
    return value


def _optional_band(bands, key):
    """Return the (decoded) band under *key*, or None when missing or None."""
    value = bands.get(key)
    return None if value is None else _as_band(value)


def build_input_tensor(bands):
    """Build the 9-channel model input from Landsat 8 Collection 2 Level 2 bands.

    Args:
        bands: Dictionary of 2-D bands keyed by name. Each value is either a
            base64-encoded float32 string (from the API) or an array (from
            direct processing).
            - Required: "Blue", "Green", "Red" (channels 0-2) and
              "NIR", "SWIR1", "SWIR2" (channels 3-5).
            - Optional: "NDVI", "NDWI", "NBR" (channels 6-8). When a key is
              missing or maps to None, the index is computed from the
              reflectance bands.

    Returns:
        float32 array of shape (1, H, W, 9), channels in the order listed above.

    Raises:
        KeyError: if a required band is missing.
        ValueError: if the bands do not all share the same shape (or a
            base64 payload cannot be decoded into a square tile).

    Expected value range (per the training pipeline):
        - Optical bands: [-0.2, 0.6]
        - Indices: [-1, 1]
    """
    # Decode the six required reflectance bands; a missing key raises KeyError.
    blue, green, red, nir, swir1, swir2 = (
        _as_band(bands[key])
        for key in ("Blue", "Green", "Red", "NIR", "SWIR1", "SWIR2")
    )

    # Caller-supplied indices take precedence; otherwise derive them.
    ndvi_map = _optional_band(bands, "NDVI")
    if ndvi_map is None:
        ndvi_map = ndvi(red, nir)
    ndwi_map = _optional_band(bands, "NDWI")
    if ndwi_map is None:
        ndwi_map = ndwi(green, nir)
    nbr_map = _optional_band(bands, "NBR")
    if nbr_map is None:
        nbr_map = nbr(nir, swir2)

    # Insertion order defines the channel order expected by the model.
    channels = {
        "Blue": blue,
        "Green": green,
        "Red": red,
        "NIR": nir,
        "SWIR1": swir1,
        "SWIR2": swir2,
        "NDVI": ndvi_map,
        "NDWI": ndwi_map,
        "NBR": nbr_map,
    }

    # Report which band is misshapen instead of np.stack's generic error.
    shapes = {name: np.shape(arr) for name, arr in channels.items()}
    if len(set(shapes.values())) > 1:
        raise ValueError(f"All bands must have the same shape; got {shapes}")

    # np.stack already returns a fresh array, so skip the extra cast copy
    # when it is float32 already.
    stacked = np.stack(list(channels.values()), axis=-1).astype(
        np.float32, copy=False
    )
    return np.expand_dims(stacked, axis=0)  # (1, H, W, 9)
# ----------------------------
# Inference
# ----------------------------
def predict_segmentation(bands, *, threshold=0.5):
    """Run the forest-segmentation model on one tile and summarise the result.

    Args:
        bands: Dictionary of Landsat 8 bands in the format accepted by
            ``build_input_tensor``.
        threshold: A pixel counts as forest when the model output is
            strictly greater than this value. Defaults to 0.5.

    Returns:
        dict with keys:
            - "mask": nested list (H rows of W ints); 255 where the pixel
              is classified as forest, 0 elsewhere.
            - "forest_percentage": percentage (0-100) of forest pixels;
              0.0 for an empty prediction.
            - "forest_confidence": mean model output over forest pixels,
              or 0.0 when no pixel exceeds ``threshold``.
            - "mean_prediction": mean model output over the whole tile;
              0.0 for an empty prediction.
            - "classes": class labels ["forest", "non-forest"].
            - "model_info": static description of the training setup.
    """
    load_segmentation_model()
    x = build_input_tensor(bands)
    # First (only) batch item, first output channel. Presumably a per-pixel
    # forest probability (sigmoid output) -- TODO confirm against the model.
    pred = model.predict(x, verbose=0)[0, :, :, 0]

    # Threshold once and reuse the boolean map for every statistic.
    forest = pred > threshold
    n_forest = int(np.count_nonzero(forest))

    mask = forest.astype(np.uint8) * 255
    forest_confidence = float(pred[forest].mean()) if n_forest else 0.0
    # Guard against an empty prediction: NaN is not valid JSON.
    forest_percentage = float(n_forest / pred.size * 100) if pred.size else 0.0
    mean_prediction = float(pred.mean()) if pred.size else 0.0

    return {
        "mask": mask.tolist(),
        "forest_percentage": forest_percentage,
        "forest_confidence": forest_confidence,
        "mean_prediction": mean_prediction,
        "classes": ["forest", "non-forest"],
        "model_info": {
            "training_data": "Landsat 8 Collection 2 Level 2",
            "bands": ["Blue", "Green", "Red", "NIR", "SWIR1", "SWIR2", "NDVI", "NDWI", "NBR"],
            "patch_size": 256,
            "value_range": "[-0.2, 0.6] for optical, [-1, 1] for indices"
        }
    }
|