agentsay committed on
Commit
b9abc4e
Β·
verified Β·
1 Parent(s): 1450eec

Upload 6 files

Browse files
Files changed (6) hide show
  1. Dockerfile +36 -0
  2. main.py +215 -0
  3. model_config.json +36 -0
  4. pinn_best.pt +3 -0
  5. requirements.txt +8 -0
  6. scaler.pkl +3 -0
Dockerfile ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# ─────────────────────────────────────────────
# Base image (lightweight Python)
# ─────────────────────────────────────────────
FROM python:3.10-slim

# Prevent python from buffering stdout/stderr or writing .pyc files
ENV PYTHONUNBUFFERED=1 \
    PYTHONDONTWRITEBYTECODE=1

# Set working directory
WORKDIR /app

# Install system dependencies (needed for torch, pandas, etc.)
RUN apt-get update && apt-get install -y \
        build-essential \
        curl \
        git \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements first (better Docker caching)
COPY requirements.txt .

# Upgrade pip, then install Python dependencies in the same layer
RUN pip install --upgrade pip \
    && pip install --no-cache-dir -r requirements.txt

# Copy entire app
COPY . .

# Expose port (HF uses 7860 internally)
EXPOSE 7860

# Start FastAPI using uvicorn
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
main.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
import numpy as np
import torch
import pickle
import json
import os
from typing import Optional

# ── App Setup ─────────────────────────────────────────────────────────────────
app = FastAPI(
    title="PSInSAR Deformation Forecast API",
    description="PINN-based ground deformation risk forecasting from PSInSAR data",
    version="1.0.0",
)

