Spaces:
Sleeping
Sleeping
| from tensorflow.keras.models import load_model | |
| import numpy as np | |
| import base64 | |
| import logging | |
# Logger shared by every helper in this inference module.
logger = logging.getLogger("forest_segmentation.inference")

# Relative path of the trained Keras segmentation model file.
MODEL_PATH = "models/Forest_Segmentation_Best.keras"

# Small constant added to spectral-index denominators to avoid dividing by zero.
EPS = 1e-6

# Cached model handle; stays None until load() populates it on first use.
model = None
def load():
    """Load the Keras model into the module-level ``model`` exactly once.

    The first call reads the file at ``MODEL_PATH`` (with ``compile=False``)
    and caches the result; later calls return immediately because the
    cached handle is already set.
    """
    global model
    if model is not None:
        # Already cached by an earlier call; nothing to do.
        return
    logger.info(f"[INFERENCE] Loading model from: {MODEL_PATH}")
    model = load_model(MODEL_PATH, compile=False)
    logger.info("[INFERENCE] Model loaded successfully")
def decode_band_float32(b64):
    """Decode a base64-encoded float32 buffer into a square 2-D array.

    Args:
        b64: Base64 text (or bytes) containing the raw float32 values of a
            square raster.

    Returns:
        A ``(side, side)`` float32 array built from the decoded buffer.

    Raises:
        ValueError: If the decoded byte length is not a multiple of four
            (raised by ``np.frombuffer``) or the number of values is not a
            perfect square.
    """
    raw = base64.b64decode(b64)
    arr = np.frombuffer(raw, dtype=np.float32)
    side = int(np.sqrt(arr.size))
    # int(sqrt(n)) truncates, so a non-square payload would otherwise only
    # fail inside reshape with an opaque message; report the real problem.
    if side * side != arr.size:
        raise ValueError(
            f"Band payload holds {arr.size} float32 values, which is not a "
            "perfect square; cannot reshape to (side, side)"
        )
    return arr.reshape((side, side))
def validate_landsat_data(bands_dict):
    """Check that every band is 2-D and return the bands as float32 arrays.

    Only dimensionality and dtype are handled here. Value ranges are NOT
    inspected; the ranges the data is expected to follow (Landsat 8
    Collection 2 Level 2) are [-0.2, 0.6] for optical bands and [-1, 1]
    for indices.

    Args:
        bands_dict: Mapping of band name to array.

    Returns:
        A new dict with the same keys, in the same order, whose values are
        float32 arrays. Arrays that are already float32 are passed through
        as-is; the input mapping is not modified.

    Raises:
        ValueError: If any band is not two-dimensional.
    """
    validated = {}
    for band_name, data in bands_dict.items():
        if data.ndim != 2:
            raise ValueError(f"{band_name}: Expected 2D array, got shape {data.shape}")
        # Store the converted array: merely rebinding the loop variable would
        # drop the float32 copy and leave the returned data unconverted.
        if data.dtype != np.float32:
            data = data.astype(np.float32)
        validated[band_name] = data
    return validated
def ndvi(red, nir):
    """Compute the Normalized Difference Vegetation Index.

    Evaluates ``(nir - red) / (nir + red + EPS)``; the EPS term keeps the
    denominator non-zero when both bands are zero.
    """
    difference = nir - red
    total = nir + red + EPS
    return difference / total
def ndwi(green, nir):
    """Compute the Normalized Difference Water Index.

    Evaluates ``(green - nir) / (green + nir + EPS)``; the EPS term keeps
    the denominator non-zero when both bands are zero.
    """
    difference = green - nir
    total = green + nir + EPS
    return difference / total
def nbr(nir, swir2):
    """Compute the Normalized Burn Ratio.

    Evaluates ``(nir - swir2) / (nir + swir2 + EPS)``; the EPS term keeps
    the denominator non-zero when both bands are zero.
    """
    difference = nir - swir2
    total = nir + swir2 + EPS
    return difference / total
def analyze_input_bands(bands):
    """Log and collect summary statistics for the known input bands.

    For each of the nine recognised band names present in *bands*, records
    min / max / mean / std as Python floats and logs min, max and mean.
    When an ``NDVI`` band is present, the percentage of pixels with
    NDVI > 0.5 is also logged and stored under ``vegetation_coverage_pct``.

    Args:
        bands: Mapping of band name to array.

    Returns:
        Dict of per-band statistics, plus the vegetation coverage entry
        when NDVI is available.
    """
    known_bands = ('Blue', 'Green', 'Red', 'NIR', 'SWIR1', 'SWIR2', 'NDVI', 'NDWI', 'NBR')
    summary = {}
    logger.info("[ANALYSIS] === INPUT BAND ANALYSIS ===")

    for key in known_bands:
        if key not in bands:
            continue
        values = bands[key]
        entry = {
            "min": float(values.min()),
            "max": float(values.max()),
            "mean": float(values.mean()),
            "std": float(values.std()),
        }
        summary[key] = entry
        logger.info(f"[ANALYSIS] {key}: min={entry['min']:.4f}, max={entry['max']:.4f}, mean={entry['mean']:.4f}")

    # Vegetation coverage can only be derived when NDVI was supplied.
    if 'NDVI' not in bands:
        return summary

    ndvi_values = bands['NDVI']
    vegetated = np.sum(ndvi_values > 0.5)
    coverage = (vegetated / ndvi_values.size) * 100
    logger.info(f"[ANALYSIS] NDVI > 0.5 (vegetation): {coverage:.2f}% of pixels")
    summary['vegetation_coverage_pct'] = coverage
    return summary
