agentsay committed on
Commit
777c571
Β·
verified Β·
1 Parent(s): b9abc4e

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +225 -214
main.py CHANGED
@@ -1,215 +1,226 @@
1
- from fastapi import FastAPI, HTTPException
2
- from pydantic import BaseModel, Field
3
- import numpy as np
4
- import torch
5
- import pickle
6
- import json
7
- import os
8
- from typing import Optional
9
-
10
- # ── App Setup ─────────────────────────────────────────────────────────────────
11
- app = FastAPI(
12
- title="PSInSAR Deformation Forecast API",
13
- description="PINN-based ground deformation risk forecasting from PSInSAR data",
14
- version="1.0.0",
15
- )
16
-
17
- # ── Global state (loaded once at startup) ─────────────────────────────────────
18
- scaler = None
19
- cfg = None
20
- model = None
21
- df = None
22
- df_clean = None
23
-
24
- DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
-
26
- # ── These must match your training setup ──────────────────────────────────────
27
- FEATURE_COLS = [] # ← replace with your actual feature column names
28
- PHYSICS_COLS = [] # ← replace with your actual physics column names
29
- SEQ_LEN = 10 # ← replace with your actual sequence length
30
- HORIZON = 3 # ← replace with your actual horizon
31
- N_PASSES = 50 # MC Dropout passes
32
-
33
-
34
- # ── Request / Response Schemas ────────────────────────────────────────────────
35
- class ForecastRequest(BaseModel):
36
- lat: float = Field(..., description="Target latitude", example=22.360001)
37
- lon: float = Field(..., description="Target longitude", example=82.530869)
38
- tolerance: Optional[float] = Field(
39
- 0.001, description="Search radius in degrees to find nearest PS point"
40
- )
41
-
42
-
43
- class EpochForecast(BaseModel):
44
- day: float
45
- failure_probability: float
46
- uncertainty_std: float
47
- high_risk: bool
48
-
49
-
50
- class ForecastResponse(BaseModel):
51
- ps_id: str
52
- actual_lat: float
53
- actual_lon: float
54
- total_epochs: int
55
- forecast_count: int
56
- high_risk_count: int
57
- high_risk_pct: float
58
- mean_failure_probability: float
59
- mean_uncertainty: float
60
- first_alarm_day: Optional[float]
61
- threshold_used: float
62
- model_auc: float
63
- model_pr_auc: float
64
- forecasts: list[EpochForecast]
65
-
66
-
67
- # ── Startup: load model & data ─────────────────────────────────────────────────
68
- @app.on_event("startup")
69
- def load_assets():
70
- global scaler, cfg, model, df, df_clean
71
-
72
- # Load scaler
73
- with open("scaler.pkl", "rb") as f:
74
- scaler = pickle.load(f)
75
-
76
- # Load config
77
- with open("model_config.json", "r") as f:
78
- cfg = json.load(f)
79
-
80
- # Load model ← import your model class before this block
81
- # from your_model_module import YourPINNModel
82
- # model = YourPINNModel(**cfg["model_params"]).to(DEVICE)
83
- model.load_state_dict(torch.load("pinn_best.pt", map_location=DEVICE))
84
- model.eval()
85
-
86
- # Load dataframes ← replace with your actual data loading logic
87
- # import pandas as pd
88
- # df = pd.read_parquet("ps_data.parquet")
89
- # df_clean = pd.read_parquet("ps_data_clean.parquet")
90
-
91
- print(f"Assets loaded. Running on {DEVICE}")
92
-
93
-
94
- # ── Helper: find nearest PS point ─────────────────────────────────────────────
95
- def get_ps_by_latlon(lat: float, lon: float, tol: float = 0.001) -> str:
96
- mask = (
97
- (np.abs(df["lat"] - lat) <= tol) &
98
- (np.abs(df["lon"] - lon) <= tol)
99
- )
100
- matches = df[mask]
101
-
102
- if len(matches) == 0:
103
- # Fallback: absolute nearest point
104
- dist = np.sqrt((df["lat"] - lat) ** 2 + (df["lon"] - lon) ** 2)
105
- nearest = df.loc[dist.idxmin()]
106
- return str(nearest["ps_id"]), nearest["lat"], nearest["lon"], True
107
-
108
- matches = matches.copy()
109
- matches["_dist"] = np.sqrt(
110
- (matches["lat"] - lat) ** 2 + (matches["lon"] - lon) ** 2
111
- )
112
- row = matches.loc[matches["_dist"].idxmin()]
113
- return str(row["ps_id"]), row["lat"], row["lon"], False
114
-
115
-
116
- # ── Forecast endpoint ──────────────────────────────────────────────────────────
117
- @app.post("/forecast", response_model=ForecastResponse)
118
- def forecast(req: ForecastRequest):
119
- try:
120
- ps_id, actual_lat, actual_lon, used_fallback = get_ps_by_latlon(
121
- req.lat, req.lon, req.tolerance
122
- )
123
- except Exception as e:
124
- raise HTTPException(status_code=404, detail=f"Could not find PS point: {e}")
125
-
126
- # Load time series for this PS point
127
- ps_raw = (
128
- df[df["ps_id"] == ps_id]
129
- .sort_values("days_since_start")
130
- .reset_index(drop=True)
131
- )
132
- ps_clean = (
133
- df_clean[df_clean["ps_id"] == ps_id]
134
- .sort_values("days_since_start")
135
- .reset_index(drop=True)
136
- )
137
-
138
- if len(ps_clean) < SEQ_LEN + HORIZON + 1:
139
- raise HTTPException(
140
- status_code=422,
141
- detail=f"Insufficient data for PS point {ps_id} "
142
- f"(need >{SEQ_LEN + HORIZON} epochs, got {len(ps_clean)})",
143
- )
144
-
145
- days_all = ps_raw["days_since_start"].values
146
- disp_all = ps_raw["cumulative_disp_mm"].values
147
-
148
- feats = ps_clean[FEATURE_COLS].values.astype(np.float32)
149
- physics = ps_clean[PHYSICS_COLS].values.astype(np.float32)
150
-
151
- threshold = cfg["best_threshold"]
152
- epoch_forecasts = []
153
-
154
- for i in range(SEQ_LEN, len(ps_clean) - HORIZON):
155
- x_seq = torch.tensor(feats[i - SEQ_LEN:i]).unsqueeze(0).to(DEVICE)
156
- p_vec = torch.tensor(physics[i]).unsqueeze(0).to(DEVICE)
157
-
158
- preds = []
159
- for _ in range(N_PASSES):
160
- with torch.no_grad():
161
- preds.append(torch.sigmoid(model(x_seq, p_vec)).item())
162
-
163
- fcst_idx = i + HORIZON
164
- mean_p = float(np.mean(preds))
165
- std_p = float(np.std(preds))
166
- high_risk = mean_p >= threshold
167
-
168
- epoch_forecasts.append(
169
- EpochForecast(
170
- day=float(days_all[fcst_idx]),
171
- failure_probability=round(mean_p, 6),
172
- uncertainty_std=round(std_p, 6),
173
- high_risk=high_risk,
174
- )
175
- )
176
-
177
- # Aggregate stats
178
- forecast_days = np.array([e.day for e in epoch_forecasts])
179
- forecast_mean = np.array([e.failure_probability for e in epoch_forecasts])
180
- forecast_std = np.array([e.uncertainty_std for e in epoch_forecasts])
181
- forecast_risk = np.array([e.high_risk for e in epoch_forecasts])
182
-
183
- n_risk = int(forecast_risk.sum())
184
- first_alarm = (
185
- float(forecast_days[forecast_risk == 1][0]) if n_risk > 0 else None
186
- )
187
-
188
- return ForecastResponse(
189
- ps_id=ps_id,
190
- actual_lat=float(actual_lat),
191
- actual_lon=float(actual_lon),
192
- total_epochs=len(ps_raw),
193
- forecast_count=len(epoch_forecasts),
194
- high_risk_count=n_risk,
195
- high_risk_pct=round(n_risk / len(epoch_forecasts) * 100, 2),
196
- mean_failure_probability=round(float(forecast_mean.mean()), 6),
197
- mean_uncertainty=round(float(forecast_std.mean()), 6),
198
- first_alarm_day=first_alarm,
199
- threshold_used=threshold,
200
- model_auc=cfg["test_auc"],
201
- model_pr_auc=cfg["test_pr_auc"],
202
- forecasts=epoch_forecasts,
203
- )
204
-
205
-
206
- # ── Health check ───────────────────────────────────────────────────────────────
207
- @app.get("/health")
208
- def health():
209
- return {"status": "ok", "device": str(DEVICE)}
210
-
211
-
212
- # ── Run locally ────────────────x────────────────────────────────────────────────
213
- if __name__ == "__main__":
214
- import uvicorn
 
 
 
 
 
 
 
 
 
 
 
