import gc
import logging
import os
from contextlib import asynccontextmanager
from math import cos, pi, radians, sin

import joblib
import pandas as pd
import requests
from fastapi import FastAPI, HTTPException, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
from pydantic import BaseModel

# -------------------------
# Logger
# -------------------------
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)

# -------------------------
# Global models
# -------------------------
# Lazily populated on first use by the get_*_model_and_scaler() helpers.
_occurrence_model = None
_occurrence_scaler = None
_severity_model = None
_severity_scaler = None

# -------------------------
# Feature setup
# -------------------------
# NASA POWER daily point endpoint and the weather parameters requested from it.
API_BASE = "https://power.larc.nasa.gov/api/temporal/daily/point"
PARAMS = ",".join(
    ("PRECTOT", "T2M", "T2M_MAX", "T2M_MIN", "ALLSKY_SFC_SW_DWN", "RH2M", "WS2M")
)

# Column order the scalers/models are fed; weather aggregates first, then the
# cyclic (sin/cos) encodings of latitude, longitude and month.
FEATURE_ORDER = [
    "RH2M",
    "T2M_MAX",
    "T2M_MIN",
    "WS2M",
    "T2M",
    "ALLSKY_SFC_SW_DWN",
    "PRECTOTCORR",
    "lat_sin",
    "lat_cos",
    "lon_sin",
    "lon_cos",
    "month_sin",
    "month_cos",
]


# -------------------------
# Utility functions
# -------------------------
def cleanup_memory():
    """Force a garbage-collection pass (returns None)."""
    gc.collect()


def safe_model_load(filename: str):
    """Load a joblib artifact that lives next to this script.

    Args:
        filename: File name relative to this module's directory.

    Returns:
        Whatever object ``joblib.load`` deserializes from the file.

    Raises:
        HTTPException: status 500 if the file is absent or cannot be loaded;
            the underlying error is logged first.
    """
    try:
        here = os.path.dirname(os.path.abspath(__file__))
        model_path = os.path.join(here, filename)
        if os.path.exists(model_path):
            return joblib.load(model_path)
        raise FileNotFoundError(f"{filename} not found")
    except Exception as exc:
        logging.error("Failed to load %s: %s", filename, exc)
        raise HTTPException(
            status_code=500, detail=f"Model loading failed: {filename}"
        )


def get_occurrence_model_and_scaler():
    """Return the occurrence model and scaler, loading them on first use.

    The pair is cached in module globals; it is reloaded whenever either
    member is still ``None``.
    """
    global _occurrence_model, _occurrence_scaler
    needs_load = _occurrence_model is None or _occurrence_scaler is None
    if needs_load:
        logging.info("Loading occurrence model/scaler...")
        _occurrence_model = safe_model_load("drought_occurrence_model.joblib")
        _occurrence_scaler = safe_model_load(
            "drought_occurrence_model_scaler.joblib"
        )
        cleanup_memory()
    return _occurrence_model, _occurrence_scaler
def get_severity_model_and_scaler():
    """Return the severity model and scaler, loading them on first use.

    The pair is cached in module globals; it is reloaded whenever either
    member is still ``None``.
    """
    global _severity_model, _severity_scaler
    if _severity_model is None or _severity_scaler is None:
        logging.info("Loading severity model/scaler...")
        _severity_model = safe_model_load("drought_severity_model.joblib")
        _severity_scaler = safe_model_load("drought_severity_model_scaler.joblib")
        cleanup_memory()
    return _severity_model, _severity_scaler


# -------------------------
# Lifespan
# -------------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup/shutdown hook: models are NOT preloaded; cached ones are dropped on shutdown."""
    global _occurrence_model, _occurrence_scaler, _severity_model, _severity_scaler
    logging.info("🚀 Drought API starting (models load on first request)")
    cleanup_memory()
    yield
    logging.info("🛑 Shutting down API")
    _occurrence_model = _occurrence_scaler = _severity_model = _severity_scaler = None
    cleanup_memory()