def _select_bands(bands, names, clip_range=None):
    """Return the bands in *names* that exist in *bands*, clipped if a range is given."""
    selected = {}
    for name in names:
        if name not in bands:
            continue
        if clip_range is None:
            selected[name] = bands[name]
        else:
            selected[name] = np.clip(bands[name], clip_range[0], clip_range[1])
    return selected


def preprocess_for_model(bands, clip_optical=False, clip_indices=False):
    """Select the nine model bands and optionally clip them.

    Bands missing from *bands* are omitted from the result whether or not
    clipping is enabled, so a missing band never appears as a ``None``
    entry. Keys other than the nine known band names are dropped.

    Args:
        bands: Dictionary of band arrays.
        clip_optical: If True, clip optical bands to [-0.2, 0.6].
        clip_indices: If True, clip indices to [-1, 1].

    Returns:
        New dict holding the optical bands followed by the indices. Bands
        that are not clipped are the same array objects as in *bands*;
        clipped bands are new arrays produced by ``np.clip``.
    """
    optical_names = ('Blue', 'Green', 'Red', 'NIR', 'SWIR1', 'SWIR2')
    index_names = ('NDVI', 'NDWI', 'NBR')

    if clip_optical:
        logger.info("[PREPROCESS] Clipping optical bands to [-0.2, 0.6]")
    processed = _select_bands(
        bands, optical_names, (-0.2, 0.6) if clip_optical else None
    )

    if clip_indices:
        logger.info("[PREPROCESS] Clipping indices to [-1.0, 1.0]")
    processed.update(
        _select_bands(bands, index_names, (-1.0, 1.0) if clip_indices else None)
    )
    return processed
| def build_input_tensor(bands): | |
| """ | |
| Build 9-channel input tensor from Landsat 8 bands | |
| Expected band dict keys: | |
| - Blue, Green, Red: Optical bands (indices 0-2) | |
| - NIR, SWIR1, SWIR2: Infrared bands (indices 3-5) | |
| - NDVI, NDWI, NBR: Pre-calculated or computed indices (indices 6-8) | |
| Returns: (1, 256, 256, 9) array ready for model inference | |
| """ | |
| # Extract optical bands | |
| blue = decode_band_float32(bands["Blue"]) if isinstance(bands["Blue"], str) else bands["Blue"] | |
| green = decode_band_float32(bands["Green"]) if isinstance(bands["Green"], str) else bands["Green"] | |
| red = decode_band_float32(bands["Red"]) if isinstance(bands["Red"], str) else bands["Red"] | |
| nir = decode_band_float32(bands["NIR"]) if isinstance(bands["NIR"], str) else bands["NIR"] | |
| swir1 = decode_band_float32(bands["SWIR1"]) if isinstance(bands["SWIR1"], str) else bands["SWIR1"] | |
| swir2 = decode_band_float32(bands["SWIR2"]) if isinstance(bands["SWIR2"], str) else bands["SWIR2"] | |
| # Use pre-calculated indices if provided, otherwise compute them | |
| if isinstance(bands.get("NDVI"), str) or isinstance(bands.get("NDVI"), np.ndarray): | |
| ndvi_map = decode_band_float32(bands["NDVI"]) if isinstance(bands["NDVI"], str) else bands["NDVI"] | |
| else: | |
| ndvi_map = ndvi(red, nir) | |
| if isinstance(bands.get("NDWI"), str) or isinstance(bands.get("NDWI"), np.ndarray): | |
| ndwi_map = decode_band_float32(bands["NDWI"]) if isinstance(bands["NDWI"], str) else bands["NDWI"] | |
| else: | |
| ndwi_map = ndwi(green, nir) | |
| if isinstance(bands.get("NBR"), str) or isinstance(bands.get("NBR"), np.ndarray): | |
| nbr_map = decode_band_float32(bands["NBR"]) if isinstance(bands["NBR"], str) else bands["NBR"] | |
| else: | |
| nbr_map = nbr(nir, swir2) | |
| # Stack into 9-channel tensor: (H, W, 9) | |
| stacked = np.stack([ | |
| blue, | |
| green, | |
| red, | |
| nir, | |
| swir1, | |
| swir2, | |
| ndvi_map, | |
| ndwi_map, | |
| nbr_map | |
| ], axis=-1).astype(np.float32) | |
| # Validate data range matches training expectations | |
| opt_min, opt_max = np.min(stacked[..., :6]), np.max(stacked[..., :6]) | |
| if opt_min < -0.3 or opt_max > 1.0: | |
| logger.warning(f"[BUILD] WARNING: Optical bands range [{opt_min:.4f}, {opt_max:.4f}] outside expected [-0.2, 0.6]") | |
| # Add batch dimension: (1, H, W, 9) | |
| stacked = np.expand_dims(stacked, axis=0) | |
| return stacked | |
def _summarize_output(pred):
    """Build the debug distribution (moments, percentiles, histogram) of *pred*."""
    p10, p25, p50, p75, p90 = np.percentile(pred, [10, 25, 50, 75, 90])

    # Ten bins of width 0.1 over [0, 1]; each bin is half-open except the
    # last, which also includes 1.0. Values outside [0, 1] are not counted.
    edges = (0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0)
    last_bin = len(edges) - 2
    histogram = {}
    for idx, (low, high) in enumerate(zip(edges, edges[1:])):
        below = (pred <= high) if idx == last_bin else (pred < high)
        histogram[f"{low:.1f}-{high:.1f}"] = int(np.sum((pred >= low) & below))

    return {
        "min": float(pred.min()),
        "max": float(pred.max()),
        "mean": float(pred.mean()),
        "std": float(pred.std()),
        "percentile_10": float(p10),
        "percentile_25": float(p25),
        "percentile_50": float(p50),
        "percentile_75": float(p75),
        "percentile_90": float(p90),
        "histogram": histogram,
    }