215
  uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=False)
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel, Field
3
+ import numpy as np
4
+ import torch
5
+ import pickle
6
+ import json
7
+ import os
8
+ from typing import Optional
9
+
10
# ── App Setup ─────────────────────────────────────────────────────────────────
app = FastAPI(
    title="PSInSAR Deformation Forecast API",
    description="PINN-based ground deformation risk forecasting from PSInSAR data",
    version="1.0.0",
)

# ── Global state (loaded once at startup) ─────────────────────────────────────
# All of these are populated by load_assets() on application startup and are
# None until then; the endpoints assume startup has completed.
scaler = None      # feature scaler fitted during training (unpickled object)
cfg = None         # training/config dict loaded from model_config.json
model = None       # trained PINN model (torch module loaded from checkpoint)
df = None          # raw PS time-series data (presumably a pandas DataFrame — loading is commented out)
df_clean = None    # cleaned/feature-engineered PS data (same caveat)

# Run inference on GPU when available, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ── These must match your training setup ──────────────────────────────────────
FEATURE_COLS = []  # ← replace with your actual feature column names
PHYSICS_COLS = []  # ← replace with your actual physics column names
SEQ_LEN = 10       # ← replace with your actual sequence length (input window, in epochs)
HORIZON = 3        # ← replace with your actual horizon (epochs ahead being forecast)
N_PASSES = 50      # MC Dropout passes
32
+
33
+
34
# ── Request / Response Schemas ────────────────────────────────────────────────
class ForecastRequest(BaseModel):
    """Query payload: a target coordinate plus an optional search radius."""

    # NOTE(review): `example=` on Field is the Pydantic v1 spelling; v2 prefers
    # `json_schema_extra={"example": ...}` — confirm the installed version.
    lat: float = Field(..., description="Target latitude", example=22.360001)
    lon: float = Field(..., description="Target longitude", example=82.530869)
    # Radius (degrees) used to match a PS point; when no point falls inside it,
    # the absolute nearest point is used instead (see get_ps_by_latlon).
    tolerance: Optional[float] = Field(
        0.001, description="Search radius in degrees to find nearest PS point"
    )
