# Source page metadata (Hugging Face file view): uploaded by adema5051,
# commit a359779 "Upload 10 files" (verified), file size 20.5 kB.
# main.py - FastAPI application for Flood Vulnerability Assessment
import asyncio
import io
import os
import time
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager
from datetime import datetime
from typing import Optional, Dict

import pandas as pd
from fastapi import FastAPI, File, UploadFile, HTTPException, Request
from fastapi.responses import StreamingResponse, HTMLResponse
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel, field_validator

from spatial_queries import get_terrain_metrics, distance_to_water
from vulnerability import calculate_vulnerability_index
from gee_auth import initialize_gee
# Set the DISABLE_HEIGHT environment variable to "true" (any case) to skip
# loading the height-prediction model on this deployment.
DISABLE_HEIGHT_PREDICTOR = os.getenv("DISABLE_HEIGHT", "false").lower() == "true"

# Readiness flags; load_heavy_models() flips them once each dependency loads.
gee_ready = False
model_ready = False

# Shared state for throttled_distance_to_water(): timestamp of the previous
# OSM request (None until the first one) and the lock serializing requests.
_last_osm_request = None
_osm_lock = asyncio.Lock()
async def throttled_distance_to_water(lat, lon):
    """
    Run ``distance_to_water(lat, lon)`` with OSM requests limited to ~2/sec.

    All calls are serialized through ``_osm_lock``. Each call waits until at
    least 0.5 s have passed since the previous OSM request finished, then
    runs the blocking query in the event loop's default executor.

    Returns:
        Whatever ``distance_to_water`` returns.

    Raises:
        Any exception raised by ``distance_to_water`` (propagated unchanged).
    """
    global _last_osm_request
    async with _osm_lock:
        if _last_osm_request is not None:
            # Monotonic clock: immune to wall-clock jumps (NTP sync, DST),
            # which could otherwise yield a negative "elapsed" and a long sleep.
            elapsed = time.monotonic() - _last_osm_request
            if elapsed < 0.5:  # 2 req/sec max
                await asyncio.sleep(0.5 - elapsed)
        loop = asyncio.get_running_loop()
        try:
            return await loop.run_in_executor(None, distance_to_water, lat, lon)
        finally:
            # Record the request even when it failed, so errors still count
            # against the rate limit instead of letting the next call fire
            # immediately.
            _last_osm_request = time.monotonic()
