# Author: prshntdxt
# Commit c82cafe — "Deploy Forest Segmentation API with LFS"
from tensorflow.keras.models import load_model
import numpy as np
import base64
import logging
# Logger shared by every function in this inference module.
logger = logging.getLogger("forest_segmentation.inference")

# Serialized Keras model; a relative path, so it is resolved against the
# process working directory when load() runs.
MODEL_PATH = "models/Forest_Segmentation_Best.keras"

# Lazily-populated model cache: stays None until load() assigns it.
model = None

# Small constant added to the spectral-index denominators (ndvi/ndwi/nbr)
# to keep them away from zero.
EPS = 1e-6
def load():
    """Load the Keras model into the module-level ``model`` cache on first use.

    Once ``model`` has been set, later calls return immediately without
    touching the filesystem. The model is loaded with ``compile=False``, since
    only ``predict`` is used by this module.
    """
    global model
    if model is not None:
        return
    logger.info(f"[INFERENCE] Loading model from: {MODEL_PATH}")
    model = load_model(MODEL_PATH, compile=False)
    logger.info("[INFERENCE] Model loaded successfully")
def decode_band_float32(b64):
    """Decode base64-encoded float32 band data into a square 2-D array.

    Args:
        b64: base64 text (``str`` or ``bytes``) holding ``side * side`` raw
            float32 values in native byte order.

    Returns:
        ``np.ndarray`` of shape ``(side, side)`` and dtype float32. It is a
        read-only view over the decoded bytes (``np.frombuffer`` semantics).

    Raises:
        ValueError: if the decoded payload is not a whole number of float32
            values, or if the value count is not a perfect square.
    """
    raw = base64.b64decode(b64)
    itemsize = np.dtype(np.float32).itemsize
    if len(raw) % itemsize:
        raise ValueError(
            f"Decoded band has {len(raw)} bytes, which is not a multiple of "
            f"{itemsize} (float32)"
        )
    arr = np.frombuffer(raw, dtype=np.float32)
    # Round before truncating so a sqrt that lands just below an exact
    # integer does not shrink the side length.
    side = int(round(float(np.sqrt(arr.size))))
    if side * side != arr.size:
        raise ValueError(
            f"Decoded band has {arr.size} values, which is not a perfect "
            "square; cannot reshape into a square image"
        )
    return arr.reshape((side, side))
def validate_landsat_data(bands_dict):
    """
    Check that every band is 2-D and return float32 copies of the bands.

    Only dimensionality is enforced here. The value ranges associated with
    Landsat 8 Collection 2 Level 2 input (roughly [-0.2, 0.6] for optical
    bands and [-1, 1] for indices) are NOT checked.

    Args:
        bands_dict: Mapping of band name -> array-like band data.

    Returns:
        A new dict with the same keys, where each value is a float32
        ``np.ndarray``. The input mapping is not modified. Arrays that are
        already float32 are passed through without copying.

    Raises:
        ValueError: if any band is not two-dimensional.
    """
    validated = {}
    for band_name, data in bands_dict.items():
        arr = np.asarray(data)
        if arr.ndim != 2:
            raise ValueError(f"{band_name}: Expected 2D array, got shape {arr.shape}")
        # Store the converted array; reassigning only the loop variable
        # would discard the conversion.
        validated[band_name] = arr.astype(np.float32, copy=False)
    return validated
def _normalized_difference(first, second):
    """Return ``(first - second) / (first + second + EPS)`` element-wise.

    EPS prevents a zero denominator when ``first + second`` is exactly zero.
    It does not bound the result when the sum is tiny or negative.
    """
    numerator = first - second
    denominator = first + second + EPS
    return numerator / denominator


def ndvi(red, nir):
    """Normalized Difference Vegetation Index: (NIR - Red) / (NIR + Red + EPS)."""
    return _normalized_difference(nir, red)


