Spaces:
Running
Running
| from fastapi import FastAPI, HTTPException, Request, Response | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.openapi.docs import get_swagger_ui_html, get_redoc_html | |
| from pydantic import BaseModel | |
| import pandas as pd | |
| import joblib | |
| import requests | |
| import gc | |
| import os | |
| import logging | |
| from math import sin, cos, radians, pi | |
| from contextlib import asynccontextmanager | |
| # ------------------------- | |
| # Logger | |
| # ------------------------- | |
# Root logger: INFO and above, each line prefixed with timestamp and level.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
| # ------------------------- | |
| # Global models | |
| # ------------------------- | |
# Lazily populated model/scaler caches. They stay None until the matching
# get_*_model_and_scaler() helper first loads them, and lifespan() resets
# them to None on shutdown.
_occurrence_model = _occurrence_scaler = None
_severity_model = _severity_scaler = None
| # ------------------------- | |
| # Feature setup | |
| # ------------------------- | |
# NASA POWER daily point endpoint and the comma-separated parameter list
# requested from it.
API_BASE = "https://power.larc.nasa.gov/api/temporal/daily/point"
# NOTE(review): NASA POWER now documents PRECTOTCORR as its precipitation
# parameter; confirm that requesting "PRECTOT" is still accepted.
PARAMS = ",".join([
    "PRECTOT",
    "T2M",
    "T2M_MAX",
    "T2M_MIN",
    "ALLSKY_SFC_SW_DWN",
    "RH2M",
    "WS2M",
])

# Column order of the model input matrix. The scalers and models are fed a
# DataFrame with exactly these columns in exactly this order.
FEATURE_ORDER = [
    # Weather aggregates over the request window.
    "RH2M",
    "T2M_MAX",
    "T2M_MIN",
    "WS2M",
    "T2M",
    "ALLSKY_SFC_SW_DWN",
    "PRECTOTCORR",
    # Cyclic encodings of position.
    "lat_sin",
    "lat_cos",
    "lon_sin",
    "lon_cos",
    # Cyclic encoding of the month of the request date.
    "month_sin",
    "month_cos",
]
| # ------------------------- | |
| # Utility functions | |
| # ------------------------- | |
def cleanup_memory():
    """Force a full (oldest-generation) garbage-collection pass.

    The number of collected objects is discarded; the function returns None.
    """
    gc.collect(2)
def safe_model_load(filename: str):
    """Load a joblib artifact stored next to this script.

    Args:
        filename: File name relative to this module's directory.

    Returns:
        Whatever object ``joblib.load`` deserializes from the file.

    Raises:
        HTTPException: status 500 if the file is absent or cannot be loaded;
            the underlying error is logged first.
    """
    try:
        here = os.path.dirname(os.path.abspath(__file__))
        artifact_path = os.path.join(here, filename)
        if os.path.exists(artifact_path):
            return joblib.load(artifact_path)
        raise FileNotFoundError(f"{filename} not found")
    except Exception as exc:
        logging.error(f"Failed to load {filename}: {exc}")
        raise HTTPException(status_code=500, detail=f"Model loading failed: {filename}")
def get_occurrence_model_and_scaler():
    """Return the cached ``(model, scaler)`` pair for drought occurrence.

    If either cached global is still None, both artifacts are loaded from disk
    with ``safe_model_load`` and a GC pass follows. Load failures propagate
    as the HTTPException raised by ``safe_model_load``.
    """
    global _occurrence_model, _occurrence_scaler
    if _occurrence_model is not None and _occurrence_scaler is not None:
        return _occurrence_model, _occurrence_scaler

    logging.info("Loading occurrence model/scaler...")
    _occurrence_model = safe_model_load("drought_occurrence_model.joblib")
    _occurrence_scaler = safe_model_load("drought_occurrence_model_scaler.joblib")
    cleanup_memory()
    return _occurrence_model, _occurrence_scaler
def get_severity_model_and_scaler():
    """Return the cached ``(model, scaler)`` pair for drought severity.

    If either cached global is still None, both artifacts are loaded from disk
    with ``safe_model_load`` and a GC pass follows. Load failures propagate
    as the HTTPException raised by ``safe_model_load``.
    """
    global _severity_model, _severity_scaler
    if _severity_model is not None and _severity_scaler is not None:
        return _severity_model, _severity_scaler

    logging.info("Loading severity model/scaler...")
    _severity_model = safe_model_load("drought_severity_model.joblib")
    _severity_scaler = safe_model_load("drought_severity_model_scaler.joblib")
    cleanup_memory()
    return _severity_model, _severity_scaler