# -------------------------
# FastAPI instance
# -------------------------
# FastAPI registers its default /docs and /redoc routes inside the constructor;
# they would match before the custom handlers at the end of this block and make
# them dead code. Disabling the defaults lets the custom handlers serve.
app = FastAPI(
    title="🌍 Drought Prediction API",
    version="2.4",
    description="Memory-optimized drought prediction API",
    lifespan=lifespan,
    docs_url=None,
    redoc_url=None,
)

# -------------------------
# CORS middleware for website
# -------------------------
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # replace with website URL in production
    allow_methods=["*"],
    allow_headers=["*"],
)


# -------------------------
# Request model
# -------------------------
class PredictionRequest(BaseModel):
    lat: float
    lon: float
    time: str  # YYYY-MM-DD


# -------------------------
# NASA feature fetcher
# -------------------------
# NASA POWER reports missing/unavailable daily samples with this fill value
# (not JSON null), so it must be excluded before aggregating.
NASA_FILL_VALUE = -999.0


def _aggregate_nasa_parameters(data: dict) -> dict:
    """Collapse each NASA POWER daily series into a single feature value.

    PRECTOT is summed over the window and stored as ``PRECTOTCORR``; every
    other parameter is averaged. ``None`` and fill-value samples are skipped,
    and a parameter with no valid samples is omitted entirely.
    """
    features = {}
    for name, daily in data.items():
        values = [
            v for v in daily.values() if v is not None and v != NASA_FILL_VALUE
        ]
        if not values:
            continue
        if name == "PRECTOT":
            features["PRECTOTCORR"] = sum(values)
        else:
            features[name] = sum(values) / len(values)
    return features


def fetch_features(lat, lon, time_str: str) -> dict:
    """Build the model feature dict for a location and date.

    Weather features are aggregated over the window from 90 days before
    ``time_str`` up to ``time_str`` (NASA POWER daily point API), then sin/cos
    encodings of latitude, longitude and month are added.

    Args:
        lat: Latitude in degrees.
        lon: Longitude in degrees.
        time_str: End date of the window, e.g. ``YYYY-MM-DD``.

    Returns:
        Dict containing at least every name in ``FEATURE_ORDER``.

    Raises:
        HTTPException: 400 if ``time_str`` is not a parseable date; 502 if the
            NASA request/parsing fails or the response lacks a required feature.
    """
    try:
        end = pd.to_datetime(time_str)
    except (ValueError, TypeError):
        end = pd.NaT
    if pd.isna(end):
        raise HTTPException(
            status_code=400,
            detail=f"Invalid time {time_str!r}; expected YYYY-MM-DD",
        )
    start = end - pd.Timedelta(days=90)

    params = {
        "latitude": lat,
        "longitude": lon,
        "start": start.strftime("%Y%m%d"),
        "end": end.strftime("%Y%m%d"),
        "parameters": PARAMS,
        "format": "JSON",
        "community": "AG",
    }

    # Only the network call and response parsing live inside this try, so the
    # validation errors below are not rewritten into a generic 502.
    try:
        response = requests.get(API_BASE, params=params, timeout=30)
        response.raise_for_status()
        data = response.json().get("properties", {}).get("parameter", {})
        features = _aggregate_nasa_parameters(data)
    except Exception as e:
        logging.error(f"NASA fetch error: {e}")
        raise HTTPException(status_code=502, detail="NASA API request failed") from e

    # Cyclic encodings keep e.g. December next to January and 180°E next to 180°W.
    features.update({
        "lat_sin": sin(radians(lat)),
        "lat_cos": cos(radians(lat)),
        "lon_sin": sin(radians(lon)),
        "lon_cos": cos(radians(lon)),
        "month_sin": sin(2 * pi * end.month / 12),
        "month_cos": cos(2 * pi * end.month / 12),
    })

    missing = [f for f in FEATURE_ORDER if f not in features]
    if missing:
        # Upstream returned no usable data for these parameters.
        logging.error(f"NASA response missing features: {missing}")
        raise HTTPException(status_code=502, detail=f"Missing features: {missing}")

    cleanup_memory()
    return features