def ndwi(green, nir):
    """Normalized Difference Water Index: (Green - NIR) / (Green + NIR + EPS)."""
    return _normalized_difference(green, nir)


def nbr(nir, swir2):
    """Normalized Burn Ratio: (NIR - SWIR2) / (NIR + SWIR2 + EPS)."""
    return _normalized_difference(nir, swir2)
def analyze_input_bands(bands):
    """Log summary statistics for the known input bands and return them.

    Args:
        bands: Mapping of band name -> band data. Values may be numpy arrays
            (or array-likes), base64-encoded float32 strings (decoded with
            ``decode_band_float32``, the same representation
            ``build_input_tensor`` accepts), or ``None``.

    Returns:
        Dict keyed by band name (Blue, Green, Red, NIR, SWIR1, SWIR2, NDVI,
        NDWI, NBR) with ``min``/``max``/``mean``/``std`` floats for each band
        that is present, not ``None`` and non-empty. When such an NDVI band
        exists, ``vegetation_coverage_pct`` holds the percentage of its
        pixels with NDVI > 0.5, as a float.
    """
    stats = {}
    analysed = {}
    logger.info("[ANALYSIS] === INPUT BAND ANALYSIS ===")
    for band_name in ('Blue', 'Green', 'Red', 'NIR', 'SWIR1', 'SWIR2', 'NDVI', 'NDWI', 'NBR'):
        raw = bands.get(band_name)
        if raw is None:
            # Missing or null bands (e.g. indices computed later) have nothing to analyse.
            continue
        # Strings are accepted downstream as base64 float32 payloads; decode
        # them here so the analysis does not crash on them.
        data = decode_band_float32(raw) if isinstance(raw, str) else np.asarray(raw)
        if data.size == 0:
            # min()/max() are undefined on empty arrays.
            logger.warning(f"[ANALYSIS] {band_name}: empty band, skipping statistics")
            continue
        analysed[band_name] = data
        stats[band_name] = {
            "min": float(data.min()),
            "max": float(data.max()),
            "mean": float(data.mean()),
            "std": float(data.std())
        }
        logger.info(f"[ANALYSIS] {band_name}: min={stats[band_name]['min']:.4f}, max={stats[band_name]['max']:.4f}, mean={stats[band_name]['mean']:.4f}")
    # Analyze vegetation coverage
    ndvi_data = analysed.get('NDVI')
    if ndvi_data is not None:
        veg_pixels = np.sum(ndvi_data > 0.5)
        veg_pct = float((veg_pixels / ndvi_data.size) * 100)
        logger.info(f"[ANALYSIS] NDVI > 0.5 (vegetation): {veg_pct:.2f}% of pixels")
        stats['vegetation_coverage_pct'] = veg_pct
    return stats
def preprocess_for_model(bands, clip_optical=False, clip_indices=False):
    """
    Optionally clip bands to the value ranges the model was trained on.

    Args:
        bands: Dictionary of band data (numpy arrays, or base64 float32
            strings; strings are decoded only when their group is clipped).
        clip_optical: If True, clip optical bands to [-0.2, 0.6]
        clip_indices: If True, clip indices to [-1, 1]

    Returns:
        A new dictionary containing the recognised bands (Blue, Green, Red,
        NIR, SWIR1, SWIR2, NDVI, NDWI, NBR) that are present in ``bands`` and
        not ``None``. Absent bands are omitted rather than stored as ``None``,
        whether or not clipping is enabled. Unrecognised keys are dropped.
    """
    groups = (
        (("Blue", "Green", "Red", "NIR", "SWIR1", "SWIR2"), clip_optical, -0.2, 0.6,
         "[PREPROCESS] Clipping optical bands to [-0.2, 0.6]"),
        (("NDVI", "NDWI", "NBR"), clip_indices, -1.0, 1.0,
         "[PREPROCESS] Clipping indices to [-1.0, 1.0]"),
    )
    processed = {}
    for names, do_clip, low, high, message in groups:
        if do_clip:
            logger.info(message)
        for name in names:
            value = bands.get(name)
            if value is None:
                continue
            if do_clip:
                # np.clip cannot operate on the base64 string form; decode first.
                if isinstance(value, str):
                    value = decode_band_float32(value)
                value = np.clip(value, low, high)
            processed[name] = value
    return processed
