kozy9 committed on
Commit
8941ef3
·
verified ·
1 Parent(s): 628607c

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +935 -0
  2. requirements.txt +10 -0
app.py ADDED
@@ -0,0 +1,935 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ UK Groundwater Level Prediction Dashboard
3
+ ==========================================
4
+ Benchmarking SARIMAX, LSTM, and TCN for Monthly Groundwater Level Prediction.
5
+
6
+ Gradio app comparing three time-series forecasting models on a long-term UK
7
+ borehole dataset (1944-2023). Presents pre-computed evaluation results and
8
+ allows interactive scenario-based predictions.
9
+
10
+ Author: Ahmed | Module: IJC319 Responsible Data Science | University of Sheffield
11
+ """
12
+
13
+ import gradio as gr
14
+ import pandas as pd
15
+ import numpy as np
16
+ import plotly.graph_objects as go
17
+ from plotly.subplots import make_subplots
18
+ import joblib
19
+ from huggingface_hub import hf_hub_download
20
+ from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
21
+ import warnings
22
+
23
+ warnings.filterwarnings("ignore")
24
+
25
+
26
+ # ======================================================================
27
+ # CONFIGURATION - UPDATE THESE TO MATCH YOUR DATA
28
+ # ======================================================================
29
+ # Check your CSV column names and update if they differ.
30
+ # Check FEATURE_COLS order matches the order your scalers were
31
+ # fitted on (open your notebook and verify).
32
+ # ======================================================================
33
+
34
+ DATE_COL = "date"
35
+ TARGET_COL = "water_level"
36
+ FEATURE_COLS = ["water_level", "temperature", "precipitation", "wind_speed"]
37
+ EXOG_COLS = ["temperature", "precipitation", "wind_speed"]
38
+
39
+ LOOKBACK = 24 # Sliding window length for LSTM/TCN
40
+
41
+ # HuggingFace repository IDs
42
+ LSTM_REPO = "kozy9/GWLSTM"
43
+ TCN_REPO = "kozy9/GWTCN"
44
+ SARIMAX_REPO = "kozy9/GWSarimax"
45
+
46
+ # Local CSV paths (place alongside app.py in your HF Space)
47
+ TRAIN_CSV = "uk_train.csv"
48
+ VALIDATE_CSV = "uk_validate.csv"
49
+ TEST_CSV = "uk_test.csv"
50
+
51
+ # Consistent colour palette across all tabs
52
+ COLOURS = {
53
+ "actual": "#1a2744",
54
+ "LSTM": "#2ecc71",
55
+ "TCN": "#e67e22",
56
+ "SARIMAX": "#3498db",
57
+ "Persistence": "#95a5a6",
58
+ "Seasonal": "#bdc3c7",
59
+ }
60
+
61
+
62
+ # ======================================================================
63
+ # DATA LOADING
64
+ # ======================================================================
65
+
66
print("=" * 60)
print("Loading data files...")
print("=" * 60)

# Read the three pre-split CSVs; DATE_COL is parsed to datetimes on load.
try:
    df_train = pd.read_csv(TRAIN_CSV, parse_dates=[DATE_COL])
    df_val = pd.read_csv(VALIDATE_CSV, parse_dates=[DATE_COL])
    df_test = pd.read_csv(TEST_CSV, parse_dates=[DATE_COL])
    print(f" Train: {len(df_train)} rows")
    print(f" Validate: {len(df_val)} rows")
    print(f" Test: {len(df_test)} rows")
except FileNotFoundError as e:
    # Chain the original error so the traceback still names the missing path.
    raise FileNotFoundError(
        f"Could not find data file: {e}\n"
        "Make sure uk_train.csv, uk_validate.csv, and uk_test.csv "
        "are in the same directory as app.py."
    ) from e

# Combine chronologically
df_all = (
    pd.concat([df_train, df_val, df_test], ignore_index=True)
    .sort_values(DATE_COL)
    .reset_index(drop=True)
)

# The test rows are located purely by position: everything after the first
# len(train) + len(validate) rows of the date-sorted frame.
test_start_idx = len(df_train) + len(df_val)

# That positional split is only correct when every test date is later than
# every train/validate date; warn loudly instead of silently mislabelling rows.
if len(df_test) > 0 and test_start_idx > 0:
    if df_all[DATE_COL].iloc[:test_start_idx].max() > df_test[DATE_COL].min():
        print(
            " WARNING - test CSV overlaps train/validate dates; "
            "the positional test split may contain non-test rows."
        )

test_dates = df_all[DATE_COL].iloc[test_start_idx:].values
test_actual = df_all[TARGET_COL].iloc[test_start_idx:].values

print(f" Total records: {len(df_all)}")
print(f" Features: {FEATURE_COLS}")
print(f" Test set starts at index: {test_start_idx}")
97
+
98
+
99
+ # ======================================================================
100
+ # MODEL LOADING (with error handling)
101
+ # ======================================================================
102
+
103
print("\n" + "=" * 60)
print("Downloading models from HuggingFace...")
print("=" * 60)

# NOTE(security): joblib.load unpickles the downloaded files, and unpickling
# can execute arbitrary code. Only point the *_REPO constants at repositories
# you control or trust.

# Each model is loaded independently so that one failure does not prevent the
# others from loading. A model and its scalers are treated as a unit: if any
# part fails, all of them are reset to None so later code never sees a model
# without its scalers.

# -- LSTM --
lstm_model = None
lstm_scaler_X = None
lstm_scaler_y = None
try:
    print(" Loading LSTM from", LSTM_REPO, "...")
    from tensorflow.keras.models import load_model

    lstm_model = load_model(hf_hub_download(LSTM_REPO, "lstm_model.keras"))
    lstm_scaler_X = joblib.load(hf_hub_download(LSTM_REPO, "scaler_X.pkl"))
    lstm_scaler_y = joblib.load(hf_hub_download(LSTM_REPO, "scaler_y.pkl"))
    print(" LSTM loaded successfully.")
except Exception as e:
    lstm_model = lstm_scaler_X = lstm_scaler_y = None
    print(f" WARNING - LSTM failed to load: {e}")

# -- TCN --
tcn_model = None
tcn_scaler_X = None
tcn_scaler_y = None
try:
    print(" Loading TCN from", TCN_REPO, "...")
    from tensorflow.keras.models import load_model as load_keras_model

    # The custom TCN layer is optional: only the import is guarded, so an
    # ImportError raised while deserialising the model is reported as a load
    # failure instead of triggering a second load attempt.
    try:
        from tcn import TCN as TCNLayer
    except ImportError:
        tcn_custom_objects = None
    else:
        tcn_custom_objects = {"TCN": TCNLayer}

    tcn_model = load_keras_model(
        hf_hub_download(TCN_REPO, "tcn_model.keras"),
        custom_objects=tcn_custom_objects,
    )
    tcn_scaler_X = joblib.load(hf_hub_download(TCN_REPO, "scaler_features.pkl"))
    tcn_scaler_y = joblib.load(hf_hub_download(TCN_REPO, "scaler_target.pkl"))
    print(" TCN loaded successfully.")
except Exception as e:
    tcn_model = tcn_scaler_X = tcn_scaler_y = None
    print(f" WARNING - TCN failed to load: {e}")

# -- SARIMAX --
sarimax_model = None
try:
    print(" Loading SARIMAX from", SARIMAX_REPO, "...")
    sarimax_model = joblib.load(
        hf_hub_download(SARIMAX_REPO, "sarimax_model.pkl")
    )
    # Verify it is a SARIMAXResultsWrapper, not a Keras model
    model_type = type(sarimax_model).__name__
    if "SARIMAX" not in model_type and "Results" not in model_type:
        print(f" WARNING: Expected SARIMAXResultsWrapper but got {model_type}.")
        print(" This may cause forecast errors. Re-run your SARIMAX notebook and")
        print(" ensure the correct object is saved to the .pkl file.")
    print(" SARIMAX loaded successfully.")
except Exception as e:
    sarimax_model = None
    print(f" WARNING - SARIMAX failed to load: {e}")

# Availability flags, keyed by model name, printed for the startup log.
loaded_models = {
    "LSTM": lstm_model is not None,
    "TCN": tcn_model is not None,
    "SARIMAX": sarimax_model is not None,
}
print(f"\n Model status: {loaded_models}")
168
+
169
+
170
+ # ======================================================================
171
+ # GENERATE TEST SET PREDICTIONS
172
+ # ======================================================================
173
+
174
+ print("\n" + "=" * 60)
175
+ print("Generating test set predictions...")
176
+ print("=" * 60)
177
+
178
+
179
def predict_dl_test(model, scaler_X, scaler_y, data, feature_cols, test_start, lookback):
    """Run sliding-window single-step-ahead inference over the test set.

    For every row index ``i`` in ``range(test_start, len(data))`` the model
    receives the ``lookback`` rows immediately preceding ``i`` (rows
    ``i - lookback`` to ``i - 1`` of ``data[feature_cols]``), scaled with
    ``scaler_X``. The model output is mapped back through
    ``scaler_y.inverse_transform`` and its first column is kept.

    All windows are stacked and passed to ``model.predict`` in one batched
    call rather than one call per time step. This assumes ``scaler_X`` scales
    each row independently and that the model's prediction for a sample does
    not depend on the other samples in the batch.

    Parameters
    ----------
    model : object
        Exposes ``predict(X, verbose=0)`` for ``X`` of shape
        ``(n_windows, lookback, n_features)`` and returns a 2-D array.
    scaler_X : object
        Exposes ``transform`` for 2-D ``(rows, n_features)`` input.
    scaler_y : object
        Exposes ``inverse_transform`` for the model's 2-D output.
    data : pd.DataFrame
        Full series containing ``feature_cols``.
    feature_cols : list of str
        Columns fed to the model, in the order the scaler expects.
    test_start : int
        Positional index of the first row to predict.
    lookback : int
        Window length in rows.

    Returns
    -------
    np.ndarray
        One float per row from ``test_start`` to the end of ``data``; entries
        whose window would start before row 0 are ``np.nan``.
    """
    features = data[feature_cols].values
    n_features = len(feature_cols)
    targets = range(test_start, len(data))
    predictions = np.full(len(targets), np.nan)

    # Positions (within the output) that have a complete lookback window.
    valid = [pos for pos, i in enumerate(targets) if i - lookback >= 0]
    if not valid:
        return predictions

    windows = np.stack(
        [features[targets[pos] - lookback : targets[pos]] for pos in valid]
    )
    # Scale as a flat (rows, n_features) matrix, then restore the window axis.
    scaled = scaler_X.transform(windows.reshape(-1, n_features))
    X_input = np.asarray(scaled).reshape(len(valid), lookback, n_features)

    y_scaled = model.predict(X_input, verbose=0)
    predictions[valid] = np.asarray(scaler_y.inverse_transform(y_scaled))[:, 0]
    return predictions
194
+
195
+
196
# Every prediction array starts as all-NaN with one slot per test row, so a
# model that is missing or fails still yields a well-formed (empty) column.
# Deep-learning inference is guarded in the same way as the SARIMAX forecast:
# a failure in one model is logged and leaves its column as NaN instead of
# aborting start-up.

# LSTM predictions
lstm_preds = np.full(len(df_test), np.nan)
if lstm_model is not None:
    print(" Running LSTM inference on test set...")
    try:
        lstm_preds = predict_dl_test(
            lstm_model, lstm_scaler_X, lstm_scaler_y,
            df_all, FEATURE_COLS, test_start_idx, LOOKBACK,
        )
        print(" LSTM predictions complete.")
    except Exception as e:
        print(f" WARNING - LSTM inference error: {e}")

# TCN predictions
tcn_preds = np.full(len(df_test), np.nan)
if tcn_model is not None:
    print(" Running TCN inference on test set...")
    try:
        tcn_preds = predict_dl_test(
            tcn_model, tcn_scaler_X, tcn_scaler_y,
            df_all, FEATURE_COLS, test_start_idx, LOOKBACK,
        )
        print(" TCN predictions complete.")
    except Exception as e:
        print(f" WARNING - TCN inference error: {e}")

# SARIMAX predictions
sarimax_preds = np.full(len(df_test), np.nan)
sarimax_lower = np.full(len(df_test), np.nan)
sarimax_upper = np.full(len(df_test), np.nan)
if sarimax_model is not None:
    print(" Running SARIMAX forecast on test set...")
    try:
        # NOTE(review): get_forecast starts from the end of the sample the
        # saved results object was fitted on. If that was the training split
        # only, these steps would begin in the validation period rather than
        # the test period - confirm against the training notebook.
        exog_test = df_all[EXOG_COLS].iloc[test_start_idx:]
        sarimax_fc = sarimax_model.get_forecast(steps=len(df_test), exog=exog_test)
        sarimax_preds = sarimax_fc.predicted_mean.values
        sarimax_ci = sarimax_fc.conf_int()
        sarimax_lower = sarimax_ci.iloc[:, 0].values
        sarimax_upper = sarimax_ci.iloc[:, 1].values
        print(" SARIMAX forecast complete.")
    except Exception as e:
        print(f" WARNING - SARIMAX forecast error: {e}")

# Naive baselines
print(" Computing naive baselines...")
# Persistence: each test row is predicted by the row immediately before it.
persistence_preds = df_all[TARGET_COL].iloc[test_start_idx - 1 : -1].values
# Seasonal naive: each test row is predicted by the row 12 positions earlier.
seasonal_preds = df_all[TARGET_COL].iloc[test_start_idx - 12 : len(df_all) - 12].values

# Assemble results DataFrame
results_df = pd.DataFrame({
    "date": test_dates,
    "actual": test_actual,
    "LSTM": lstm_preds,
    "TCN": tcn_preds,
    "SARIMAX": sarimax_preds,
    "SARIMAX_lower": sarimax_lower,
    "SARIMAX_upper": sarimax_upper,
    "Persistence": persistence_preds,
    "Seasonal": seasonal_preds,
})

print("All predictions generated.\n")
252
+
253
+
254
+ # ======================================================================
255
+ # METRICS
256
+ # ======================================================================
257
+
258
def compute_metrics(actual, predicted, name):
    """Compute RMSE, MAE, MAPE, R-squared and NSE for one model.

    Pairs where either value is NaN are dropped before scoring.

    Parameters
    ----------
    actual : array-like of float
        Observed values.
    predicted : array-like of float
        Model predictions, aligned element-wise with ``actual``.
    name : str
        Label stored under the ``"Model"`` key of the result.

    Returns
    -------
    dict
        ``"Model"`` plus five metric entries. When no valid pair remains,
        every metric is the string ``"N/A"``; otherwise metrics are rounded
        numbers. MAPE is NaN if any observed value is zero, and NSE is NaN if
        the observed values have zero variance (the ratio is undefined).
    """
    # Accept lists / Series as well as ndarrays; boolean masking needs arrays.
    actual = np.asarray(actual, dtype=float)
    predicted = np.asarray(predicted, dtype=float)

    mask = ~np.isnan(predicted) & ~np.isnan(actual)
    a, p = actual[mask], predicted[mask]
    if len(a) == 0:
        # NOTE(review): the fifth key looks mis-encoded (probably "R²"); it is
        # kept byte-for-byte because other code indexes the table by it.
        return {"Model": name, "RMSE (m)": "N/A", "MAE (m)": "N/A",
                "MAPE (%)": "N/A", "RΒ²": "N/A", "NSE": "N/A"}

    rmse = np.sqrt(mean_squared_error(a, p))
    mae = mean_absolute_error(a, p)
    mape = np.mean(np.abs((a - p) / a)) * 100 if np.all(a != 0) else np.nan
    r2 = r2_score(a, p)

    # Guard the NSE denominator: constant observations give zero variance.
    ss_tot = np.sum((a - np.mean(a)) ** 2)
    nse = 1 - np.sum((a - p) ** 2) / ss_tot if ss_tot > 0 else np.nan

    return {
        "Model": name,
        "RMSE (m)": round(rmse, 3),
        "MAE (m)": round(mae, 3),
        "MAPE (%)": round(mape, 2),
        "RΒ²": round(r2, 4),
        "NSE": round(nse, 4),
    }
278
+
279
+
280
# Score every forecast column against the observed test values; the row order
# here is the row order of the metrics table.
metrics_list = [
    compute_metrics(test_actual, preds, label)
    for label, preds in (
        ("SARIMAX", sarimax_preds),
        ("LSTM", lstm_preds),
        ("TCN", tcn_preds),
        ("Persistence Baseline", persistence_preds),
        ("Seasonal Naive Baseline", seasonal_preds),
    )
]
metrics_df = pd.DataFrame(metrics_list)
288
+
289
+
290
+ # ======================================================================
291
+ # PREPROCESSING FOR SCENARIO PREDICTION
292
+ # ======================================================================
293
+
294
def preprocess_dl(last_24_rows, next_month_meteo, scaler_X, lookback=LOOKBACK):
    """
    Construct a scaled sliding window for LSTM/TCN inference.

    A synthetic row for the scenario month is appended to the observed rows
    (its target value is a copy of the last observed target), and the final
    ``lookback`` rows of the result are scaled and reshaped for the model.

    Parameters
    ----------
    last_24_rows : pd.DataFrame
        Most recent observed months with columns matching FEATURE_COLS.
        Must hold at least ``lookback - 1`` rows.
    next_month_meteo : dict
        User-specified values: {temperature, precipitation, wind_speed}.
    scaler_X : MinMaxScaler
        Fitted on training data only.
    lookback : int
        Number of timesteps in the returned window (default LOOKBACK).

    Returns
    -------
    np.ndarray of shape (1, lookback, n_features)

    Raises
    ------
    ValueError
        If ``lookback`` is not positive, or if fewer than ``lookback`` rows
        are available after appending the scenario row.
    """
    if lookback < 1:
        raise ValueError(f"lookback must be a positive integer, got {lookback}")

    # Use last known water_level as placeholder for target in the appended row
    last_wl = last_24_rows[TARGET_COL].iloc[-1]
    new_row = pd.DataFrame([{
        TARGET_COL: last_wl,
        "temperature": next_month_meteo["temperature"],
        "precipitation": next_month_meteo["precipitation"],
        "wind_speed": next_month_meteo["wind_speed"],
    }])

    # Append and take the last `lookback` rows as the input window
    combined = pd.concat(
        [last_24_rows[FEATURE_COLS], new_row[FEATURE_COLS]], ignore_index=True
    )
    if len(combined) < lookback:
        # Fail with a readable message instead of an opaque reshape error.
        raise ValueError(
            f"Need at least {lookback - 1} observed rows to build a window of "
            f"{lookback} timesteps, got {len(last_24_rows)}"
        )

    window = combined.iloc[-lookback:].values
    window_scaled = scaler_X.transform(window)
    return window_scaled.reshape(1, lookback, len(FEATURE_COLS))
327
+
328
+
329
# Most recent LOOKBACK observed rows: the model input for the scenario tab.
last_24_df = df_all[FEATURE_COLS + [DATE_COL]].iloc[-LOOKBACK:].copy()

# Human-readable copy of that window for the UI table: the date is rendered as
# "YYYY-MM" and the columns get display headings.
last_24_display = last_24_df.assign(
    **{DATE_COL: last_24_df[DATE_COL].dt.strftime("%Y-%m")}
).rename(columns={
    DATE_COL: "Month",
    TARGET_COL: "Water Level (m)",
    "temperature": "Temp (C)",
    "precipitation": "Precip (mm)",
    "wind_speed": "Wind (m/s)",
})

# Slider bounds (min/max) and default positions (mean, one decimal place),
# all taken from the training split.
temp_min, temp_max = (
    float(df_train["temperature"].min()),
    float(df_train["temperature"].max()),
)
precip_min, precip_max = (
    float(df_train["precipitation"].min()),
    float(df_train["precipitation"].max()),
)
wind_min, wind_max = (
    float(df_train["wind_speed"].min()),
    float(df_train["wind_speed"].max()),
)
temp_mean = round(float(df_train["temperature"].mean()), 1)
precip_mean = round(float(df_train["precipitation"].mean()), 1)
wind_mean = round(float(df_train["wind_speed"].mean()), 1)
351
+
352
+
353
+ # ======================================================================
354
+ # TAB 1: FORECAST COMPARISON (PRE-COMPUTED)
355
+ # ======================================================================
356
+
357
def build_forecast_comparison(show_lstm, show_tcn, show_sarimax, show_ci):
    """Build the test-set overlay figure: observed series plus optional model traces.

    The observed series is always drawn. Each ``show_*`` flag adds the
    corresponding column of ``results_df``; ``show_ci`` adds the shaded
    SARIMAX interval between the ``SARIMAX_upper`` and ``SARIMAX_lower``
    columns.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    dates = results_df["date"]

    def line_trace(column, label, width):
        # One line series, coloured by the palette entry named after the column.
        return go.Scatter(
            x=dates, y=results_df[column],
            name=label, mode="lines",
            line=dict(color=COLOURS[column], width=width),
        )

    fig = go.Figure()

    # Ground truth first so it appears first in the legend.
    fig.add_trace(line_trace("actual", "Actual (Ground Truth)", 2.5))

    if show_sarimax:
        fig.add_trace(line_trace("SARIMAX", "SARIMAX", 1.8))
        # NOTE(review): the band is drawn only when the SARIMAX line is shown;
        # nesting assumed from the grouping of the original statements.
        if show_ci:
            # Closed polygon: upper bound left-to-right, then lower bound back.
            band_x = list(dates) + list(dates[::-1])
            band_y = (
                list(results_df["SARIMAX_upper"])
                + list(results_df["SARIMAX_lower"][::-1])
            )
            fig.add_trace(go.Scatter(
                x=band_x,
                y=band_y,
                fill="toself", fillcolor="rgba(52, 152, 219, 0.1)",
                line=dict(color="rgba(0,0,0,0)"),
                name="SARIMAX 95% CI", showlegend=True,
            ))

    if show_lstm:
        fig.add_trace(line_trace("LSTM", "LSTM", 1.8))

    if show_tcn:
        fig.add_trace(line_trace("TCN", "TCN", 1.8))

    fig.update_layout(
        title="Test Set: Model Predictions vs Actual Groundwater Level",
        xaxis_title="Date",
        yaxis_title="Groundwater Level (m)",
        height=520,
        template="plotly_white",
        font=dict(family="IBM Plex Sans, system-ui, sans-serif"),
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
        margin=dict(t=60, b=40),
        xaxis=dict(rangeslider=dict(visible=True, thickness=0.05)),
    )
    return fig
