Spaces:
Running
Running
| # main.py - FastAPI application for Flood Vulnerability Assessment | |
| from fastapi import FastAPI, File, UploadFile, HTTPException, Request | |
| from fastapi.responses import StreamingResponse, HTMLResponse | |
| from fastapi.templating import Jinja2Templates | |
| from pydantic import BaseModel, field_validator | |
| from typing import Optional, Dict | |
| import pandas as pd | |
| import io | |
| import asyncio | |
| from concurrent.futures import ThreadPoolExecutor | |
| from contextlib import asynccontextmanager | |
| from datetime import datetime | |
| from spatial_queries import get_terrain_metrics, distance_to_water | |
| from vulnerability import calculate_vulnerability_index | |
| from gee_auth import initialize_gee | |
| import os | |
# Set DISABLE_HEIGHT=true in the environment to skip loading the height predictor.
DISABLE_HEIGHT_PREDICTOR = os.getenv("DISABLE_HEIGHT", "false").lower() == "true"

# Readiness flags; load_heavy_models() flips these once each dependency is up.
gee_ready = False
model_ready = False

# OSM rate-limiting state used by throttled_distance_to_water(): the lock that
# serialises requests, and the timestamp of the most recent request
# (None before the first one).
_osm_lock = asyncio.Lock()
_last_osm_request = None
async def throttled_distance_to_water(lat, lon, min_interval: float = 0.5):
    """Call ``distance_to_water(lat, lon)`` in a worker thread, rate-limited.

    Calls are serialised through the module-level ``_osm_lock`` and spaced at
    least ``min_interval`` seconds apart (the default 0.5 s means at most
    2 requests/sec). The gap is measured from the end of the previous call.

    Args:
        lat: Latitude forwarded to ``distance_to_water``.
        lon: Longitude forwarded to ``distance_to_water``.
        min_interval: Minimum number of seconds between consecutive calls.

    Returns:
        Whatever ``distance_to_water`` returns.

    Raises:
        Any exception raised by ``distance_to_water`` is propagated.
    """
    global _last_osm_request
    loop = asyncio.get_running_loop()
    async with _osm_lock:
        # loop.time() is monotonic. The previous datetime.now() is wall-clock
        # time and can jump backwards (DST, NTP), which produced a negative
        # `elapsed` and a huge sleep while holding the lock.
        if _last_osm_request is not None:
            elapsed = loop.time() - _last_osm_request
            if elapsed < min_interval:
                await asyncio.sleep(min_interval - elapsed)
        try:
            return await loop.run_in_executor(None, distance_to_water, lat, lon)
        finally:
            # Record failed requests too, so errors do not bypass the throttle.
            _last_osm_request = loop.time()