def _coerce_band(value):
    """Convert one band value into a numeric numpy array.

    Base64 strings are decoded with ``decode_band_float32``; anything else goes
    through ``np.asarray``. Non-floating dtypes (e.g. uint16 inputs) are
    promoted to float32 so index arithmetic such as ``nir - red`` cannot wrap
    around. Floating inputs keep their dtype.
    """
    if value is None:
        raise ValueError("Band data is None")
    if isinstance(value, str):
        return decode_band_float32(value)
    arr = np.asarray(value)
    if not np.issubdtype(arr.dtype, np.floating):
        arr = arr.astype(np.float32)
    return arr


def build_input_tensor(bands):
    """
    Build the 9-channel model input from Landsat 8 bands.

    Expected band dict keys:
    - Blue, Green, Red: Optical bands (channels 0-2)
    - NIR, SWIR1, SWIR2: Infrared bands (channels 3-5)
    - NDVI, NDWI, NBR: Optional indices (channels 6-8). A supplied index is
      used only when it is a base64 string or a numpy array; otherwise it is
      computed from the optical bands.

    Band values may be numpy arrays / array-likes or base64-encoded float32
    strings.

    Returns:
        float32 array of shape (1, H, W, 9), where (H, W) is the shape shared
        by all bands.

    Raises:
        KeyError: if an optical band key is missing.
        ValueError: if an optical band is None, a string fails to decode, or
            the bands do not all share one shape.
    """
    optical_names = ("Blue", "Green", "Red", "NIR", "SWIR1", "SWIR2")
    blue, green, red, nir, swir1, swir2 = [_coerce_band(bands[name]) for name in optical_names]

    def index_or_compute(name, compute):
        # Use a caller-supplied index only if it is a string or ndarray.
        supplied = bands.get(name)
        if isinstance(supplied, (str, np.ndarray)):
            return _coerce_band(supplied)
        return compute()

    ndvi_map = index_or_compute("NDVI", lambda: ndvi(red, nir))
    ndwi_map = index_or_compute("NDWI", lambda: ndwi(green, nir))
    nbr_map = index_or_compute("NBR", lambda: nbr(nir, swir2))

    channel_names = optical_names + ("NDVI", "NDWI", "NBR")
    channels = [blue, green, red, nir, swir1, swir2, ndvi_map, ndwi_map, nbr_map]
    if len({np.shape(c) for c in channels}) != 1:
        detail = ", ".join(f"{n}={np.shape(c)}" for n, c in zip(channel_names, channels))
        raise ValueError(f"All bands must share the same shape; got {detail}")

    # Stack into 9-channel tensor: (H, W, 9)
    stacked = np.stack(channels, axis=-1).astype(np.float32)

    # Warn (do not fail) when optical values look out of range. The trigger
    # thresholds [-0.3, 1.0] are wider than the [-0.2, 0.6] quoted in the
    # message. NOTE(review): confirm the intended tolerance.
    opt_min, opt_max = np.min(stacked[..., :6]), np.max(stacked[..., :6])
    if opt_min < -0.3 or opt_max > 1.0:
        logger.warning(f"[BUILD] WARNING: Optical bands range [{opt_min:.4f}, {opt_max:.4f}] outside expected [-0.2, 0.6]")

    # Add batch dimension: (1, H, W, 9)
    return np.expand_dims(stacked, axis=0)
