UCS2014 commited on
Commit
843472e
·
verified ·
1 Parent(s): ac61d22

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +395 -493
app.py CHANGED
@@ -1,83 +1,75 @@
1
- # app.py
2
- import io, os, json, base64
3
  from pathlib import Path
4
-
5
- import numpy as np
6
- import pandas as pd
7
  import streamlit as st
 
 
8
  import joblib
9
 
10
- # =========================
11
- # Constants / defaults
12
- # =========================
 
 
 
 
 
 
13
  FEATURES = ["Q, gpm", "SPP(psi)", "T (kft.lbf)", "WOB (klbf)", "ROP (ft/h)"]
14
  TARGET = "UCS"
15
-
16
  MODELS_DIR = Path("models")
17
  DEFAULT_MODEL = MODELS_DIR / "ucs_rf.joblib"
18
  MODEL_FALLBACKS = [MODELS_DIR / "model.joblib", MODELS_DIR / "model.pkl"]
19
 
20
- COLORS = {
21
- "pred": "#1f77b4", # blue
22
- "actual": "#f2c94c", # yellow
23
- "ref": "#444444", # 1:1 line
24
- }
25
-
26
- # =========================
27
- # Page config + CSS
28
- # =========================
29
  st.set_page_config(page_title="ST_GeoMech_UCS", page_icon="logo.png", layout="wide")
30
 
31
- st.markdown("""
32
- <style>
33
- /* Hide default header/footer chrome */
34
- header, footer {visibility: hidden !important;}
35
- .stApp { background: #ffffff; }
36
-
37
- /* Sidebar look */
38
- section[data-testid="stSidebar"] { background: #F6F9FC; }
39
-
40
- /* Hero */
41
- .st-hero { display:flex; align-items:center; gap:14px; padding: 4px 0 2px 0; }
42
- .st-hero .brand { width:90px; height:90px; object-fit:contain; }
43
- .st-hero h1 { margin:0; line-height:1.05; }
44
- .st-hero .tagline { margin:2px 0 0 2px; color:#6b7280; font-size:1.05rem; font-style:italic; }
45
-
46
- /* Keep hero snug to the top */
47
- [data-testid="stBlock"] { margin-top:0 !important; }
48
-
49
- /* Global primary button style (Run Model stays blue) */
50
- .stButton > button {
51
- background:#2563eb; color:#fff; font-weight:600; border:none; border-radius:8px;
52
- padding:9px 18px;
53
- }
54
-
55
- /* Orange preview button (scoped by wrapper) */
56
- #preview-btn button {
57
- background:#f59e0b !important; color:#fff !important;
58
- }
59
-
60
- /* Green proceed button (scoped by wrapper) */
61
- #proceed-btn button {
62
- background:#16a34a !important; color:#fff !important;
63
- }
64
-
65
- /* Info helper chip */
66
- .helper-note {
67
- background:#e7f0ff; border-radius:10px; padding:14px 16px; border:1px solid #d4e3ff;
68
- color:#0f172a;
69
- }
70
-
71
- /* Make tab content tighter */
72
- [data-baseweb="tab-border"] { margin-top: 0.2rem; }
73
-
74
- /* Plotly charts use white backgrounds via functions below */
75
- </style>
76
- """, unsafe_allow_html=True)
77
-
78
- # =========================
79
- # Utils
80
- # =========================
81
  def inline_logo(path="logo.png") -> str:
82
  try:
83
  p = Path(path)
@@ -86,32 +78,26 @@ def inline_logo(path="logo.png") -> str:
86
  except Exception:
87
  return ""
88
 
89
- def _get_model_url():
90
- # Safe access (prevents the "No secrets files" banner)
91
- try:
92
- return (st.secrets.get("MODEL_URL", "") or os.environ.get("MODEL_URL", "") or "").strip()
93
- except Exception:
94
- return (os.environ.get("MODEL_URL", "") or "").strip()
95
 
96
  @st.cache_data(show_spinner=False)
97
- def parse_excel_bytes(data_bytes: bytes):
98
  bio = io.BytesIO(data_bytes)
99
  xl = pd.ExcelFile(bio)
100
  return {sh: xl.parse(sh) for sh in xl.sheet_names}
101
 
102
- def ensure_required_columns(df: pd.DataFrame, cols) -> bool:
103
  miss = [c for c in cols if c not in df.columns]
104
  if miss:
105
  st.error(f"Missing columns: {miss}\nFound: {list(df.columns)}")
106
  return False
107
  return True
108
 
109
- @st.cache_resource(show_spinner=False)
110
- def load_model(model_path: str):
111
- return joblib.load(model_path)
112
-
113
  def infer_features_from_model(m):
114
- # Try scikit-learn feature names if present
115
  try:
116
  if hasattr(m, "feature_names_in_") and len(getattr(m, "feature_names_in_")):
117
  return [str(x) for x in m.feature_names_in_]
@@ -124,25 +110,114 @@ def infer_features_from_model(m):
124
  except Exception: pass
125
  return None
126
 
127
- def rmse(y_true, y_pred): # convenience
128
- from sklearn.metrics import mean_squared_error
129
- return float(np.sqrt(mean_squared_error(y_true, y_pred)))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
- # =========================
132
- # Model availability
133
- # =========================
134
- MODEL_URL = _get_model_url()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
  def ensure_model_present() -> Path | None:
 
137
  for p in [DEFAULT_MODEL, *MODEL_FALLBACKS]:
138
- if p.exists(): return p
 
 
 
139
  if MODEL_URL:
140
  try:
141
  import requests
142
- except Exception:
143
- st.error("Downloading the model requires 'requests'. Please add it to requirements.txt.")
144
- return None
145
- try:
146
  DEFAULT_MODEL.parent.mkdir(parents=True, exist_ok=True)
147
  with requests.get(MODEL_URL, stream=True) as r:
148
  r.raise_for_status()
@@ -151,22 +226,22 @@ def ensure_model_present() -> Path | None:
151
  f.write(chunk)
152
  return DEFAULT_MODEL
153
  except Exception as e:
154
- st.error(f"Failed to download model from MODEL_URL. {e}")
155
- return None
156
  return None
157
 
158
  model_path = ensure_model_present()
159
  if not model_path:
160
- st.error("Model not found. Upload models/ucs_rf.joblib (or set MODEL_URL in Settings → Variables).")
161
  st.stop()
162
 
 
163
  try:
164
  model = load_model(str(model_path))
165
  except Exception as e:
166
  st.error(f"Failed to load model: {model_path}\n{e}")
167
  st.stop()
168
 
169
- # Optional meta overrides
170
  meta_path = MODELS_DIR / "meta.json"
171
  if meta_path.exists():
172
  try:
@@ -176,159 +251,26 @@ if meta_path.exists():
176
  except Exception:
177
  pass
178
  else:
