"""FastAPI service exposing PINN-based failure-risk forecasts for PSInSAR points.

Assets (scaler, model config, model weights and the PS dataframes) are loaded
once at application startup into module-level globals; the ``/forecast``
endpoint then runs Monte-Carlo-dropout inference over every sliding window of
the requested PS point's time series.
"""

import json
import os
import pickle
from typing import Optional

import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field

# ── App Setup ─────────────────────────────────────────────────────────────────
app = FastAPI(
    title="PSInSAR Deformation Forecast API",
    description="PINN-based ground deformation risk forecasting from PSInSAR data",
    version="1.0.0",
)

# ── Global state (loaded once at startup) ─────────────────────────────────────
scaler = None
cfg = None
model = None
df = None
df_clean = None
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ── These must match your training setup ──────────────────────────────────────
FEATURE_COLS = []  # ← replace with your actual feature column names
PHYSICS_COLS = []  # ← replace with your actual physics column names
SEQ_LEN = 10       # ← replace with your actual sequence length
HORIZON = 3        # ← replace with your actual horizon
N_PASSES = 50      # MC Dropout passes

# Default half-width (in degrees) of the lat/lon search box.
DEFAULT_TOLERANCE = 0.001

# Dropout layer classes that must stay stochastic for MC Dropout.  Built with
# hasattr() so the module still imports on torch versions lacking some of them.
_DROPOUT_TYPES = tuple(
    getattr(torch.nn, _name)
    for _name in (
        "Dropout",
        "Dropout1d",
        "Dropout2d",
        "Dropout3d",
        "AlphaDropout",
        "FeatureAlphaDropout",
    )
    if hasattr(torch.nn, _name)
)


# ── Request / Response Schemas ────────────────────────────────────────────────
class ForecastRequest(BaseModel):
    """Body of ``POST /forecast``: a target coordinate and a search tolerance."""

    lat: float = Field(..., description="Target latitude", example=22.360001)
    lon: float = Field(..., description="Target longitude", example=82.530869)
    tolerance: Optional[float] = Field(
        DEFAULT_TOLERANCE,
        description="Search radius in degrees to find nearest PS point",
    )


class EpochForecast(BaseModel):
    """Forecast for a single target epoch."""

    day: float
    failure_probability: float
    uncertainty_std: float
    high_risk: bool


class ForecastResponse(BaseModel):
    """Per-epoch forecasts plus aggregate statistics for one PS point."""

    ps_id: str
    actual_lat: float
    actual_lon: float
    total_epochs: int
    forecast_count: int
    high_risk_count: int
    high_risk_pct: float
    mean_failure_probability: float
    mean_uncertainty: float
    first_alarm_day: Optional[float]
    threshold_used: float
    model_auc: float
    model_pr_auc: float
    forecasts: list[EpochForecast]


# ── Helpers ───────────────────────────────────────────────────────────────────
def _enable_mc_dropout(module: torch.nn.Module) -> int:
    """Switch every dropout layer of *module* back to training mode.

    ``module.eval()`` disables dropout, which would make all MC passes
    identical.  Only dropout layers are touched, so any other layer keeps the
    mode it already had.

    Returns:
        The number of dropout layers that were re-enabled.
    """
    enabled = 0
    for layer in module.modules():
        if isinstance(layer, _DROPOUT_TYPES):
            layer.train()
            enabled += 1
    return enabled


def _require_assets() -> None:
    """Raise HTTP 503 when startup did not populate the model, config or data."""
    missing = [
        name
        for name, value in (
            ("model", model),
            ("cfg", cfg),
            ("df", df),
            ("df_clean", df_clean),
        )
        if value is None
    ]
    if missing:
        raise HTTPException(
            status_code=503,
            detail=f"Service assets not loaded: {', '.join(missing)}",
        )