# -------------------------
# Prediction endpoint
# -------------------------
def _run_prediction(req: PredictionRequest) -> dict:
    """Blocking part of /predict: fetch features, lazily load models, and score."""
    features = fetch_features(req.lat, req.lon, req.time)
    X = pd.DataFrame([[features[f] for f in FEATURE_ORDER]], columns=FEATURE_ORDER)

    occ_model, occ_scaler = get_occurrence_model_and_scaler()
    sev_model, sev_scaler = get_severity_model_and_scaler()

    X_occ = occ_scaler.transform(X)
    X_sev = sev_scaler.transform(X)

    result = {
        "input": {"lat": req.lat, "lon": req.lon, "time": req.time},
        "occurrence": {
            "prediction": int(occ_model.predict(X_occ)[0]),
            "probabilities": occ_model.predict_proba(X_occ)[0].tolist(),
        },
        "severity": {
            "prediction": int(sev_model.predict(X_sev)[0]),
            "probabilities": sev_model.predict_proba(X_sev)[0].tolist(),
        },
        "features_used": {
            k: round(v, 4) for k, v in zip(FEATURE_ORDER, X.iloc[0].tolist())
        },
    }
    cleanup_memory()
    return result


@app.post("/predict")
async def predict(req: PredictionRequest):
    """Predict drought occurrence and severity for a location and date."""
    # Local import keeps this block self-contained.
    from fastapi.concurrency import run_in_threadpool

    try:
        # The NASA HTTP call (up to 30 s), joblib loading and inference are all
        # blocking; running them on the event loop would stall every request.
        return await run_in_threadpool(_run_prediction, req)
    except HTTPException:
        raise
    except Exception as e:
        logging.error(f"Prediction error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e


# -------------------------
# Health check
# -------------------------
@app.api_route("/health", methods=["GET", "HEAD"])
async def health_check(request: Request):
    """Liveness probe; HEAD gets an empty 200, GET gets a status payload."""
    if request.method == "HEAD":
        return Response(status_code=200)
    return {"status": "healthy", "api_version": "2.4"}


# -------------------------
# Debug endpoint
# -------------------------
@app.get("/debug")
async def debug_info():
    """Report which models are currently cached and the expected feature order."""
    return {
        "models_loaded": {
            "occurrence_model": _occurrence_model is not None,
            "occurrence_scaler": _occurrence_scaler is not None,
            "severity_model": _severity_model is not None,
            "severity_scaler": _severity_scaler is not None,
        },
        "feature_order": FEATURE_ORDER,
    }


# -------------------------
# Test endpoint
# -------------------------
@app.get("/test")
async def test_prediction():
    """Run a prediction for a fixed sample point (40.7128, -74.0060) on 2024-08-15."""
    try:
        test_req = PredictionRequest(lat=40.7128, lon=-74.0060, time="2024-08-15")
        result = await predict(test_req)
        return {"test_status": "success", "result": result}
    except Exception as e:
        return {"test_status": "failed", "error": str(e)}


# -------------------------
# Root endpoint
# -------------------------
@app.get("/")
async def root():
    """List the available endpoints."""
    return {
        "message": "🌍 Drought Prediction API",
        "version": "2.4",
        "endpoints": {
            "predict": "/predict",
            "health": "/health",
            "debug": "/debug",
            "test": "/test",
            "docs": "/docs",
            "redoc": "/redoc",
        },
    }


# -------------------------
# Swagger UI and Redoc
# -------------------------
@app.get("/docs", include_in_schema=False)
async def custom_swagger_ui():
    return get_swagger_ui_html(openapi_url="/openapi.json", title="API Docs")


@app.get("/redoc", include_in_schema=False)
async def custom_redoc():
    return get_redoc_html(openapi_url="/openapi.json", title="ReDoc")