409
+
410
+
411
+ # ======================================================================
412
+ # TAB 2: SCENARIO PREDICTION
413
+ # ======================================================================
414
+
415
def predict_scenario(temperature, precipitation, wind_speed):
    """Run all three models with user-specified next-month meteorological values.

    Parameters
    ----------
    temperature, precipitation, wind_speed : float
        Scenario values for the month being predicted.

    Returns
    -------
    tuple
        ``(markdown_text, figure)``: a Markdown summary with one bullet per
        model (a value in metres, an error message, or "Model not loaded") and
        a bar chart containing only the models that produced a number.
    """
    meteo = {
        "temperature": temperature,
        "precipitation": precipitation,
        "wind_speed": wind_speed,
    }

    def run_dl(model, scaler_X, scaler_y):
        # Shared LSTM/TCN path: returns a float, or a message string on failure.
        if model is None:
            return "Model not loaded"
        try:
            X_in = preprocess_dl(last_24_df, meteo, scaler_X)
            y_sc = model.predict(X_in, verbose=0)
            return float(scaler_y.inverse_transform(y_sc)[0][0])
        except Exception as e:
            return f"Error: {e}"

    def sarimax_one_step(exog_values):
        # One-step forecast for a single row of exogenous values.
        fc = sarimax_model.get_forecast(steps=1, exog=pd.DataFrame([exog_values]))
        return float(fc.predicted_mean.iloc[0])

    results = {
        "LSTM": run_dl(lstm_model, lstm_scaler_X, lstm_scaler_y),
        "TCN": run_dl(tcn_model, tcn_scaler_X, tcn_scaler_y),
    }

    # -- SARIMAX --
    if sarimax_model is not None:
        try:
            results["SARIMAX"] = sarimax_one_step(meteo)
        except Exception as e:
            results["SARIMAX"] = f"Error: {e}"
    else:
        results["SARIMAX"] = "Model not loaded"

    # -- Build output text --
    lines = ["## Predicted Groundwater Level (Next Month)\n"]
    for model_name in ["LSTM", "TCN", "SARIMAX"]:
        val = results[model_name]
        if isinstance(val, float):
            lines.append(f"- **{model_name}:** {val:.2f} m")
        else:
            lines.append(f"- **{model_name}:** {val}")

    # SARIMAX sensitivity check: compare the scenario forecast with the
    # forecast under mean conditions. Skipped when the scenario IS the mean
    # conditions (the slider defaults), where the difference is trivially zero
    # and the note would be misleading.
    sarimax_note = ""
    mean_conditions = {
        "temperature": temp_mean,
        "precipitation": precip_mean,
        "wind_speed": wind_mean,
    }
    if isinstance(results.get("SARIMAX"), float) and meteo != mean_conditions:
        try:
            alt_pred = sarimax_one_step(mean_conditions)
        except Exception as e:
            # Best-effort extra: report the failure but keep the predictions.
            print(f" WARNING - SARIMAX sensitivity check failed: {e}")
        else:
            diff = abs(results["SARIMAX"] - alt_pred)
            if diff < 0.5:
                sarimax_note = (
                    "\n\n> **Note:** SARIMAX predictions are largely unaffected by "
                    "meteorological inputs (prediction changed by only "
                    f"{diff:.2f} m compared to mean conditions). This is consistent "
                    "with this study's finding that the model relies on autoregressive "
                    "structure rather than exogenous features."
                )

    lines.append(sarimax_note)

    # -- Build bar chart (numeric results only) --
    fig = go.Figure()
    model_names = [m for m in ["LSTM", "TCN", "SARIMAX"] if isinstance(results[m], float)]
    pred_values = [results[m] for m in model_names]
    bar_colours = [COLOURS[m] for m in model_names]

    if pred_values:
        fig.add_trace(go.Bar(
            x=model_names, y=pred_values,
            marker_color=bar_colours,
            text=[f"{v:.2f} m" for v in pred_values],
            textposition="outside",
            width=0.5,
        ))
    # NOTE(review): the layout is applied even when no model produced a value;
    # confirm against the original indentation.
    fig.update_layout(
        title="Scenario Prediction: All Models",
        yaxis_title="Groundwater Level (m)",
        height=400, template="plotly_white",
        font=dict(family="IBM Plex Sans, system-ui, sans-serif"),
        margin=dict(t=60, b=30),
    )

    return "\n".join(lines), fig
