Spaces:
Sleeping
Sleeping
| # main.py | |
import base64
import logging
import os
import sys
import threading
import time
from datetime import datetime, timezone
from logging.handlers import RotatingFileHandler

import numpy as np
import tensorflow as tf
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import HTMLResponse, JSONResponse
from huggingface_hub import hf_hub_download

from inference.forest import build_input_tensor, predict_forest
from schemas import PredictRequest, PredictResponse
# =============================================================================
# LOGGING CONFIGURATION
# =============================================================================
# Records go both to stdout and to a size-rotated file at logs/server.log
# (the "logs" directory is created relative to the current working directory).
os.makedirs("logs", exist_ok=True)

# One formatter instance shared by both handlers.
_log_formatter = logging.Formatter(
    fmt="%(asctime)s | %(levelname)-8s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

logger = logging.getLogger("forest_segmentation")
logger.setLevel(logging.DEBUG)

console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(logging.DEBUG)
console_handler.setFormatter(_log_formatter)

# Rolls over at ~10 MB and keeps five backups. No explicit handler level is
# set, so filtering is left to the logger's DEBUG level.
file_handler = RotatingFileHandler(
    "logs/server.log",
    maxBytes=10_000_000,
    backupCount=5,
    encoding="utf-8",
)
file_handler.setFormatter(_log_formatter)

logger.addHandler(console_handler)
logger.addHandler(file_handler)

_startup_rule = "=" * 80
logger.info(_startup_rule)
logger.info("FOREST SEGMENTATION SERVER STARTING")
logger.info(_startup_rule)
# =============================================================================
# INVERSION DETECTION
# =============================================================================
def detect_inversion(image_stack, confidence_map, ndvi_threshold=0.3, *, ndvi_channel=6):
    """Heuristically detect whether the model's confidence map is inverted.

    Pixels whose NDVI value exceeds ``ndvi_threshold`` are treated as
    vegetation. The mean confidence over vegetation pixels is compared with
    the mean over all remaining pixels; if the non-vegetation pixels score
    higher, the output is flagged as inverted.

    Args:
        image_stack: (H, W, C) array whose channel ``ndvi_channel`` holds NDVI
            (the original layout is (H, W, 9) with NDVI at index 6).
        confidence_map: (H, W) array of model confidences.
        ndvi_threshold: NDVI value above which a pixel counts as vegetation.
        ndvi_channel: Index of the NDVI channel in ``image_stack``.

    Returns:
        ``(is_inverted, correlation)`` as ``(bool, float)``. ``correlation``
        is the difference of means ``veg_conf - non_veg_conf``, not a
        statistical correlation coefficient. If either group is empty its
        mean defaults to 0.5, so the other group's mean is effectively
        compared against 0.5.
    """
    ndvi = image_stack[:, :, ndvi_channel]
    vegetation_mask = ndvi > ndvi_threshold

    # Neutral 0.5 fallback when a group has no pixels (mean of empty is NaN).
    veg_conf = (
        confidence_map[vegetation_mask].mean()
        if vegetation_mask.any() else 0.5
    )
    non_veg_conf = (
        confidence_map[~vegetation_mask].mean()
        if (~vegetation_mask).any() else 0.5
    )

    is_inverted = non_veg_conf > veg_conf
    correlation = veg_conf - non_veg_conf
    return bool(is_inverted), float(correlation)
# =============================================================================
# FASTAPI APP
# =============================================================================
app = FastAPI(
    title="Forest Segmentation API",
    version="1.0.0",
    description="Landsat 8 Forest Segmentation",
)

# Side length in pixels; not referenced in the visible handlers.
IMG_SIZE = 256

# Band names reported by predict() under model_info["bands"]. Entry 6 is
# "NDVI", the channel detect_inversion reads by default; presumably this is
# also the channel order produced by build_input_tensor — TODO confirm.
LANDSAT_BANDS = [
    "Blue",
    "Green",
    "Red",
    "NIR",
    "SWIR1",
    "SWIR2",
    "NDVI",
    "NDWI",
    "NBR",
]
# =============================================================================
# MIDDLEWARE
# =============================================================================
async def log_requests(request: Request, call_next):
    """Log method, path, status code and elapsed time for each request.

    Elapsed time is measured with ``time.perf_counter`` (monotonic), so it is
    not skewed by system clock adjustments the way ``time.time`` can be.
    If the downstream handler raises, a one-line summary is logged before the
    exception is re-raised, so failed requests still appear in the log.
    """
    start = time.perf_counter()
    try:
        response = await call_next(request)
    except Exception:
        duration = time.perf_counter() - start
        # No traceback here: the exception is re-raised for the framework's
        # own error handling to report.
        logger.error(
            f"{request.method} {request.url.path} | "
            f"unhandled exception | {duration:.3f}s"
        )
        raise
    duration = time.perf_counter() - start
    logger.info(
        f"{request.method} {request.url.path} | "
        f"{response.status_code} | {duration:.3f}s"
    )
    return response
# =============================================================================
# ROOT ENDPOINT
# =============================================================================
# Static landing page text, built once at import time and returned verbatim.
_ROOT_HTML = """
<!DOCTYPE html>
<html>
<head>
<title>Forest Segmentation API</title>
<style>
body { font-family: Arial, sans-serif; margin: 40px; background-color: #f5f5f5; }
.container { background-color: white; padding: 20px; border-radius: 8px; max-width: 800px; }
h1 { color: #333; }
.endpoint { background-color: #f0f0f0; padding: 10px; margin: 10px 0; border-left: 4px solid #4CAF50; }
code { background-color: #f9f9f9; padding: 2px 6px; border-radius: 3px; }
.status { color: #4CAF50; font-weight: bold; }
</style>
</head>
<body>
<div class="container">
<h1>🌲 Forest Segmentation API</h1>
<p>Landsat 8 Forest Segmentation Model</p>
<h2>API Endpoints</h2>
<div class="endpoint">
<strong>Health Check</strong><br>
<code>GET /health</code><br>
Returns API status
</div>
<div class="endpoint">
<strong>Predict</strong><br>
<code>POST /predict</code><br>
Send Landsat bands for forest segmentation
</div>
<div class="endpoint">
<strong>API Docs</strong><br>
<code>GET /docs</code><br>
Interactive Swagger UI
</div>
<h2>Status</h2>
<p><span class="status">✓ API is running</span></p>
</div>
</body>
</html>
"""


def root():
    """Return the static HTML page that lists the API's endpoints."""
    return _ROOT_HTML
# =============================================================================
# HEALTH
# =============================================================================
def health():
    """Return a liveness payload with a status flag and the current UTC time.

    ``timestamp`` is a naive ISO-8601 string with no UTC offset. This is the
    same format the deprecated ``datetime.utcnow().isoformat()`` produced.
    """
    # datetime.utcnow() is deprecated since Python 3.12. Derive the same naive
    # UTC value from an aware "now" so the wire format stays unchanged.
    now_utc = datetime.now(timezone.utc).replace(tzinfo=None)
    return {
        "status": "healthy",
        "timestamp": now_utc.isoformat(),
    }
# =============================================================================
# PREDICT ENDPOINT (FIXED - CONTINUOUS VALUES)
# =============================================================================
MODEL_REPO = "prshntdxt/Forest_Segmentation_Best"
MODEL_FILE = "Forest_Segmentation_Best.keras"

# The model is loaded lazily by _get_model() and then reused. Previously it
# was downloaded and deserialized on every request.
_model = None
_model_lock = threading.Lock()


def _get_model():
    """Return the Keras segmentation model, loading it on first use.

    The model file is fetched from the Hugging Face repo ``MODEL_REPO`` with
    ``hf_hub_download`` and deserialized with ``compile=False``. Later calls
    return the same model object.
    """
    global _model
    if _model is None:
        with _model_lock:
            # Re-check under the lock: another thread may have finished loading
            # the model while this one was waiting.
            if _model is None:
                model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE)
                _model = tf.keras.models.load_model(model_path, compile=False)
    return _model