# Lifespan context manager - loads heavy models AFTER port binding
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: start model loading in the background, then serve.

    Startup returns right after scheduling ``load_heavy_models()`` as a task,
    so the server can bind its port while models load. On shutdown the task
    is cancelled if it is still running.
    """
    print("π FastAPI server starting - port binding now")
    # Keep a strong reference. The event loop only holds weak references to
    # tasks, so an unreferenced task can be garbage-collected mid-run.
    loading_task = asyncio.create_task(load_heavy_models())
    try:
        yield
    finally:
        if not loading_task.done():
            loading_task.cancel()
            try:
                await loading_task
            except asyncio.CancelledError:
                pass
        print("π Shutting down")
async def load_heavy_models():
    """Initialise GEE, the SHAP explainer and the height predictor.

    Each blocking initialiser runs in the default thread-pool executor, so the
    event loop keeps serving requests (e.g. the health check) while models
    load. Without this, the coroutine never awaited anything and blocked the
    whole server until loading finished.

    Side effects (module globals):
        gee_ready: set True once ``initialize_gee()`` succeeds.
        explainer: a ``VulnerabilityExplainer`` instance, or None on failure.
        model_ready: True only if the height predictor loaded successfully.

    If GEE initialisation fails, the remaining steps are skipped and the error
    is printed. Explainer and predictor failures are printed and leave the
    corresponding global at None/False.
    """
    global model_ready, gee_ready, explainer

    def _build_explainer():
        # The import runs in the worker thread too; the module may be slow to import.
        from explainability import VulnerabilityExplainer
        return VulnerabilityExplainer()

    def _warm_height_predictor():
        # get_predictor() is called only for its loading side effect here.
        from height_predictor.inference import get_predictor
        get_predictor()

    loop = asyncio.get_running_loop()
    try:
        # Initialize GEE immediately (no delay needed)
        print("π‘ Initializing GEE...")
        await loop.run_in_executor(None, initialize_gee)
        gee_ready = True
        print("β GEE initialized")

        # Load SHAP explainer
        try:
            explainer = await loop.run_in_executor(None, _build_explainer)
            print("β SHAP model initialized")
        except Exception as e:
            print(f"β οΈ SHAP explainer not available: {e}")
            explainer = None

        # Load height predictor (334 MB model)
        print("π¦ Loading height predictor...")
        if DISABLE_HEIGHT_PREDICTOR:
            print("β οΈ Height predictor disabled for this deployment.")
            model_ready = False
        else:
            try:
                await loop.run_in_executor(None, _warm_height_predictor)
                model_ready = True
                print("β Height predictor ready")
            except Exception as e:
                print(f"β οΈ Height predictor failed to load: {e}")
                model_ready = False
    except Exception as e:
        print(f"β Model loading failed: {e}")
# ---------------------------------------------------------------------------
# Application objects
# ---------------------------------------------------------------------------
# The lifespan hook schedules model loading without delaying startup.
app = FastAPI(
    lifespan=lifespan,
    title="Flood Vulnerability Assessment API",
    version="1.0",
)

# Jinja2 templates for the browser UI are read from ./templates.
templates = Jinja2Templates(directory="templates")

# Dedicated 10-worker thread pool.
# NOTE(review): the handlers visible in this file pass None (the default
# executor) to run_in_executor, so confirm whether this pool is used at all.
executor = ThreadPoolExecutor(max_workers=10)

# SHAP explainer placeholder; load_heavy_models() replaces it at startup.
explainer = None
# DATA MODEL
class SingleAssessment(BaseModel):
    """Request body for single-location flood assessments.

    Attributes:
        latitude: Decimal degrees; must be within [-90, 90].
        longitude: Decimal degrees; must be within [-180, 180].
        height: Height value forwarded to the vulnerability calculation
            (defaults to 0.0).
        basement: Basement level; must be 0 or negative (e.g. -1, -2, -3).
            Defaults to 0.0.
    """

    latitude: float
    longitude: float
    height: Optional[float] = 0.0
    basement: Optional[float] = 0.0

    # Without @field_validator these methods were never called, so
    # out-of-range coordinates and positive basements were silently accepted.
    @field_validator('latitude')
    @classmethod
    def check_lat(cls, v: float) -> float:
        """Reject latitudes outside [-90, 90]."""
        if not -90 <= v <= 90:
            raise ValueError('Latitude must be between -90 and 90')
        return v

    @field_validator('longitude')
    @classmethod
    def check_lon(cls, v: float) -> float:
        """Reject longitudes outside [-180, 180]."""
        if not -180 <= v <= 180:
            raise ValueError('Longitude must be between -180 and 180')
        return v

    @field_validator('basement')
    @classmethod
    def check_basement(cls, v: Optional[float]) -> Optional[float]:
        """Reject positive basement values; None is passed through."""
        # The field is Optional. `None > 0` would raise TypeError, which
        # pydantic does not turn into a validation error, so skip None.
        if v is not None and v > 0:
            raise ValueError('Basement height must be 0 or negative (e.g., -1, -2, -3)')
        return v
# FRONTEND ROUTE
async def home(request: Request):
    """Render the browser UI (index.html from the templates directory)."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
