Vikctor's picture
Update app.py
a772142 verified
from fastapi import FastAPI, HTTPException, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.docs import get_swagger_ui_html, get_redoc_html
from pydantic import BaseModel
import pandas as pd
import joblib
import requests
import gc
import os
import logging
from math import sin, cos, radians, pi
from contextlib import asynccontextmanager
# -------------------------
# Logger
# -------------------------
# Configure the root logger once at import time: INFO and above, with each
# record rendered as "<timestamp> - <level> - <message>".
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# -------------------------
# Global models
# -------------------------
# Module-level caches for the lazily loaded artifacts. Each name stays None
# until its loader function assigns it, and is reset to None on shutdown.
_occurrence_model = _occurrence_scaler = None
_severity_model = _severity_scaler = None
# -------------------------
# Feature setup
# -------------------------
# Endpoint queried by fetch_features for daily point data.
API_BASE = "https://power.larc.nasa.gov/api/temporal/daily/point"

# Comma-separated parameter list sent as the "parameters" query argument.
# NOTE(review): this requests "PRECTOT" while FEATURE_ORDER expects
# "PRECTOTCORR"; fetch_features renames the key. Confirm the API still
# accepts the "PRECTOT" name.
PARAMS = ",".join(
    [
        "PRECTOT",
        "T2M",
        "T2M_MAX",
        "T2M_MIN",
        "ALLSKY_SFC_SW_DWN",
        "RH2M",
        "WS2M",
    ]
)

# Column order used when the single-row model input frame is assembled.
# Kept as a list because it is also returned verbatim by the /debug endpoint.
FEATURE_ORDER = [
    # Aggregated weather parameters.
    "RH2M",
    "T2M_MAX",
    "T2M_MIN",
    "WS2M",
    "T2M",
    "ALLSKY_SFC_SW_DWN",
    "PRECTOTCORR",
    # Cyclic encodings of the coordinates.
    "lat_sin",
    "lat_cos",
    "lon_sin",
    "lon_cos",
    # Cyclic encoding of the month.
    "month_sin",
    "month_cos",
]
# -------------------------
# Utility functions
# -------------------------
def cleanup_memory():
    """Trigger a garbage-collection pass and discard the collector's result."""
    gc.collect()
    return None
def safe_model_load(filename: str):
    """Load a joblib artifact that lives next to this source file.

    Args:
        filename: File name, resolved relative to this module's directory.

    Returns:
        The object that ``joblib.load`` deserializes from the file.

    Raises:
        HTTPException: Status 500 if the file does not exist or cannot be
            loaded. The underlying error is chained as ``__cause__`` so it is
            still visible in tracebacks.
    """
    try:
        script_dir = os.path.dirname(os.path.abspath(__file__))
        path = os.path.join(script_dir, filename)
        if not os.path.exists(path):
            raise FileNotFoundError(f"{filename} not found")
        # NOTE: joblib.load unpickles the file, which can execute arbitrary
        # code. Only artifacts shipped with the app should ever be passed here.
        return joblib.load(path)
    except Exception as e:
        # logging.exception keeps the traceback, which the plain error log lost.
        logging.exception("Failed to load %s: %s", filename, e)
        raise HTTPException(
            status_code=500, detail=f"Model loading failed: {filename}"
        ) from e
def get_occurrence_model_and_scaler():
    """Return the occurrence model and its scaler, loading them on first use.

    Both artifacts are cached in module globals. If either one is still None,
    both files are (re)loaded before returning.

    Returns:
        Tuple ``(model, scaler)``.
    """
    global _occurrence_model, _occurrence_scaler
    already_loaded = (
        _occurrence_model is not None and _occurrence_scaler is not None
    )
    if not already_loaded:
        logging.info("Loading occurrence model/scaler...")
        _occurrence_model = safe_model_load("drought_occurrence_model.joblib")
        _occurrence_scaler = safe_model_load("drought_occurrence_model_scaler.joblib")
        cleanup_memory()
    return _occurrence_model, _occurrence_scaler
def get_severity_model_and_scaler():
    """Return the severity model and its scaler, loading them on first use.

    Both artifacts are cached in module globals. If either one is still None,
    both files are (re)loaded before returning.

    Returns:
        Tuple ``(model, scaler)``.
    """
    global _severity_model, _severity_scaler
    already_loaded = (
        _severity_model is not None and _severity_scaler is not None
    )
    if not already_loaded:
        logging.info("Loading severity model/scaler...")
        _severity_model = safe_model_load("drought_severity_model.joblib")
        _severity_scaler = safe_model_load("drought_severity_model_scaler.joblib")
        cleanup_memory()
    return _severity_model, _severity_scaler