def _decode_bands(bands):
    """Decode per-band request payloads into float32 numpy arrays.

    A string value is base64 holding raw float32 bytes (native byte order) for
    a square image, and is reshaped to ``(side, side)``. Any other value is
    converted with ``np.array(..., dtype=np.float32)``.

    Raises:
        ValueError: malformed base64 (``binascii.Error`` subclasses
            ValueError), a byte length that is not a multiple of 4, or an
            element count that is not a perfect square.
    """
    decoded = {}
    for band, data in bands.items():
        if isinstance(data, str):
            raw = base64.b64decode(data)
            arr = np.frombuffer(raw, dtype=np.float32)
            side = int(np.sqrt(arr.size))
            if side * side != arr.size:
                raise ValueError(
                    f"Band '{band}' has {arr.size} values, which is not a square image"
                )
            decoded[band] = arr.reshape((side, side))
        else:
            decoded[band] = np.array(data, dtype=np.float32)
    return decoded


def predict(payload: PredictRequest):
    """Run forest segmentation on the bands carried by *payload*.

    The steps are:

    1. Decode the bands.
    2. Build the model input with ``build_input_tensor``.
    3. Run the model.
    4. Flip the confidence map if ``detect_inversion`` reports it is inverted.
    5. Return the confidence as continuous 0-255 masks plus summary
       statistics.

    Raises:
        HTTPException: status 400 for any ``ValueError`` (invalid input), and
            status 500 for any other failure, including model download or
            loading.
    """
    try:
        logger.info("[PREDICT] Request received")
        if not payload.bands:
            raise ValueError("No bands provided")
        # ---------------------------------------------------------------------
        # Decode bands
        # ---------------------------------------------------------------------
        decoded_bands = _decode_bands(payload.bands)
        logger.info(f"[PREDICT] Decoded {len(decoded_bands)} bands")
        # ---------------------------------------------------------------------
        # Build input tensor
        # ---------------------------------------------------------------------
        input_tensor = build_input_tensor(decoded_bands)  # (1, H, W, 9)
        input_stack = input_tensor[0]  # (H, W, 9)
        # ---------------------------------------------------------------------
        # Run model (raw confidence)
        # ---------------------------------------------------------------------
        # NOTE(review): the cached model is shared across requests; confirm
        # Model.predict is safe to call concurrently in the deployed TF/Keras.
        model = _get_model()
        confidence_map = model.predict(input_tensor, verbose=0)[0, :, :, 0]
        # Log raw model output stats
        logger.info(
            f"[MODEL OUTPUT] Raw confidence: min={confidence_map.min():.4f}, "
            f"max={confidence_map.max():.4f}, mean={confidence_map.mean():.4f}"
        )
        # ---------------------------------------------------------------------
        # Inversion detection & correction
        # ---------------------------------------------------------------------
        is_inverted, corr = detect_inversion(input_stack, confidence_map)
        if is_inverted:
            logger.warning(
                f"[INVERSION] Detected | NDVI correlation={corr:.4f} | FIX APPLIED"
            )
            corrected_conf = 1.0 - confidence_map
        else:
            logger.info(
                f"[INVERSION] Not detected | NDVI correlation={corr:.4f}"
            )
            corrected_conf = confidence_map
        # ---------------------------------------------------------------------
        # Create masks (CONTINUOUS values for density visualization)
        # ---------------------------------------------------------------------
        # Continuous confidence scaled to 0-255 (not binary). Clip first: the
        # model's output range is not enforced here, and casting values outside
        # [0, 1] * 255 to uint8 overflows (the result is platform-dependent).
        mask_255 = (np.clip(corrected_conf, 0.0, 1.0) * 255).astype(np.uint8)
        inverted_mask_255 = (255 - mask_255).astype(np.uint8)
        # Stats come from the unclipped corrected map: the percentage uses a
        # 0.5 threshold, and the confidence is the plain mean.
        forest_percentage = float((corrected_conf > 0.5).sum() / corrected_conf.size * 100)
        forest_confidence = float(corrected_conf.mean())
        unique_values = int(len(np.unique(mask_255)))
        # Log mask stats to verify continuous values
        logger.info(
            f"[MASK] Range: [{mask_255.min()}, {mask_255.max()}] | "
            f"Unique values: {unique_values}"
        )
        logger.info(
            f"[PREDICT] Forest={forest_percentage:.2f}% | "
            f"Confidence={forest_confidence:.4f}"
        )
        # ---------------------------------------------------------------------
        # Response
        # ---------------------------------------------------------------------
        return {
            "mask": mask_255.flatten().tolist(),
            "inverted_mask": inverted_mask_255.flatten().tolist(),
            "forest_percentage": forest_percentage,
            "forest_confidence": forest_confidence,
            "mean_prediction": forest_confidence,
            "classes": {"forest": 1, "non_forest": 0},
            "model_info": {
                "name": "Forest_Segmentation_Best",
                "bands": LANDSAT_BANDS
            },
            "debug": {
                "was_inverted": is_inverted,
                "inversion_correlation": corr,
                "mask_min": int(mask_255.min()),
                "mask_max": int(mask_255.max()),
                "unique_values": unique_values
            }
        }
    except ValueError as e:
        logger.error(f"[PREDICT] Validation error: {e}")
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        logger.exception("[PREDICT] Inference failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
# =============================================================================
# STARTUP / SHUTDOWN
# =============================================================================
def _log_banner(title):
    """Log *title* between two 80-character lines of '='."""
    rule = "=" * 80
    for line in (rule, title, rule):
        logger.info(line)


async def startup():
    """Log the 'SERVER READY' banner."""
    _log_banner("SERVER READY")


async def shutdown():
    """Log the 'SERVER SHUTDOWN' banner."""
    _log_banner("SERVER SHUTDOWN")