"""Forest segmentation inference on 9-channel (6 optical bands + 3 indices) input."""
import base64
import logging
import math

import numpy as np

try:
    from tensorflow.keras.models import load_model
except ImportError as _exc:
    # TensorFlow is only needed to load/run the model. Deferring the failure
    # to load() keeps the NumPy-only band helpers below importable without it.
    load_model = None
    _TF_IMPORT_ERROR = _exc
else:
    _TF_IMPORT_ERROR = None

MODEL_PATH = "models/Forest_Segmentation_Best.keras"
model = None
EPS = 1e-6

# Channel order of the model input tensor: optical bands first, then indices.
_OPTICAL_BANDS = ("Blue", "Green", "Red", "NIR", "SWIR1", "SWIR2")
_INDEX_BANDS = ("NDVI", "NDWI", "NBR")

# Setup logging
logger = logging.getLogger("forest_segmentation.inference")


def load():
    """Load the Keras model into the module-level ``model`` on first call.

    Subsequent calls are no-ops once ``model`` is set.

    Raises:
        ImportError: If TensorFlow could not be imported.
    """
    global model
    if model is not None:
        return
    if load_model is None:
        raise ImportError(
            "TensorFlow is required to load the segmentation model"
        ) from _TF_IMPORT_ERROR
    logger.info("[INFERENCE] Loading model from: " + MODEL_PATH)
    model = load_model(MODEL_PATH, compile=False)
    logger.info("[INFERENCE] Model loaded successfully")


def decode_band_float32(b64):
    """Decode base64-encoded float32 band data to a square 2D array.

    Args:
        b64: Base64 text (or bytes) wrapping raw float32 values.

    Returns:
        Array of shape (side, side) with dtype float32.

    Raises:
        ValueError: If the payload is not a whole number of float32 values
            or the value count is not a perfect square.
    """
    raw = base64.b64decode(b64)
    arr = np.frombuffer(raw, dtype=np.float32)
    # Integer square root avoids float rounding on large sizes and lets us
    # reject non-square payloads with a readable message.
    side = math.isqrt(arr.size)
    if side * side != arr.size:
        raise ValueError(
            f"Cannot reshape {arr.size} float32 values into a square band"
        )
    return arr.reshape((side, side))


def _as_array(value):
    """Return a band value as an array, decoding base64 text/bytes first."""
    if isinstance(value, (str, bytes)):
        return decode_band_float32(value)
    return np.asarray(value)


def validate_landsat_data(bands_dict):
    """Check that every band is 2D and return the bands cast to float32.

    Value ranges are NOT enforced here; only dimensionality is validated.

    Args:
        bands_dict: Mapping of band name to array-like data.

    Returns:
        New dictionary with the same keys and float32 arrays. Arrays that are
        already float32 are reused without copying.

    Raises:
        ValueError: If any band is not two-dimensional.
    """
    validated = {}
    for band_name, data in bands_dict.items():
        data = np.asarray(data)
        if data.ndim != 2:
            raise ValueError(f"{band_name}: Expected 2D array, got shape {data.shape}")
        # Store the cast result: rebinding the loop variable alone would
        # silently discard the conversion.
        validated[band_name] = data.astype(np.float32, copy=False)
    return validated


def ndvi(red, nir):
    """Normalized Difference Vegetation Index: (nir - red) / (nir + red + EPS)."""
    return (nir - red) / (nir + red + EPS)


def ndwi(green, nir):
    """Normalized Difference Water Index: (green - nir) / (green + nir + EPS)."""
    return (green - nir) / (green + nir + EPS)


def nbr(nir, swir2):
    """Normalized Burn Ratio: (nir - swir2) / (nir + swir2 + EPS)."""
    return (nir - swir2) / (nir + swir2 + EPS)


# Index name -> (formula, names of the optical bands it takes, in call order).
_INDEX_FORMULAS = {
    "NDVI": (ndvi, ("Red", "NIR")),
    "NDWI": (ndwi, ("Green", "NIR")),
    "NBR": (nbr, ("NIR", "SWIR2")),
}


def _compute_index(name, optical):
    """Compute index *name* from *optical* arrays, or None if inputs are missing."""
    formula, required = _INDEX_FORMULAS[name]
    inputs = [optical.get(band) for band in required]
    if any(value is None for value in inputs):
        return None
    return formula(*inputs)