# ── Startup: load model & data ────────────────────────────────────────────────
@app.on_event("startup")
def load_assets():
    """Load the scaler, config and model into the module-level globals.

    Paths come from the ``MODEL_PATH``, ``SCALER_PATH`` and ``CONFIG_PATH``
    environment variables, with defaults under ``artifacts/``.
    """
    global scaler, cfg, model, df, df_clean

    MODEL_PATH = os.getenv("MODEL_PATH", "artifacts/pinn_best.pt")
    SCALER_PATH = os.getenv("SCALER_PATH", "artifacts/scaler.pkl")
    CONFIG_PATH = os.getenv("CONFIG_PATH", "artifacts/model_config.json")

    # ── 1. Scaler ─────────────────────────────────────────────────────────────
    # SECURITY: pickle.load executes arbitrary code embedded in the file; only
    # point SCALER_PATH at artifacts you produced yourself.
    with open(SCALER_PATH, "rb") as f:
        scaler = pickle.load(f)

    # ── 2. Config ─────────────────────────────────────────────────────────────
    with open(CONFIG_PATH, "r") as f:
        cfg = json.load(f)

    # ── 3. Model ──────────────────────────────────────────────────────────────
    # OPTION A (recommended): instantiate your model class, then load weights
    #
    #   from your_model_module import YourPINNModel
    #   model = YourPINNModel(**cfg["model_params"]).to(DEVICE)
    #   checkpoint = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True)
    #   model.load_state_dict(checkpoint)
    #
    # OPTION B (fallback): load the entire pickled model object
    # Use this if pinn_best.pt was saved with torch.save(model, path)
    # SECURITY: weights_only=False unpickles arbitrary objects; the checkpoint
    # must come from a trusted source.
    model = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)
    model.eval()
    # eval() turns dropout off; re-enable just those layers so that the
    # N_PASSES forward passes in /forecast actually differ.
    if _enable_mc_dropout(model) == 0:
        print(
            "Warning: model has no dropout layers; "
            "MC Dropout uncertainty_std will be 0"
        )

    # ── 4. Data ───────────────────────────────────────────────────────────────
    # import pandas as pd
    # df = pd.read_parquet(os.getenv("DATA_PATH", "artifacts/ps_data.parquet"))
    # df_clean = pd.read_parquet(os.getenv("DATA_CLEAN_PATH", "artifacts/ps_data_clean.parquet"))

    print(f"Assets loaded | Device={DEVICE} | Threshold={cfg.get('best_threshold')}")


# ── Helper: find nearest PS point ─────────────────────────────────────────────
def get_ps_by_latlon(
    lat: float, lon: float, tol: float = 0.001
) -> tuple[str, float, float, bool]:
    """Return the PS point in the global ``df`` closest to ``(lat, lon)``.

    Rows inside the ``±tol`` degree box around the target are preferred; when
    the box is empty the globally nearest row is used instead.  Distance is
    plain Euclidean distance in degree space, not a geodesic distance.

    Args:
        lat: Target latitude.
        lon: Target longitude.
        tol: Half-width of the search box, in the same units as the
            ``lat``/``lon`` columns.

    Returns:
        ``(ps_id, lat, lon, used_fallback)`` of the selected row, where
        ``used_fallback`` is True when no row fell inside the box.

    Raises:
        ValueError: If ``df`` contains no rows.
    """
    if len(df) == 0:
        raise ValueError("PS dataset is empty")

    d_lat = df["lat"] - lat
    d_lon = df["lon"] - lon
    dist = np.sqrt(d_lat ** 2 + d_lon ** 2)

    in_box = (np.abs(d_lat) <= tol) & (np.abs(d_lon) <= tol)
    candidates = dist[in_box]
    used_fallback = len(candidates) == 0
    if used_fallback:
        candidates = dist

    row = df.loc[candidates.idxmin()]
    return str(row["ps_id"]), float(row["lat"]), float(row["lon"]), used_fallback