179
- _inf = infer_features_from_model(model)
180
- if _inf: FEATURES = _inf
181
-
182
- # =========================
183
- # Plotly helpers (no titles, white background, safe margins)
184
- # =========================
185
- def _apply_plotly_base_layout(fig, *, top=40, left=60):
186
- fig.update_layout(
187
- margin=dict(l=left, r=10, t=top, b=40),
188
- paper_bgcolor="#ffffff",
189
- plot_bgcolor="#ffffff",
190
- font=dict(size=12),
191
- )
192
- fig.update_xaxes(automargin=True, title_font=dict(size=12), tickfont=dict(size=11))
193
- fig.update_yaxes(automargin=True, title_font=dict(size=12), tickfont=dict(size=11))
194
- return fig
195
-
196
- def cross_plotly(actual, pred):
197
- import plotly.graph_objects as go
198
- lo = float(np.nanmin([actual.min(), pred.min()]))
199
- hi = float(np.nanmax([actual.max(), pred.max()]))
200
- pad = 0.03 * (hi - lo if hi > lo else 1.0)
201
 
202
- fig = go.Figure()
203
- fig.add_trace(go.Scatter(
204
- x=actual, y=pred, mode="markers",
205
- marker=dict(size=6, color=COLORS["pred"]),
206
- hovertemplate="Actual: %{x:.2f}<br>Pred: %{y:.2f}<extra></extra>",
207
- showlegend=False, name="Points",
208
- ))
209
- fig.add_trace(go.Scatter(
210
- x=[lo - pad, hi + pad], y=[lo - pad, hi + pad],
211
- mode="lines", line=dict(dash="dash", width=1.5, color=COLORS["ref"]),
212
- hoverinfo="skip", showlegend=False,
213
- ))
214
-
215
- _apply_plotly_base_layout(fig, top=10, left=60)
216
- fig.update_xaxes(
217
- title_text="Actual UCS", title_standoff=10,
218
- showgrid=True, gridcolor="rgba(0,0,0,0.12)",
219
- zeroline=False, scaleanchor="y", scaleratio=1
220
- )
221
- fig.update_yaxes(
222
- title_text="Predicted UCS", title_standoff=10,
223
- showgrid=True, gridcolor="rgba(0,0,0,0.12)",
224
- zeroline=False
225
- )
226
- return fig
227
-
228
- def track_plotly(df, include_actual=True):
229
- import plotly.graph_objects as go
230
- depth_col = next((c for c in df.columns if 'depth' in str(c).lower()), None)
231
- if depth_col is not None:
232
- y = df[depth_col]; y_label = depth_col
233
- else:
234
- y = np.arange(1, len(df) + 1); y_label = "Point Index"
235
-
236
- fig = go.Figure()
237
- fig.add_trace(go.Scatter(
238
- x=df["UCS_Pred"], y=y, mode="lines",
239
- line=dict(color=COLORS["pred"], width=2),
240
- name="UCS_Pred",
241
- hovertemplate="UCS_Pred: %{x:.2f}<br>"+y_label+": %{y}<extra></extra>"
242
- ))
243
- if include_actual and TARGET in df.columns:
244
- fig.add_trace(go.Scatter(
245
- x=df[TARGET], y=y, mode="lines",
246
- line=dict(color=COLORS["actual"], dash="dot", width=2.2),
247
- name="UCS (actual)",
248
- hovertemplate="UCS (actual): %{x:.2f}<br>"+y_label+": %{y}<extra></extra>"
249
- ))
250
-
251
- _apply_plotly_base_layout(fig, top=60, left=70)
252
- fig.update_layout(
253
- legend=dict(orientation="h", yanchor="bottom", y=1.02, x=0),
254
- height=650
255
- )
256
- fig.update_xaxes(
257
- title_text="UCS", side="top", title_standoff=12,
258
- showgrid=True, gridcolor="rgba(0,0,0,0.12)"
259
- )
260
- fig.update_yaxes(
261
- title_text=y_label, autorange="reversed", title_standoff=10,
262
- showgrid=True, gridcolor="rgba(0,0,0,0.12)"
263
- )
264
- return fig
265
-
266
- def make_index_tracks_plotly(df: pd.DataFrame, cols: list[str]):
267
- from plotly.subplots import make_subplots
268
- import plotly.graph_objects as go
269
-
270
- cols = [c for c in cols if c in df.columns]
271
- if not cols:
272
- fig = go.Figure()
273
- fig.add_annotation(text="No selected columns in sheet", showarrow=False, x=0.5, y=0.5)
274
- fig.update_xaxes(visible=False); fig.update_yaxes(visible=False)
275
- fig.update_layout(height=200, margin=dict(l=10,r=10,t=10,b=10),
276
- paper_bgcolor="#ffffff", plot_bgcolor="#ffffff")
277
- return fig
278
-
279
- n = len(cols)
280
- # IMPORTANT: shared_yaxes (not shared_y)
281
- fig = make_subplots(rows=1, cols=n, shared_yaxes=True, horizontal_spacing=0.05)
282
- idx = np.arange(1, len(df) + 1)
283
-
284
- for i, col in enumerate(cols, start=1):
285
- fig.add_trace(
286
- go.Scatter(
287
- x=df[col], y=idx, mode="lines",
288
- line=dict(color="#333333", width=1.2),
289
- hovertemplate=f"{col}: "+"%{x:.2f}<br>Index: %{y}<extra></extra>",
290
- showlegend=False, name=col,
291
- ), row=1, col=i
292
- )
293
- fig.update_xaxes(
294
- title_text=col, side="top", title_standoff=10,
295
- tickfont=dict(size=10),
296
- showgrid=True, gridcolor="rgba(0,0,0,0.12)",
297
- row=1, col=i
298
- )
299
-
300
- fig.update_yaxes(
301
- autorange="reversed", title_text="Point Index", title_standoff=10,
302
- tickfont=dict(size=10),
303
- showgrid=True, gridcolor="rgba(0,0,0,0.12)",
304
- row=1, col=1
305
- )
306
- fig.update_layout(
307
- height=650,
308
- margin=dict(l=60, r=10, t=60, b=40),
309
- paper_bgcolor="#ffffff",
310
- plot_bgcolor="#ffffff",
311
- font=dict(size=12),
312
- )
313
- return fig
314
-
315
- # =========================
316
  # Session state defaults
317
- # =========================
318
  ss = st.session_state
319
- ss.setdefault("app_step", "dev") # intro/dev/predict (you asked to start at dev)
320
- ss.setdefault("dev_bytes", None) # raw uploaded bytes
321
- ss.setdefault("dev_book", None) # parsed workbook dict
322
- ss.setdefault("dev_sheet_train", None) # chosen train sheet
323
- ss.setdefault("dev_sheet_test", None) # chosen test sheet
324
  ss.setdefault("dev_previewed", False)
325
  ss.setdefault("dev_ran", False)
326
  ss.setdefault("results", {})
327
  ss.setdefault("train_ranges", None)
 
 
328
 