# ── Global state (loaded once at startup) ─────────────────────────────────────
# Populated by load_assets() at startup; request handlers read these globals.
scaler = None      # fitted feature scaler (scaler.pkl)
cfg = None         # training metadata / thresholds (model_config.json)
model = None       # PINN model; weights come from pinn_best.pt
df = None          # raw PS time-series DataFrame
df_clean = None    # cleaned / feature-engineered DataFrame

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ── These must match your training setup ──────────────────────────────────────
# NOTE(review): model_config.json in this repo records feature_cols /
# physics_cols plus seq_len=8 and horizon=2, which conflicts with the
# placeholders below (empty lists, SEQ_LEN=10, HORIZON=3) — reconcile
# before serving, otherwise feats/physics arrays will be empty.
FEATURE_COLS = []  # ← replace with your actual feature column names
PHYSICS_COLS = []  # ← replace with your actual physics column names
SEQ_LEN = 10  # ← replace with your actual sequence length
HORIZON = 3  # ← replace with your actual horizon
N_PASSES = 50  # MC Dropout passes
33
+
34
+ # ── Request / Response Schemas ────────────────────────────────────────────────
35
+ class ForecastRequest(BaseModel):
36
+ lat: float = Field(..., description="Target latitude", example=22.360001)
37
+ lon: float = Field(..., description="Target longitude", example=82.530869)
38
+ tolerance: Optional[float] = Field(
39
+ 0.001, description="Search radius in degrees to find nearest PS point"
40
+ )
41
+
42
+
43
+ class EpochForecast(BaseModel):
44
+ day: float
45
+ failure_probability: float
46
+ uncertainty_std: float
47
+ high_risk: bool
48
+
49
+
50
+ class ForecastResponse(BaseModel):
51
+ ps_id: str
52
+ actual_lat: float
53
+ actual_lon: float
54
+ total_epochs: int
55
+ forecast_count: int
56
+ high_risk_count: int
57
+ high_risk_pct: float
58
+ mean_failure_probability: float
59
+ mean_uncertainty: float
60
+ first_alarm_day: Optional[float]
61
+ threshold_used: float
62
+ model_auc: float
63
+ model_pr_auc: float
64
+ forecasts: list[EpochForecast]
65
+
66
+
67
+ # ── Startup: load model & data ─────────────────────────────────────────────────
68
+ @app.on_event("startup")
69
+ def load_assets():
70
+ global scaler, cfg, model, df, df_clean
71
+
72
+ # Load scaler
73
+ with open("scaler.pkl", "rb") as f:
74
+ scaler = pickle.load(f)
75
+
76
+ # Load config
77
+ with open("model_config.json", "r") as f:
78
+ cfg = json.load(f)
79
+
80
+ # Load model ← import your model class before this block
81
+ # from your_model_module import YourPINNModel
82
+ # model = YourPINNModel(**cfg["model_params"]).to(DEVICE)
83
+ model.load_state_dict(torch.load("pinn_best.pt", map_location=DEVICE))
84
+ model.eval()
85
+
86
+ # Load dataframes ← replace with your actual data loading logic
87
+ # import pandas as pd
88
+ # df = pd.read_parquet("ps_data.parquet")
89
+ # df_clean = pd.read_parquet("ps_data_clean.parquet")
90
+
91
+ print(f"Assets loaded. Running on {DEVICE}")
92
+
93
+
94
+ # ── Helper: find nearest PS point ─────────────────────────────────────────────
95
+ def get_ps_by_latlon(lat: float, lon: float, tol: float = 0.001) -> str:
96
+ mask = (
97
+ (np.abs(df["lat"] - lat) <= tol) &
98
+ (np.abs(df["lon"] - lon) <= tol)
99
+ )
100
+ matches = df[mask]
101
+
102
+ if len(matches) == 0:
103
+ # Fallback: absolute nearest point
104
+ dist = np.sqrt((df["lat"] - lat) ** 2 + (df["lon"] - lon) ** 2)
105
+ nearest = df.loc[dist.idxmin()]
106
+ return str(nearest["ps_id"]), nearest["lat"], nearest["lon"], True
107
+
108
+ matches = matches.copy()
109
+ matches["_dist"] = np.sqrt(
110
+ (matches["lat"] - lat) ** 2 + (matches["lon"] - lon) ** 2
111
+ )
112
+ row = matches.loc[matches["_dist"].idxmin()]
113
+ return str(row["ps_id"]), row["lat"], row["lon"], False
114
+
115
+
116
+ # ── Forecast endpoint ──────────────────────────────────────────────────────────
117
+ @app.post("/forecast", response_model=ForecastResponse)
118
+ def forecast(req: ForecastRequest):
119
+ try:
120
+ ps_id, actual_lat, actual_lon, used_fallback = get_ps_by_latlon(
121
+ req.lat, req.lon, req.tolerance
122
+ )
123
+ except Exception as e:
124
+ raise HTTPException(status_code=404, detail=f"Could not find PS point: {e}")
125
+
126
+ # Load time series for this PS point
127
+ ps_raw = (
128
+ df[df["ps_id"] == ps_id]
129
+ .sort_values("days_since_start")
130
+ .reset_index(drop=True)
131
+ )
132
+ ps_clean = (
133
+ df_clean[df_clean["ps_id"] == ps_id]
134
+ .sort_values("days_since_start")
135
+ .reset_index(drop=True)
136
+ )
137
+
138
+ if len(ps_clean) < SEQ_LEN + HORIZON + 1:
139
+ raise HTTPException(
140
+ status_code=422,
141
+ detail=f"Insufficient data for PS point {ps_id} "
142
+ f"(need >{SEQ_LEN + HORIZON} epochs, got {len(ps_clean)})",
143
+ )
144
+
145
+ days_all = ps_raw["days_since_start"].values
146
+ disp_all = ps_raw["cumulative_disp_mm"].values
147
+
148
+ feats = ps_clean[FEATURE_COLS].values.astype(np.float32)
149
+ physics = ps_clean[PHYSICS_COLS].values.astype(np.float32)
150
+
151
+ threshold = cfg["best_threshold"]
152
+ epoch_forecasts = []
153
+
154
+ for i in range(SEQ_LEN, len(ps_clean) - HORIZON):
155
+ x_seq = torch.tensor(feats[i - SEQ_LEN:i]).unsqueeze(0).to(DEVICE)
156
+ p_vec = torch.tensor(physics[i]).unsqueeze(0).to(DEVICE)
157
+
158
+ preds = []
159
+ for _ in range(N_PASSES):
160
+ with torch.no_grad():
161
+ preds.append(torch.sigmoid(model(x_seq, p_vec)).item())
162
+
163
+ fcst_idx = i + HORIZON
164
+ mean_p = float(np.mean(preds))
165
+ std_p = float(np.std(preds))
166
+ high_risk = mean_p >= threshold
167
+
168
+ epoch_forecasts.append(
169
+ EpochForecast(
170
+ day=float(days_all[fcst_idx]),
171
+ failure_probability=round(mean_p, 6),
172
+ uncertainty_std=round(std_p, 6),
173
+ high_risk=high_risk,
174
+ )
175
+ )
176
+
177
+ # Aggregate stats
178
+ forecast_days = np.array([e.day for e in epoch_forecasts])
179
+ forecast_mean = np.array([e.failure_probability for e in epoch_forecasts])
180
+ forecast_std = np.array([e.uncertainty_std for e in epoch_forecasts])
181
+ forecast_risk = np.array([e.high_risk for e in epoch_forecasts])
182
+
183
+ n_risk = int(forecast_risk.sum())
184
+ first_alarm = (
185
+ float(forecast_days[forecast_risk == 1][0]) if n_risk > 0 else None
186
+ )
187
+
188
+ return ForecastResponse(
189
+ ps_id=ps_id,
190
+ actual_lat=float(actual_lat),
191
+ actual_lon=float(actual_lon),
192
+ total_epochs=len(ps_raw),
193
+ forecast_count=len(epoch_forecasts),
194
+ high_risk_count=n_risk,
195
+ high_risk_pct=round(n_risk / len(epoch_forecasts) * 100, 2),
196
+ mean_failure_probability=round(float(forecast_mean.mean()), 6),
197
+ mean_uncertainty=round(float(forecast_std.mean()), 6),
198
+ first_alarm_day=first_alarm,
199
+ threshold_used=threshold,
200
+ model_auc=cfg["test_auc"],
201
+ model_pr_auc=cfg["test_pr_auc"],
202
+ forecasts=epoch_forecasts,
203
+ )
204
+
205
+
206
+ # ── Health check ───────────────────────────────────────────────────────────────
207
+ @app.get("/health")
208
+ def health():
209
+ return {"status": "ok", "device": str(DEVICE)}
210
+
211
+
212
+ # ── Run locally ────────────────x────────────────────────────────────────────────
213
+ if __name__ == "__main__":
214
+ import uvicorn
215
+ uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=False)
model_config.json ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "n_features": 14,
3
+ "n_physics": 3,
4
+ "hidden": 32,
5
+ "n_heads": 2,
6
+ "dropout": 0.4,
7
+ "seq_len": 8,
8
+ "horizon": 2,
9
+ "feature_cols": [
10
+ "lat",
11
+ "lon",
12
+ "cumulative_disp_mm",
13
+ "incremental_disp_mm",
14
+ "velocity_last_3ep",
15
+ "velocity_last_6ep",
16
+ "disp_rolling_mean_3",
17
+ "disp_rolling_mean_6",
18
+ "disp_rolling_std_6",
19
+ "coherence",
20
+ "mean_velocity_mm_yr",
21
+ "dem_height_m",
22
+ "seasonal_sin",
23
+ "seasonal_cos"
24
+ ],
25
+ "physics_cols": [
26
+ "acceleration",
27
+ "velocity_last_3ep",
28
+ "dem_height_m"
29
+ ],
30
+ "best_threshold": 0.5551024079322815,
31
+ "best_val_auc": 0.8073059156530157,
32
+ "test_auc": 0.8027481911831947,
33
+ "test_pr_auc": 0.2522120332201386,
34
+ "pos_weight": 11.660771704180064,
35
+ "stopped_epoch": 36
36
+ }
pinn_best.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0949c95525952af798cdf05deba454fd526f378a3203c3b57db862d98e198297
3
+ size 208926
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ pydantic
4
+ numpy
5
+ pandas
6
+ torch
7
+ scikit-learn
8
+ pyarrow
scaler.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f62d69e1f6f60f2ce4a68992841051d7c0d21623220e5c62d313ca132fa6d901
3
+ size 1112