41
+
42
+
43
class EpochForecast(BaseModel):
    """Forecast for a single future epoch of one PS point."""

    day: float                  # days_since_start of the forecast target epoch
    failure_probability: float  # MC-Dropout mean of sigmoid(model output), rounded to 6 dp
    uncertainty_std: float      # MC-Dropout std of the same passes, rounded to 6 dp
    high_risk: bool             # True when failure_probability >= cfg["best_threshold"]
48
+
49
+
50
class ForecastResponse(BaseModel):
    """Full forecast result for one PS point, returned by POST /forecast."""

    ps_id: str                        # identifier of the matched PS point
    actual_lat: float                 # latitude of the matched point (may differ from the query)
    actual_lon: float                 # longitude of the matched point
    total_epochs: int                 # number of raw observation epochs for this point
    forecast_count: int               # number of forecast epochs produced
    high_risk_count: int              # forecasts whose probability met the threshold
    high_risk_pct: float              # high_risk_count as a percentage of forecast_count
    mean_failure_probability: float   # mean probability over all forecast epochs
    mean_uncertainty: float           # mean MC-Dropout std over all forecast epochs
    first_alarm_day: Optional[float]  # day of the earliest high-risk forecast, or None
    threshold_used: float             # decision threshold (cfg["best_threshold"])
    model_auc: float                  # cfg["test_auc"] from model_config.json
    model_pr_auc: float               # cfg["test_pr_auc"] from model_config.json
    forecasts: list[EpochForecast]    # per-epoch forecast details