| # ------------------------- | |
| # Lifespan | |
| # ------------------------- | |
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: log startup, then drop cached models on shutdown.

    Models are not loaded here; they are loaded lazily on first use. Shutdown
    runs in ``finally`` so the caches are cleared even if the lifespan is
    exited with an exception.
    """
    global _occurrence_model, _occurrence_scaler, _severity_model, _severity_scaler
    logging.info("π Drought API starting (models load on first request)")
    cleanup_memory()
    try:
        yield
    finally:
        logging.info("π Shutting down API")
        _occurrence_model = _occurrence_scaler = _severity_model = _severity_scaler = None
        cleanup_memory()
| # ------------------------- | |
| # FastAPI instance | |
| # ------------------------- | |
# The ASGI application. The lifespan hook handles startup logging and model
# cleanup. docs_url/redoc_url are not overridden, so FastAPI's defaults apply.
app = FastAPI(
    lifespan=lifespan,
    title="π Drought Prediction API",
    description="Memory-optimized drought prediction API",
    version="2.4",
)
| # ------------------------- | |
| # CORS middleware for website | |
| # ------------------------- | |
# Allowed CORS origins. The default is "*" (any origin), as before. To restrict
# it in production without editing code, set CORS_ALLOW_ORIGINS to a
# comma-separated list, e.g. "https://example.org,https://www.example.org".
# An empty or whitespace-only value falls back to "*".
_cors_allow_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_allow_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
| # ------------------------- | |
| # Request model | |
| # ------------------------- | |
class PredictionRequest(BaseModel):
    """JSON body accepted by the prediction endpoint.

    All three fields are required; pydantic only checks their types here.
    """

    lat: float  # latitude in degrees (no range check at this layer)
    lon: float  # longitude in degrees (no range check at this layer)
    time: str # YYYY-MM-DD; end date of the weather window, parsed later by pd.to_datetime
| # ------------------------- | |
| # NASA feature fetcher | |
| # ------------------------- | |
def fetch_features(lat, lon, time_str: str) -> dict:
    """Build the model feature dict for a location and date from NASA POWER data.

    Queries the NASA POWER daily point API for the 90 days ending on
    ``time_str`` and aggregates each returned parameter. ``PRECTOT`` is summed
    and stored as ``PRECTOTCORR``; every other parameter is averaged. Sin/cos
    encodings of latitude, longitude and the month of ``time_str`` are then
    added.

    Args:
        lat: Latitude in degrees.
        lon: Longitude in degrees.
        time_str: End date of the window, parsed with ``pd.to_datetime``
            (expected ``YYYY-MM-DD``).

    Returns:
        Dict containing (at least) every name in ``FEATURE_ORDER``.

    Raises:
        HTTPException: 422 if ``time_str`` is not a parseable date;
            502 if the NASA request or response parsing fails;
            500 if any feature in ``FEATURE_ORDER`` could not be computed.
    """
    try:
        end = pd.to_datetime(time_str)
    except (ValueError, TypeError) as e:
        raise HTTPException(status_code=422, detail=f"Invalid time: {time_str!r}") from e
    if pd.isna(end):
        # pd.to_datetime("") yields NaT, which cannot be strftime-formatted below.
        raise HTTPException(status_code=422, detail=f"Invalid time: {time_str!r}")
    start = end - pd.Timedelta(days=90)
    params = {
        "latitude": lat,
        "longitude": lon,
        "start": start.strftime("%Y%m%d"),
        "end": end.strftime("%Y%m%d"),
        "parameters": PARAMS,
        "format": "JSON",
        "community": "AG",
    }

    # Only the upstream call and response parsing are covered here, so that
    # HTTPExceptions raised further down are not remapped to 502.
    try:
        response = requests.get(API_BASE, params=params, timeout=30)
        response.raise_for_status()
        payload = response.json()
        data = payload.get("properties", {}).get("parameter", {})
        # NASA POWER marks missing days with a numeric fill value (advertised in
        # the response header, -999 by default) rather than null. Drop it so it
        # does not distort the sums and means.
        fill_value = payload.get("header", {}).get("fill_value", -999.0)

        features = {}
        for p, vals in data.items():
            values = [v for v in vals.values() if v is not None and v != fill_value]
            if not values:
                continue
            if p == "PRECTOT":
                # Precipitation is accumulated over the window.
                features["PRECTOTCORR"] = sum(values)
            else:
                # NOTE(review): if the API returns the key "PRECTOTCORR"
                # directly, it is averaged here, not summed. Confirm which
                # aggregation the model was trained on.
                features[p] = sum(values) / len(values)
    except Exception as e:
        logging.error(f"NASA fetch error: {e}")
        raise HTTPException(status_code=502, detail="NASA API request failed") from e

    features.update({
        "lat_sin": sin(radians(lat)),
        "lat_cos": cos(radians(lat)),
        "lon_sin": sin(radians(lon)),
        "lon_cos": cos(radians(lon)),
        "month_sin": sin(2 * pi * end.month / 12),
        "month_cos": cos(2 * pi * end.month / 12),
    })
    missing = [f for f in FEATURE_ORDER if f not in features]
    if missing:
        raise HTTPException(status_code=500, detail=f"Missing features: {missing}")
    cleanup_memory()
    return features
| # ------------------------- | |
| # Prediction endpoint | |
| # ------------------------- | |
@app.post("/predict")
async def predict(req: PredictionRequest):
    """Predict drought occurrence and severity for a location and date.

    Features are fetched from NASA POWER, arranged in ``FEATURE_ORDER``,
    scaled with each model's own scaler, and passed to both classifiers.

    Returns:
        Dict with the echoed input, the occurrence and severity predictions
        plus class probabilities, and the features used (rounded to 4 dp).

    Raises:
        HTTPException: propagated unchanged from feature fetching or model
            loading. Any other error becomes a 500 carrying the error text.
    """
    # Local import keeps this block self-contained; fastapi is already required.
    from fastapi.concurrency import run_in_threadpool

    try:
        # fetch_features makes a blocking `requests` call (30 s timeout).
        # Running it in the threadpool keeps the event loop free for other
        # requests (e.g. /health) in the meantime.
        features = await run_in_threadpool(fetch_features, req.lat, req.lon, req.time)
        X = pd.DataFrame([[features[f] for f in FEATURE_ORDER]], columns=FEATURE_ORDER)

        occ_model, occ_scaler = get_occurrence_model_and_scaler()
        sev_model, sev_scaler = get_severity_model_and_scaler()
        X_occ = occ_scaler.transform(X)
        X_sev = sev_scaler.transform(X)

        occurrence_pred = int(occ_model.predict(X_occ)[0])
        occurrence_proba = occ_model.predict_proba(X_occ)[0].tolist()
        severity_pred = int(sev_model.predict(X_sev)[0])
        severity_proba = sev_model.predict_proba(X_sev)[0].tolist()

        result = {
            "input": {"lat": req.lat, "lon": req.lon, "time": req.time},
            "occurrence": {"prediction": occurrence_pred, "probabilities": occurrence_proba},
            "severity": {"prediction": severity_pred, "probabilities": severity_proba},
            "features_used": {k: round(v, 4) for k, v in zip(FEATURE_ORDER, X.iloc[0].tolist())},
        }
        cleanup_memory()
        return result
    except HTTPException:
        # Already carries the intended status code and detail.
        raise
    except Exception as e:
        logging.exception(f"Prediction error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
| # ------------------------- | |
| # Health check | |
| # ------------------------- | |
@app.api_route("/health", methods=["GET", "HEAD"])
async def health_check(request: Request):
    """Health probe for /health.

    HEAD gets an empty 200 response. GET returns a small JSON status object.
    The route is registered for both methods so the HEAD branch is reachable.
    """
    if request.method == "HEAD":
        return Response(status_code=200)
    return {"status": "healthy", "api_version": "2.4"}
| # ------------------------- | |
| # Debug endpoint | |
| # ------------------------- | |
@app.get("/debug")
async def debug_info():
    """Report which lazily loaded artifacts are currently cached.

    Also returns ``FEATURE_ORDER``. Nothing is loaded by this call.
    """
    return {
        "models_loaded": {
            "occurrence_model": _occurrence_model is not None,
            "occurrence_scaler": _occurrence_scaler is not None,
            "severity_model": _severity_model is not None,
            "severity_scaler": _severity_scaler is not None,
        },
        "feature_order": FEATURE_ORDER,
    }
| # ------------------------- | |
| # Test endpoint | |
| # ------------------------- | |
@app.get("/test")
async def test_prediction():
    """Smoke test: run ``predict`` for a fixed location and date.

    Any failure is reported in the JSON body instead of raised, so this
    endpoint always returns normally.
    """
    try:
        test_req = PredictionRequest(lat=40.7128, lon=-74.0060, time="2024-08-15")
        result = await predict(test_req)
        return {"test_status": "success", "result": result}
    except Exception as e:
        # Deliberate best-effort: surface the error text to the caller.
        return {"test_status": "failed", "error": str(e)}
| # ------------------------- | |
| # Root endpoint | |
| # ------------------------- | |
@app.get("/")
async def root():
    """Index endpoint listing the API name, version and available paths."""
    return {
        "message": "π Drought Prediction API",
        "version": "2.4",
        "endpoints": {
            "predict": "/predict",
            "health": "/health",
            "debug": "/debug",
            "test": "/test",
            "docs": "/docs",
            "redoc": "/redoc",
        },
    }
| # ------------------------- | |
| # Swagger UI and Redoc | |
| # ------------------------- | |
async def custom_swagger_ui():
    """Render a Swagger UI page for this app's ``/openapi.json`` schema.

    NOTE(review): no route decorator is attached here, so this handler is not
    served. Registering it at /docs would also require disabling FastAPI's
    built-in docs route (``docs_url=None``), which is otherwise matched first.
    """
    page = get_swagger_ui_html(title="API Docs", openapi_url="/openapi.json")
    return page
async def custom_redoc():
    """Render a ReDoc page for this app's ``/openapi.json`` schema.

    NOTE(review): no route decorator is attached here, so this handler is not
    served. Registering it at /redoc would also require disabling FastAPI's
    built-in ReDoc route (``redoc_url=None``), which is otherwise matched first.
    """
    page = get_redoc_html(title="ReDoc", openapi_url="/openapi.json")
    return page