329
- # =========================
330
- # Hero header
331
- # =========================
332
  st.markdown(
333
  f"""
334
  <div class="st-hero">
@@ -342,309 +284,269 @@ st.markdown(
342
  unsafe_allow_html=True,
343
  )
344
 
345
- # =========================
346
- # INTRO (kept for completeness – you said start in dev)
347
- # =========================
348
  if ss.app_step == "intro":
349
  st.header("Welcome!")
350
  st.markdown(
351
- "1. **Upload your data** to build the case and preview the performance of our model.\n"
352
- "2. **Run Model** to compute metrics and plots.\n"
353
- "3. **Proceed to Prediction** to validate on a new dataset and export results."
354
  )
355
- if st.button("Start", type="primary"): ss.app_step = "dev"; st.rerun()
356
-
357
- # =========================
358
- # DEVELOPMENT
359
- # =========================
360
- if ss.app_step == "dev":
361
- # Sidebar controls
362
- st.sidebar.header("Model Development Data")
363
- dev_file = st.sidebar.file_uploader("Replace data (Excel)", type=["xlsx","xls"], key="dev_upload")
364
-
365
- # Cache uploaded file into session (so preview doesn't clear it)
366
- if dev_file is not None:
367
- ss.dev_bytes = dev_file.getvalue()
368
- try:
369
- ss.dev_book = parse_excel_bytes(ss.dev_bytes)
370
- except Exception as e:
371
- st.sidebar.error(f"Failed to read workbook: {e}")
372
- ss.dev_book = None
373
- ss.dev_previewed = False
374
- ss.dev_ran = False
375
-
376
- # PREVIEW button (orange)
377
- st.sidebar.markdown("<div id='preview-btn'>", unsafe_allow_html=True)
378
- preview_click = st.sidebar.button("Preview data", use_container_width=True)
379
- st.sidebar.markdown("</div>", unsafe_allow_html=True)
380
-
381
- # RUN button (blue)
382
- run_click = st.sidebar.button("Run Model", use_container_width=True)
383
-
384
- # Proceed button (green; enabled after run)
385
- st.sidebar.markdown("<div id='proceed-btn'>", unsafe_allow_html=True)
386
- proceed_click = st.sidebar.button(
387
- "Proceed to Prediction ▶",
388
- use_container_width=True,
389
- disabled=not ss.dev_ran
390
- )
391
- st.sidebar.markdown("</div>", unsafe_allow_html=True)
392
-
393
- if proceed_click and ss.dev_ran:
394
- ss.app_step = "predict"
395
  st.rerun()
396
 
397
- # Section heading
 
 
 
398
  st.subheader("Model Development")
399
-
400
- # Helper message (sticks here always)
401
- helper = st.empty()
402
- if ss.dev_book is None:
403
- helper.markdown("<div class='helper-note'>Upload your data to build the case and preview the dataset.</div>", unsafe_allow_html=True)
404
- elif not ss.dev_previewed:
405
- helper.markdown("<div class='helper-note'>Data loaded ✓ — click <b>Preview data</b> to review tracks and summary.</div>", unsafe_allow_html=True)
406
  elif ss.dev_previewed and not ss.dev_ran:
407
- helper.markdown("<div class='helper-note'>Previewed ✓ — now click <b>Run Model</b> to build the case.</div>", unsafe_allow_html=True)
408
- else:
409
- helper.markdown("<div class='helper-note'>Case built ✓ — results are displayed below.</div>", unsafe_allow_html=True)
410
-
411
- # ----------------- Preview modal -----------------
412
- def preview_modal(book: dict, feature_cols: list[str]):
413
- if not book: return
414
- with st.expander("▼ Preview (tracks & summary)", expanded=True):
415
- # Choose a sheet to preview
416
- sheetnames = list(book.keys())
417
- sh = st.selectbox("Sheet", options=sheetnames, index=0, key="preview_sheet_sel")
418
- df = book[sh].copy()
419
-
420
- # Tracks tab + Stats tab
421
- t1, t2 = st.tabs(["Tracks", "Summary"])
422
-
423
- with t1:
424
- fig = make_index_tracks_plotly(df, feature_cols)
425
- st.plotly_chart(fig, use_container_width=True, config={"displayModeBar": False})
426
-
427
- with t2:
428
- stats = df[feature_cols].describe().T[["min", "max", "mean", "std"]].rename(
429
- columns={"min":"Min", "max":"Max", "mean":"Mean", "std":"Std"}
430
- )
431
- st.dataframe(stats, use_container_width=True)
432
-
433
- # If preview clicked and we have data
434
- if preview_click:
435
- if ss.dev_book:
436
- preview_modal(ss.dev_book, FEATURES)
437
- ss.dev_previewed = True
438
- ss.dev_ran = False
439
- st.rerun()
440
  else:
441
- st.warning("Please upload an Excel file first.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
442
 
443
- # If run clicked and we have data
444
- if run_click:
445
- if not ss.dev_book:
446
- st.warning("Please upload and preview your data first.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
447
  else:
448
- # Try to find common sheet names
449
- names = list(ss.dev_book.keys())
450
- def find_sheet(book, alts):
451
- lo = {k.lower(): k for k in book.keys()}
452
- for nm in alts:
453
- if nm.lower() in lo: return lo[nm.lower()]
454
- return None
455
-
456
- sh_train = find_sheet(ss.dev_book, ["Train","Training","training2","train","training"]) or names[0]
457
- sh_test = find_sheet(ss.dev_book, ["Test","Testing","testing2","test","testing"]) or (names[1] if len(names)>1 else names[0])
458
- ss.dev_sheet_train, ss.dev_sheet_test = sh_train, sh_test
459
-
460
- df_tr = ss.dev_book[sh_train].copy()
461
- df_te = ss.dev_book[sh_test].copy()
462
-
463
- ok = ensure_required_columns(df_tr, FEATURES+[TARGET]) and ensure_required_columns(df_te, FEATURES+[TARGET])
464
- if ok:
465
  df_tr["UCS_Pred"] = model.predict(df_tr[FEATURES])
466
  df_te["UCS_Pred"] = model.predict(df_te[FEATURES])
467
-
468
- from sklearn.metrics import r2_score, mean_absolute_error
469
  ss.results["Train"] = df_tr
470
  ss.results["Test"] = df_te
471
  ss.results["metrics_train"] = {
472
  "R2": r2_score(df_tr[TARGET], df_tr["UCS_Pred"]),
473
  "RMSE": rmse(df_tr[TARGET], df_tr["UCS_Pred"]),
474
- "MAE": mean_absolute_error(df_tr[TARGET], df_tr["UCS_Pred"]),
475
  }
476
  ss.results["metrics_test"] = {
477
  "R2": r2_score(df_te[TARGET], df_te["UCS_Pred"]),
478
  "RMSE": rmse(df_te[TARGET], df_te["UCS_Pred"]),
479
- "MAE": mean_absolute_error(df_te[TARGET], df_te["UCS_Pred"]),
480
  }
481
-
482
  tr_min = df_tr[FEATURES].min().to_dict()
483
  tr_max = df_tr[FEATURES].max().to_dict()
484
  ss.train_ranges = {f:(float(tr_min[f]), float(tr_max[f])) for f in FEATURES}
485
-
486
  ss.dev_ran = True
487
- helper.markdown("<div class='helper-note'>Case built ✓ — results are displayed below.</div>", unsafe_allow_html=True)
488
- else:
489
- ss.dev_ran = False
490
 
491
- # Show results if available
492
- if ss.dev_ran and ("Train" in ss.results or "Test" in ss.results):
493
- ttr, tte = st.tabs(["Training", "Testing"])
494
 
495
- if "Train" in ss.results:
496
- with ttr:
497
- m = ss.results["metrics_train"]
498
- c1,c2,c3 = st.columns([1,1,1])
499
  c1.metric("R²", f"{m['R2']:.4f}")
500
  c2.metric("RMSE", f"{m['RMSE']:.4f}")
501
  c3.metric("MAE", f"{m['MAE']:.4f}")
502
- l, r = st.columns([0.55, 0.45])
503
- with l:
504
- st.plotly_chart(cross_plotly(ss.results["Train"][TARGET], ss.results["Train"]["UCS_Pred"]),
505
- use_container_width=True, config={"displayModeBar": False})
506
- with r:
507
- st.plotly_chart(track_plotly(ss.results["Train"], include_actual=True),
508
- use_container_width=True, config={"displayModeBar": False})
509
-
510
- if "Test" in ss.results:
511
- with tte:
512
- m = ss.results["metrics_test"]
513
- c1,c2,c3 = st.columns([1,1,1])
 
514
  c1.metric("R²", f"{m['R2']:.4f}")
515
  c2.metric("RMSE", f"{m['RMSE']:.4f}")
516
  c3.metric("MAE", f"{m['MAE']:.4f}")
517
- l, r = st.columns([0.55, 0.45])
518
- with l:
519
- st.plotly_chart(cross_plotly(ss.results["Test"][TARGET], ss.results["Test"]["UCS_Pred"]),
520
- use_container_width=True, config={"displayModeBar": False})
521
- with r:
522
- st.plotly_chart(track_plotly(ss.results["Test"], include_actual=True),
523
- use_container_width=True, config={"displayModeBar": False})
524
-
525
- # =========================
526
- # PREDICTION
527
- # =========================
528
- if ss.app_step == "predict":
529
- st.sidebar.header("Prediction (Validation)")
530
- val_file = st.sidebar.file_uploader("Upload Validation Excel", type=["xlsx","xls"], key="val_upload")
531
- predict_click = st.sidebar.button("Predict", use_container_width=True)
532
- back_click = st.sidebar.button("⬅ Back", use_container_width=True)
533
 
534
- if back_click:
535
- ss.app_step = "dev"; st.rerun()
 
 
 
 
 
536
 
537
- st.subheader("Prediction")
538
- st.markdown("Upload a new dataset to generate UCS predictions and evaluate performance on unseen data.")
539
- st.success("Predictions ready ✓" if "Validate" in ss.results else "Waiting for input…")
 
 
540
 
541
- if predict_click and val_file is not None:
542
- try:
543
- vbook = parse_excel_bytes(val_file.getvalue())
544
- except Exception as e:
545
- st.error(f"Could not read the Validation Excel: {e}")
546
- vbook = {}
547
-
548
- if vbook:
549
- # Pick first sheet by default
550
- vname = list(vbook.keys())[0]
551
- df_val = vbook[vname].copy()
552
-
553
- if ensure_required_columns(df_val, FEATURES):
554
- df_val["UCS_Pred"] = model.predict(df_val[FEATURES])
555
- ss.results["Validate"] = df_val
556
-
557
- # Out-of-range check vs training ranges
558
- ranges = ss.train_ranges; oor_table = None; oor_pct = 0.0
559
- if ranges:
560
- viol = {f: (df_val[f] < ranges[f][0]) | (df_val[f] > ranges[f][1]) for f in FEATURES}
561
- any_viol = pd.DataFrame(viol).any(axis=1); oor_pct = float(any_viol.mean()*100.0)
562
- if any_viol.any():
563
- offenders = df_val.loc[any_viol, FEATURES].copy()
564
- offenders["Violations"] = pd.DataFrame(viol).loc[any_viol].apply(
565
- lambda r: ", ".join([c for c,v in r.items() if v]), axis=1)
566
- offenders.index = offenders.index + 1; oor_table = offenders
567
-
568
- from sklearn.metrics import r2_score, mean_absolute_error
569
- metrics_val = None
570
- if TARGET in df_val.columns:
571
- metrics_val = {
572
- "R2": r2_score(df_val[TARGET], df_val["UCS_Pred"]),
573
- "RMSE": rmse(df_val[TARGET], df_val["UCS_Pred"]),
574
- "MAE": mean_absolute_error(df_val[TARGET], df_val["UCS_Pred"]),
575
- }
576
- ss.results["metrics_val"] = metrics_val
577
- ss.results["summary_val"] = {
578
- "n_points": len(df_val),
579
- "pred_min": float(df_val["UCS_Pred"].min()),
580
- "pred_max": float(df_val["UCS_Pred"].max()),
581
- "oor_pct": oor_pct
582
- }
583
- ss.results["oor_table"] = oor_table
584
- st.experimental_rerun()
585
-
586
- # Show prediction results
587
- if "Validate" in ss.results:
 
 
 
 
 
 
 
588
  sv = ss.results["summary_val"]; oor_table = ss.results.get("oor_table")
589
  c1,c2,c3,c4 = st.columns(4)
590
- c1.metric("# points", f"{sv['n_points']}")
591
- c2.metric("Pred min", f"{sv['pred_min']:.2f}")
592
- c3.metric("Pred max", f"{sv['pred_max']:.2f}")
593
- c4.metric("OOR %", f"{sv['oor_pct']:.1f}%")
594
- if sv["oor_pct"] > 0:
595
- st.warning("Some validation rows contain inputs outside the Training min–max ranges. Review the table below.")
596
-
597
- left, right = st.columns([0.55, 0.45])
598
  with left:
599
  if TARGET in ss.results["Validate"].columns:
600
  st.plotly_chart(
601
- cross_plotly(ss.results["Validate"][TARGET], ss.results["Validate"]["UCS_Pred"]),
602
- use_container_width=True, config={"displayModeBar": False}
603
  )
604
  else:
605
- st.info("Actual UCS values are not available in the validation data. Cross-plot cannot be generated.")
606
  with right:
607
  st.plotly_chart(
608
- track_plotly(ss.results["Validate"], include_actual=(TARGET in ss.results["Validate"].columns)),
609
- use_container_width=True, config={"displayModeBar": False}
 
610
  )
611
-
612
  if oor_table is not None:
613
- st.write("*Out-of-range rows (vs. Training min–max):*")
614
  st.dataframe(oor_table, use_container_width=True)
615
 
616
- # Export
617
- def export_workbook(sheets_dict, summary_df=None):
618
- try:
619
- import openpyxl
620
- except Exception:
621
- raise RuntimeError("Export requires openpyxl. Please add it to requirements.txt.")
622
- buf = io.BytesIO()
623
- with pd.ExcelWriter(buf, engine="openpyxl") as xw:
624
- for name, frame in sheets_dict.items():
625
- frame.to_excel(xw, sheet_name=name[:31], index=False)
626
- if summary_df is not None:
627
- summary_df.to_excel(xw, sheet_name="Summary", index=False)
628
- return buf.getvalue()
629
-
630
- st.markdown("---")
631
- sheets_to_save = {"Validate_with_pred": ss.results["Validate"]}
632
- rows = []
633
- for name, key in [("Train","metrics_train"), ("Test","metrics_test"), ("Validate","metrics_val")]:
634
- m = ss.results.get(key)
635
- if m: rows.append({"Split": name, **{k: round(v,6) for k,v in m.items()}})
636
- summary_df = pd.DataFrame(rows) if rows else None
637
- try:
638
- data_bytes = export_workbook(sheets_to_save, summary_df)
639
- st.download_button("Export Validation Results to Excel",
640
- data=data_bytes, file_name="UCS_Validation_Results.xlsx",
641
- mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet")
642
- except RuntimeError as e:
643
- st.warning(str(e))
644
-
645
- # =========================
646
  # Footer
647
- # =========================
648
  st.markdown("---")
649
  st.markdown(
650
  "<div style='text-align:center; color:#6b7280;'>"
 
1
+ import io, json, os, base64
 
2
  from pathlib import Path
 
 
 
3
  import streamlit as st
4
+ import pandas as pd
5
+ import numpy as np
6
  import joblib
7
 
8
+ # --- Plotly (interactive) ---
9
+ import plotly.graph_objects as go
10
+ from plotly.subplots import make_subplots
11
+
12
+ from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
13
+
14
+ # =========================================================
15
+ # Defaults (overridden by models/meta.json or model.feature_names_in_)
16
+ # =========================================================
17
  FEATURES = ["Q, gpm", "SPP(psi)", "T (kft.lbf)", "WOB (klbf)", "ROP (ft/h)"]
18
  TARGET = "UCS"
 
19
  MODELS_DIR = Path("models")
20
  DEFAULT_MODEL = MODELS_DIR / "ucs_rf.joblib"
21
  MODEL_FALLBACKS = [MODELS_DIR / "model.joblib", MODELS_DIR / "model.pkl"]
22
 
23
+ # =========================================================
24
+ # Page / Theme + CSS
25
+ # =========================================================
 
 
 
 
 
 
26
  st.set_page_config(page_title="ST_GeoMech_UCS", page_icon="logo.png", layout="wide")
27
 
28
+ st.markdown(
29
+ """
30
+ <style>
31
+ /* App + sidebar background */
32
+ .stApp { background: #FFFFFF; }
33
+ section[data-testid="stSidebar"] { background: #F6F9FC; }
34
+
35
+ /* Tighten top spacing */
36
+ [data-testid="stBlock"]{ margin-top: 0 !important; }
37
+
38
+ /* Hero row */
39
+ .st-hero { display:flex; align-items:center; gap:16px; padding-top: 6px; }
40
+ .st-hero .brand { width:90px; height:90px; object-fit:contain; }
41
+ .st-hero h1 { margin:0; line-height:1.05; }
42
+ .st-hero .tagline { margin:2px 0 0 2px; color:#6b7280; font-size:1.05rem; font-style:italic; }
43
+
44
+ /* Sidebar button palette (order-based within the Sidebar section)
45
+ 1) Preview (orange) 2) Run (blue) 3) Proceed (green)
46
+ We scope to the sidebar and to stButton blocks only. */
47
+ section[data-testid="stSidebar"] div.stButton > button {
48
+ font-weight:700; border-radius:10px; border:none; padding:10px 20px;
49
+ }
50
+ section[data-testid="stSidebar"] div.stButton:nth-of-type(1) > button { /* Preview */
51
+ background:#f59e0b; color:#fff;
52
+ }
53
+ section[data-testid="stSidebar"] div.stButton:nth-of-type(2) > button { /* Run (blue) */
54
+ background:#2563eb; color:#fff;
55
+ }
56
+ section[data-testid="stSidebar"] div.stButton:nth-of-type(3) > button { /* Proceed (green) */
57
+ background:#10b981; color:#fff;
58
+ }
59
+ section[data-testid="stSidebar"] div.stButton:nth-of-type(3) > button:disabled {
60
+ background:#a7f3d0 !important; color:#064e3b !important; opacity:.7 !important;
61
+ }
62
+
63
+ /* Modal tabs spacing */
64
+ .stTabs [data-baseweb="tab-list"] { gap: 6px; }
65
+ </style>
66
+ """,
67
+ unsafe_allow_html=True
68
+ )
69
+
70
+ # =========================================================
71
+ # Helpers
72
+ # =========================================================
 
 
 
 
 
73
  def inline_logo(path="logo.png") -> str:
74
  try:
75
  p = Path(path)
 
78
  except Exception:
79
  return ""
80
 
81
+ def rmse(y_true, y_pred): return float(np.sqrt(mean_squared_error(y_true, y_pred)))
82
+
83
+ @st.cache_resource(show_spinner=False)
84
+ def load_model(model_path: str):
85
+ return joblib.load(model_path)
 
86
 
87
  @st.cache_data(show_spinner=False)
88
+ def parse_excel(data_bytes: bytes):
89
  bio = io.BytesIO(data_bytes)
90
  xl = pd.ExcelFile(bio)
91
  return {sh: xl.parse(sh) for sh in xl.sheet_names}
92
 
93
+ def ensure_cols(df, cols):
94
  miss = [c for c in cols if c not in df.columns]
95
  if miss:
96
  st.error(f"Missing columns: {miss}\nFound: {list(df.columns)}")
97
  return False
98
  return True
99
 
 
 
 
 
100
  def infer_features_from_model(m):
 
101
  try:
102
  if hasattr(m, "feature_names_in_") and len(getattr(m, "feature_names_in_")):
103
  return [str(x) for x in m.feature_names_in_]
 
110
  except Exception: pass
111
  return None
112
 
113
+ def export_workbook(sheets_dict, summary_df=None):
114
+ try: import openpyxl # ensure engine is available
115
+ except Exception:
116
+ raise RuntimeError("Export requires openpyxl. Please add it to requirements.txt.")
117
+ buf = io.BytesIO()
118
+ with pd.ExcelWriter(buf, engine="openpyxl") as xw:
119
+ for name, frame in sheets_dict.items():
120
+ frame.to_excel(xw, sheet_name=name[:31], index=False)
121
+ if summary_df is not None:
122
+ summary_df.to_excel(xw, sheet_name="Summary", index=False)
123
+ return buf.getvalue()
124
+
125
+ # -------------------- Plotly styling blocks --------------------
126
+ AXES_STYLE = dict(
127
+ showline=True, linewidth=1.4, linecolor="#444",
128
+ mirror=True, ticks="outside", ticklen=4, tickwidth=1,
129
+ showgrid=True, gridcolor="rgba(0,0,0,0.08)"
130
+ )
131
+ FONT = dict(color="#111", size=13)
132
+
133
+ def style_layout(fig, width=None, height=None, margins=(12,18,36,12)):
134
+ t, r, b, l = margins
135
+ fig.update_layout(
136
+ margin=dict(t=t, r=r, b=b, l=l),
137
+ paper_bgcolor="white",
138
+ plot_bgcolor="white",
139
+ font=FONT,
140
+ legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1)
141
+ )
142
+ if width: fig.update_layout(width=width)
143
+ if height: fig.update_layout(height=height)
144
+ # Apply to all axes
145
+ fig.update_xaxes(**AXES_STYLE, title_font=dict(size=14, color="#111"))
146
+ fig.update_yaxes(**AXES_STYLE, title_font=dict(size=14, color="#111"))
147
+ return fig
148
 
149
+ def make_cross_plotly(A, P, height=440, width=640):
150
+ a = pd.Series(A).astype(float)
151
+ p = pd.Series(P).astype(float)
152
+ lo = float(np.nanmin([a.min(), p.min()]))
153
+ hi = float(np.nanmax([a.max(), p.max()]))
154
+
155
+ fig = go.Figure()
156
+ fig.add_trace(go.Scatter(
157
+ x=a, y=p, mode="markers", name="Points", marker=dict(size=6)
158
+ ))
159
+ fig.add_trace(go.Scatter(
160
+ x=[lo, hi], y=[lo, hi], mode="lines", name="1:1",
161
+ line=dict(color="#666", width=2, dash="dash")
162
+ ))
163
+ fig.update_xaxes(range=[lo, hi], title="Actual UCS")
164
+ fig.update_yaxes(range=[lo, hi], title="Predicted UCS", scaleanchor="x", scaleratio=1)
165
+ style_layout(fig, width=width, height=height, margins=(8,10,36,50))
166
+ return fig
167
+
168
+ def make_depth_track_plotly(df, include_actual=True, height=640, width=360):
169
+ idx = np.arange(1, len(df) + 1)
170
+ fig = go.Figure()
171
+ # Predicted (solid blue)
172
+ fig.add_trace(go.Scatter(
173
+ x=df["UCS_Pred"], y=idx, mode="lines", name="UCS_Pred",
174
+ line=dict(color="#1f77b4", width=2)
175
+ ))
176
+ # Actual (dotted yellow)
177
+ if include_actual and TARGET in df.columns:
178
+ fig.add_trace(go.Scatter(
179
+ x=df[TARGET], y=idx, mode="lines", name="UCS (actual)",
180
+ line=dict(color="#f2b01e", width=2, dash="dot")
181
+ ))
182
+ fig.update_yaxes(autorange="reversed", title="Point Index")
183
+ fig.update_xaxes(title="UCS")
184
+ style_layout(fig, width=width, height=height, margins=(8,12,36,60))
185
+ return fig
186
+
187
+ def make_index_tracks_plotly(df, feature_cols, height=640, width=980):
188
+ n = len(feature_cols)
189
+ fig = make_subplots(rows=1, cols=n, shared_yaxes=True, horizontal_spacing=0.05)
190
+ idx = np.arange(1, len(df) + 1)
191
+
192
+ for i, col in enumerate(feature_cols, start=1):
193
+ fig.add_trace(
194
+ go.Scatter(x=df[col], y=idx, mode="lines", line=dict(color="#444", width=1.2), name=col, showlegend=False),
195
+ row=1, col=i
196
+ )
197
+ fig.update_xaxes(title=col, row=1, col=i)
198
+ fig.update_yaxes(autorange="reversed", title="Point Index", row=1, col=1)
199
+ style_layout(fig, width=width, height=height, margins=(6,8,36,60))
200
+ return fig
201
+
202
+ # =========================================================
203
+ # Model availability (cloud-safe)
204
+ # =========================================================
205
+ def _get_model_url():
206
+ try:
207
+ return (st.secrets.get("MODEL_URL", "") or os.environ.get("MODEL_URL", "") or "").strip()
208
+ except Exception:
209
+ return (os.environ.get("MODEL_URL", "") or "").strip()
210
 
211
  def ensure_model_present() -> Path | None:
212
+ # local candidates
213
  for p in [DEFAULT_MODEL, *MODEL_FALLBACKS]:
214
+ if p.exists():
215
+ return p
216
+ # cloud download
217
+ MODEL_URL = _get_model_url()
218
  if MODEL_URL:
219
  try:
220
  import requests
 
 
 
 
221
  DEFAULT_MODEL.parent.mkdir(parents=True, exist_ok=True)
222
  with requests.get(MODEL_URL, stream=True) as r:
223
  r.raise_for_status()
 
226
  f.write(chunk)
227
  return DEFAULT_MODEL
228
  except Exception as e:
229
+ st.error(f"Failed to download model from MODEL_URL: {e}")
 
230
  return None
231
 
232
  model_path = ensure_model_present()
233
  if not model_path:
234
+ st.error("Model not found. Upload models/ucs_rf.joblib (or set MODEL_URL).")
235
  st.stop()
236
 
237
+ # Load model
238
  try:
239
  model = load_model(str(model_path))
240
  except Exception as e:
241
  st.error(f"Failed to load model: {model_path}\n{e}")
242
  st.stop()
243
 
244
+ # Meta overrides
245
  meta_path = MODELS_DIR / "meta.json"
246
  if meta_path.exists():
247
  try:
 
251
  except Exception:
252
  pass
253
  else:
254
+ infer = infer_features_from_model(model)
255
+ if infer: FEATURES = infer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
256
 
257
+ # =========================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258
  # Session state defaults
259
+ # =========================================================
260
  ss = st.session_state
261
+ ss.setdefault("app_step", "intro") # we start on Intro
262
+ ss.setdefault("dev_bytes", None)
263
+ ss.setdefault("dev_book", None)
 
 
264
  ss.setdefault("dev_previewed", False)
265
  ss.setdefault("dev_ran", False)
266
  ss.setdefault("results", {})
267
  ss.setdefault("train_ranges", None)
268
+ ss.setdefault("val_bytes", None)
269
+ ss.setdefault("val_book", None)
270
 
271
+ # =========================================================
272
+ # HERO (logo + title)
273
+ # =========================================================
274
  st.markdown(
275
  f"""
276
  <div class="st-hero">
 
284
  unsafe_allow_html=True,
285
  )
286
 
287
+ # =========================================================
288
+ # INTRO PAGE
289
+ # =========================================================
290
  if ss.app_step == "intro":
291
  st.header("Welcome!")
292
  st.markdown(
293
+ "1. **Upload your data to build the case** and preview the performance of our model. \n"
294
+ "2. Click **Run Model** to compute metrics, cross-plots, and the index track. \n"
295
+ "3. Click **Proceed to Prediction** to validate on a new dataset."
296
  )
297
+ if st.button("Start", type="primary"):
298
+ ss.app_step = "dev"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
299
  st.rerun()
300
 
301
+ # =========================================================
302
+ # Helper banner (stays at top of Development page)
303
+ # =========================================================
304
+ def render_dev_helper():
305
  st.subheader("Model Development")
306
+ if not ss.dev_bytes:
307
+ st.info("Upload your data to build the case and preview the performance of our model.")
308
+ elif ss.dev_bytes and not ss.dev_previewed and not ss.dev_ran:
309
+ st.info("File loaded — click **Preview data**.")
 
 
 
310
  elif ss.dev_previewed and not ss.dev_ran:
311
+ st.info("Previewed ✓ — now click **Run Model** to build the case.")
312
+ elif ss.dev_ran:
313
+ st.success("Case built ✓ — results are displayed below. You can now **Proceed to Prediction**.")
314
+
315
+ # =========================================================
316
+ # PREVIEW MODAL
317
+ # =========================================================
318
+ def preview_modal_dev(book, feature_cols):
319
+ sh_train = None
320
+ sh_test = None
321
+ # try common names
322
+ low2orig = {k.lower(): k for k in book.keys()}
323
+ for nm in ["train","training","training2"]:
324
+ if nm in low2orig: sh_train = low2orig[nm]; break
325
+ for nm in ["test","testing","testing2"]:
326
+ if nm in low2orig: sh_test = low2orig[nm]; break
327
+
328
+ tabs = st.tabs(["Tracks", "Summary"])
329
+ with tabs[0]:
330
+ # prefer Train if available; else first sheet
331
+ pick = sh_train or list(book.keys())[0]
332
+ df = book[pick]
333
+ # only numeric columns needed for plotting
334
+ ok_cols = [c for c in feature_cols if c in df.columns]
335
+ if not ok_cols:
336
+ st.warning("No matching feature columns found for plotting.")
 
 
 
 
 
 
 
337
  else:
338
+ fig = make_index_tracks_plotly(df, ok_cols, height=640, width=1000)
339
+ st.plotly_chart(fig, use_container_width=True, theme=None)
340
+ with tabs[1]:
341
+ pick = sh_train or list(book.keys())[0]
342
+ df = book[pick]
343
+ st.dataframe(
344
+ df.describe().T.rename(columns={
345
+ "mean":"Mean","std":"Std","min":"Min","max":"Max"
346
+ })[["Min","Max","Mean","Std"]].round(4),
347
+ use_container_width=True
348
+ )
349
+
350
+ # =========================================================
351
+ # DEVELOPMENT PAGE
352
+ # =========================================================
353
+ if ss.app_step == "dev":
354
+ render_dev_helper()
355
+
356
+ with st.sidebar:
357
+ st.header("Model Development Data")
358
+
359
+ def _on_dev_upload():
360
+ file = st.session_state.get("dev_upload")
361
+ if file is not None:
362
+ ss.dev_bytes = file.getvalue()
363
+ ss.dev_book = parse_excel(ss.dev_bytes)
364
+ ss.dev_previewed = False
365
+ ss.dev_ran = False
366
 
367
+ st.file_uploader("Replace data (Excel)", type=["xlsx","xls"], key="dev_upload",
368
+ on_change=_on_dev_upload, help="Limit 200MB per file • XLSX, XLS")
369
+
370
+ if ss.dev_bytes and ss.dev_book:
371
+ # Small status line under upload
372
+ any_sheet = next(iter(ss.dev_book.values()))
373
+ st.caption(f"Data loaded: {getattr(st.session_state.get('dev_upload'), 'name', 'file')} • "
374
+ f"{any_sheet.shape[0]} rows × {any_sheet.shape[1]} cols")
375
+
376
+ preview_clicked = st.button("Preview data", disabled=not bool(ss.dev_book))
377
+ run_clicked = st.button("Run Model", disabled=not bool(ss.dev_book))
378
+ proceed_clicked = st.button("Proceed to Prediction ▶", disabled=not ss.get("dev_ran", False))
379
+
380
+ # Modal preview (does NOT clear the uploaded file)
381
+ if preview_clicked and ss.dev_book:
382
+ with st.modal("Preview data"):
383
+ st.write("Use the tabs below to inspect the uploaded data before running the model.")
384
+ preview_modal_dev(ss.dev_book, FEATURES)
385
+ if st.button("Close", type="primary"):
386
+ ss.dev_previewed = True
387
+ st.rerun()
388
+
389
+ # Run model
390
+ if run_clicked and ss.dev_book:
391
+ # pick sheets
392
+ book = ss.dev_book
393
+ low2orig = {k.lower(): k for k in book.keys()}
394
+ sh_train = None; sh_test=None
395
+ for nm in ["train","training","training2"]:
396
+ if nm in low2orig: sh_train = low2orig[nm]; break
397
+ for nm in ["test","testing","testing2"]:
398
+ if nm in low2orig: sh_test = low2orig[nm]; break
399
+
400
+ if sh_train is None or sh_test is None:
401
+ st.error("Workbook must include sheets named *Train/Training* and *Test/Testing* (any one of those).")
402
  else:
403
+ df_tr = book[sh_train].copy()
404
+ df_te = book[sh_test].copy()
405
+ if ensure_cols(df_tr, FEATURES+[TARGET]) and ensure_cols(df_te, FEATURES+[TARGET]):
406
+ # predict
 
 
 
 
 
 
 
 
 
 
 
 
 
407
  df_tr["UCS_Pred"] = model.predict(df_tr[FEATURES])
408
  df_te["UCS_Pred"] = model.predict(df_te[FEATURES])
 
 
409
  ss.results["Train"] = df_tr
410
  ss.results["Test"] = df_te
411
  ss.results["metrics_train"] = {
412
  "R2": r2_score(df_tr[TARGET], df_tr["UCS_Pred"]),
413
  "RMSE": rmse(df_tr[TARGET], df_tr["UCS_Pred"]),
414
+ "MAE": mean_absolute_error(df_tr[TARGET], df_tr["UCS_Pred"])
415
  }
416
  ss.results["metrics_test"] = {
417
  "R2": r2_score(df_te[TARGET], df_te["UCS_Pred"]),
418
  "RMSE": rmse(df_te[TARGET], df_te["UCS_Pred"]),
419
+ "MAE": mean_absolute_error(df_te[TARGET], df_te["UCS_Pred"])
420
  }
 
421
  tr_min = df_tr[FEATURES].min().to_dict()
422
  tr_max = df_tr[FEATURES].max().to_dict()
423
  ss.train_ranges = {f:(float(tr_min[f]), float(tr_max[f])) for f in FEATURES}
 
424
  ss.dev_ran = True
425
+ st.rerun()
 
 
426
 
427
+ # Results (if available)
428
+ if ss.results.get("Train") is not None or ss.results.get("Test") is not None:
429
+ tab1, tab2 = st.tabs(["Training", "Testing"])
430
 
431
+ if ss.results.get("Train") is not None:
432
+ with tab1:
433
+ df = ss.results["Train"]; m = ss.results["metrics_train"]
434
+ c1,c2,c3 = st.columns(3)
435
  c1.metric("R²", f"{m['R2']:.4f}")
436
  c2.metric("RMSE", f"{m['RMSE']:.4f}")
437
  c3.metric("MAE", f"{m['MAE']:.4f}")
438
+
439
+ left, right = st.columns([0.58, 0.42])
440
+ with left:
441
+ st.plotly_chart(make_cross_plotly(df[TARGET], df["UCS_Pred"], height=440, width=640),
442
+ use_container_width=True, theme=None)
443
+ with right:
444
+ st.plotly_chart(make_depth_track_plotly(df, include_actual=True, height=640, width=360),
445
+ use_container_width=True, theme=None)
446
+
447
+ if ss.results.get("Test") is not None:
448
+ with tab2:
449
+ df = ss.results["Test"]; m = ss.results["metrics_test"]
450
+ c1,c2,c3 = st.columns(3)
451
  c1.metric("R²", f"{m['R2']:.4f}")
452
  c2.metric("RMSE", f"{m['RMSE']:.4f}")
453
  c3.metric("MAE", f"{m['MAE']:.4f}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
454
 
455
+ left, right = st.columns([0.58, 0.42])
456
+ with left:
457
+ st.plotly_chart(make_cross_plotly(df[TARGET], df["UCS_Pred"], height=440, width=640),
458
+ use_container_width=True, theme=None)
459
+ with right:
460
+ st.plotly_chart(make_depth_track_plotly(df, include_actual=True, height=640, width=360),
461
+ use_container_width=True, theme=None)
462
 
463
+ # =========================================================
464
+ # PREDICTION PAGE
465
+ # =========================================================
466
+ if ss.app_step == "dev" and st.sidebar.button("→ Open Prediction in main area", key="force_pred"):
467
+ ss.app_step = "predict"; st.rerun()
468
 
469
+ if ss.app_step == "predict":
470
+ st.subheader("Prediction")
471
+ st.write("Upload a new dataset to generate UCS predictions and evaluate performance on unseen data.")
472
+
473
+ with st.sidebar:
474
+ st.header("Prediction (Validation)")
475
+ def _on_val_upload():
476
+ file = st.session_state.get("val_upload")
477
+ if file is not None:
478
+ ss.val_bytes = file.getvalue()
479
+ ss.val_book = parse_excel(ss.val_bytes)
480
+
481
+ st.file_uploader("Upload Validation Excel", type=["xlsx","xls"], key="val_upload", on_change=_on_val_upload)
482
+ predict_clicked = st.button("Predict", type="primary", use_container_width=True)
483
+ st.button("⬅ Back", on_click=lambda: ss.update(app_step="dev"))
484
+
485
+ if predict_clicked and ss.val_book:
486
+ vname = list(ss.val_book.keys())[0]
487
+ df_val = ss.val_book[vname].copy()
488
+ if not ensure_cols(df_val, FEATURES):
489
+ st.stop()
490
+ df_val["UCS_Pred"] = model.predict(df_val[FEATURES])
491
+ ss.results["Validate"] = df_val
492
+
493
+ ranges = ss.train_ranges
494
+ oor_table = None; oor_pct = 0.0
495
+ if ranges:
496
+ viol = {f: (df_val[f] < ranges[f][0]) | (df_val[f] > ranges[f][1]) for f in FEATURES}
497
+ any_viol = pd.DataFrame(viol).any(axis=1); oor_pct = float(any_viol.mean()*100.0)
498
+ if any_viol.any():
499
+ offenders = df_val.loc[any_viol, FEATURES].copy()
500
+ offenders["Violations"] = pd.DataFrame(viol).loc[any_viol].apply(
501
+ lambda r: ", ".join([c for c,v in r.items() if v]), axis=1)
502
+ offenders.index = offenders.index + 1; oor_table = offenders
503
+
504
+ metrics_val = None
505
+ if TARGET in df_val.columns:
506
+ metrics_val = {
507
+ "R2": r2_score(df_val[TARGET], df_val["UCS_Pred"]),
508
+ "RMSE": rmse(df_val[TARGET], df_val["UCS_Pred"]),
509
+ "MAE": mean_absolute_error(df_val[TARGET], df_val["UCS_Pred"])
510
+ }
511
+ ss.results["metrics_val"] = metrics_val
512
+ ss.results["summary_val"] = {
513
+ "n_points": len(df_val),
514
+ "pred_min": float(df_val["UCS_Pred"].min()),
515
+ "pred_max": float(df_val["UCS_Pred"].max()),
516
+ "oor_pct": oor_pct
517
+ }
518
+ ss.results["oor_table"] = oor_table
519
+ st.experimental_rerun()
520
+
521
+ if ss.results.get("Validate") is not None:
522
+ st.subheader("Validation Results")
523
  sv = ss.results["summary_val"]; oor_table = ss.results.get("oor_table")
524
  c1,c2,c3,c4 = st.columns(4)
525
+ c1.metric("# points", f"{sv['n_points']}"); c2.metric("Pred min", f"{sv['pred_min']:.2f}")
526
+ c3.metric("Pred max", f"{sv['pred_max']:.2f}"); c4.metric("OOR %", f"{sv['oor_pct']:.1f}%")
527
+
528
+ left,right = st.columns([0.58, 0.42])
 
 
 
 
529
  with left:
530
  if TARGET in ss.results["Validate"].columns:
531
  st.plotly_chart(
532
+ make_cross_plotly(ss.results["Validate"][TARGET], ss.results["Validate"]["UCS_Pred"], height=440, width=640),
533
+ use_container_width=True, theme=None
534
  )
535
  else:
536
+ st.info("Actual UCS values are not available in the validation data.")
537
  with right:
538
  st.plotly_chart(
539
+ make_depth_track_plotly(ss.results["Validate"], include_actual=(TARGET in ss.results["Validate"].columns),
540
+ height=640, width=360),
541
+ use_container_width=True, theme=None
542
  )
 
543
  if oor_table is not None:
544
+ st.warning("Some validation rows contain inputs **outside** the training min–max. Review the table below.")
545
  st.dataframe(oor_table, use_container_width=True)
546
 
547
+ # =========================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
548
  # Footer
549
+ # =========================================================
550
  st.markdown("---")
551
  st.markdown(
552
  "<div style='text-align:center; color:#6b7280;'>"