def analyze_input_bands(bands):
    """Log and return min/max/mean/std for each known band that is present.

    Args:
        bands: Mapping of band name to array or base64-encoded float32 data.
            Missing, ``None`` and empty bands are skipped.

    Returns:
        Dictionary of per-band statistics, plus ``vegetation_coverage_pct``
        (percentage of pixels with NDVI > 0.5) when NDVI is supplied.
    """
    stats = {}
    ndvi_data = None
    logger.info("[ANALYSIS] === INPUT BAND ANALYSIS ===")
    for band_name in _OPTICAL_BANDS + _INDEX_BANDS:
        value = bands.get(band_name)
        if value is None:
            continue
        data = _as_array(value)
        if data.size == 0:
            continue
        band_stats = {
            "min": float(data.min()),
            "max": float(data.max()),
            "mean": float(data.mean()),
            "std": float(data.std()),
        }
        stats[band_name] = band_stats
        logger.info(
            f"[ANALYSIS] {band_name}: min={band_stats['min']:.4f}, "
            f"max={band_stats['max']:.4f}, mean={band_stats['mean']:.4f}"
        )
        if band_name == "NDVI":
            ndvi_data = data

    # Analyze vegetation coverage
    if ndvi_data is not None:
        veg_pixels = int(np.count_nonzero(ndvi_data > 0.5))
        veg_pct = float(veg_pixels / ndvi_data.size * 100)
        logger.info(f"[ANALYSIS] NDVI > 0.5 (vegetation): {veg_pct:.2f}% of pixels")
        stats['vegetation_coverage_pct'] = veg_pct

    return stats


def preprocess_for_model(bands, clip_optical=False, clip_indices=False):
    """Preprocess bands, optionally clipping them to fixed value ranges.

    Args:
        bands: Dictionary of band arrays (base64-encoded bands are decoded).
        clip_optical: If True, clip optical bands to [-0.2, 0.6]
        clip_indices: If True, clip indices to [-1, 1]. Indices that are not
            supplied are computed from the (already processed) optical bands
            so that the clip also applies to them.

    Returns:
        Preprocessed bands dictionary containing only bands that are present
        (or computable); missing bands are omitted rather than set to None.
    """
    processed = {}

    if clip_optical:
        logger.info("[PREPROCESS] Clipping optical bands to [-0.2, 0.6]")
    for name in _OPTICAL_BANDS:
        value = bands.get(name)
        if value is None:
            continue
        data = _as_array(value)
        processed[name] = np.clip(data, -0.2, 0.6) if clip_optical else data

    if clip_indices:
        logger.info("[PREPROCESS] Clipping indices to [-1.0, 1.0]")
    for name in _INDEX_BANDS:
        value = bands.get(name)
        if value is not None:
            data = _as_array(value)
        elif clip_indices:
            # Without this, an index computed later by build_input_tensor
            # would bypass the requested clipping entirely.
            data = _compute_index(name, processed)
            if data is None:
                continue
        else:
            continue
        processed[name] = np.clip(data, -1.0, 1.0) if clip_indices else data

    return processed


def build_input_tensor(bands):
    """Build the 9-channel model input tensor from a band dictionary.

    Expected band dict keys:
    - Blue, Green, Red: Optical bands (indices 0-2)
    - NIR, SWIR1, SWIR2: Infrared bands (indices 3-5)
    - NDVI, NDWI, NBR: Optional; computed from the optical bands when absent
      (indices 6-8)

    Values may be arrays or base64-encoded float32 data.

    Returns:
        float32 array of shape (1, H, W, 9), where (H, W) is the band shape.

    Raises:
        KeyError: If a required optical band is missing.
        ValueError: If the bands do not all share the same shape.
    """
    channels = {}
    for name in _OPTICAL_BANDS:
        value = bands.get(name)
        if value is None:
            raise KeyError(f"Missing required band: {name}")
        channels[name] = _as_array(value)

    # Use pre-calculated indices if provided, otherwise compute them
    for name in _INDEX_BANDS:
        value = bands.get(name)
        if value is None:
            channels[name] = _compute_index(name, channels)
        else:
            channels[name] = _as_array(value)

    shapes = {name: np.shape(data) for name, data in channels.items()}
    if len(set(shapes.values())) != 1:
        raise ValueError(f"All bands must share one shape, got: {shapes}")

    # Stack into 9-channel tensor: (H, W, 9)
    stacked = np.stack(list(channels.values()), axis=-1).astype(np.float32)

    # NaN/Inf make the min/max comparisons below evaluate False, so report
    # them separately instead of letting them pass unnoticed.
    if not np.isfinite(stacked).all():
        logger.warning("[BUILD] WARNING: Input contains NaN or infinite values")

    # Validate data range matches training expectations
    opt_min, opt_max = np.min(stacked[..., :6]), np.max(stacked[..., :6])
    if opt_min < -0.3 or opt_max > 1.0:
        logger.warning(f"[BUILD] WARNING: Optical bands range [{opt_min:.4f}, {opt_max:.4f}] outside expected [-0.2, 0.6]")

    # Add batch dimension: (1, H, W, 9)
    return np.expand_dims(stacked, axis=0)