# Lifespan context manager - loads heavy models AFTER port binding
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Start background model loading at startup; cancel it at shutdown.

    Model loading is scheduled as a task rather than awaited, so startup
    does not wait for the models; handlers check the readiness flags that
    ``load_heavy_models`` sets.
    """
    print("πŸš€ FastAPI server starting - port binding now")
    # Keep a strong reference: the event loop holds only weak references to
    # tasks, so an unreferenced task may be garbage-collected mid-run.
    loader = asyncio.create_task(load_heavy_models())
    try:
        yield
    finally:
        print("πŸ›‘ Shutting down")
        # Don't leave a half-finished load running past shutdown.
        if not loader.done():
            loader.cancel()
async def load_heavy_models():
    """Load GEE, the SHAP explainer and the height predictor in the background.

    Every blocking step runs in the event loop's default executor, so the
    loop stays free to serve requests (e.g. /health) while loading.

    Side effects: sets the module globals ``gee_ready``, ``explainer`` and
    ``model_ready``.

    Failure handling:
        - A failing explainer or height predictor is reported and skipped.
        - A GEE failure is reported and the remaining steps are skipped.
        - Ordinary exceptions are never propagated.
    """
    global model_ready, gee_ready, explainer

    def _build_explainer():
        # Imported lazily so a missing/broken module only disables
        # explanations instead of failing the whole load.
        from explainability import VulnerabilityExplainer
        return VulnerabilityExplainer()

    def _load_height_predictor():
        from height_predictor.inference import get_predictor
        get_predictor()

    loop = asyncio.get_running_loop()
    try:
        # GEE first: every assessment endpoint checks gee_ready.
        print("πŸ“‘ Initializing GEE...")
        await loop.run_in_executor(None, initialize_gee)
        gee_ready = True
        print("βœ… GEE initialized")

        # Load SHAP explainer (optional)
        try:
            explainer = await loop.run_in_executor(None, _build_explainer)
            print("βœ… SHAP model initialized")
        except Exception as e:
            print(f"⚠️ SHAP explainer not available: {e}")
            explainer = None

        # Load height predictor (334 MB model)
        print("πŸ“¦ Loading height predictor...")
        if DISABLE_HEIGHT_PREDICTOR:
            print("⚠️ Height predictor disabled for this deployment.")
            model_ready = False
        else:
            try:
                await loop.run_in_executor(None, _load_height_predictor)
                model_ready = True
                print("βœ… Height predictor ready")
            except Exception as e:
                print(f"⚠️ Height predictor failed to load: {e}")
                model_ready = False
    except Exception as e:
        print(f"❌ Model loading failed: {e}")
# ---------------------------------------------------------------------------
# APP INITIALIZATION
# ---------------------------------------------------------------------------
app = FastAPI(
    title="Flood Vulnerability Assessment API",
    version="1.0",
    lifespan=lifespan,
)

# Jinja2 templates backing the HTML front end served at "/".
templates = Jinja2Templates(directory="templates")

# Dedicated 10-thread worker pool. NOTE(review): the handlers in this module
# call run_in_executor(None, ...), i.e. the loop's default executor, so this
# pool is not currently used by them.
executor = ThreadPoolExecutor(max_workers=10)

# Placeholder until load_heavy_models() installs a VulnerabilityExplainer;
# stays None if the explainer cannot be loaded.
explainer = None
# DATA MODEL
class SingleAssessment(BaseModel):
    """Request body for single-location assessments.

    Used by /assess, /explain and /assess_multihazard. Coordinates are
    range-checked; ``basement`` must be 0 or negative.
    """

    latitude: float  # must lie in [-90, 90]
    longitude: float  # must lie in [-180, 180]
    height: Optional[float] = 0.0
    basement: Optional[float] = 0.0  # 0 or negative, e.g. -1, -2, -3

    @field_validator('latitude')
    @classmethod
    def check_lat(cls, v: float) -> float:
        """Reject latitudes outside [-90, 90] (NaN fails the check too)."""
        if not -90 <= v <= 90:
            raise ValueError('Latitude must be between -90 and 90')
        return v

    @field_validator('longitude')
    @classmethod
    def check_lon(cls, v: float) -> float:
        """Reject longitudes outside [-180, 180] (NaN fails the check too)."""
        if not -180 <= v <= 180:
            raise ValueError('Longitude must be between -180 and 180')
        return v

    @field_validator('basement')
    @classmethod
    def check_basement(cls, v: Optional[float]) -> Optional[float]:
        """Reject positive basement values; allow 0, negatives and null."""
        # The field is Optional, so an explicit JSON null reaches this
        # validator as None. `None > 0` would raise TypeError, which pydantic
        # v2 does not turn into a validation error (it escapes as an
        # unhandled error instead).
        if v is not None and v > 0:
            raise ValueError('Basement height must be 0 or negative (e.g., -1, -2, -3)')
        return v
# FRONTEND ROUTE
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    """Render templates/index.html, the browser front end."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
# API ROUTES
@app.get("/api")
async def root() -> Dict:
    """Describe the service and list every route this module registers."""
    return {
        "service": "Flood Vulnerability Assessment API",
        "version": "1.0",
        "endpoints": {
            "GET /": "Web interface",
            "GET /api": "API info",
            "POST /assess": "Assess single location",
            "POST /assess_batch": "Assess batch from CSV file",
            "POST /assess_multihazard": "Multi-hazard assessment (fluvial + coastal + pluvial) for a single location",
            "POST /assess_batch_multihazard": "Multi-hazard assessment for a batch from CSV file",
            "POST /explain": "Assess single location with SHAP explanation",
            "POST /predict_height": "Predict height from coordinates",
            "GET /health": "Health check"
        }
    }
@app.get("/health")
async def health_check() -> Dict:
    """Liveness probe.

    Answers right away, even while models are still loading in the
    background. The two flags report which dependencies are usable yet.
    """
    status = {"status": "healthy"}
    status["gee_initialized"] = gee_ready
    status["height_predictor_ready"] = model_ready
    return status
