Spaces:
Runtime error
Runtime error
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel, Field | |
| import numpy as np | |
| import torch | |
| import pickle | |
| import json | |
| import os | |
| from typing import Optional | |
| # ββ App Setup βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| app = FastAPI( | |
| title="PSInSAR Deformation Forecast API", | |
| description="PINN-based ground deformation risk forecasting from PSInSAR data", | |
| version="1.0.0", | |
| ) | |
| # ββ Global state (loaded once at startup) βββββββββββββββββββββββββββββββββββββ | |
| scaler = None | |
| cfg = None | |
| model = None | |
| df = None | |
| df_clean = None | |
| DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| # ββ These must match your training setup ββββββββββββββββββββββββββββββββββββββ | |
| FEATURE_COLS = [] # β replace with your actual feature column names | |
| PHYSICS_COLS = [] # β replace with your actual physics column names | |
| SEQ_LEN = 10 # β replace with your actual sequence length | |
| HORIZON = 3 # β replace with your actual horizon | |
| N_PASSES = 50 # MC Dropout passes | |
| # ββ Request / Response Schemas ββββββββββββββββββββββββββββββββββββββββββββββββ | |
| class ForecastRequest(BaseModel): | |
| lat: float = Field(..., description="Target latitude", example=22.360001) | |
| lon: float = Field(..., description="Target longitude", example=82.530869) | |
| tolerance: Optional[float] = Field( | |
| 0.001, description="Search radius in degrees to find nearest PS point" | |
| ) | |
| class EpochForecast(BaseModel): | |
| day: float | |
| failure_probability: float | |
| uncertainty_std: float | |
| high_risk: bool | |
| class ForecastResponse(BaseModel): | |
| ps_id: str | |
| actual_lat: float | |
| actual_lon: float | |
| total_epochs: int | |
| forecast_count: int | |
| high_risk_count: int | |
| high_risk_pct: float | |
| mean_failure_probability: float | |
| mean_uncertainty: float | |
| first_alarm_day: Optional[float] | |
| threshold_used: float | |
| model_auc: float | |
| model_pr_auc: float | |
| forecasts: list[EpochForecast] | |
| # ββ Startup: load model & data βββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def load_assets(): | |
| global scaler, cfg, model, df, df_clean | |
| MODEL_PATH = os.getenv("MODEL_PATH", "artifacts/pinn_best.pt") | |
| SCALER_PATH = os.getenv("SCALER_PATH", "artifacts/scaler.pkl") | |
| CONFIG_PATH = os.getenv("CONFIG_PATH", "artifacts/model_config.json") | |
| # ββ 1. Scaler ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| with open(SCALER_PATH, "rb") as f: | |
| scaler = pickle.load(f) | |
| # ββ 2. Config ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| with open(CONFIG_PATH, "r") as f: | |
| cfg = json.load(f) | |
| # ββ 3. Model βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # OPTION A (recommended): instantiate your model class, then load weights | |
| # | |
| # from your_model_module import YourPINNModel | |
| # model = YourPINNModel(**cfg["model_params"]).to(DEVICE) | |
| # checkpoint = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True) | |
| # model.load_state_dict(checkpoint) | |
| # | |
| # OPTION B (fallback): load the entire pickled model object | |
| # Use this if pinn_best.pt was saved with torch.save(model, path) | |
| model = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False) | |
| model.eval() | |
| # ββ 4. Data ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # import pandas as pd | |
| # df = pd.read_parquet(os.getenv("DATA_PATH", "artifacts/ps_data.parquet")) | |
| # df_clean = pd.read_parquet(os.getenv("DATA_CLEAN_PATH", "artifacts/ps_data_clean.parquet")) | |
| print(f"Assets loaded | Device={DEVICE} | Threshold={cfg.get('best_threshold')}") | |
| # ββ Helper: find nearest PS point βββββββββββββββββββββββββββββββββββββββββββββ | |
| def get_ps_by_latlon(lat: float, lon: float, tol: float = 0.001) -> str: | |
| mask = ( | |
| (np.abs(df["lat"] - lat) <= tol) & | |
| (np.abs(df["lon"] - lon) <= tol) | |
| ) | |
| matches = df[mask] | |
| if len(matches) == 0: | |
| # Fallback: absolute nearest point | |
| dist = np.sqrt((df["lat"] - lat) ** 2 + (df["lon"] - lon) ** 2) | |
| nearest = df.loc[dist.idxmin()] | |
| return str(nearest["ps_id"]), nearest["lat"], nearest["lon"], True | |
| matches = matches.copy() | |
| matches["_dist"] = np.sqrt( | |
| (matches["lat"] - lat) ** 2 + (matches["lon"] - lon) ** 2 | |
| ) | |
| row = matches.loc[matches["_dist"].idxmin()] | |
| return str(row["ps_id"]), row["lat"], row["lon"], False | |
| # ββ Forecast endpoint ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def forecast(req: ForecastRequest): | |
| try: | |
| ps_id, actual_lat, actual_lon, used_fallback = get_ps_by_latlon( | |
| req.lat, req.lon, req.tolerance | |
| ) | |
| except Exception as e: | |
| raise HTTPException(status_code=404, detail=f"Could not find PS point: {e}") | |
| # Load time series for this PS point | |
| ps_raw = ( | |
| df[df["ps_id"] == ps_id] | |
| .sort_values("days_since_start") | |
| .reset_index(drop=True) | |
| ) | |
| ps_clean = ( | |
| df_clean[df_clean["ps_id"] == ps_id] | |
| .sort_values("days_since_start") | |
| .reset_index(drop=True) | |
| ) | |
| if len(ps_clean) < SEQ_LEN + HORIZON + 1: | |
| raise HTTPException( | |
| status_code=422, | |
| detail=f"Insufficient data for PS point {ps_id} " | |
| f"(need >{SEQ_LEN + HORIZON} epochs, got {len(ps_clean)})", | |
| ) | |
| days_all = ps_raw["days_since_start"].values | |
| disp_all = ps_raw["cumulative_disp_mm"].values | |
| feats = ps_clean[FEATURE_COLS].values.astype(np.float32) | |
| physics = ps_clean[PHYSICS_COLS].values.astype(np.float32) | |
| threshold = cfg["best_threshold"] | |
| epoch_forecasts = [] | |
| for i in range(SEQ_LEN, len(ps_clean) - HORIZON): | |
| x_seq = torch.tensor(feats[i - SEQ_LEN:i]).unsqueeze(0).to(DEVICE) | |
| p_vec = torch.tensor(physics[i]).unsqueeze(0).to(DEVICE) | |
| preds = [] | |
| for _ in range(N_PASSES): | |
| with torch.no_grad(): | |
| preds.append(torch.sigmoid(model(x_seq, p_vec)).item()) | |
| fcst_idx = i + HORIZON | |
| mean_p = float(np.mean(preds)) | |
| std_p = float(np.std(preds)) | |
| high_risk = mean_p >= threshold | |
| epoch_forecasts.append( | |
| EpochForecast( | |
| day=float(days_all[fcst_idx]), | |
| failure_probability=round(mean_p, 6), | |
| uncertainty_std=round(std_p, 6), | |
| high_risk=high_risk, | |
| ) | |
| ) | |
| # Aggregate stats | |
| forecast_days = np.array([e.day for e in epoch_forecasts]) | |
| forecast_mean = np.array([e.failure_probability for e in epoch_forecasts]) | |
| forecast_std = np.array([e.uncertainty_std for e in epoch_forecasts]) | |
| forecast_risk = np.array([e.high_risk for e in epoch_forecasts]) | |
| n_risk = int(forecast_risk.sum()) | |
| first_alarm = ( | |
| float(forecast_days[forecast_risk == 1][0]) if n_risk > 0 else None | |
| ) | |
| return ForecastResponse( | |
| ps_id=ps_id, | |
| actual_lat=float(actual_lat), | |
| actual_lon=float(actual_lon), | |
| total_epochs=len(ps_raw), | |
| forecast_count=len(epoch_forecasts), | |
| high_risk_count=n_risk, | |
| high_risk_pct=round(n_risk / len(epoch_forecasts) * 100, 2), | |
| mean_failure_probability=round(float(forecast_mean.mean()), 6), | |
| mean_uncertainty=round(float(forecast_std.mean()), 6), | |
| first_alarm_day=first_alarm, | |
| threshold_used=threshold, | |
| model_auc=cfg["test_auc"], | |
| model_pr_auc=cfg["test_pr_auc"], | |
| forecasts=epoch_forecasts, | |
| ) | |
| # ββ Health check βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def health(): | |
| return {"status": "ok", "device": str(DEVICE)} | |
| # ββ Run locally ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| if __name__ == "__main__": | |
| import uvicorn | |
| uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=False) |