Spaces:
Sleeping
Sleeping
import base64
import io
import math

import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model
# Location of the saved Keras segmentation model read by load_segmentation_model().
MODEL_PATH = "models/Forest_Segmentation_Best.keras"
# Small constant added to spectral-index denominators to guard against a
# zero denominator when the two bands sum to exactly zero.
EPS = 1e-6
# Module-level model cache; stays None until load_segmentation_model() runs.
model = None
# ----------------------------
# Load model once
# ----------------------------
def load_segmentation_model():
    """Populate the module-level ``model`` cache on first use.

    The model file at ``MODEL_PATH`` is loaded only while ``model`` is still
    None; once it is set, calls return immediately. ``compile=False`` skips
    compiling the model after loading. There is no locking here, so
    concurrent first calls could each trigger a load.
    """
    global model
    if model is not None:
        return
    model = load_model(MODEL_PATH, compile=False)
# ----------------------------
# Decode Landsat band from base64
# ----------------------------
def decode_band_float32(b64):
    """Decode a base64-encoded float32 buffer into a square 2-D array.

    Args:
        b64: Base64 text (``str`` or ``bytes``) whose decoded payload is a
            flat buffer of native-endian float32 values forming a square tile.

    Returns:
        A ``(side, side)`` float32 array viewing the decoded bytes
        (read-only, as produced by ``np.frombuffer``).

    Raises:
        ValueError: if the payload is not a whole number of float32 values,
            or the number of values is not a perfect square.
    """
    raw = base64.b64decode(b64)
    arr = np.frombuffer(raw, dtype=np.float32)
    # math.isqrt is exact for any size; int(np.sqrt(...)) goes through
    # floating point and silently truncates non-square sizes, which then
    # surface as an opaque reshape error.
    side = math.isqrt(arr.size)
    if side * side != arr.size:
        raise ValueError(
            f"Expected a square tile, but decoded {arr.size} float32 values "
            "(not a perfect square)"
        )
    return arr.reshape((side, side))
# ----------------------------
# Spectral Indices (matching training pipeline)
# ----------------------------
def _normalized_difference(a, b):
    """Return ``(a - b) / (a + b + EPS)``, the form shared by all three indices."""
    numerator = a - b
    # EPS guards against a zero denominator when a + b is exactly zero.
    denominator = a + b + EPS
    return numerator / denominator


def ndvi(red, nir):
    """Normalized Difference Vegetation Index: (NIR - Red) / (NIR + Red + EPS)."""
    return _normalized_difference(nir, red)


def ndwi(green, nir):
    """Normalized Difference Water Index: (Green - NIR) / (Green + NIR + EPS)."""
    return _normalized_difference(green, nir)


def nbr(nir, swir2):
    """Normalized Burn Ratio: (NIR - SWIR2) / (NIR + SWIR2 + EPS)."""
    return _normalized_difference(nir, swir2)
# ----------------------------
# Build 9-channel tensor from Landsat 8
# ----------------------------
def _as_band(value):
    """Return *value* as an array, base64-decoding it first if it is a string."""
    if isinstance(value, str):
        return decode_band_float32(value)
    # np.asarray returns ndarrays unchanged and also accepts nested lists,
    # so index arithmetic below works for any array-like input.
    return np.asarray(value)


def build_input_tensor(bands):
    """Build the (1, H, W, 9) float32 model input from Landsat 8 bands.

    Channel order matches the training data (Landsat 8 Collection 2 Level 2):
    Blue, Green, Red, NIR, SWIR1, SWIR2, NDVI, NDWI, NBR.

    Args:
        bands: Mapping with required keys "Blue", "Green", "Red", "NIR",
            "SWIR1", "SWIR2" and optional keys "NDVI", "NDWI", "NBR".
            Each value is either a base64-encoded float32 string (from the
            API) or an array-like such as a NumPy array (from direct
            processing). Indices that are missing or None are computed from
            the optical/infrared bands.

    Returns:
        np.ndarray of shape (1, H, W, 9) and dtype float32, ready for
        model inference.

    Raises:
        KeyError: if one of the six required bands is missing.
        ValueError: if the channels do not all share the same (H, W) shape.

    Expected value range:
        - Optical bands: [-0.2, 0.6]
        - Indices: [-1, 1]
    """
    blue, green, red, nir, swir1, swir2 = (
        _as_band(bands[key])
        for key in ("Blue", "Green", "Red", "NIR", "SWIR1", "SWIR2")
    )

    # Prefer pre-calculated indices when supplied; otherwise derive them.
    ndvi_map = bands.get("NDVI")
    ndvi_map = ndvi(red, nir) if ndvi_map is None else _as_band(ndvi_map)

    ndwi_map = bands.get("NDWI")
    ndwi_map = ndwi(green, nir) if ndwi_map is None else _as_band(ndwi_map)

    nbr_map = bands.get("NBR")
    nbr_map = nbr(nir, swir2) if nbr_map is None else _as_band(nbr_map)

    # Stack into (H, W, 9) in training-channel order, then add the batch axis.
    stacked = np.stack(
        [blue, green, red, nir, swir1, swir2, ndvi_map, ndwi_map, nbr_map],
        axis=-1,
    ).astype(np.float32)
    return np.expand_dims(stacked, axis=0)  # (1, H, W, 9)
# ----------------------------
# Inference
# ----------------------------
def predict_segmentation(bands, *, threshold=0.5):
    """Predict a forest segmentation mask for one Landsat 8 tile.

    Args:
        bands: Band dictionary in the format accepted by build_input_tensor.
        threshold: Prediction value above which a pixel is counted as forest.
            Defaults to 0.5, the value previously hard-coded.

    Returns:
        Dictionary with:
        - mask: (H, W) nested list, 255 where the prediction exceeds
          ``threshold`` and 0 elsewhere
        - forest_percentage: % of pixels above ``threshold``
        - forest_confidence: mean prediction over forest pixels
          (0.0 when there are none)
        - mean_prediction: mean prediction over all pixels
        - classes: class-name list
        - model_info: static description of training data and expected input
    """
    load_segmentation_model()
    x = build_input_tensor(bands)
    # Keep batch 0, channel 0 -> (H, W). This presumably assumes a
    # single-channel per-pixel probability output; confirm against the model.
    pred = model.predict(x, verbose=0)[0, :, :, 0]

    # Threshold once and reuse for the mask and every statistic.
    forest = pred > threshold
    n_forest = int(np.count_nonzero(forest))

    mask = forest.astype(np.uint8) * 255
    forest_confidence = float(pred[forest].mean()) if n_forest else 0.0
    # Guard against an empty prediction map instead of dividing by zero.
    forest_percentage = float(n_forest / pred.size * 100) if pred.size else 0.0

    return {
        "mask": mask.tolist(),
        "forest_percentage": forest_percentage,
        "forest_confidence": forest_confidence,
        "mean_prediction": float(pred.mean()),
        # NOTE(review): the mask encodes forest as 255 (index 1 in a 0/1
        # view), while this list puts "forest" first; confirm which
        # convention clients expect before relying on the order.
        "classes": ["forest", "non-forest"],
        "model_info": {
            "training_data": "Landsat 8 Collection 2 Level 2",
            "bands": ["Blue", "Green", "Red", "NIR", "SWIR1", "SWIR2", "NDVI", "NDWI", "NBR"],
            "patch_size": 256,
            "value_range": "[-0.2, 0.6] for optical, [-1, 1] for indices",
        },
    }