524
+
525
+
526
+ # ======================================================================
527
+ # TAB 3: PERFORMANCE METRICS
528
+ # ======================================================================
529
+
530
def build_metrics_bar():
    """Build a two-panel bar figure of the metrics table.

    The left panel holds RMSE and MAE, the right panel R-squared and NSE.
    Non-numeric entries in ``metrics_df`` (such as "N/A") are coerced to NaN.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    fig = make_subplots(
        rows=1, cols=2,
        subplot_titles=(
            "Error Metrics (Lower is Better)",
            "Goodness-of-Fit (Higher is Better)",
        ),
    )

    models = metrics_df["Model"].tolist()
    # Palette lookup uses the first word of the model label, grey if unknown.
    colours = [COLOURS.get(label.split(" ")[0], "#888") for label in models]

    # (metrics_df column, subplot column, bar opacity), in drawing order.
    bar_specs = (
        ("RMSE (m)", 1, 0.9),
        ("MAE (m)", 1, 0.55),
        ("RΒ²", 2, 0.9),
        ("NSE", 2, 0.55),
    )
    for metric, panel, alpha in bar_specs:
        values = pd.to_numeric(metrics_df[metric], errors="coerce")
        fig.add_trace(
            go.Bar(
                name=metric, x=models, y=values,
                marker_color=colours, opacity=alpha,
            ),
            row=1, col=panel,
        )

    fig.update_layout(
        height=430, template="plotly_white",
        font=dict(family="IBM Plex Sans, system-ui, sans-serif"),
        showlegend=True,
        legend=dict(orientation="h", yanchor="bottom", y=1.08, xanchor="center", x=0.5),
        margin=dict(t=70, b=30),
    )
    return fig
574
+
575
+
576
+ # ======================================================================
577
+ # TAB 4: FEATURE IMPORTANCE
578
+ # ======================================================================
579
+ # UPDATE: Replace these placeholder values with your actual results
580
+ # from your notebooks.
581
+
582
+ lstm_importance = {
583
+ "water_level": 0.85, # UPDATE with your actual value
584
+ "temperature": 0.12, # UPDATE with your actual value
585
+ "wind_speed": 0.08, # UPDATE with your actual value
586
+ "precipitation": 0.03, # UPDATE with your actual value
587
+ }
588
+
589
+ sarimax_importance = {
590
+ "temperature": -0.02, # UPDATE with your actual value
591
+ "precipitation": -0.01, # UPDATE with your actual value
592
+ "wind_speed": 0.005, # UPDATE with your actual value
593
+ }
594
+
595
+
596
def build_feature_importance():
    """Build side-by-side horizontal bar charts of the LSTM and SARIMAX importances.

    Each panel lists the features of the corresponding module-level
    importance dict in ascending order of value. Positive bars use the
    model's palette colour; all other bars are red.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    fig = make_subplots(
        rows=1, cols=2,
        subplot_titles=(
            "LSTM - Permutation Feature Importance",
            "SARIMAX - Permutation Feature Importance",
        ),
        horizontal_spacing=0.2,
    )

    def add_panel(importance, model_key, panel_col):
        # Ascending sort by importance value.
        ranked = sorted(importance.items(), key=lambda item: item[1])
        names = [feature for feature, _ in ranked]
        scores = [score for _, score in ranked]
        fig.add_trace(
            go.Bar(
                y=names, x=scores,
                orientation="h",
                marker_color=[
                    COLOURS[model_key] if score > 0 else "#e74c3c"
                    for score in scores
                ],
                text=[f"{score:.3f}" for score in scores],
                textposition="outside",
                name=model_key, showlegend=False,
            ),
            row=1, col=panel_col,
        )

    add_panel(lstm_importance, "LSTM", 1)
    add_panel(sarimax_importance, "SARIMAX", 2)

    # Zero reference line on the SARIMAX panel only.
    fig.add_vline(x=0, line_dash="dot", line_color="#ccc", row=1, col=2)

    fig.update_layout(
        height=380, template="plotly_white",
        font=dict(family="IBM Plex Sans, system-ui, sans-serif"),
        margin=dict(t=60, b=30, l=130),
    )
    return fig