# -------------------------
# Lifespan
# -------------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: log startup, then drop cached models on shutdown.

    Nothing is loaded at startup; the model getters load on first use. On
    shutdown the four module-level caches are reset to None and a garbage
    collection pass is requested.

    Args:
        app: The FastAPI application instance (unused here).
    """
    global _occurrence_model, _occurrence_scaler, _severity_model, _severity_scaler

    # The emoji in both log messages were mis-encoded in the source; restored.
    logging.info("🚀 Drought API starting (models load on first request)")
    cleanup_memory()
    yield
    logging.info("🛑 Shutting down API")
    _occurrence_model = _occurrence_scaler = _severity_model = _severity_scaler = None
    cleanup_memory()
# -------------------------
# FastAPI instance
# -------------------------
# The ASGI application object. Startup and shutdown are delegated to `lifespan`.
app = FastAPI(
    lifespan=lifespan,
    title="🌍 Drought Prediction API",
    description="Memory-optimized drought prediction API",
    version="2.4",
)
# -------------------------
# CORS middleware for website
# -------------------------
# Allow cross-origin requests from any origin, with any method and any header.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],  # replace with website URL in production
)
# -------------------------
# Request model
# -------------------------
class PredictionRequest(BaseModel):
    """JSON body accepted by the /predict endpoint."""

    # Latitude in degrees (converted with radians() when features are built).
    lat: float
    # Longitude in degrees (converted with radians() when features are built).
    lon: float
    # End date of the weather window, parsed with pandas.to_datetime.
    time: str  # YYYY-MM-DD
# -------------------------
# NASA feature fetcher
# -------------------------
def fetch_features(lat, lon, time_str: str) -> dict:
    """Build the model feature dict for one location and end date.

    Requests daily data for the 90 days ending at ``time_str`` from the
    endpoint in ``API_BASE``, aggregates every returned parameter
    ("PRECTOT" is summed and stored as "PRECTOTCORR"; every other parameter
    is averaged), and adds sine/cosine encodings of latitude, longitude and
    the end month.

    Args:
        lat: Latitude in degrees.
        lon: Longitude in degrees.
        time_str: End date of the window; anything ``pandas.to_datetime``
            can parse (callers document it as YYYY-MM-DD).

    Returns:
        Dict mapping feature name to its numeric value. Every name in
        ``FEATURE_ORDER`` is guaranteed to be present.

    Raises:
        HTTPException: 400 if ``time_str`` cannot be parsed into a date,
            500 if a required feature is absent after aggregation,
            502 if the HTTP request or response handling fails.
    """
    # Validate the date up front so a bad input is reported as a client error
    # instead of surfacing later as an opaque server error.
    try:
        end = pd.to_datetime(time_str)
        start = end - pd.Timedelta(days=90)
    except (ValueError, TypeError, OverflowError) as e:
        raise HTTPException(status_code=400, detail=f"Invalid date: {time_str!r}") from e
    if pd.isna(end) or pd.isna(start):
        # e.g. an empty string parses to NaT, which cannot be formatted.
        raise HTTPException(status_code=400, detail=f"Invalid date: {time_str!r}")

    params = {
        "latitude": lat,
        "longitude": lon,
        "start": start.strftime("%Y%m%d"),
        "end": end.strftime("%Y%m%d"),
        "parameters": PARAMS,
        "format": "JSON",
        "community": "AG",
    }
    try:
        response = requests.get(API_BASE, params=params, timeout=30)
        response.raise_for_status()
        payload = response.json()
        data = payload.get("properties", {}).get("parameter", {})
        # The service marks missing days with a numeric fill value rather than
        # null. Prefer the value announced in the response header and fall
        # back to the documented -999.
        fill_value = payload.get("header", {}).get("fill_value", -999.0)

        features = {}
        for name, series in data.items():
            values = [
                v for v in series.values()
                if v is not None and v != fill_value
            ]
            if not values:
                # No usable observations: leave the feature out so the
                # missing-feature check below reports it.
                continue
            if name == "PRECTOT":
                # Precipitation is accumulated over the window, not averaged.
                features["PRECTOTCORR"] = sum(values)
            else:
                # NOTE(review): if the API returns the key "PRECTOTCORR"
                # directly, it is averaged here rather than summed. Confirm
                # which aggregation the models were trained on.
                features[name] = sum(values) / len(values)

        features.update({
            "lat_sin": sin(radians(lat)),
            "lat_cos": cos(radians(lat)),
            "lon_sin": sin(radians(lon)),
            "lon_cos": cos(radians(lon)),
            "month_sin": sin(2 * pi * end.month / 12),
            "month_cos": cos(2 * pi * end.month / 12),
        })

        missing = [f for f in FEATURE_ORDER if f not in features]
        if missing:
            raise HTTPException(status_code=500, detail=f"Missing features: {missing}")
        cleanup_memory()
        return features
    except HTTPException:
        # Keep deliberately raised HTTP errors (such as the missing-feature
        # 500 above) intact instead of rewriting them as a 502.
        raise
    except Exception as e:
        logging.error(f"NASA fetch error: {e}")
        raise HTTPException(status_code=502, detail="NASA API request failed") from e
# -------------------------
# Prediction endpoint
# -------------------------
@app.post("/predict")
async def predict(req: PredictionRequest):
    """Predict drought occurrence and severity for a location and date.

    Fetches the feature dict, arranges it in ``FEATURE_ORDER``, scales it with
    each model's own scaler and calls ``predict`` / ``predict_proba`` on both
    models.

    Args:
        req: Request body carrying ``lat``, ``lon`` and ``time``.

    Returns:
        Dict with the echoed input, the occurrence and severity predictions
        (class plus probability list), and the feature values used, rounded
        to 4 decimals.

    Raises:
        HTTPException: Re-raised unchanged when produced by the helpers; any
            other failure is reported as status 500 with the error text.
    """
    # Local import keeps this handler self-contained; fastapi is already a
    # dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    try:
        # fetch_features performs a blocking HTTP request (up to a 30 s
        # timeout). Running it in the threadpool keeps the event loop free to
        # serve other requests in the meantime.
        features = await run_in_threadpool(
            fetch_features, req.lat, req.lon, req.time
        )
        X = pd.DataFrame([[features[f] for f in FEATURE_ORDER]], columns=FEATURE_ORDER)

        occ_model, occ_scaler = get_occurrence_model_and_scaler()
        sev_model, sev_scaler = get_severity_model_and_scaler()

        # Each model has its own scaler, so the same row is scaled twice.
        X_occ = occ_scaler.transform(X)
        X_sev = sev_scaler.transform(X)

        occurrence_pred = int(occ_model.predict(X_occ)[0])
        occurrence_proba = occ_model.predict_proba(X_occ)[0].tolist()
        severity_pred = int(sev_model.predict(X_sev)[0])
        severity_proba = sev_model.predict_proba(X_sev)[0].tolist()

        result = {
            "input": {"lat": req.lat, "lon": req.lon, "time": req.time},
            "occurrence": {"prediction": occurrence_pred, "probabilities": occurrence_proba},
            "severity": {"prediction": severity_pred, "probabilities": severity_proba},
            "features_used": {k: round(v, 4) for k, v in zip(FEATURE_ORDER, X.iloc[0].tolist())},
        }
        cleanup_memory()
        return result
    except HTTPException:
        # Bare raise preserves the original traceback.
        raise
    except Exception as e:
        logging.exception("Prediction error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
# -------------------------
# Health check
# -------------------------
@app.api_route("/health", methods=["GET", "HEAD"])
async def health_check(request: Request):
    """Liveness probe.

    GET returns a small JSON status document; HEAD returns an empty 200
    response.
    """
    if request.method != "HEAD":
        return {"status": "healthy", "api_version": "2.4"}
    return Response(status_code=200)
# -------------------------
# Debug endpoint
# -------------------------
@app.get("/debug")
async def debug_info():
    """Report which cached artifacts are currently loaded, plus the feature order."""
    # Reads the module-level caches directly; does not trigger any loading.
    loaded_flags = {
        "occurrence_model": _occurrence_model is not None,
        "occurrence_scaler": _occurrence_scaler is not None,
        "severity_model": _severity_model is not None,
        "severity_scaler": _severity_scaler is not None,
    }
    return {"models_loaded": loaded_flags, "feature_order": FEATURE_ORDER}
# -------------------------
# Test endpoint
# -------------------------
@app.get("/test")
async def test_prediction():
    """Run one fixed sample through `predict` and report success or failure.

    Any exception is caught and returned in the body as
    ``{"test_status": "failed", "error": ...}`` rather than raised.
    """
    try:
        sample = PredictionRequest(lat=40.7128, lon=-74.0060, time="2024-08-15")
        outcome = await predict(sample)
    except Exception as e:
        return {"test_status": "failed", "error": str(e)}
    return {"test_status": "success", "result": outcome}
# -------------------------
# Root endpoint
# -------------------------
@app.get("/")
async def root():
    """Return the API name, version and a map of endpoint names to paths."""
    endpoint_paths = {
        "predict": "/predict",
        "health": "/health",
        "debug": "/debug",
        "test": "/test",
        "docs": "/docs",
        "redoc": "/redoc",
    }
    return {
        "message": "🌍 Drought Prediction API",
        "version": "2.4",
        "endpoints": endpoint_paths,
    }
# -------------------------
# Swagger UI and Redoc
# -------------------------
@app.get("/docs", include_in_schema=False)
async def custom_swagger_ui():
    """Return a Swagger UI page that loads the schema from /openapi.json."""
    # NOTE(review): FastAPI registers its own /docs route unless docs_url=None
    # is passed to FastAPI(); verify which handler actually serves this path.
    page = get_swagger_ui_html(title="API Docs", openapi_url="/openapi.json")
    return page
@app.get("/redoc", include_in_schema=False)
async def custom_redoc():
    """Return a ReDoc page that loads the schema from /openapi.json."""
    # NOTE(review): FastAPI registers its own /redoc route unless
    # redoc_url=None is passed to FastAPI(); verify which handler actually
    # serves this path.
    page = get_redoc_html(title="ReDoc", openapi_url="/openapi.json")
    return page