65
+
66
+
67
# ── Startup: load model & data ─────────────────────────────────────────────────
# NOTE(review): @app.on_event is deprecated in recent FastAPI; consider
# migrating to a lifespan handler when the FastAPI version allows.
@app.on_event("startup")
def load_assets():
    """Load scaler, config, model and dataframes once at application startup.

    Populates the module-level globals (scaler, cfg, model, df, df_clean)
    that the endpoints read. Artifact locations are overridable through the
    MODEL_PATH / SCALER_PATH / CONFIG_PATH environment variables.
    """
    global scaler, cfg, model, df, df_clean

    MODEL_PATH = os.getenv("MODEL_PATH", "artifacts/pinn_best.pt")
    SCALER_PATH = os.getenv("SCALER_PATH", "artifacts/scaler.pkl")
    CONFIG_PATH = os.getenv("CONFIG_PATH", "artifacts/model_config.json")

    # ── 1. Scaler ──────────────────────────────────────────────────────────────
    with open(SCALER_PATH, "rb") as f:
        scaler = pickle.load(f)

    # ── 2. Config ──────────────────────────────────────────────────────────────
    with open(CONFIG_PATH, "r") as f:
        cfg = json.load(f)

    # ── 3. Model ───────────────────────────────────────────────────────────────
    # OPTION A (recommended): instantiate your model class, then load weights
    #
    # from your_model_module import YourPINNModel
    # model = YourPINNModel(**cfg["model_params"]).to(DEVICE)
    # checkpoint = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True)
    # model.load_state_dict(checkpoint)
    #
    # OPTION B (fallback): load the entire pickled model object
    # Use this if pinn_best.pt was saved with torch.save(model, path)
    # SECURITY NOTE(review): weights_only=False unpickles arbitrary objects and
    # can execute code — only load checkpoints from trusted sources.
    model = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)
    model.eval()

    # ── 4. Data ────────────────────────────────────────────────────────────────
    # import pandas as pd
    # df = pd.read_parquet(os.getenv("DATA_PATH", "artifacts/ps_data.parquet"))
    # df_clean = pd.read_parquet(os.getenv("DATA_CLEAN_PATH", "artifacts/ps_data_clean.parquet"))

    print(f"Assets loaded | Device={DEVICE} | Threshold={cfg.get('best_threshold')}")
103
+
104
+
105
# ── Helper: find nearest PS point ─────────────────────────────────────────────
def get_ps_by_latlon(lat: float, lon: float, tol: float = 0.001) -> tuple[str, float, float, bool]:
    """Resolve a (lat, lon) query to a PS point.

    Returns ``(ps_id, actual_lat, actual_lon, used_fallback)``: the closest
    point whose lat AND lon both lie within ``tol`` degrees of the query, or —
    when no point falls inside that box — the absolute nearest point in the
    dataset with ``used_fallback=True``.

    Note: distances are Euclidean in degree space (no great-circle
    correction), which is adequate at the small tolerances used here.
    (Fixed: the annotation previously claimed ``-> str`` although a 4-tuple
    is returned.)
    """
    # Cheap bounding-box filter before computing any distances.
    mask = (
        (np.abs(df["lat"] - lat) <= tol) &
        (np.abs(df["lon"] - lon) <= tol)
    )
    matches = df[mask]

    if len(matches) == 0:
        # Fallback: absolute nearest point across the whole dataset.
        dist = np.sqrt((df["lat"] - lat) ** 2 + (df["lon"] - lon) ** 2)
        nearest = df.loc[dist.idxmin()]
        return str(nearest["ps_id"]), nearest["lat"], nearest["lon"], True

    # Several candidates may fall inside the box — pick the closest one.
    matches = matches.copy()  # copy before adding a column (avoids SettingWithCopy)
    matches["_dist"] = np.sqrt(
        (matches["lat"] - lat) ** 2 + (matches["lon"] - lon) ** 2
    )
    row = matches.loc[matches["_dist"].idxmin()]
    return str(row["ps_id"]), row["lat"], row["lon"], False
125
+
126
+
127
# ── Forecast endpoint ──────────────────────────────────────────────────────────
def _mc_dropout_probs(x_seq: torch.Tensor, p_vec: torch.Tensor) -> list[float]:
    """Run N_PASSES stochastic forward passes; return the sigmoid outputs.

    Only Dropout submodules are switched to train mode around the passes (the
    rest of the model — e.g. any normalization layers — stays in eval mode),
    and the whole model is restored to eval afterwards. Without this, a model
    in eval() produces identical passes and the reported uncertainty collapses
    to zero, defeating the MC-Dropout scheme.
    """
    try:
        for m in model.modules():
            if isinstance(m, torch.nn.Dropout):
                m.train()
        # no_grad hoisted around the whole sampling loop — inference only.
        with torch.no_grad():
            return [
                torch.sigmoid(model(x_seq, p_vec)).item() for _ in range(N_PASSES)
            ]
    finally:
        model.eval()