def _prediction_distribution(pred):
    """Summarise a probability map for debugging.

    Returns min/max/mean/std, the 10th/25th/50th/75th/90th percentiles, and
    a 10-bin histogram over [0, 1]. Bins are half-open [lo, hi) except the
    last, which is closed so that 1.0 is counted.
    """
    summary = {
        "min": float(pred.min()),
        "max": float(pred.max()),
        "mean": float(pred.mean()),
        "std": float(pred.std()),
    }
    for q in (10, 25, 50, 75, 90):
        summary[f"percentile_{q}"] = float(np.percentile(pred, q))

    edges = (0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0)
    histogram = {}
    for lo, hi in zip(edges[:-1], edges[1:]):
        below_upper = (pred <= hi) if hi == 1.0 else (pred < hi)
        histogram[f"{lo:.1f}-{hi:.1f}"] = int(np.sum((pred >= lo) & below_upper))
    summary["histogram"] = histogram
    return summary


def predict_forest(bands, debug=False, clip_optical=False, clip_indices=False):
    """
    Run forest segmentation on one Landsat 8 9-band input.

    Args:
        bands: Dictionary with keys Blue, Green, Red, NIR, SWIR1, SWIR2 and
            optionally NDVI, NDWI, NBR (see build_input_tensor).
        debug: If True, attach input statistics and the output distribution
            under the "debug" key.
        clip_optical: If True, clip optical bands to [-0.2, 0.6]
        clip_indices: If True, clip indices to [-1, 1]

    Returns:
        Dictionary with the binary mask (255 where probability > 0.5, else 0),
        mean probability of pixels above 0.5 ("forest_confidence"), percentage
        of pixels above 0.5, overall mean prediction, class names, model
        version, and optionally "debug".
    """
    load()

    logger.info("[PREDICT] Starting prediction...")
    input_stats = analyze_input_bands(bands)

    if clip_optical or clip_indices:
        logger.info(f"[PREDICT] Applying preprocessing (clip_optical={clip_optical}, clip_indices={clip_indices})...")
        bands = preprocess_for_model(bands, clip_optical=clip_optical, clip_indices=clip_indices)

    logger.info("[PREDICT] Building input tensor...")
    x = build_input_tensor(bands)

    logger.info("[PREDICT] Running model inference...")
    # Keep the single image's single output channel: (1, H, W, 1) -> (H, W).
    pred = model.predict(x, verbose=0)[0, :, :, 0]

    is_forest = pred > 0.5
    forest_count = np.sum(is_forest)
    total = pred.size

    logger.info("[PREDICT] === RAW MODEL OUTPUT ===")
    logger.info(f"[PREDICT] Output shape: {pred.shape}, dtype: {pred.dtype}")
    logger.info(f"[PREDICT] Output range: [{pred.min():.4f}, {pred.max():.4f}]")
    logger.info(f"[PREDICT] Output mean: {pred.mean():.4f}, std: {pred.std():.4f}")
    logger.info(f"[PREDICT] Pixels > 0.5: {forest_count:,} / {total:,} ({100*forest_count/total:.2f}%)")
    logger.info(f"[PREDICT] Pixels > 0.8: {np.sum(pred > 0.8):,} / {total:,}")

    mask = is_forest.astype(np.uint8) * 255
    if np.any(is_forest):
        forest_confidence = float(np.mean(pred[is_forest]))
    else:
        forest_confidence = 0.0
    forest_percentage = float(forest_count / total * 100)

    result = {
        "mask": mask.tolist(),
        "forest_confidence": forest_confidence,
        "forest_percentage": forest_percentage,
        "mean_prediction": float(pred.mean()),
        "classes": ["forest", "non-forest"],
        "model_version": "landsat8_trained",
    }

    if debug:
        logger.info("[PREDICT] Adding debug information...")
        result["debug"] = {
            "input_stats": input_stats,
            "output_distribution": _prediction_distribution(pred),
        }

    logger.info(f"[PREDICT] Forest prediction: {forest_percentage:.2f}%")
    return result