def _summarize_prediction(pred, input_stats, debug):
    """Log raw-output statistics and build the result dict for a 2D prediction."""
    forest = pred > 0.5  # computed once; reused for mask, counts and confidence
    forest_count = int(np.count_nonzero(forest))
    high_count = int(np.count_nonzero(pred > 0.8))

    logger.info("[PREDICT] === RAW MODEL OUTPUT ===")
    logger.info(f"[PREDICT] Output shape: {pred.shape}, dtype: {pred.dtype}")
    logger.info(f"[PREDICT] Output range: [{pred.min():.4f}, {pred.max():.4f}]")
    logger.info(f"[PREDICT] Output mean: {pred.mean():.4f}, std: {pred.std():.4f}")
    logger.info(f"[PREDICT] Pixels > 0.5: {forest_count:,} / {pred.size:,} ({100*forest_count/pred.size:.2f}%)")
    logger.info(f"[PREDICT] Pixels > 0.8: {high_count:,} / {pred.size:,}")

    # Generate binary mask
    mask = forest.astype(np.uint8) * 255

    # Calculate statistics
    forest_confidence = float(np.mean(pred[forest])) if forest_count else 0.0
    forest_percentage = float(forest_count / pred.size * 100)

    result = {
        "mask": mask.tolist(),
        "forest_confidence": forest_confidence,
        "forest_percentage": forest_percentage,
        "mean_prediction": float(pred.mean()),
        "classes": ["forest", "non-forest"],
        "model_version": "landsat8_trained"
    }

    if debug:
        logger.info("[PREDICT] Adding debug information...")
        distribution = {
            "min": float(pred.min()),
            "max": float(pred.max()),
            "mean": float(pred.mean()),
            "std": float(pred.std()),
        }
        for q in (10, 25, 50, 75, 90):
            distribution[f"percentile_{q}"] = float(np.percentile(pred, q))

        # Ten 0.1-wide bins; only the last one includes its upper edge (1.0).
        histogram = {}
        for i in range(10):
            lo, hi = i / 10, (i + 1) / 10
            below_hi = (pred <= hi) if i == 9 else (pred < hi)
            histogram[f"{lo:.1f}-{hi:.1f}"] = int(np.count_nonzero((pred >= lo) & below_hi))
        distribution["histogram"] = histogram

        result["debug"] = {
            "input_stats": input_stats,
            "output_distribution": distribution,
        }

    logger.info("[PREDICT] Forest prediction: {:.2f}%".format(forest_percentage))
    return result


def predict_forest(bands, debug=False, clip_optical=False, clip_indices=False):
    """Predict a forest segmentation mask from 9-band input.

    Args:
        bands: Dictionary with keys: Blue, Green, Red, NIR, SWIR1, SWIR2 and
            optionally NDVI, NDWI, NBR. Values may be arrays or
            base64-encoded float32 data.
        debug: If True, return detailed debug statistics
        clip_optical: If True, clip optical bands to [-0.2, 0.6]
        clip_indices: If True, clip indices to [-1, 1]

    Returns:
        Dictionary with mask, confidence scores, and optional debug data
    """
    load()

    logger.info("[PREDICT] Starting prediction...")

    # Decode the known bands once up front so analysis and preprocessing see
    # arrays even when the caller sent base64 strings. Other keys are kept
    # untouched.
    decoded = dict(bands)
    for name in _OPTICAL_BANDS + _INDEX_BANDS:
        if decoded.get(name) is not None:
            decoded[name] = _as_array(decoded[name])

    # Analyze input
    input_stats = analyze_input_bands(decoded)

    # Preprocess if requested
    if clip_optical or clip_indices:
        logger.info("[PREDICT] Applying preprocessing (clip_optical={}, clip_indices={})...".format(clip_optical, clip_indices))
        decoded = preprocess_for_model(decoded, clip_optical=clip_optical, clip_indices=clip_indices)

    # Build input tensor
    logger.info("[PREDICT] Building input tensor...")
    x = build_input_tensor(decoded)

    # Run inference
    logger.info("[PREDICT] Running model inference...")
    pred = model.predict(x, verbose=0)[0, :, :, 0]  # Extract (H, W) from (1, H, W, 1)

    return _summarize_prediction(pred, input_stats, debug)