@app.post("/forecast", response_model=ForecastResponse)
def forecast(req: ForecastRequest):
    """Forecast failure probability for the PS point nearest (req.lat, req.lon).

    Slides a SEQ_LEN window over the point's cleaned time series and, for each
    window position, predicts the epoch HORIZON steps ahead via MC-Dropout.
    Returns per-epoch probabilities plus aggregate risk statistics.

    Raises:
        HTTPException 404: no PS point could be resolved for the coordinates.
        HTTPException 422: the point has too few epochs for one full window.
    """
    try:
        # used_fallback is not surfaced in the response yet; kept for future use.
        ps_id, actual_lat, actual_lon, used_fallback = get_ps_by_latlon(
            req.lat, req.lon, req.tolerance
        )
    except Exception as e:
        raise HTTPException(status_code=404, detail=f"Could not find PS point: {e}")

    # Time series for this PS point, in chronological order.
    ps_raw = (
        df[df["ps_id"] == ps_id]
        .sort_values("days_since_start")
        .reset_index(drop=True)
    )
    ps_clean = (
        df_clean[df_clean["ps_id"] == ps_id]
        .sort_values("days_since_start")
        .reset_index(drop=True)
    )

    # Need at least one full input window plus a target HORIZON steps ahead.
    if len(ps_clean) < SEQ_LEN + HORIZON + 1:
        raise HTTPException(
            status_code=422,
            detail=f"Insufficient data for PS point {ps_id} "
            f"(need >{SEQ_LEN + HORIZON} epochs, got {len(ps_clean)})",
        )

    # assumes ps_raw has at least as many epochs as ps_clean, so indexing
    # days_all by a ps_clean index is safe — TODO confirm against data prep
    days_all = ps_raw["days_since_start"].values

    feats = ps_clean[FEATURE_COLS].values.astype(np.float32)
    physics = ps_clean[PHYSICS_COLS].values.astype(np.float32)

    threshold = cfg["best_threshold"]
    epoch_forecasts = []

    for i in range(SEQ_LEN, len(ps_clean) - HORIZON):
        x_seq = torch.tensor(feats[i - SEQ_LEN:i]).unsqueeze(0).to(DEVICE)
        p_vec = torch.tensor(physics[i]).unsqueeze(0).to(DEVICE)

        preds = _mc_dropout_probs(x_seq, p_vec)

        fcst_idx = i + HORIZON  # epoch this window is forecasting
        mean_p = float(np.mean(preds))
        std_p = float(np.std(preds))

        epoch_forecasts.append(
            EpochForecast(
                day=float(days_all[fcst_idx]),
                failure_probability=round(mean_p, 6),
                uncertainty_std=round(std_p, 6),
                high_risk=mean_p >= threshold,
            )
        )

    # Aggregate statistics over all forecast epochs (loop above ran at least
    # once, guaranteed by the length check, so the division below is safe).
    forecast_days = np.array([e.day for e in epoch_forecasts])
    forecast_mean = np.array([e.failure_probability for e in epoch_forecasts])
    forecast_std = np.array([e.uncertainty_std for e in epoch_forecasts])
    forecast_risk = np.array([e.high_risk for e in epoch_forecasts])

    n_risk = int(forecast_risk.sum())
    first_alarm = (
        float(forecast_days[forecast_risk == 1][0]) if n_risk > 0 else None
    )

    return ForecastResponse(
        ps_id=ps_id,
        actual_lat=float(actual_lat),
        actual_lon=float(actual_lon),
        total_epochs=len(ps_raw),
        forecast_count=len(epoch_forecasts),
        high_risk_count=n_risk,
        high_risk_pct=round(n_risk / len(epoch_forecasts) * 100, 2),
        mean_failure_probability=round(float(forecast_mean.mean()), 6),
        mean_uncertainty=round(float(forecast_std.mean()), 6),
        first_alarm_day=first_alarm,
        threshold_used=threshold,
        model_auc=cfg["test_auc"],
        model_pr_auc=cfg["test_pr_auc"],
        forecasts=epoch_forecasts,
    )
215
+
216
+
217
# ── Health check ───────────────────────────────────────────────────────────────
@app.get("/health")
def health():
    """Liveness probe: report service status and the active torch device."""
    payload = dict(status="ok", device=str(DEVICE))
    return payload
221
+
222
+
223
# ── Run locally ────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Dev entry point; in production run `uvicorn main:app` directly instead.
    import uvicorn

    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=False)