# API ROUTES
async def root() -> Dict:
    """Describe the service: its name, version and main endpoints."""
    endpoints = {
        "POST /assess": "Assess single location",
        "POST /assess_batch": "Assess batch from CSV file",
        "GET /health": "Health check",
    }
    return {
        "service": "Flood Vulnerability Assessment API",
        "version": "1.0",
        "endpoints": endpoints,
    }
async def health_check() -> Dict:
    """Report liveness plus the readiness of GEE and the height predictor.

    Always reports "healthy" and never waits on model loading; the two flags
    show which dependencies have finished initialising.
    """
    payload = dict(
        status="healthy",
        gee_initialized=gee_ready,
        height_predictor_ready=model_ready,
    )
    return payload
async def assess_single(data: SingleAssessment) -> Dict:
    """Assess flood vulnerability for a single location (non-blocking).

    Terrain metrics are fetched in a worker thread and the distance to water
    through the rate-limited OSM helper. Both are then combined by
    ``calculate_vulnerability_index``.

    Args:
        data: Validated request body (coordinates, height, basement).

    Returns:
        ``{"status": "success", "input": <request fields>, "assessment": <result>}``.

    Raises:
        HTTPException: 503 while GEE is still initialising; 500 if any step
            of the assessment fails.
    """
    if not gee_ready:
        raise HTTPException(
            status_code=503,
            detail="GEE still initializing, try again in 10 seconds"
        )

    loop = asyncio.get_running_loop()
    try:
        # Run terrain query in background thread
        terrain = await loop.run_in_executor(
            None, get_terrain_metrics, data.latitude, data.longitude
        )
        # Throttled water distance query
        water_dist = await throttled_distance_to_water(data.latitude, data.longitude)
        result = calculate_vulnerability_index(
            lat=data.latitude,
            lon=data.longitude,
            height=data.height,
            basement=data.basement,
            terrain_metrics=terrain,
            water_distance=water_dist
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Assessment failed: {e}") from e

    return {
        "status": "success",
        # model_dump() is the pydantic v2 replacement for the deprecated .dict().
        "input": data.model_dump(),
        "assessment": result
    }
async def process_single_row_async(row, use_predicted_height: bool = False):
    """Assess one CSV row and return a flat dict of output-CSV columns.

    Args:
        row: A row from ``DataFrame.iterrows()`` with ``latitude`` and
            ``longitude``. ``height`` and ``basement`` are optional and
            default to 0.0 when the column is missing or the cell is blank (NaN).
        use_predicted_height: If True, replace ``height`` with the height
            predictor's output whenever it reports ``status == "success"``
            and a non-None ``predicted_height``.

    Returns:
        A dict of output columns. This coroutine never raises for row-level
        failures: it returns a row with an ``error`` message and None for
        every result column instead.
    """
    try:
        lat = row['latitude']
        lon = row['longitude']
        height = row.get('height', 0.0)
        basement = row.get('basement', 0.0)
        # Blank CSV cells arrive as NaN; treat them like a missing column.
        if pd.isna(height):
            height = 0.0
        if pd.isna(basement):
            basement = 0.0

        loop = asyncio.get_running_loop()

        if use_predicted_height:
            if not model_ready:
                raise ValueError("Height predictor not ready yet")
            try:
                from height_predictor.inference import get_predictor
                predictor = get_predictor()
                # Run inference off the event loop so other requests are not
                # blocked (predict_height() offloads the same call).
                pred = await loop.run_in_executor(
                    None, predictor.predict_from_coordinates, lat, lon
                )
                if pred.get("status") == "success" and pred.get("predicted_height") is not None:
                    height = float(pred["predicted_height"])
            except Exception as e:
                raise ValueError(f"Height prediction failed for ({lat}, {lon}): {e}") from e

        # Run terrain in thread pool
        terrain = await loop.run_in_executor(None, get_terrain_metrics, lat, lon)
        # Throttled water distance
        water_dist = await throttled_distance_to_water(lat, lon)

        result = calculate_vulnerability_index(
            lat=lat,
            lon=lon,
            height=height,
            basement=basement,
            terrain_metrics=terrain,
            water_distance=water_dist
        )

        uncertainty = result['uncertainty_analysis']
        flags = uncertainty['data_quality_flags']
        # CSV output - essential columns
        return {
            'latitude': lat,
            'longitude': lon,
            'height': height,
            'basement': basement,
            'vulnerability_index': result['vulnerability_index'],
            'ci_lower_95': result['confidence_interval']['lower_bound_95'],
            'ci_upper_95': result['confidence_interval']['upper_bound_95'],
            'risk_level': result['risk_level'],
            'confidence': uncertainty['confidence'],
            'confidence_interpretation': uncertainty['interpretation'],
            'elevation_m': result['elevation_m'],
            'tpi_m': result['relative_elevation_m'],
            'slope_degrees': result['slope_degrees'],
            'distance_to_water_m': result['distance_to_water_m'],
            'quality_flags': ','.join(flags) if flags else ''
        }
    except Exception as e:
        # Best-effort batch: report the failure in the row instead of aborting.
        return {
            'latitude': row.get('latitude'),
            'longitude': row.get('longitude'),
            'height': row.get('height', 0.0),
            'basement': row.get('basement', 0.0),
            'error': str(e),
            'vulnerability_index': None,
            'ci_lower_95': None,
            'ci_upper_95': None,
            'risk_level': None,
            'confidence': None,
            'confidence_interpretation': None,
            'elevation_m': None,
            'tpi_m': None,
            'slope_degrees': None,
            'distance_to_water_m': None,
            'quality_flags': ''
        }
async def assess_batch(file: UploadFile = File(...), use_predicted_height: bool = False) -> StreamingResponse:
    """Assess flood vulnerability for multiple locations from a CSV file.

    The CSV must contain ``latitude`` and ``longitude`` columns. Optional
    ``height`` and ``basement`` columns default to 0.0, both when the column
    is missing and for blank cells. Rows with non-numeric or out-of-range
    coordinates are dropped. Rows are processed one at a time, and per-row
    failures are reported in an ``error`` column rather than failing the request.

    Args:
        file: Uploaded CSV (UTF-8, optionally with a BOM).
        use_predicted_height: Replace heights with the height predictor output.

    Returns:
        A CSV download named ``vulnerability_results.csv``.

    Raises:
        HTTPException: 503 if GEE (or the requested height predictor) is not
            ready; 400 for unreadable CSVs, missing coordinate columns or no
            valid coordinates; 500 for unexpected failures.
    """
    if not gee_ready:
        raise HTTPException(
            status_code=503,
            detail="GEE still initializing, try again in 10 seconds"
        )
    if use_predicted_height and not model_ready:
        raise HTTPException(
            status_code=503,
            detail="Height predictor still loading, try again in 30 seconds"
        )

    try:
        contents = await file.read()
        try:
            # 'utf-8-sig' also strips a leading BOM (common in Excel exports).
            # A BOM would otherwise corrupt the first column name.
            df = pd.read_csv(io.StringIO(contents.decode('utf-8-sig')))
        except (UnicodeDecodeError, pd.errors.ParserError, pd.errors.EmptyDataError) as e:
            raise HTTPException(status_code=400, detail=f"Could not parse CSV: {e}") from e

        if 'latitude' not in df.columns or 'longitude' not in df.columns:
            raise HTTPException(
                status_code=400,
                detail="CSV must contain 'latitude' and 'longitude' columns"
            )

        # Coerce coordinates to numbers (invalid values become NaN), then keep
        # rows inside the valid ranges. NaN fails both comparisons and is dropped.
        for col in ('latitude', 'longitude'):
            df[col] = pd.to_numeric(df[col], errors='coerce')
        df = df[(df['latitude'].abs() <= 90) & (df['longitude'].abs() <= 180)].copy()
        if df.empty:
            raise HTTPException(status_code=400, detail="No valid coordinates in CSV (lat -90..90, lon -180..180)")

        # Set defaults for optional columns (missing column or blank cells)
        for col in ('height', 'basement'):
            if col in df.columns:
                df[col] = df[col].fillna(0.0)
            else:
                df[col] = 0.0

        # Sequential on purpose: water-distance lookups are throttled anyway.
        results = [
            await process_single_row_async(row, use_predicted_height)
            for _, row in df.iterrows()
        ]
        csv_text = pd.DataFrame(results).to_csv(index=False)
    except HTTPException:
        # Deliberate client/availability errors must not be rewrapped as 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Batch processing failed: {str(e)}") from e

    return StreamingResponse(
        io.BytesIO(csv_text.encode('utf-8')),
        media_type="text/csv",
        headers={
            "Content-Disposition": (
                "attachment; filename=vulnerability_results.csv; "
                "filename*=UTF-8''vulnerability_results.csv"
            )
        }
    )
async def assess_batch_multihazard(file: UploadFile = File(...)) -> StreamingResponse:
    """Run the multi-hazard assessment for every row of an uploaded CSV.

    The CSV must contain ``latitude`` and ``longitude`` columns. Optional
    ``height`` and ``basement`` columns default to 0.0 (also for blank
    cells). Rows with non-numeric or out-of-range coordinates are dropped,
    matching the single-hazard batch endpoint. Per-row failures are reported
    in an ``error`` column.

    Args:
        file: Uploaded CSV (UTF-8, optionally with a BOM).

    Returns:
        A CSV download named ``multihazard_results.csv``.

    Raises:
        HTTPException: 503 while GEE is initialising; 400 for unreadable
            CSVs, missing coordinate columns or no valid coordinates; 500 for
            unexpected failures.
    """
    if not gee_ready:
        raise HTTPException(
            status_code=503,
            detail="GEE still initializing, try again in 10 seconds"
        )

    try:
        contents = await file.read()
        try:
            # 'utf-8-sig' strips a leading BOM, which would otherwise corrupt
            # the first column name.
            df = pd.read_csv(io.StringIO(contents.decode('utf-8-sig')))
        except (UnicodeDecodeError, pd.errors.ParserError, pd.errors.EmptyDataError) as e:
            raise HTTPException(status_code=400, detail=f"Could not parse CSV: {e}") from e

        if 'latitude' not in df.columns or 'longitude' not in df.columns:
            raise HTTPException(
                status_code=400,
                detail="CSV must contain 'latitude' and 'longitude' columns"
            )

        # Same coordinate validation as the single-hazard batch endpoint.
        for col in ('latitude', 'longitude'):
            df[col] = pd.to_numeric(df[col], errors='coerce')
        df = df[(df['latitude'].abs() <= 90) & (df['longitude'].abs() <= 180)].copy()
        if df.empty:
            raise HTTPException(status_code=400, detail="No valid coordinates in CSV (lat -90..90, lon -180..180)")

        for col in ('height', 'basement'):
            if col in df.columns:
                df[col] = df[col].fillna(0.0)
            else:
                df[col] = 0.0

        results = [
            await process_single_row_multihazard_async(row)
            for _, row in df.iterrows()
        ]
        csv_text = pd.DataFrame(results).to_csv(index=False)
    except HTTPException:
        # Deliberate client/availability errors must not be rewrapped as 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Batch multihazard failed: {str(e)}") from e

    return StreamingResponse(
        io.BytesIO(csv_text.encode('utf-8')),
        media_type="text/csv",
        headers={
            "Content-Disposition": (
                "attachment; filename=multihazard_results.csv; "
                "filename*=UTF-8''multihazard_results.csv"
            )
        }
    )
async def explain_assessment(data: SingleAssessment) -> Dict:
    """Assess vulnerability and attach a SHAP explanation when available.

    Performs the same assessment as the single-location endpoint. If the
    module-level ``explainer`` was loaded, it is asked to explain
    ``result['components']``. Explanation failures are printed and yield
    ``"explanation": None`` rather than failing the request.

    Args:
        data: Validated request body (coordinates, height, basement).

    Returns:
        ``{"status", "input", "assessment", "explanation"}``.

    Raises:
        HTTPException: 503 while GEE is initialising; 500 if the assessment
            itself fails.
    """
    if not gee_ready:
        raise HTTPException(
            status_code=503,
            detail="GEE still initializing, try again in 10 seconds"
        )

    loop = asyncio.get_running_loop()
    try:
        # Run terrain in background thread
        terrain = await loop.run_in_executor(
            None, get_terrain_metrics, data.latitude, data.longitude
        )
        # Throttled water distance
        water_dist = await throttled_distance_to_water(data.latitude, data.longitude)
        result = calculate_vulnerability_index(
            lat=data.latitude,
            lon=data.longitude,
            height=data.height,
            basement=data.basement,
            terrain_metrics=terrain,
            water_distance=water_dist
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Assessment failed: {e}") from e

    # Generate explanation if explainer available (best-effort)
    explanation = None
    if explainer is not None:
        try:
            explanation = explainer.explain(result['components'])
        except Exception as e:
            print(f"SHAP explanation failed: {e}")

    return {
        "status": "success",
        "input": data.model_dump(),
        "assessment": result,
        "explanation": explanation
    }
async def process_single_row_multihazard_async(row):
    """Run the multi-hazard assessment for one CSV row.

    Args:
        row: A row from ``DataFrame.iterrows()`` with ``latitude`` and
            ``longitude``. ``height`` and ``basement`` are optional and
            default to 0.0 when the column is missing or the cell is blank (NaN).

    Returns:
        A flat dict of output-CSV columns, including the per-hazard
        breakdown. Row-level failures never raise: they return a row with an
        ``error`` message and ``vulnerability_index`` set to None.
    """
    try:
        from vulnerability import calculate_multi_hazard_vulnerability

        lat = row['latitude']
        lon = row['longitude']
        height = row.get('height', 0.0)
        basement = row.get('basement', 0.0)
        # Blank CSV cells arrive as NaN; treat them like a missing column.
        if pd.isna(height):
            height = 0.0
        if pd.isna(basement):
            basement = 0.0

        loop = asyncio.get_running_loop()
        terrain = await loop.run_in_executor(None, get_terrain_metrics, lat, lon)
        water_dist = await throttled_distance_to_water(lat, lon)

        result = calculate_multi_hazard_vulnerability(
            lat=lat,
            lon=lon,
            height=height,
            basement=basement,
            terrain_metrics=terrain,
            water_distance=water_dist
        )

        uncertainty = result['uncertainty_analysis']
        hazards = result['hazard_breakdown']
        flags = uncertainty['data_quality_flags']
        return {
            'latitude': lat,
            'longitude': lon,
            'height': height,
            'basement': basement,
            'vulnerability_index': result['vulnerability_index'],
            'ci_lower_95': result['confidence_interval']['lower_bound_95'],
            'ci_upper_95': result['confidence_interval']['upper_bound_95'],
            'risk_level': result['risk_level'],
            'confidence': uncertainty['confidence'],
            'confidence_interpretation': uncertainty['interpretation'],
            'elevation_m': result['elevation_m'],
            'tpi_m': result['relative_elevation_m'],
            'slope_degrees': result['slope_degrees'],
            'distance_to_water_m': result['distance_to_water_m'],
            'dominant_hazard': result['dominant_hazard'],
            'fluvial_risk': hazards['fluvial_riverine'],
            'coastal_risk': hazards['coastal_surge'],
            'pluvial_risk': hazards['pluvial_drainage'],
            'combined_risk': hazards['combined_index'],
            'quality_flags': ','.join(flags) if flags else ''
        }
    except Exception as e:
        # Keep the input columns, consistent with the single-hazard error row.
        return {
            'latitude': row.get('latitude'),
            'longitude': row.get('longitude'),
            'height': row.get('height', 0.0),
            'basement': row.get('basement', 0.0),
            'error': str(e),
            'vulnerability_index': None
        }
async def assess_multihazard(data: SingleAssessment) -> Dict:
    """Multi-hazard assessment (fluvial + coastal + pluvial) for one location.

    Args:
        data: Validated request body (coordinates, height, basement).

    Returns:
        ``{"status": "success", "input": <request fields>, "assessment": <result>}``
        where ``assessment`` comes from ``calculate_multi_hazard_vulnerability``.

    Raises:
        HTTPException: 503 while GEE is initialising; 500 if any step of the
            assessment fails.
    """
    if not gee_ready:
        raise HTTPException(
            status_code=503,
            detail="GEE still initializing, try again in 10 seconds"
        )

    loop = asyncio.get_running_loop()
    try:
        from vulnerability import calculate_multi_hazard_vulnerability

        # Run terrain in background thread
        terrain = await loop.run_in_executor(
            None, get_terrain_metrics, data.latitude, data.longitude
        )
        # Throttled water distance
        water_dist = await throttled_distance_to_water(data.latitude, data.longitude)
        result = calculate_multi_hazard_vulnerability(
            lat=data.latitude,
            lon=data.longitude,
            height=data.height,
            basement=data.basement,
            terrain_metrics=terrain,
            water_distance=water_dist
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Assessment failed: {e}") from e

    return {
        "status": "success",
        "input": data.model_dump(),
        "assessment": result
    }
class HeightRequest(BaseModel):
    """Request body for height prediction.

    Attributes:
        latitude: Decimal degrees; must be within [-90, 90].
        longitude: Decimal degrees; must be within [-180, 180].
    """

    latitude: float
    longitude: float

    # Without @field_validator these checks were plain methods that never ran.
    @field_validator("latitude")
    @classmethod
    def check_lat(cls, v: float) -> float:
        """Reject latitudes outside [-90, 90]."""
        if not -90 <= v <= 90:
            raise ValueError("Latitude must be between -90 and 90")
        return v

    @field_validator("longitude")
    @classmethod
    def check_lon(cls, v: float) -> float:
        """Reject longitudes outside [-180, 180]."""
        if not -180 <= v <= 180:
            raise ValueError("Longitude must be between -180 and 180")
        return v
async def predict_height(data: HeightRequest) -> Dict:
    """Predict a height for the given coordinates with the height model.

    Inference runs in the default thread-pool executor. The predictor's
    result is returned as-is (presumably a dict; the batch path reads its
    ``status`` and ``predicted_height`` keys, so confirm the exact schema).

    Args:
        data: Validated coordinates.

    Raises:
        HTTPException: 503 if the predictor is disabled on this deployment
            or still loading; 500 if prediction fails.
    """
    if DISABLE_HEIGHT_PREDICTOR:
        raise HTTPException(status_code=503,
                            detail="Height predictor disabled on this deployment.")
    if not model_ready:
        raise HTTPException(
            status_code=503,
            detail="Height predictor still loading, try again later."
        )

    loop = asyncio.get_running_loop()
    try:
        from height_predictor.inference import get_predictor
        predictor = get_predictor()
        return await loop.run_in_executor(
            None,
            predictor.predict_from_coordinates,
            data.latitude,
            data.longitude,
        )
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Height prediction failed: {str(e)}",
        ) from e
# Local development entry point: serve the app with uvicorn.
if __name__ == "__main__":
    import os
    import uvicorn

    # Use the PORT environment variable when set, otherwise 8000.
    listen_port = int(os.environ.get("PORT", 8000))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)