643
+
644
+
645
+ # ======================================================================
646
+ # TAB 5: MODEL ARCHITECTURES
647
+ # ======================================================================
648
+ # UPDATE: Replace all (UPDATE) placeholders with your actual
649
+ # hyperparameters from your notebooks.
650
+
651
+ ARCHITECTURE_MD = """
652
+ ## SARIMAX
653
+
654
+ | Parameter | Value |
655
+ |-----------|-------|
656
+ | Order (p, d, q) | *(UPDATE)* |
657
+ | Seasonal Order (P, D, Q, s) | *(UPDATE, e.g. (P, D, Q, 12))* |
658
+ | Optimisation | Optuna (TPE sampler, 80 trials, seed=42) |
659
+ | Exogenous Variables | temperature, precipitation, wind_speed |
660
+ | Key Finding | Performance driven by autoregressive structure; meteorological inputs statistically insignificant |
661
+
662
+ [View on HuggingFace](https://huggingface.co/Kozy9/GWSarimax)
663
+
664
+ ---
665
+
666
+ ## LSTM
667
+
668
+ | Parameter | Value |
669
+ |-----------|-------|
670
+ | Architecture | *(UPDATE: e.g. 2 LSTM layers)* |
671
+ | Units per Layer | *(UPDATE)* |
672
+ | Dropout | *(UPDATE)* |
673
+ | Optimiser | *(UPDATE: e.g. Adam)* |
674
+ | Lookback Window | 24 months |
675
+ | Optimisation | Keras Tuner (BayesianOptimization) |
676
+ | Input Shape | (24, 4) - 24 timesteps x 4 features |
677
+
678
+ [View on HuggingFace](https://huggingface.co/Kozy9/GWLSTM)
679
+
680
+ ---
681
+
682
+ ## TCN
683
+
684
+ | Parameter | Value |
685
+ |-----------|-------|
686
+ | Receptive Field | *(UPDATE)* |
687
+ | Filters | *(UPDATE)* |
688
+ | Kernel Size | *(UPDATE)* |
689
+ | Dilations | *(UPDATE: e.g. [1, 2, 4, 8])* |
690
+ | Dropout | *(UPDATE)* |
691
+ | Lookback Window | 24 months |
692
+ | Optimisation | Keras Tuner (BayesianOptimization, 20 trials) |
693
+ | Input Shape | (24, 4) - 24 timesteps x 4 features |
694
+ | Baseline RMSE (before tuning) | 5.91 m (R-squared/NSE = -0.82) |
695
+ | Tuned RMSE | 3.58 m (R-squared/NSE = 0.33) |
696
+ | Underperformance Factors | Small dataset (~766 training sequences), constrained search space, MSE loss under-predicting peaks |
697
+
698
+ [View on HuggingFace](https://huggingface.co/Kozy9/GWTCN)
699
+
700
+ ---
701
+
702
+ ## Preprocessing (Shared Across Models)
703
+
704
+ | Component | Detail |
705
+ |-----------|--------|
706
+ | Scaling | MinMaxScaler (separate scalers for features and target) |
707
+ | Fitting | Scalers fitted on training data only (no data leakage) |
708
+ | Lookback Window | 24 monthly timesteps for LSTM and TCN |
709
+ | Target Variable | water_level (metres) |
710
+ """
711
+
712
+
713
+ # ======================================================================
714
+ # GRADIO APP
715
+ # ======================================================================
716
+
717
+ with gr.Blocks(
718
+ title="UK Groundwater Level Prediction",
719
+ theme=gr.themes.Soft(
720
+ primary_hue="teal",
721
+ secondary_hue="blue",
722
+ font=["IBM Plex Sans", "system-ui", "sans-serif"],
723
+ ),
724
+ css="""
725
+ .main-header { text-align: center; margin-bottom: 0.3rem; }
726
+ .sub-header { text-align: center; color: #666; font-size: 0.95rem; margin-bottom: 1rem; }
727
+ .caveat-box { background: #f0f7ff; border-left: 4px solid #3498db;
728
+ padding: 12px 16px; border-radius: 6px; margin: 10px 0;
729
+ font-size: 0.88rem; color: #2c3e50; }
730
+ .warn-box { background: #fef9e7; border-left: 4px solid #f39c12;
731
+ padding: 12px 16px; border-radius: 6px; margin: 10px 0;
732
+ font-size: 0.88rem; color: #7d6608; }
733
+ """,
734
+ ) as app:
735
+
736
+ gr.Markdown(
737
+ "# Benchmarking SARIMAX, LSTM, and TCN for Monthly Groundwater Level Prediction",
738
+ elem_classes="main-header",
739
+ )
740
+ gr.Markdown(
741
+ "Comparing statistical and deep learning forecasting models on 79 years of UK "
742
+ "borehole observations (1944-2023). Module IJC319 | University of Sheffield.",
743
+ elem_classes="sub-header",
744
+ )
745
+
746
+ # ──────────────────────────────────────────────
747
+ # TAB 1 - FORECAST COMPARISON
748
+ # ──────────────────────────────────────────────
749
+ with gr.Tab("Forecast Comparison"):
750
+ gr.Markdown("### Test Set: Predicted vs Actual Groundwater Level")
751
+ gr.Markdown(
752
+ "Toggle individual model traces with the checkboxes below. "
753
+ "Use the date-range slider beneath the chart to zoom into specific periods."
754
+ )
755
+
756
+ with gr.Row():
757
+ fc_lstm = gr.Checkbox(value=True, label="LSTM")
758
+ fc_tcn = gr.Checkbox(value=True, label="TCN")
759
+ fc_sarimax = gr.Checkbox(value=True, label="SARIMAX")
760
+ fc_ci = gr.Checkbox(value=True, label="SARIMAX 95% CI")
761
+
762
+ fc_plot = gr.Plot(
763
+ value=build_forecast_comparison(True, True, True, True),
764
+ )
765
+
766
+ for chk in [fc_lstm, fc_tcn, fc_sarimax, fc_ci]:
767
+ chk.change(
768
+ fn=build_forecast_comparison,
769
+ inputs=[fc_lstm, fc_tcn, fc_sarimax, fc_ci],
770
+ outputs=fc_plot,
771
+ )
772
+
773
+ # ──────────────────────────────────────────────
774
+ # TAB 2 - SCENARIO PREDICTION
775
+ # ──────────────────────────────────────────────
776
+ with gr.Tab("Scenario Prediction"):
777
+ gr.Markdown("### Interactive Next-Month Prediction")
778
+ gr.Markdown(
779
+ "Adjust the meteorological sliders to define a scenario for the next month. "
780
+ "All three models will generate a prediction based on the last 24 months "
781
+ "of observed data shown below."
782
+ )
783
+
784
+ with gr.Accordion("Important Methodological Caveats", open=False):
785
+ gr.Markdown(
786
+ '<div class="caveat-box">'
787
+ "<strong>Different forecasting procedures:</strong> LSTM and TCN produce "
788
+ "single-step-ahead predictions using the last 24 months as a sliding window input. "
789
+ "SARIMAX forecasts using its fitted autoregressive parameters and internal state. "
790
+ "These are not identical forecasting procedures. See the Performance Metrics tab "
791
+ "for further details on this methodological asymmetry."
792
+ "</div>"
793
+ )
794
+ gr.Markdown(
795
+ '<div class="warn-box">'
796
+ "Predictions are based on models trained on a <strong>single UK observation "
797
+ "borehole</strong> dataset (1944-2023) and should <strong>not</strong> be used for "
798
+ "operational groundwater management decisions."
799
+ "</div>"
800
+ )
801
+
802
+ with gr.Row():
803
+ with gr.Column(scale=1):
804
+ gr.Markdown("#### Historical Context (Last 24 Observed Months)")
805
+ gr.DataFrame(
806
+ value=last_24_display,
807
+ label="Lookback Window",
808
+ interactive=False,
809
+ )
810
+
811
+ with gr.Column(scale=1):
812
+ gr.Markdown("#### Next Month's Meteorological Scenario")
813
+ sl_temp = gr.Slider(
814
+ minimum=temp_min, maximum=temp_max, value=temp_mean,
815
+ step=0.5, label="Temperature (C)",
816
+ )
817
+ sl_precip = gr.Slider(
818
+ minimum=precip_min, maximum=precip_max, value=precip_mean,
819
+ step=1.0, label="Precipitation (mm)",
820
+ )
821
+ sl_wind = gr.Slider(
822
+ minimum=wind_min, maximum=wind_max, value=wind_mean,
823
+ step=0.1, label="Wind Speed (m/s)",
824
+ )
825
+ btn_predict = gr.Button(
826
+ "Predict Next Month", variant="primary", size="lg",
827
+ )
828
+
829
+ pred_output = gr.Markdown()
830
+ pred_chart = gr.Plot()
831
+
832
+ btn_predict.click(
833
+ fn=predict_scenario,
834
+ inputs=[sl_temp, sl_precip, sl_wind],
835
+ outputs=[pred_output, pred_chart],
836
+ )
837
+
838
+ # ──────────────────────────────────────────────
839
+ # TAB 3 - PERFORMANCE METRICS
840
+ # ──────────────────────────────────────────────
841
+ with gr.Tab("Performance Metrics"):
842
+ gr.Markdown("### Evaluation Metrics on Held-Out Test Set")
843
+ gr.Markdown(
844
+ "All models evaluated on the same test period. Persistence (previous month's value) "
845
+ "and seasonal naive (same month, previous year) baselines provide benchmarking context."
846
+ )
847
+
848
+ gr.DataFrame(value=metrics_df, label="Performance Metrics", interactive=False)
849
+
850
+ gr.Markdown(
851
+ '<div class="caveat-box">'
852
+ "<strong>Methodological note:</strong> SARIMAX was evaluated using "
853
+ "<em>multi-step-ahead forecasting</em>; LSTM and TCN used "
854
+ "<em>single-step-ahead (rolling one-step) evaluation</em>. Direct metric "
855
+ "comparison should be interpreted with caution due to this methodological "
856
+ "difference. Multi-step forecasting accumulates error over the forecast horizon, "
857
+ "which may disadvantage SARIMAX relative to the deep learning models."
858
+ "</div>"
859
+ )
860
+
861
+ gr.Markdown("### Visual Comparison")
862
+ gr.Plot(value=build_metrics_bar())
863
+
864
+ # ──────────────────────────────────────────────
865
+ # TAB 4 - FEATURE IMPORTANCE
866
+ # ──────────────────────────────────────────────
867
+ with gr.Tab("Feature Importance"):
868
+ gr.Markdown("### Permutation Feature Importance Analysis")
869
+ gr.Markdown(
870
+ "Permutation feature importance measures how much each input variable "
871
+ "contributes to model accuracy. A feature is shuffled, and the resulting "
872
+ "increase in prediction error indicates its importance."
873
+ )
874
+
875
+ gr.Plot(value=build_feature_importance())
876
+
877
+ with gr.Row():
878
+ with gr.Column():
879
+ gr.Markdown(
880
+ "#### LSTM Interpretation\n\n"
881
+ "**Water level history** is the dominant input feature, confirming that "
882
+ "the LSTM relies heavily on autoregressive patterns in the target series. "
883
+ "Among meteorological variables, **temperature** is the most influential, "
884
+ "followed by wind speed and precipitation."
885
+ )
886
+ with gr.Column():
887
+ gr.Markdown(
888
+ "#### SARIMAX Interpretation\n\n"
889
+ "**Negative importance values** indicate that the exogenous meteorological "
890
+ "features did not contribute meaningfully to prediction accuracy. In some "
891
+ "cases, removing these features actually *improved* predictions. This is "
892
+ "consistent with the finding that SARIMAX performance is driven by its "
893
+ "**autoregressive and seasonal components**, not by external weather inputs."
894
+ )
895
+
896
+ gr.Markdown(
897
+ '<div class="warn-box">'
898
+ "<strong>Note:</strong> Feature importance analysis was not performed for "
899
+ "the TCN model in this study due to the model's weaker overall performance "
900
+ "and the focus on comparing the two stronger-performing approaches."
901
+ "</div>"
902
+ )
903
+
904
+ # ──────────────────────────────────────────────
905
+ # TAB 5 - MODEL ARCHITECTURES
906
+ # ──────────────────────────────────────────────
907
+ with gr.Tab("Model Architectures"):
908
+ gr.Markdown("### Model Specifications and Hyperparameters")
909
+ gr.Markdown(
910
+ "Full details of each model's architecture, optimisation approach, and "
911
+ "training configuration. Links to HuggingFace repositories are provided "
912
+ "for full reproducibility."
913
+ )
914
+ gr.Markdown(ARCHITECTURE_MD)
915
+
916
+ # ──────────────────────────────────────────────
917
+ # FOOTER
918
+ # ──────────────────────────────────────────────
919
+ gr.Markdown(
920
+ "---\n"
921
+ "**IJC319 Responsible Data Science** | University of Sheffield | "
922
+ "[LSTM Repo](https://huggingface.co/Kozy9/GWLSTM) | "
923
+ "[TCN Repo](https://huggingface.co/Kozy9/GWTCN) | "
924
+ "[SARIMAX Repo](https://huggingface.co/Kozy9/GWSarimax)\n\n"
925
+ "*This tool is a research demonstrator trained on a single UK observation borehole. "
926
+ "Predictions are site-specific and must not be used for operational water management decisions.*"
927
+ )
928
+
929
+
930
+ # ======================================================================
931
+ # LAUNCH
932
+ # ======================================================================
933
+
934
+ if __name__ == "__main__":
935
+ app.launch()
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio>=4.0.0
2
+ tensorflow
3
+ keras-tcn
4
+ joblib
5
+ pandas
6
+ numpy
7
+ plotly
8
+ huggingface_hub
9
+ scikit-learn
10
+ statsmodels