@app.post("/assess")
async def assess_single(data: SingleAssessment) -> Dict:
    """Assess flood vulnerability for a single location (non-blocking).

    Terrain metrics are fetched in the default executor. Distance to water
    goes through the OSM throttle. Both are then combined by
    ``calculate_vulnerability_index``.

    Raises:
        HTTPException 503: GEE has not finished initializing.
        HTTPException 500: any failure while fetching data or scoring.
    """
    if not gee_ready:
        raise HTTPException(
            status_code=503,
            detail="GEE still initializing, try again in 10 seconds"
        )
    loop = asyncio.get_running_loop()
    try:
        # Terrain query is blocking; run it in a background thread.
        terrain = await loop.run_in_executor(
            None,
            get_terrain_metrics,
            data.latitude,
            data.longitude
        )
        # Throttled water distance query
        water_dist = await throttled_distance_to_water(data.latitude, data.longitude)
        result = calculate_vulnerability_index(
            lat=data.latitude,
            lon=data.longitude,
            height=data.height,
            basement=data.basement,
            terrain_metrics=terrain,
            water_distance=water_dist
        )
        return {
            "status": "success",
            # model_dump() is the pydantic v2 API; .dict() is deprecated.
            "input": data.model_dump(),
            "assessment": result
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Assessment failed: {e}") from e
async def process_single_row_async(row, use_predicted_height: bool = False):
    """Assess one CSV row and flatten the result into output-CSV columns.

    Args:
        row: mapping with ``latitude``/``longitude`` and optional
            ``height``/``basement`` (a pandas Series from ``iterrows`` in
            /assess_batch). Missing or empty (NaN) height/basement cells
            default to 0.0.
        use_predicted_height: when True, replace the row's height with the
            height predictor's estimate if it reports success.

    Returns:
        dict of output columns. Ordinary exceptions are caught and turned
        into a row with an ``error`` message and empty result columns, so a
        single bad row does not abort the batch.
    """
    def _cell(key):
        # Empty CSV cells arrive as NaN. Treat them like a missing column so
        # they get the same 0.0 default /assess_batch uses for absent columns.
        value = row.get(key, 0.0)
        return 0.0 if pd.isna(value) else value

    try:
        lat = row['latitude']
        lon = row['longitude']
        height = _cell('height')
        basement = _cell('basement')
        loop = asyncio.get_running_loop()

        if use_predicted_height:
            if not model_ready:
                raise ValueError("Height predictor not ready yet")
            try:
                from height_predictor.inference import get_predictor
                predictor = get_predictor()
                # predict_from_coordinates is synchronous. Run it off the
                # event loop (as /predict_height does) so other requests are
                # not stalled.
                pred = await loop.run_in_executor(
                    None, predictor.predict_from_coordinates, lat, lon
                )
                if pred.get("status") == "success" and pred.get("predicted_height") is not None:
                    height = float(pred["predicted_height"])
            except Exception as e:
                raise ValueError(f"Height prediction failed for ({lat}, {lon}): {e}") from e

        # Terrain in the thread pool, then the throttled OSM water query.
        terrain = await loop.run_in_executor(None, get_terrain_metrics, lat, lon)
        water_dist = await throttled_distance_to_water(lat, lon)

        result = calculate_vulnerability_index(
            lat=lat,
            lon=lon,
            height=height,
            basement=basement,
            terrain_metrics=terrain,
            water_distance=water_dist
        )
        uncertainty = result['uncertainty_analysis']
        ci = result['confidence_interval']
        flags = uncertainty['data_quality_flags']

        # CSV output - essential columns
        return {
            'latitude': lat,
            'longitude': lon,
            'height': height,
            'basement': basement,
            'vulnerability_index': result['vulnerability_index'],
            'ci_lower_95': ci['lower_bound_95'],
            'ci_upper_95': ci['upper_bound_95'],
            'risk_level': result['risk_level'],
            'confidence': uncertainty['confidence'],
            'confidence_interpretation': uncertainty['interpretation'],
            'elevation_m': result['elevation_m'],
            'tpi_m': result['relative_elevation_m'],
            'slope_degrees': result['slope_degrees'],
            'distance_to_water_m': result['distance_to_water_m'],
            'quality_flags': ','.join(flags) if flags else ''
        }
    except Exception as e:
        return {
            'latitude': row.get('latitude'),
            'longitude': row.get('longitude'),
            'height': _cell('height'),
            'basement': _cell('basement'),
            'error': str(e),
            'vulnerability_index': None,
            'ci_lower_95': None,
            'ci_upper_95': None,
            'risk_level': None,
            'confidence': None,
            'confidence_interpretation': None,
            'elevation_m': None,
            'tpi_m': None,
            'slope_degrees': None,
            'distance_to_water_m': None,
            'quality_flags': ''
        }
@app.post("/assess_batch")
async def assess_batch(file: UploadFile = File(...), use_predicted_height: bool = False) -> StreamingResponse:
    """Assess flood vulnerability for every location in an uploaded CSV file.

    Input format:
        - The CSV must contain ``latitude`` and ``longitude`` columns.
        - ``height`` and ``basement`` are optional and default to 0.0.

    Processing:
        - Rows whose coordinates are non-numeric or out of range are dropped.
        - The remaining rows are processed one at a time.
        - A row that fails appears in the output with an ``error`` column
          rather than aborting the batch.
        - Results are returned as a CSV download.

    Raises:
        HTTPException 503: GEE (or, when requested, the height predictor)
            is not ready yet.
        HTTPException 400: the upload is not a readable UTF-8 CSV, lacks
            the coordinate columns, or has no valid coordinates.
        HTTPException 500: any other unexpected failure.
    """
    if not gee_ready:
        raise HTTPException(
            status_code=503,
            detail="GEE still initializing, try again in 10 seconds"
        )
    if use_predicted_height and not model_ready:
        raise HTTPException(
            status_code=503,
            detail="Height predictor still loading, try again in 30 seconds"
        )
    try:
        contents = await file.read()
        try:
            # 'utf-8-sig' also strips a leading byte-order mark (common in
            # Excel exports); plain 'utf-8' would leave it glued to the first
            # header and break the column check below.
            df = pd.read_csv(io.StringIO(contents.decode('utf-8-sig')))
        except (UnicodeDecodeError, pd.errors.ParserError, pd.errors.EmptyDataError) as e:
            raise HTTPException(status_code=400, detail=f"Could not parse CSV: {e}") from e

        if 'latitude' not in df.columns or 'longitude' not in df.columns:
            raise HTTPException(
                status_code=400,
                detail="CSV must contain 'latitude' and 'longitude' columns"
            )

        # Coerce to numbers so stray text becomes NaN and is filtered out by
        # the range check (NaN fails `between`) instead of crashing it.
        df['latitude'] = pd.to_numeric(df['latitude'], errors='coerce')
        df['longitude'] = pd.to_numeric(df['longitude'], errors='coerce')
        valid = df['latitude'].between(-90, 90) & df['longitude'].between(-180, 180)
        # .copy() so the default columns below are added to a real frame,
        # not to a view of the unfiltered one.
        df = df[valid].copy()
        if len(df) == 0:
            raise HTTPException(status_code=400, detail="No valid coordinates in CSV (lat -90..90, lon -180..180)")

        # Set defaults for optional columns
        if 'height' not in df.columns:
            df['height'] = 0.0
        if 'basement' not in df.columns:
            df['basement'] = 0.0

        # Sequential on purpose: OSM calls are serialized by the throttle anyway.
        results = [
            await process_single_row_async(row, use_predicted_height)
            for _, row in df.iterrows()
        ]
        csv_text = pd.DataFrame(results).to_csv(index=False)
        return StreamingResponse(
            io.BytesIO(csv_text.encode('utf-8')),
            media_type="text/csv",
            headers={
                "Content-Disposition": (
                    "attachment; filename=vulnerability_results.csv; "
                    "filename*=UTF-8''vulnerability_results.csv"
                )
            }
        )
    except HTTPException:
        # Client errors raised above must keep their own status code.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Batch processing failed: {str(e)}") from e
@app.post("/assess_batch_multihazard")
async def assess_batch_multihazard(file: UploadFile = File(...)) -> StreamingResponse:
    """Run the multi-hazard assessment for every row of an uploaded CSV file.

    Input format:
        - The CSV must contain ``latitude`` and ``longitude`` columns.

    Processing:
        - Rows are processed one at a time.
        - A row that fails appears in the output with an ``error`` column.
        - Results are returned as a CSV download.

    Raises:
        HTTPException 503: GEE has not finished initializing.
        HTTPException 400: the upload is not a readable UTF-8 CSV or lacks
            the coordinate columns.
        HTTPException 500: any other unexpected failure.
    """
    if not gee_ready:
        raise HTTPException(
            status_code=503,
            detail="GEE still initializing, try again in 10 seconds"
        )
    try:
        contents = await file.read()
        try:
            # 'utf-8-sig' also strips a leading byte-order mark (common in
            # Excel exports) that would otherwise corrupt the first header.
            df = pd.read_csv(io.StringIO(contents.decode('utf-8-sig')))
        except (UnicodeDecodeError, pd.errors.ParserError, pd.errors.EmptyDataError) as e:
            raise HTTPException(status_code=400, detail=f"Could not parse CSV: {e}") from e

        if 'latitude' not in df.columns or 'longitude' not in df.columns:
            raise HTTPException(
                status_code=400,
                detail="CSV must contain 'latitude' and 'longitude' columns"
            )

        results = [
            await process_single_row_multihazard_async(row)
            for _, row in df.iterrows()
        ]
        csv_text = pd.DataFrame(results).to_csv(index=False)
        return StreamingResponse(
            io.BytesIO(csv_text.encode('utf-8')),
            media_type="text/csv",
            headers={
                "Content-Disposition": (
                    "attachment; filename=multihazard_results.csv; "
                    "filename*=UTF-8''multihazard_results.csv"
                )
            }
        )
    except HTTPException:
        # Client errors raised above must keep their own status code.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Batch multihazard failed: {str(e)}") from e
@app.post("/explain")
async def explain_assessment(data: SingleAssessment) -> Dict:
    """Assess vulnerability and attach a SHAP explanation when available.

    The explanation is best effort: if the explainer is not loaded or fails,
    ``explanation`` is None and the assessment is still returned.

    Raises:
        HTTPException 503: GEE has not finished initializing.
        HTTPException 500: any failure while fetching data or scoring.
    """
    if not gee_ready:
        raise HTTPException(
            status_code=503,
            detail="GEE still initializing, try again in 10 seconds"
        )
    loop = asyncio.get_running_loop()
    try:
        # Terrain query is blocking; run it in a background thread.
        terrain = await loop.run_in_executor(
            None,
            get_terrain_metrics,
            data.latitude,
            data.longitude
        )
        # Throttled water distance
        water_dist = await throttled_distance_to_water(data.latitude, data.longitude)
        result = calculate_vulnerability_index(
            lat=data.latitude,
            lon=data.longitude,
            height=data.height,
            basement=data.basement,
            terrain_metrics=terrain,
            water_distance=water_dist
        )

        explanation = None
        # Explicit None test: the explainer object's truthiness is not a
        # reliable "is loaded" signal.
        if explainer is not None:
            try:
                explanation = explainer.explain(result['components'])
            except Exception as e:
                print(f"SHAP explanation failed: {e}")

        return {
            "status": "success",
            # model_dump() is the pydantic v2 API; .dict() is deprecated.
            "input": data.model_dump(),
            "assessment": result,
            "explanation": explanation
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Assessment failed: {e}") from e
async def process_single_row_multihazard_async(row):
    """Run the multi-hazard assessment for one CSV row, flattened for CSV output.

    Args:
        row: mapping with ``latitude``/``longitude`` and optional
            ``height``/``basement``. Missing or empty (NaN) height/basement
            cells default to 0.0.

    Returns:
        dict of output columns. It includes the per-hazard breakdown and
        ``dominant_hazard``. Ordinary exceptions are caught and returned as
        a row with ``error`` set and ``vulnerability_index`` None; like the
        single-hazard rows, it echoes the input height/basement.
    """
    def _cell(key):
        # Empty CSV cells arrive as NaN; treat them like a missing column.
        value = row.get(key, 0.0)
        return 0.0 if pd.isna(value) else value

    try:
        from vulnerability import calculate_multi_hazard_vulnerability
        lat = row['latitude']
        lon = row['longitude']
        height = _cell('height')
        basement = _cell('basement')

        loop = asyncio.get_running_loop()
        terrain = await loop.run_in_executor(None, get_terrain_metrics, lat, lon)
        water_dist = await throttled_distance_to_water(lat, lon)

        result = calculate_multi_hazard_vulnerability(
            lat=lat,
            lon=lon,
            height=height,
            basement=basement,
            terrain_metrics=terrain,
            water_distance=water_dist
        )
        uncertainty = result['uncertainty_analysis']
        ci = result['confidence_interval']
        hazards = result['hazard_breakdown']
        flags = uncertainty['data_quality_flags']

        return {
            'latitude': lat,
            'longitude': lon,
            'height': height,
            'basement': basement,
            'vulnerability_index': result['vulnerability_index'],
            'ci_lower_95': ci['lower_bound_95'],
            'ci_upper_95': ci['upper_bound_95'],
            'risk_level': result['risk_level'],
            'confidence': uncertainty['confidence'],
            'confidence_interpretation': uncertainty['interpretation'],
            'elevation_m': result['elevation_m'],
            'tpi_m': result['relative_elevation_m'],
            'slope_degrees': result['slope_degrees'],
            'distance_to_water_m': result['distance_to_water_m'],
            'dominant_hazard': result['dominant_hazard'],
            'fluvial_risk': hazards['fluvial_riverine'],
            'coastal_risk': hazards['coastal_surge'],
            'pluvial_risk': hazards['pluvial_drainage'],
            'combined_risk': hazards['combined_index'],
            'quality_flags': ','.join(flags) if flags else ''
        }
    except Exception as e:
        return {
            'latitude': row.get('latitude'),
            'longitude': row.get('longitude'),
            'height': _cell('height'),
            'basement': _cell('basement'),
            'error': str(e),
            'vulnerability_index': None
        }
@app.post("/assess_multihazard")
async def assess_multihazard(data: SingleAssessment) -> Dict:
    """Multi-hazard assessment (fluvial + coastal + pluvial) for one location.

    Raises:
        HTTPException 503: GEE has not finished initializing.
        HTTPException 500: any failure while fetching data or scoring.
    """
    if not gee_ready:
        raise HTTPException(
            status_code=503,
            detail="GEE still initializing, try again in 10 seconds"
        )
    loop = asyncio.get_running_loop()
    try:
        from vulnerability import calculate_multi_hazard_vulnerability
        # Terrain query is blocking; run it in a background thread.
        terrain = await loop.run_in_executor(
            None,
            get_terrain_metrics,
            data.latitude,
            data.longitude
        )
        # Throttled water distance
        water_dist = await throttled_distance_to_water(data.latitude, data.longitude)
        result = calculate_multi_hazard_vulnerability(
            lat=data.latitude,
            lon=data.longitude,
            height=data.height,
            basement=data.basement,
            terrain_metrics=terrain,
            water_distance=water_dist
        )
        return {
            "status": "success",
            # model_dump() is the pydantic v2 API; .dict() is deprecated.
            "input": data.model_dump(),
            "assessment": result
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Assessment failed: {e}") from e
class HeightRequest(BaseModel):
    """Request body for /predict_height: a latitude/longitude pair."""

    latitude: float
    longitude: float

    @field_validator("latitude")
    @classmethod
    def check_lat(cls, v: float) -> float:
        """Accept only latitudes within [-90, 90]; anything else (incl. NaN) is rejected."""
        if -90 <= v <= 90:
            return v
        raise ValueError("Latitude must be between -90 and 90")

    @field_validator("longitude")
    @classmethod
    def check_lon(cls, v: float) -> float:
        """Accept only longitudes within [-180, 180]; anything else (incl. NaN) is rejected."""
        if -180 <= v <= 180:
            return v
        raise ValueError("Longitude must be between -180 and 180")
@app.post("/predict_height")
async def predict_height(data: HeightRequest) -> Dict:
    """Return the height predictor's output for one coordinate pair.

    The (synchronous) prediction runs in the default executor so the event
    loop is not blocked. The predictor's result dict is returned as-is.

    Raises:
        HTTPException 503: predictor disabled on this deployment or still loading.
        HTTPException 500: the prediction raised.
    """
    if DISABLE_HEIGHT_PREDICTOR:
        raise HTTPException(
            status_code=503,
            detail="Height predictor disabled on this deployment."
        )
    if not model_ready:
        raise HTTPException(
            status_code=503,
            detail="Height predictor still loading, try again later."
        )
    try:
        from height_predictor.inference import get_predictor
        predictor = get_predictor()
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None,
            predictor.predict_from_coordinates,
            data.latitude,
            data.longitude,
        )
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Height prediction failed: {str(e)}",
        ) from e
# For local development: `python main.py` serves on $PORT (default 8000).
if __name__ == "__main__":
    import uvicorn

    # `os` is already imported at module level.
    listen_port = int(os.environ.get("PORT", 8000))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)