def predict_forest(bands, debug=False, clip_optical=False, clip_indices=False):
    """Predict a forest segmentation mask from Landsat 8 band inputs.

    Args:
        bands: Dictionary with keys Blue, Green, Red, NIR, SWIR1, SWIR2 and
            optionally NDVI, NDWI, NBR. Values may be arrays or base64
            strings of float32 data; strings are decoded before any
            analysis or clipping runs.
        debug: If True, include detailed debug statistics in the result.
        clip_optical: If True, clip optical bands to [-0.2, 0.6].
        clip_indices: If True, clip indices to [-1, 1].

    Returns:
        Dictionary with:
            mask: nested list of 0/255 values (255 where prediction > 0.5),
            forest_confidence: mean prediction over pixels > 0.5, or 0.0
                when no pixel exceeds the threshold,
            forest_percentage: share of pixels > 0.5, in percent,
            mean_prediction, classes, model_version,
            debug: input statistics and output distribution (only when
                *debug* is True).
    """
    load()

    # Decode base64 payloads up front: the analysis and clipping steps call
    # array methods and would otherwise fail on string inputs that the
    # tensor builder is able to accept.
    bands = {
        name: decode_band_float32(value) if isinstance(value, str) else value
        for name, value in bands.items()
    }

    # Analyze input
    logger.info("[PREDICT] Starting prediction...")
    input_stats = analyze_input_bands(bands)

    # Preprocess if requested
    if clip_optical or clip_indices:
        logger.info("[PREDICT] Applying preprocessing (clip_optical={}, clip_indices={})...".format(clip_optical, clip_indices))
        bands = preprocess_for_model(bands, clip_optical=clip_optical, clip_indices=clip_indices)

    # Build input tensor
    logger.info("[PREDICT] Building input tensor...")
    x = build_input_tensor(bands)

    # Run inference
    logger.info("[PREDICT] Running model inference...")
    pred = model.predict(x, verbose=0)[0, :, :, 0]  # Extract (H, W) from (1, H, W, 1)

    # Threshold once and reuse the boolean map for logging, mask and stats.
    forest = pred > 0.5
    forest_pixels = int(forest.sum())
    confident_pixels = int(np.sum(pred > 0.8))

    # Analyze output
    logger.info("[PREDICT] === RAW MODEL OUTPUT ===")
    logger.info(f"[PREDICT] Output shape: {pred.shape}, dtype: {pred.dtype}")
    logger.info(f"[PREDICT] Output range: [{pred.min():.4f}, {pred.max():.4f}]")
    logger.info(f"[PREDICT] Output mean: {pred.mean():.4f}, std: {pred.std():.4f}")
    logger.info(f"[PREDICT] Pixels > 0.5: {forest_pixels:,} / {pred.size:,} ({100 * forest_pixels / pred.size:.2f}%)")
    logger.info(f"[PREDICT] Pixels > 0.8: {confident_pixels:,} / {pred.size:,}")

    # Generate binary mask
    mask = forest.astype(np.uint8) * 255

    # Calculate statistics
    forest_confidence = float(np.mean(pred[forest])) if forest_pixels else 0.0
    forest_percentage = float(forest_pixels / pred.size * 100)

    result = {
        "mask": mask.tolist(),
        "forest_confidence": forest_confidence,
        "forest_percentage": forest_percentage,
        "mean_prediction": float(pred.mean()),
        "classes": ["forest", "non-forest"],
        "model_version": "landsat8_trained",
    }

    if debug:
        logger.info("[PREDICT] Adding debug information...")
        result["debug"] = {
            "input_stats": input_stats,
            "output_distribution": _summarize_output(pred),
        }

    logger.info("[PREDICT] Forest prediction: {:.2f}%".format(forest_percentage))
    return result