# ── Forecast endpoint ─────────────────────────────────────────────────────────
@app.post("/forecast", response_model=ForecastResponse)
def forecast(req: ForecastRequest):
    """Run MC-dropout inference for the PS point nearest to the request.

    For every index ``i`` in ``[SEQ_LEN, len(series) - HORIZON)`` the model is
    given the preceding ``SEQ_LEN`` feature rows and the physics row at ``i``;
    the mean and standard deviation of ``N_PASSES`` sigmoid outputs are
    reported for the epoch at ``i + HORIZON``.

    Raises:
        HTTPException: 503 if assets are not loaded, 404 if no PS point can be
            resolved, 422 if the point has too few epochs, 500 if the raw and
            clean series are inconsistent.
    """
    _require_assets()

    # An explicit JSON null for `tolerance` is allowed by the schema.
    tol = DEFAULT_TOLERANCE if req.tolerance is None else req.tolerance
    try:
        ps_id, actual_lat, actual_lon, _used_fallback = get_ps_by_latlon(
            req.lat, req.lon, tol
        )
    except ValueError as e:
        raise HTTPException(
            status_code=404, detail=f"Could not find PS point: {e}"
        ) from e

    # Load time series for this PS point.  ps_id was stringified above, so
    # compare on the string form to also match non-string id columns.
    ps_raw = (
        df[df["ps_id"].astype(str) == ps_id]
        .sort_values("days_since_start")
        .reset_index(drop=True)
    )
    ps_clean = (
        df_clean[df_clean["ps_id"].astype(str) == ps_id]
        .sort_values("days_since_start")
        .reset_index(drop=True)
    )

    if len(ps_clean) < SEQ_LEN + HORIZON + 1:
        raise HTTPException(
            status_code=422,
            detail=f"Insufficient data for PS point {ps_id} "
                   f"(need >{SEQ_LEN + HORIZON} epochs, got {len(ps_clean)})",
        )

    # NOTE(review): forecast days are read from the raw series but indexed
    # with positions from the clean series, which assumes both are row-aligned
    # — confirm against the preprocessing pipeline.
    days_all = ps_raw["days_since_start"].values
    if len(days_all) < len(ps_clean):
        raise HTTPException(
            status_code=500,
            detail=f"Raw series for PS point {ps_id} is shorter than the clean "
                   f"series ({len(days_all)} < {len(ps_clean)})",
        )

    feats = ps_clean[FEATURE_COLS].values.astype(np.float32)
    physics = ps_clean[PHYSICS_COLS].values.astype(np.float32)
    threshold = cfg["best_threshold"]

    # Move the full arrays to the device once instead of once per window.
    feats_t = torch.as_tensor(feats, device=DEVICE)
    physics_t = torch.as_tensor(physics, device=DEVICE)

    epoch_forecasts = []
    with torch.no_grad():
        for i in range(SEQ_LEN, len(ps_clean) - HORIZON):
            x_seq = feats_t[i - SEQ_LEN:i].unsqueeze(0)
            p_vec = physics_t[i].unsqueeze(0)

            preds = [
                torch.sigmoid(model(x_seq, p_vec)).item()
                for _ in range(N_PASSES)
            ]

            mean_p = float(np.mean(preds))
            std_p = float(np.std(preds))
            epoch_forecasts.append(
                EpochForecast(
                    day=float(days_all[i + HORIZON]),
                    failure_probability=round(mean_p, 6),
                    uncertainty_std=round(std_p, 6),
                    high_risk=bool(mean_p >= threshold),
                )
            )

    # Aggregate stats (the length check above guarantees at least one forecast).
    n_risk = sum(1 for e in epoch_forecasts if e.high_risk)
    first_alarm = next((e.day for e in epoch_forecasts if e.high_risk), None)
    mean_prob = float(np.mean([e.failure_probability for e in epoch_forecasts]))
    mean_std = float(np.mean([e.uncertainty_std for e in epoch_forecasts]))

    return ForecastResponse(
        ps_id=ps_id,
        actual_lat=float(actual_lat),
        actual_lon=float(actual_lon),
        total_epochs=len(ps_raw),
        forecast_count=len(epoch_forecasts),
        high_risk_count=n_risk,
        high_risk_pct=round(n_risk / len(epoch_forecasts) * 100, 2),
        mean_failure_probability=round(mean_prob, 6),
        mean_uncertainty=round(mean_std, 6),
        first_alarm_day=first_alarm,
        threshold_used=threshold,
        model_auc=cfg["test_auc"],
        model_pr_auc=cfg["test_pr_auc"],
        forecasts=epoch_forecasts,
    )


# ── Health check ──────────────────────────────────────────────────────────────
@app.get("/health")
def health():
    """Liveness probe reporting the torch device in use."""
    return {"status": "ok", "device": str(DEVICE)}


# ── Run locally ───────────────────────────────────────────────────────────────
if __name__ == "__main__":
    import uvicorn

    # Pass the app object rather than "main:app" so this works whatever the
    # file is named (an import string is only needed for reload/workers).
    uvicorn.run(app, host="0.0.0.0", port=8000, reload=False)