UCS2014 commited on
Commit
ef237f6
·
verified ·
1 Parent(s): 73b13cf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +295 -406
app.py CHANGED
@@ -1,36 +1,63 @@
1
  # app.py
2
- import io, json, os, base64
3
  from pathlib import Path
4
  import streamlit as st
5
  import pandas as pd
6
  import numpy as np
7
  import joblib
8
 
9
- # keep matplotlib ONLY for the preview modal (static thumbnails)
10
  import matplotlib
11
  matplotlib.use("Agg")
12
  import matplotlib.pyplot as plt
13
 
14
  import plotly.graph_objects as go
15
  from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
16
- from math import floor, log10
17
 
18
  # =========================
19
- # Defaults
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  # =========================
21
  FEATURES = ["Q, gpm", "SPP(psi)", "T (kft.lbf)", "WOB (klbf)", "ROP (ft/h)"]
22
  TARGET = "UCS"
23
  MODELS_DIR = Path("models")
24
  DEFAULT_MODEL = MODELS_DIR / "ucs_rf.joblib"
25
  MODEL_FALLBACKS = [MODELS_DIR / "model.joblib", MODELS_DIR / "model.pkl"]
26
-
27
  COLORS = {"pred": "#1f77b4", "actual": "#f2b702", "ref": "#5a5a5a"}
28
 
 
 
 
 
 
29
  # =========================
30
- # Page / Theme
31
  # =========================
32
- st.set_page_config(page_title="ST_GeoMech_UCS", page_icon="logo.png", layout="wide")
33
-
34
  def inline_logo(path="logo.png") -> str:
35
  try:
36
  p = Path(path)
@@ -39,10 +66,72 @@ def inline_logo(path="logo.png") -> str:
39
  except Exception:
40
  return ""
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  # =========================
43
- # Password (brand-gated)
44
  # =========================
45
  def add_password_gate() -> bool:
 
 
46
  try:
47
  required = st.secrets.get("APP_PASSWORD", "")
48
  except Exception:
@@ -60,8 +149,7 @@ def add_password_gate() -> bool:
60
  </div>
61
  <div style="font-size:1.25rem;font-weight:700;margin:8px 0 4px 0;">Protected Area</div>
62
  <div style="color:#6b7280;margin-bottom:14px;">
63
- Admin action required: set <code>APP_PASSWORD</code> in <b>Settings → Secrets</b>
64
- (or as an environment variable) and restart the Space.
65
  </div>
66
  """,
67
  unsafe_allow_html=True,
@@ -81,9 +169,7 @@ def add_password_gate() -> bool:
81
  </div>
82
  </div>
83
  <div style="font-size:1.25rem;font-weight:700;margin:8px 0 4px 0;">Protected</div>
84
- <div style="color:#6b7280;margin-bottom:14px;">
85
- Please enter your access key to continue.
86
- </div>
87
  """,
88
  unsafe_allow_html=True
89
  )
@@ -98,102 +184,22 @@ def add_password_gate() -> bool:
98
  st.error("Incorrect key. Please try again.")
99
  st.stop()
100
 
 
101
  add_password_gate()
102
 
103
- # CSS
104
- st.markdown("<style>header, footer{visibility:hidden !important;}</style>", unsafe_allow_html=True)
105
- st.markdown(
106
- """
107
- <style>
108
- .stApp { background: #FFFFFF; }
109
- section[data-testid="stSidebar"] { background: #F6F9FC; }
110
- .block-container { padding-top: .5rem; padding-bottom: .5rem; }
111
- .stButton>button{ background:#007bff; color:#fff; font-weight:bold; border-radius:8px; border:none; padding:10px 24px; }
112
- .stButton>button:hover{ background:#0056b3; }
113
- .st-hero { display:flex; align-items:center; gap:16px; padding-top: 4px; }
114
- .st-hero .brand { width:110px; height:110px; object-fit:contain; }
115
- .st-hero h1 { margin:0; line-height:1.05; }
116
- .st-hero .tagline { margin:2px 0 0 2px; color:#6b7280; font-size:1.05rem; font-style:italic; }
117
- [data-testid="stBlock"]{ margin-top:0 !important; }
118
- </style>
119
- """,
120
- unsafe_allow_html=True
121
- )
122
 
123
  # =========================
124
- # Helpers
125
  # =========================
126
- try:
127
- dialog = st.dialog
128
- except AttributeError:
129
- def dialog(title):
130
- def deco(fn):
131
- def wrapper(*args, **kwargs):
132
- with st.expander(title, expanded=True):
133
- return fn(*args, **kwargs)
134
- return wrapper
135
- return deco
136
-
137
- def _get_model_url():
138
- return (os.environ.get("MODEL_URL", "") or "").strip()
139
-
140
- def rmse(y_true, y_pred): return float(np.sqrt(mean_squared_error(y_true, y_pred)))
141
-
142
- def ensure_cols(df, cols):
143
- miss = [c for c in cols if c not in df.columns]
144
- if miss:
145
- st.error(f"Missing columns: {miss}\nFound: {list(df.columns)}")
146
- return False
147
- return True
148
-
149
- @st.cache_resource(show_spinner=False)
150
- def load_model(model_path: str):
151
- return joblib.load(model_path)
152
-
153
- @st.cache_data(show_spinner=False)
154
- def parse_excel(data_bytes: bytes):
155
- bio = io.BytesIO(data_bytes)
156
- xl = pd.ExcelFile(bio)
157
- return {sh: xl.parse(sh) for sh in xl.sheet_names}
158
-
159
- def read_book_bytes(data_bytes: bytes):
160
- if not data_bytes: return {}
161
- try: return parse_excel(data_bytes)
162
- except Exception as e:
163
- st.error(f"Failed to read Excel: {e}"); return {}
164
-
165
- def find_sheet(book, names):
166
- low2orig = {k.lower(): k for k in book.keys()}
167
- for nm in names:
168
- if nm.lower() in low2orig: return low2orig[nm.lower()]
169
- return None
170
-
171
- # ----- Nice tick step for cross-plot -----
172
- def _nice_dtick(data_range: float) -> float:
173
- if data_range <= 0 or np.isnan(data_range): return 1.0
174
- raw = data_range / 6.0 # aim ~6 ticks
175
- k = floor(log10(raw))
176
- base = 10 ** k
177
- m = raw / base
178
- if m <= 1.5:
179
- step = 1 * base
180
- elif m <= 3.5:
181
- step = 2 * base
182
- elif m <= 7.5:
183
- step = 5 * base
184
- else:
185
- step = 10 * base
186
- return step
187
-
188
- # ---------- Interactive plotting ----------
189
- def cross_plot_interactive(actual, pred, size=(3.9, 3.9)):
190
  a = pd.Series(actual).astype(float)
191
  p = pd.Series(pred).astype(float)
192
  lo = float(np.nanmin([a.min(), p.min()]))
193
  hi = float(np.nanmax([a.max(), p.max()]))
194
- pad = 0.03 * (hi - lo if hi > lo else 1.0)
195
  x0, x1 = lo - pad, hi + pad
196
- dtick = _nice_dtick(x1 - x0)
197
 
198
  fig = go.Figure()
199
  fig.add_trace(go.Scatter(
@@ -208,30 +214,28 @@ def cross_plot_interactive(actual, pred, size=(3.9, 3.9)):
208
  hoverinfo="skip", showlegend=False
209
  ))
210
  fig.update_layout(
 
211
  paper_bgcolor="#ffffff", plot_bgcolor="#ffffff",
212
- margin=dict(l=50, r=10, t=10, b=36),
213
- hovermode="closest", font=dict(size=13), dragmode="zoom"
214
  )
215
  fig.update_xaxes(
216
- title_text="<b>Actual UCS</b>",
217
- range=[x0, x1], tickmode="linear", dtick=dtick, ticks="outside",
218
  showline=True, linewidth=1.2, linecolor="#444", mirror=True,
219
- showgrid=True, gridcolor="rgba(0,0,0,0.12)",
220
- tickformat=",.0f", automargin=True
221
  )
222
  fig.update_yaxes(
223
- title_text="<b>Predicted UCS</b>",
224
- range=[x0, x1], tickmode="linear", dtick=dtick, ticks="outside",
225
  showline=True, linewidth=1.2, linecolor="#444", mirror=True,
226
  showgrid=True, gridcolor="rgba(0,0,0,0.12)",
227
- tickformat=",.0f", scaleanchor="x", scaleratio=1,
228
- automargin=True
229
  )
230
- w = int(size[0] * 100); h = int(size[1] * 100)
231
- fig.update_layout(width=w, height=h)
232
  return fig
233
 
234
- def depth_or_index_track_interactive(df, title=None, include_actual=True, x_range=None):
 
235
  depth_col = next((c for c in df.columns if 'depth' in str(c).lower()), None)
236
  if depth_col is not None:
237
  y = df[depth_col]; y_label = depth_col
@@ -252,38 +256,46 @@ def depth_or_index_track_interactive(df, title=None, include_actual=True, x_rang
252
  name="UCS (actual)",
253
  hovertemplate="UCS (actual): %{x:.0f}<br>"+y_label+": %{y}<extra></extra>"
254
  ))
255
-
256
- # slimmer & taller like a log profile
257
  fig.update_layout(
 
258
  paper_bgcolor="#ffffff", plot_bgcolor="#ffffff",
259
- margin=dict(l=60, r=10, t=10, b=36),
260
  hovermode="closest", font=dict(size=13),
261
  legend=dict(
262
  x=0.98, y=0.05, xanchor="right", yanchor="bottom",
263
  bgcolor="rgba(255,255,255,0.75)", bordercolor="#cccccc", borderwidth=1
264
  ),
265
- legend_title_text="",
266
- width=int(2.4 * 100), # narrower
267
- height=int(8.4 * 100), # taller
268
- dragmode="zoom"
269
  )
270
  fig.update_xaxes(
271
  title_text="<b>UCS</b>", side="top",
272
- ticks="outside", showline=True, linewidth=1.2, linecolor="#444", mirror=True,
273
- showgrid=True, gridcolor="rgba(0,0,0,0.12)",
274
- tickformat=",.0f",
275
- automargin=True,
276
- range=x_range
277
  )
278
  fig.update_yaxes(
279
  title_text=f"<b>{y_label}</b>", autorange="reversed",
280
- ticks="outside", showline=True, linewidth=1.2, linecolor="#444", mirror=True,
281
- showgrid=True, gridcolor="rgba(0,0,0,0.12)",
282
- automargin=True
283
  )
284
  return fig
285
 
286
- # ---------- Preview modal helpers (matplotlib static) ----------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
287
  def make_index_tracks(df: pd.DataFrame, cols: list[str]):
288
  cols = [c for c in cols if c in df.columns]
289
  n = len(cols)
@@ -313,37 +325,24 @@ def stats_table(df: pd.DataFrame, cols: list[str]) -> pd.DataFrame:
313
  return out.reset_index().rename(columns={"index": "Feature"})
314
 
315
  @dialog("Preview data")
316
- def preview_modal_dev(book: dict[str, pd.DataFrame], feature_cols: list[str]):
317
  if not book:
318
  st.info("No data loaded yet."); return
319
- sh_train = find_sheet(book, ["Train","Training","training2","train","training"])
320
- sh_test = find_sheet(book, ["Test","Testing","testing2","test","testing"])
321
- tabs, data = [], []
322
- if sh_train: tabs.append("Train"); data.append(book[sh_train])
323
- if sh_test: tabs.append("Test"); data.append(book[sh_test])
324
- if not tabs:
325
  first_name = list(book.keys())[0]
326
- tabs = [first_name]; data = [book[first_name]]
327
- st.write("Use the tabs to switch between Train/Test views (if available).")
328
- t_objs = st.tabs(tabs)
329
- for t, df in zip(t_objs, data):
330
  with t:
331
- t1, t2 = st.tabs(["Tracks", "Summary"])
332
  with t1: st.pyplot(make_index_tracks(df, feature_cols), use_container_width=True)
333
  with t2: st.dataframe(stats_table(df, feature_cols), use_container_width=True)
334
 
335
- @dialog("Preview data")
336
- def preview_modal_val(book: dict[str, pd.DataFrame], feature_cols: list[str]):
337
- if not book:
338
- st.info("No data loaded yet."); return
339
- vname = find_sheet(book, ["Validation","Validate","validation2","Val","val"]) or list(book.keys())[0]
340
- df = book[vname]
341
- t1, t2 = st.tabs(["Tracks", "Summary"])
342
- with t1: st.pyplot(make_index_tracks(df, feature_cols), use_container_width=True)
343
- with t2: st.dataframe(stats_table(df, feature_cols), use_container_width=True)
344
 
345
  # =========================
346
- # Model presence
347
  # =========================
348
  MODEL_URL = _get_model_url()
349
 
@@ -384,24 +383,12 @@ if meta_path.exists():
384
  try:
385
  meta = json.loads(meta_path.read_text(encoding="utf-8"))
386
  FEATURES = meta.get("features", FEATURES); TARGET = meta.get("target", TARGET)
387
- except Exception:
388
- pass
389
  else:
390
- def infer_features_from_model(m):
391
- try:
392
- if hasattr(m, "feature_names_in_") and len(getattr(m, "feature_names_in_")):
393
- return [str(x) for x in m.feature_names_in_]
394
- except Exception: pass
395
- try:
396
- if hasattr(m, "steps") and len(m.steps):
397
- last = m.steps[-1][1]
398
- if hasattr(last, "feature_names_in_") and len(last.feature_names_in_):
399
- return [str(x) for x in last.feature_names_in_]
400
- except Exception: pass
401
- return None
402
  infer = infer_features_from_model(model)
403
  if infer: FEATURES = infer
404
 
 
405
  # =========================
406
  # Session state
407
  # =========================
@@ -409,17 +396,21 @@ if "app_step" not in st.session_state: st.session_state.app_step = "intro"
409
  if "results" not in st.session_state: st.session_state.results = {}
410
  if "train_ranges" not in st.session_state: st.session_state.train_ranges = None
411
 
412
- # Dev/Val/Pred state
413
- defaults = {
414
- "dev_ready": False, "dev_file_loaded": False, "dev_previewed": False,
415
- "dev_file_signature": None, "dev_preview_request": False,
416
- "dev_file_bytes": b"", "dev_file_name": "", "dev_file_rows": 0, "dev_file_cols": 0,
417
- "val_file_bytes": b"", "val_file_loaded": False, "val_preview_request": False,
418
- "pred_file_bytes": b"", "pred_file_loaded": False, "pred_preview_request": False,
419
- }
420
- for k, v in defaults.items():
 
 
 
421
  if k not in st.session_state: st.session_state[k] = v
422
 
 
423
  # =========================
424
  # Hero header
425
  # =========================
@@ -436,12 +427,15 @@ st.markdown(
436
  unsafe_allow_html=True,
437
  )
438
 
 
439
  # =========================
440
- # INTRO PAGE
441
  # =========================
442
  if st.session_state.app_step == "intro":
443
  st.header("Welcome!")
444
- st.markdown("This software is developed by *Smart Thinking AI-Solutions Team* to estimate UCS from drilling data.")
 
 
445
  st.subheader("Expected Input Features (in Order)")
446
  st.markdown(
447
  "- Q, gpm — Flow rate (gallons per minute) \n"
@@ -450,28 +444,29 @@ if st.session_state.app_step == "intro":
450
  "- WOB (klbf) — Weight on bit \n"
451
  "- ROP (ft/h) — Rate of penetration"
452
  )
453
- st.subheader("How It Works")
454
  st.markdown(
455
- "1. **Upload your data to build the case and preview the performance of our model.** \n"
456
- "2. Click **Run Model** to compute metrics and plots. \n"
457
- "3. Click **Proceed to Validation** to evaluate on a new dataset with actual UCS (if available). \n"
458
- "4. Click **Proceed to Prediction** to generate production predictions (no actuals). \n"
459
- "5. Export results to Excel at any time."
460
  )
461
- if st.button("Start Showcase", type="primary", key="start_showcase"):
462
  st.session_state.app_step = "dev"; st.rerun()
463
 
 
464
  # =========================
465
- # 1) CASE BUILDING (Development)
466
  # =========================
467
  if st.session_state.app_step == "dev":
468
  st.sidebar.header("Case Building (Development)")
469
- dev_label = "Upload Data (Excel)" if not st.session_state.dev_file_name else "Replace data (Excel)"
470
  train_test_file = st.sidebar.file_uploader(dev_label, type=["xlsx","xls"], key="dev_upload")
471
 
 
472
  if train_test_file is not None:
473
  try:
474
- file_bytes = train_test_file.getvalue(); size = len(file_bytes)
 
475
  except Exception:
476
  file_bytes = b""; size = 0
477
  sig = (train_test_file.name, size)
@@ -494,37 +489,35 @@ if st.session_state.app_step == "dev":
494
  f"{st.session_state.dev_file_rows} rows × {st.session_state.dev_file_cols} cols"
495
  )
496
 
497
- preview_btn = st.sidebar.button("Preview data", use_container_width=True, disabled=not st.session_state.dev_file_loaded)
498
- if preview_btn and st.session_state.dev_file_loaded:
499
  st.session_state.dev_preview_request = True
500
-
501
  run_btn = st.sidebar.button("Run Model", type="primary", use_container_width=True)
502
-
503
- # jump links
504
- proceed_val = st.sidebar.button("Proceed to Validation ▶", use_container_width=True)
505
- proceed_pred = st.sidebar.button("Proceed to Prediction ▶", use_container_width=True)
506
- if proceed_val:
507
  st.session_state.app_step = "validate"; st.rerun()
508
- if proceed_pred:
509
  st.session_state.app_step = "predict"; st.rerun()
510
 
511
- with st.container():
512
- st.subheader("Case Building")
513
- if st.session_state.dev_ready:
514
- st.success("Case has been built and results are displayed below.")
515
- elif st.session_state.dev_file_loaded and st.session_state.dev_previewed:
516
- st.info("Previewed ✓ — now click **Run Model** to build the case.")
517
- elif st.session_state.dev_file_loaded:
518
- st.info("📄 **Preview uploaded data** using the sidebar button, then click **Run Model**.")
519
- else:
520
- st.write("**Upload your data to build a case, then run the model to review development performance.**")
521
 
 
522
  if st.session_state.dev_preview_request and st.session_state.dev_file_bytes:
523
  _book = read_book_bytes(st.session_state.dev_file_bytes)
524
  st.session_state.dev_previewed = True
525
  st.session_state.dev_preview_request = False
526
- preview_modal_dev(_book, FEATURES)
527
 
 
528
  if run_btn and st.session_state.dev_file_bytes:
529
  with st.status("Processing…", expanded=False) as status:
530
  book = read_book_bytes(st.session_state.dev_file_bytes)
@@ -533,7 +526,7 @@ if st.session_state.app_step == "dev":
533
  sh_train = find_sheet(book, ["Train","Training","training2","train","training"])
534
  sh_test = find_sheet(book, ["Test","Testing","testing2","test","testing"])
535
  if sh_train is None or sh_test is None:
536
- status.update(label="Workbook must include Train/Training/training2 and Test/Testing/testing2.", state="error"); st.stop()
537
  df_tr = book[sh_train].copy(); df_te = book[sh_test].copy()
538
  if not (ensure_cols(df_tr, FEATURES + [TARGET]) and ensure_cols(df_te, FEATURES + [TARGET])):
539
  status.update(label="Missing required columns.", state="error"); st.stop()
@@ -553,70 +546,42 @@ if st.session_state.app_step == "dev":
553
  "RMSE": rmse(df_te[TARGET], df_te["UCS_Pred"]),
554
  "MAE": mean_absolute_error(df_te[TARGET], df_te["UCS_Pred"]),
555
  }
556
-
557
  tr_min = df_tr[FEATURES].min().to_dict(); tr_max = df_tr[FEATURES].max().to_dict()
558
  st.session_state.train_ranges = {f:(float(tr_min[f]), float(tr_max[f])) for f in FEATURES}
559
 
560
  st.session_state.dev_ready = True
561
  status.update(label="Done ✓", state="complete"); st.rerun()
562
 
 
563
  if ("Train" in st.session_state.results) or ("Test" in st.session_state.results):
564
  tab1, tab2 = st.tabs(["Training", "Testing"])
 
 
565
  if "Train" in st.session_state.results:
566
  with tab1:
567
  df = st.session_state.results["Train"]; m = st.session_state.results["metrics_train"]
568
  c1,c2,c3 = st.columns(3)
569
  c1.metric("R²", f"{m['R2']:.4f}"); c2.metric("RMSE", f"{m['RMSE']:.4f}"); c3.metric("MAE", f"{m['MAE']:.4f}")
570
- left, right = st.columns([0.9, 0.55])
571
  with left:
572
- st.plotly_chart(
573
- cross_plot_interactive(df[TARGET], df["UCS_Pred"], size=(3.9,3.9)),
574
- use_container_width=True, config={"displayModeBar": False}
575
- )
576
  with right:
577
- # Zoom control for UCS axis
578
- pr_min = float(df["UCS_Pred"].min())
579
- xs = [pr_min]
580
- if TARGET in df: xs.append(float(df[TARGET].min()))
581
- x_min = min(xs)
582
- pr_max = float(df["UCS_Pred"].max())
583
- xs = [pr_max]
584
- if TARGET in df: xs.append(float(df[TARGET].max()))
585
- x_max = max(xs)
586
- with st.expander("Zoom (UCS axis)", expanded=False):
587
- z = st.slider("UCS range", min_value=float(x_min), max_value=float(x_max),
588
- value=(float(x_min), float(x_max)), step=10.0, key="zoom_train")
589
- st.plotly_chart(
590
- depth_or_index_track_interactive(df, title=None, include_actual=True, x_range=z),
591
- use_container_width=True, config={"displayModeBar": False}
592
- )
593
  if "Test" in st.session_state.results:
594
  with tab2:
595
  df = st.session_state.results["Test"]; m = st.session_state.results["metrics_test"]
596
  c1,c2,c3 = st.columns(3)
597
  c1.metric("R²", f"{m['R2']:.4f}"); c2.metric("RMSE", f"{m['RMSE']:.4f}"); c3.metric("MAE", f"{m['MAE']:.4f}")
598
- left, right = st.columns([0.9, 0.55])
599
  with left:
600
- st.plotly_chart(
601
- cross_plot_interactive(df[TARGET], df["UCS_Pred"], size=(3.9,3.9)),
602
- use_container_width=True, config={"displayModeBar": False}
603
- )
604
  with right:
605
- pr_min = float(df["UCS_Pred"].min())
606
- xs = [pr_min]
607
- if TARGET in df: xs.append(float(df[TARGET].min()))
608
- x_min = min(xs)
609
- pr_max = float(df["UCS_Pred"].max())
610
- xs = [pr_max]
611
- if TARGET in df: xs.append(float(df[TARGET].max()))
612
- x_max = max(xs)
613
- with st.expander("Zoom (UCS axis)", expanded=False):
614
- z2 = st.slider("UCS range", min_value=float(x_min), max_value=float(x_max),
615
- value=(float(x_min), float(x_max)), step=10.0, key="zoom_test")
616
- st.plotly_chart(
617
- depth_or_index_track_interactive(df, title=None, include_actual=True, x_range=z2),
618
- use_container_width=True, config={"displayModeBar": False}
619
- )
620
 
621
  st.markdown("---")
622
  sheets = {}; rows = []
@@ -643,48 +608,38 @@ if st.session_state.app_step == "dev":
643
  except Exception as e:
644
  st.warning(str(e))
645
 
 
646
  # =========================
647
- # 2) VALIDATE THE MODEL
648
  # =========================
649
  if st.session_state.app_step == "validate":
650
- st.sidebar.header("Validate the model")
651
  validation_file = st.sidebar.file_uploader("Upload Validation Excel", type=["xlsx","xls"], key="val_upload")
652
-
653
  if validation_file is not None:
654
- st.session_state.val_file_bytes = validation_file.getvalue()
655
- _book_tmp = read_book_bytes(st.session_state.val_file_bytes)
656
  if _book_tmp:
657
  first_df = next(iter(_book_tmp.values()))
658
- st.session_state.val_file_loaded = True
659
  st.sidebar.caption(f"**Data loaded:** {validation_file.name} • {first_df.shape[0]} rows × {first_df.shape[1]} cols")
660
 
661
- preview_val_btn = st.sidebar.button("Preview data", use_container_width=True, disabled=not st.session_state.val_file_loaded)
662
- if preview_val_btn and st.session_state.val_file_loaded:
663
- st.session_state.val_preview_request = True
664
 
665
  predict_btn = st.sidebar.button("Run Validation", type="primary", use_container_width=True)
666
- proceed_pred = st.sidebar.button("Proceed to Prediction ▶", use_container_width=True)
667
  st.sidebar.button("⬅ Back to Case Building", on_click=lambda: st.session_state.update(app_step="dev"), use_container_width=True)
668
- if proceed_pred:
669
- st.session_state.app_step = "predict"; st.rerun()
670
 
671
- with st.container():
672
- st.subheader("Validate the model")
673
- st.write("Upload a validation dataset (with actual UCS if available), preview it, then run to view metrics and plots.")
674
 
675
- if st.session_state.val_preview_request and st.session_state.val_file_bytes:
676
- _book = read_book_bytes(st.session_state.val_file_bytes)
677
- st.session_state.val_preview_request = False
678
- preview_modal_val(_book, FEATURES)
679
-
680
- if predict_btn and st.session_state.val_file_bytes:
681
  with st.status("Validating…", expanded=False) as status:
682
- vbook = read_book_bytes(st.session_state.val_file_bytes)
683
  if not vbook: status.update(label="Could not read the Validation Excel.", state="error"); st.stop()
684
  status.update(label="Workbook read ✓")
685
  vname = find_sheet(vbook, ["Validation","Validate","validation2","Val","val"]) or list(vbook.keys())[0]
686
  df_val = vbook[vname].copy()
687
- if not ensure_cols(df_val, FEATURES): status.update(label="Missing required columns.", state="error"); st.stop()
688
  status.update(label="Columns validated ✓")
689
  df_val["UCS_Pred"] = model.predict(df_val[FEATURES])
690
  st.session_state.results["Validate"] = df_val
@@ -698,13 +653,11 @@ if st.session_state.app_step == "validate":
698
  offenders["Violations"] = pd.DataFrame(viol).loc[any_viol].apply(lambda r: ", ".join([c for c,v in r.items() if v]), axis=1)
699
  offenders.index = offenders.index + 1; oor_table = offenders
700
 
701
- metrics_val = None
702
- if TARGET in df_val.columns:
703
- metrics_val = {
704
- "R2": r2_score(df_val[TARGET], df_val["UCS_Pred"]),
705
- "RMSE": rmse(df_val[TARGET], df_val["UCS_Pred"]),
706
- "MAE": mean_absolute_error(df_val[TARGET], df_val["UCS_Pred"])
707
- }
708
  st.session_state.results["metrics_val"] = metrics_val
709
  st.session_state.results["summary_val"] = {
710
  "n_points": len(df_val),
@@ -713,71 +666,34 @@ if st.session_state.app_step == "validate":
713
  "oor_pct": oor_pct
714
  }
715
  st.session_state.results["oor_table"] = oor_table
716
- status.update(label="Validation ready ✓", state="complete")
717
 
718
  if "Validate" in st.session_state.results:
719
- sv = st.session_state.results["summary_val"]; oor_table = st.session_state.results.get("oor_table")
720
-
 
721
  if sv["oor_pct"] > 0:
722
  st.warning("Some validation inputs fall outside the **training min–max** ranges. Interpret predictions with caution.")
 
 
723
 
724
- metrics_val = st.session_state.results.get("metrics_val")
725
- if metrics_val is not None:
726
- c1, c2, c3 = st.columns(3)
727
- c1.metric("R²", f"{metrics_val['R2']:.4f}")
728
- c2.metric("RMSE", f"{metrics_val['RMSE']:.4f}")
729
- c3.metric("MAE", f"{metrics_val['MAE']:.4f}")
730
- else:
731
- c1, c2, c3 = st.columns(3)
732
- c1.metric("# points", f"{sv['n_points']}")
733
- c2.metric("Pred min", f"{sv['pred_min']:.2f}")
734
- c3.metric("Pred max", f"{sv['pred_max']:.2f}")
735
-
736
- left, right = st.columns([0.9, 0.55])
737
  with left:
738
- if TARGET in st.session_state.results["Validate"].columns:
739
- st.plotly_chart(
740
- cross_plot_interactive(
741
- st.session_state.results["Validate"][TARGET],
742
- st.session_state.results["Validate"]["UCS_Pred"],
743
- size=(3.9,3.9)
744
- ),
745
- use_container_width=True, config={"displayModeBar": False}
746
- )
747
- else:
748
- st.info("Actual UCS values are not available in the validation data. Cross-plot cannot be generated.")
749
  with right:
750
- df = st.session_state.results["Validate"]
751
- pr_min = float(df["UCS_Pred"].min())
752
- xs = [pr_min]
753
- if TARGET in df: xs.append(float(df[TARGET].min()))
754
- x_min = min(xs)
755
- pr_max = float(df["UCS_Pred"].max())
756
- xs = [pr_max]
757
- if TARGET in df: xs.append(float(df[TARGET].max()))
758
- x_max = max(xs)
759
- with st.expander("Zoom (UCS axis)", expanded=False):
760
- zv = st.slider("UCS range", min_value=float(x_min), max_value=float(x_max),
761
- value=(float(x_min), float(x_max)), step=10.0, key="zoom_val")
762
- st.plotly_chart(
763
- depth_or_index_track_interactive(
764
- df, title=None,
765
- include_actual=(TARGET in df.columns),
766
- x_range=zv
767
- ),
768
- use_container_width=True, config={"displayModeBar": False}
769
- )
770
 
771
- if oor_table is not None:
772
  st.write("*Out-of-range rows (vs. Training min–max):*")
773
- st.dataframe(oor_table, use_container_width=True)
774
 
 
775
  st.markdown("---")
776
- sheets = {"Validate_with_pred": st.session_state.results["Validate"]}
777
  rows = []
778
  for name, key in [("Train","metrics_train"), ("Test","metrics_test"), ("Validate","metrics_val")]:
779
- m = st.session_state.results.get(key)
780
- if m: rows.append({"Split": name, **{k: round(v,6) for k,v in m.items()}})
781
  summary_df = pd.DataFrame(rows) if rows else None
782
  try:
783
  buf = io.BytesIO()
@@ -795,114 +711,86 @@ if st.session_state.app_step == "validate":
795
  except Exception as e:
796
  st.warning(str(e))
797
 
 
 
 
 
798
  # =========================
799
- # 3) PREDICTION (no actual UCS)
800
  # =========================
801
  if st.session_state.app_step == "predict":
802
  st.sidebar.header("Prediction")
803
- pred_file = st.sidebar.file_uploader("Upload Prediction Excel", type=["xlsx","xls"], key="pred_upload")
804
-
805
  if pred_file is not None:
806
- st.session_state.pred_file_bytes = pred_file.getvalue()
807
- _book_tmp = read_book_bytes(st.session_state.pred_file_bytes)
808
  if _book_tmp:
809
  first_df = next(iter(_book_tmp.values()))
810
- st.session_state.pred_file_loaded = True
811
  st.sidebar.caption(f"**Data loaded:** {pred_file.name} • {first_df.shape[0]} rows × {first_df.shape[1]} cols")
812
 
813
- preview_pred_btn = st.sidebar.button("Preview data", use_container_width=True, disabled=not st.session_state.pred_file_loaded)
814
- if preview_pred_btn and st.session_state.pred_file_loaded:
815
- st.session_state.pred_preview_request = True
816
 
817
- predict_btn = st.sidebar.button("Predict", type="primary", use_container_width=True)
818
  st.sidebar.button("⬅ Back to Validation", on_click=lambda: st.session_state.update(app_step="validate"), use_container_width=True)
 
819
 
820
- with st.container():
821
- st.subheader("Prediction")
822
- st.write("Upload a dataset (no actual UCS needed), preview it, then click **Predict** to generate UCS estimates.")
823
-
824
- if st.session_state.pred_preview_request and st.session_state.pred_file_bytes:
825
- _book = read_book_bytes(st.session_state.pred_file_bytes)
826
- st.session_state.pred_preview_request = False
827
- preview_modal_val(_book, FEATURES)
828
 
829
- if predict_btn and st.session_state.pred_file_bytes:
830
  with st.status("Predicting…", expanded=False) as status:
831
- pbook = read_book_bytes(st.session_state.pred_file_bytes)
832
- if not pbook: status.update(label="Could not read the Excel file.", state="error"); st.stop()
833
  status.update(label="Workbook read ✓")
834
- pname = list(pbook.keys())[0]
835
- df_pred = pbook[pname].copy()
836
- if not ensure_cols(df_pred, FEATURES): status.update(label="Missing required columns.", state="error"); st.stop()
837
  status.update(label="Columns validated ✓")
838
- df_pred["UCS_Pred"] = model.predict(df_pred[FEATURES])
839
- st.session_state.results["Prediction"] = df_pred
840
 
841
- ranges = st.session_state.train_ranges; oor_table = None; oor_pct = 0.0
842
  if ranges:
843
- viol = {f: (df_pred[f] < ranges[f][0]) | (df_pred[f] > ranges[f][1]) for f in FEATURES}
844
  any_viol = pd.DataFrame(viol).any(axis=1); oor_pct = float(any_viol.mean()*100.0)
845
- if any_viol.any():
846
- offenders = df_pred.loc[any_viol, FEATURES].copy()
847
- offenders["Violations"] = pd.DataFrame(viol).loc[any_viol].apply(lambda r: ", ".join([c for c,v in r.items() if v]), axis=1)
848
- offenders.index = offenders.index + 1; oor_table = offenders
849
-
850
- st.session_state.results["summary_pred"] = {
851
- "n_points": len(df_pred),
852
- "pred_min": float(df_pred["UCS_Pred"].min()),
853
- "pred_max": float(df_pred["UCS_Pred"].max()),
854
- "pred_mean": float(df_pred["UCS_Pred"].mean()),
855
- "pred_std": float(df_pred["UCS_Pred"].std(ddof=0)),
856
  "oor_pct": oor_pct
857
  }
858
- st.session_state.results["oor_table_pred"] = oor_table
859
  status.update(label="Predictions ready ✓", state="complete")
860
 
861
- if "Prediction" in st.session_state.results:
862
- sv = st.session_state.results["summary_pred"]
863
- if sv.get("oor_pct", 0) > 0:
864
- st.warning("Some inputs fall outside the **training min–max** ranges. Interpret predictions with caution.")
865
 
866
- left, right = st.columns([0.6, 0.9])
 
867
  with left:
868
- table = pd.DataFrame(
869
- {
870
- "Metric": ["# points", "Pred min", "Pred max", "Pred mean", "Pred std", "OOR %"],
871
- "Value": [
872
- f"{sv['n_points']}",
873
- f"{sv['pred_min']:.2f}",
874
- f"{sv['pred_max']:.2f}",
875
- f"{sv['pred_mean']:.2f}",
876
- f"{sv['pred_std']:.2f}",
877
- f"{sv['oor_pct']:.1f}%",
878
- ],
879
- }
880
- )
881
- st.dataframe(table, use_container_width=True, hide_index=True)
882
- # ★ footnote under table
883
- st.caption("★ OOR % = percentage of rows where at least one input feature is outside the training set's min–max range.")
884
  with right:
885
- # Optional zoom
886
- dfp = st.session_state.results["Prediction"]
887
- pmin, pmax = float(dfp["UCS_Pred"].min()), float(dfp["UCS_Pred"].max())
888
- with st.expander("Zoom (UCS axis)", expanded=False):
889
- zp = st.slider("UCS range", min_value=pmin, max_value=pmax, value=(pmin, pmax), step=10.0, key="zoom_pred")
890
  st.plotly_chart(
891
- depth_or_index_track_interactive(
892
- dfp, title=None, include_actual=False, x_range=zp
893
- ),
894
- use_container_width=True, config={"displayModeBar": False}
895
  )
896
 
897
- if st.session_state.results.get("oor_table_pred") is not None:
898
- st.write("*Out-of-range rows (vs. Training min–max):*")
899
- st.dataframe(st.session_state.results["oor_table_pred"], use_container_width=True)
900
-
901
  st.markdown("---")
902
  try:
903
  buf = io.BytesIO()
904
  with pd.ExcelWriter(buf, engine="openpyxl") as xw:
905
- st.session_state.results["Prediction"].to_excel(xw, sheet_name="Prediction_with_pred", index=False)
906
  pd.DataFrame([sv]).to_excel(xw, sheet_name="Summary", index=False)
907
  st.download_button(
908
  "Export Prediction Results to Excel",
@@ -913,6 +801,7 @@ if st.session_state.app_step == "predict":
913
  except Exception as e:
914
  st.warning(str(e))
915
 
 
916
  # =========================
917
  # Footer
918
  # =========================
 
1
  # app.py
2
+ import io, json, os, base64, math
3
  from pathlib import Path
4
  import streamlit as st
5
  import pandas as pd
6
  import numpy as np
7
  import joblib
8
 
9
+ # matplotlib only for preview modal thumbnails
10
  import matplotlib
11
  matplotlib.use("Agg")
12
  import matplotlib.pyplot as plt
13
 
14
  import plotly.graph_objects as go
15
  from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
 
16
 
17
  # =========================
18
+ # Page / Theme
19
+ # =========================
20
+ st.set_page_config(page_title="ST_GeoMech_UCS", page_icon="logo.png", layout="wide")
21
+ st.markdown(
22
+ """
23
+ <style>
24
+ header, footer {visibility:hidden !important;}
25
+ .stApp { background: #FFFFFF; }
26
+ section[data-testid="stSidebar"] { background: #F6F9FC; }
27
+ .block-container { padding-top: .5rem; padding-bottom: .5rem; }
28
+ /* Buttons */
29
+ .stButton>button{ background:#007bff; color:#fff; font-weight:600;
30
+ border-radius:8px; border:none; padding:10px 20px; }
31
+ .stButton>button:hover{ background:#0056b3; }
32
+ /* Hero */
33
+ .st-hero { display:flex; align-items:center; gap:16px; padding-top: 4px; }
34
+ .st-hero .brand { width:110px; height:110px; object-fit:contain; }
35
+ .st-hero h1 { margin:0; line-height:1.05; }
36
+ .st-hero .tagline { margin:2px 0 0 2px; color:#6b7280; font-size:1.05rem; font-style:italic; }
37
+ [data-testid="stBlock"]{ margin-top:0 !important; }
38
+ </style>
39
+ """,
40
+ unsafe_allow_html=True
41
+ )
42
+
43
+ # =========================
44
+ # Small constants / defaults
45
  # =========================
46
  FEATURES = ["Q, gpm", "SPP(psi)", "T (kft.lbf)", "WOB (klbf)", "ROP (ft/h)"]
47
  TARGET = "UCS"
48
  MODELS_DIR = Path("models")
49
  DEFAULT_MODEL = MODELS_DIR / "ucs_rf.joblib"
50
  MODEL_FALLBACKS = [MODELS_DIR / "model.joblib", MODELS_DIR / "model.pkl"]
 
51
  COLORS = {"pred": "#1f77b4", "actual": "#f2b702", "ref": "#5a5a5a"}
52
 
53
+ # Fixed plot sizes (tuned to fit without scrolling on typical 14–15" laptops)
54
+ CROSS_W, CROSS_H = 540, 540 # px — square cross-plot
55
+ TRACK_W, TRACK_H = 360, 700 # px — slim & tall track
56
+
57
+
58
  # =========================
59
+ # Helpers (general)
60
  # =========================
 
 
61
  def inline_logo(path="logo.png") -> str:
62
  try:
63
  p = Path(path)
 
66
  except Exception:
67
  return ""
68
 
69
+ def _get_model_url():
70
+ return (os.environ.get("MODEL_URL", "") or "").strip()
71
+
72
+ def rmse(y_true, y_pred):
73
+ return float(np.sqrt(mean_squared_error(y_true, y_pred)))
74
+
75
+ @st.cache_resource(show_spinner=False)
76
+ def load_model(model_path: str):
77
+ return joblib.load(model_path)
78
+
79
+ @st.cache_data(show_spinner=False)
80
+ def parse_excel(data_bytes: bytes):
81
+ bio = io.BytesIO(data_bytes)
82
+ xl = pd.ExcelFile(bio)
83
+ return {sh: xl.parse(sh) for sh in xl.sheet_names}
84
+
85
+ def read_book_bytes(data_bytes: bytes):
86
+ if not data_bytes: return {}
87
+ try: return parse_excel(data_bytes)
88
+ except Exception as e:
89
+ st.error(f"Failed to read Excel: {e}"); return {}
90
+
91
+ def find_sheet(book, names):
92
+ low2orig = {k.lower(): k for k in book.keys()}
93
+ for nm in names:
94
+ if nm.lower() in low2orig: return low2orig[nm.lower()]
95
+ return None
96
+
97
+ def ensure_cols(df, cols):
98
+ miss = [c for c in cols if c not in df.columns]
99
+ if miss:
100
+ st.error(f"Missing columns: {miss}\nFound: {list(df.columns)}")
101
+ return False
102
+ return True
103
+
104
+ def infer_features_from_model(m):
105
+ try:
106
+ if hasattr(m, "feature_names_in_") and len(getattr(m, "feature_names_in_")):
107
+ return [str(x) for x in m.feature_names_in_]
108
+ except Exception: pass
109
+ try:
110
+ if hasattr(m, "steps") and len(m.steps):
111
+ last = m.steps[-1][1]
112
+ if hasattr(last, "feature_names_in_") and len(last.feature_names_in_):
113
+ return [str(x) for x in last.feature_names_in_]
114
+ except Exception: pass
115
+ return None
116
+
117
+ def compute_tick_step(lo, hi, target_ticks=6):
118
+ rng = max(hi - lo, 1.0)
119
+ raw = rng / target_ticks
120
+ power = 10 ** math.floor(math.log10(raw))
121
+ mult = round(raw / power)
122
+ step = mult * power
123
+ # snap to 50/100/200 etc for tidier thousands
124
+ if step >= 50 and step % 50 != 0:
125
+ step = round(step / 50) * 50
126
+ return step
127
+
128
+
129
  # =========================
130
+ # Password Gate (define FIRST, then call)
131
  # =========================
132
  def add_password_gate() -> bool:
133
+ """Branded password screen. Requires APP_PASSWORD in Secrets/Env."""
134
+ required = ""
135
  try:
136
  required = st.secrets.get("APP_PASSWORD", "")
137
  except Exception:
 
149
  </div>
150
  <div style="font-size:1.25rem;font-weight:700;margin:8px 0 4px 0;">Protected Area</div>
151
  <div style="color:#6b7280;margin-bottom:14px;">
152
+ Admin action required: set <code>APP_PASSWORD</code> in <b>Settings → Secrets</b>, then restart the Space.
 
153
  </div>
154
  """,
155
  unsafe_allow_html=True,
 
169
  </div>
170
  </div>
171
  <div style="font-size:1.25rem;font-weight:700;margin:8px 0 4px 0;">Protected</div>
172
+ <div style="color:#6b7280;margin-bottom:14px;">Please enter your access key to continue.</div>
 
 
173
  """,
174
  unsafe_allow_html=True
175
  )
 
184
  st.error("Incorrect key. Please try again.")
185
  st.stop()
186
 
187
+ # 🔒 Invoke password gate
188
  add_password_gate()
189
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
 
191
  # =========================
192
+ # Interactive plots
193
  # =========================
194
+ def cross_plot_interactive(actual, pred, width=CROSS_W, height=CROSS_H):
195
+ """Fixed-size square cross-plot, 1:1 axes, tidy ticks, full outline."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
  a = pd.Series(actual).astype(float)
197
  p = pd.Series(pred).astype(float)
198
  lo = float(np.nanmin([a.min(), p.min()]))
199
  hi = float(np.nanmax([a.max(), p.max()]))
200
+ pad = 0.04 * (hi - lo if hi > lo else 1.0)
201
  x0, x1 = lo - pad, hi + pad
202
+ step = compute_tick_step(x0, x1)
203
 
204
  fig = go.Figure()
205
  fig.add_trace(go.Scatter(
 
214
  hoverinfo="skip", showlegend=False
215
  ))
216
  fig.update_layout(
217
+ width=width, height=height, title=None,
218
  paper_bgcolor="#ffffff", plot_bgcolor="#ffffff",
219
+ margin=dict(l=60, r=20, t=10, b=50),
220
+ hovermode="closest", font=dict(size=13)
221
  )
222
  fig.update_xaxes(
223
+ title_text="<b>Actual UCS</b>", range=[x0, x1], dtick=step,
224
+ ticks="outside", tickformat=",.0f",
225
  showline=True, linewidth=1.2, linecolor="#444", mirror=True,
226
+ showgrid=True, gridcolor="rgba(0,0,0,0.12)", automargin=True
 
227
  )
228
  fig.update_yaxes(
229
+ title_text="<b>Predicted UCS</b>", range=[x0, x1], dtick=step,
230
+ ticks="outside", tickformat=",.0f",
231
  showline=True, linewidth=1.2, linecolor="#444", mirror=True,
232
  showgrid=True, gridcolor="rgba(0,0,0,0.12)",
233
+ scaleanchor="x", scaleratio=1, automargin=True
 
234
  )
 
 
235
  return fig
236
 
237
+ def depth_or_index_track_interactive(df, include_actual=True, width=TRACK_W, height=TRACK_H):
238
+ """Tall & slim track; legend inside; x on top; full outline; reversed y."""
239
  depth_col = next((c for c in df.columns if 'depth' in str(c).lower()), None)
240
  if depth_col is not None:
241
  y = df[depth_col]; y_label = depth_col
 
256
  name="UCS (actual)",
257
  hovertemplate="UCS (actual): %{x:.0f}<br>"+y_label+": %{y}<extra></extra>"
258
  ))
 
 
259
  fig.update_layout(
260
+ width=width, height=height,
261
  paper_bgcolor="#ffffff", plot_bgcolor="#ffffff",
262
+ margin=dict(l=60, r=10, t=10, b=40),
263
  hovermode="closest", font=dict(size=13),
264
  legend=dict(
265
  x=0.98, y=0.05, xanchor="right", yanchor="bottom",
266
  bgcolor="rgba(255,255,255,0.75)", bordercolor="#cccccc", borderwidth=1
267
  ),
268
+ legend_title_text=""
 
 
 
269
  )
270
  fig.update_xaxes(
271
  title_text="<b>UCS</b>", side="top",
272
+ ticks="outside", tickformat=",.0f",
273
+ showline=True, linewidth=1.2, linecolor="#444", mirror=True,
274
+ showgrid=True, gridcolor="rgba(0,0,0,0.12)", automargin=True
 
 
275
  )
276
  fig.update_yaxes(
277
  title_text=f"<b>{y_label}</b>", autorange="reversed",
278
+ ticks="outside",
279
+ showline=True, linewidth=1.2, linecolor="#444", mirror=True,
280
+ showgrid=True, gridcolor="rgba(0,0,0,0.12)", automargin=True
281
  )
282
  return fig
283
 
284
+
285
+ # =========================
286
+ # Preview modals (matplotlib)
287
+ # =========================
288
+ try:
289
+ dialog = st.dialog
290
+ except AttributeError:
291
+ def dialog(title):
292
+ def deco(fn):
293
+ def wrapper(*args, **kwargs):
294
+ with st.expander(title, expanded=True):
295
+ return fn(*args, **kwargs)
296
+ return wrapper
297
+ return deco
298
+
299
  def make_index_tracks(df: pd.DataFrame, cols: list[str]):
300
  cols = [c for c in cols if c in df.columns]
301
  n = len(cols)
 
325
  return out.reset_index().rename(columns={"index": "Feature"})
326
 
327
  @dialog("Preview data")
328
+ def preview_modal(book: dict[str, pd.DataFrame], feature_cols: list[str], sheet_names):
329
  if not book:
330
  st.info("No data loaded yet."); return
331
+ resolved = [find_sheet(book, [nm]) for nm in sheet_names]
332
+ existing = [(nm, book[nm]) for nm in resolved if nm is not None]
333
+ if not existing:
 
 
 
334
  first_name = list(book.keys())[0]
335
+ existing = [(first_name, book[first_name])]
336
+ tabs = st.tabs([nm for nm,_ in existing])
337
+ for t,(nm,df) in zip(tabs, existing):
 
338
  with t:
339
+ t1,t2 = st.tabs(["Tracks","Summary"])
340
  with t1: st.pyplot(make_index_tracks(df, feature_cols), use_container_width=True)
341
  with t2: st.dataframe(stats_table(df, feature_cols), use_container_width=True)
342
 
 
 
 
 
 
 
 
 
 
343
 
344
  # =========================
345
+ # Model loading (includes optional remote download)
346
  # =========================
347
  MODEL_URL = _get_model_url()
348
 
 
383
  try:
384
  meta = json.loads(meta_path.read_text(encoding="utf-8"))
385
  FEATURES = meta.get("features", FEATURES); TARGET = meta.get("target", TARGET)
386
+ except Exception: pass
 
387
  else:
 
 
 
 
 
 
 
 
 
 
 
 
388
  infer = infer_features_from_model(model)
389
  if infer: FEATURES = infer
390
 
391
+
392
  # =========================
393
  # Session state
394
  # =========================
 
396
  if "results" not in st.session_state: st.session_state.results = {}
397
  if "train_ranges" not in st.session_state: st.session_state.train_ranges = None
398
 
399
+ # persist dev upload
400
+ for k, v in {
401
+ "dev_ready": False,
402
+ "dev_file_loaded": False,
403
+ "dev_previewed": False,
404
+ "dev_file_signature": None,
405
+ "dev_preview_request": False,
406
+ "dev_file_bytes": b"",
407
+ "dev_file_name": "",
408
+ "dev_file_rows": 0,
409
+ "dev_file_cols": 0,
410
+ }.items():
411
  if k not in st.session_state: st.session_state[k] = v
412
 
413
+
414
  # =========================
415
  # Hero header
416
  # =========================
 
427
  unsafe_allow_html=True,
428
  )
429
 
430
+
431
  # =========================
432
+ # INTRO
433
  # =========================
434
  if st.session_state.app_step == "intro":
435
  st.header("Welcome!")
436
+ st.markdown(
437
+ "This software is developed by *Smart Thinking AI-Solutions Team* to estimate UCS from drilling data."
438
+ )
439
  st.subheader("Expected Input Features (in Order)")
440
  st.markdown(
441
  "- Q, gpm — Flow rate (gallons per minute) \n"
 
444
  "- WOB (klbf) — Weight on bit \n"
445
  "- ROP (ft/h) — Rate of penetration"
446
  )
447
+ st.subheader("Process")
448
  st.markdown(
449
+ "1) **Case Building**: Upload your data, preview, then run the model. \n"
450
+ "2) **Validate the Model**: Upload a validation dataset (with actual UCS) to evaluate performance. \n"
451
+ "3) **Prediction**: Upload a production dataset (no UCS target) to get predictions."
 
 
452
  )
453
+ if st.button("Start Showcase", type="primary"):
454
  st.session_state.app_step = "dev"; st.rerun()
455
 
456
+
457
  # =========================
458
+ # CASE BUILDING (Development)
459
  # =========================
460
  if st.session_state.app_step == "dev":
461
  st.sidebar.header("Case Building (Development)")
462
+ dev_label = "Upload data (Excel)" if not st.session_state.dev_file_name else "Replace data (Excel)"
463
  train_test_file = st.sidebar.file_uploader(dev_label, type=["xlsx","xls"], key="dev_upload")
464
 
465
+ # Persist upload
466
  if train_test_file is not None:
467
  try:
468
+ file_bytes = train_test_file.getvalue()
469
+ size = len(file_bytes)
470
  except Exception:
471
  file_bytes = b""; size = 0
472
  sig = (train_test_file.name, size)
 
489
  f"{st.session_state.dev_file_rows} rows × {st.session_state.dev_file_cols} cols"
490
  )
491
 
492
+ # Always-on navigation
493
+ if st.sidebar.button("Preview data", use_container_width=True, disabled=not st.session_state.dev_file_loaded):
494
  st.session_state.dev_preview_request = True
 
495
  run_btn = st.sidebar.button("Run Model", type="primary", use_container_width=True)
496
+ # Enabled always so users can jump ahead
497
+ if st.sidebar.button("Proceed to Validation ▶", use_container_width=True):
 
 
 
498
  st.session_state.app_step = "validate"; st.rerun()
499
+ if st.sidebar.button("Proceed to Prediction ▶", use_container_width=True):
500
  st.session_state.app_step = "predict"; st.rerun()
501
 
502
+ # Helper text
503
+ st.subheader("Case Building (Development)")
504
+ if st.session_state.dev_ready:
505
+ st.success("Case has been built and results are displayed below.")
506
+ elif st.session_state.dev_file_loaded and st.session_state.dev_previewed:
507
+ st.info("Previewed ✓ — now click **Run Model** to build the case.")
508
+ elif st.session_state.dev_file_loaded:
509
+ st.info("📄 **Preview uploaded data** using the sidebar button, then click **Run Model**.")
510
+ else:
511
+ st.write("**Upload your data to build a case, then run the model to review development performance.**")
512
 
513
+ # Preview modal
514
  if st.session_state.dev_preview_request and st.session_state.dev_file_bytes:
515
  _book = read_book_bytes(st.session_state.dev_file_bytes)
516
  st.session_state.dev_previewed = True
517
  st.session_state.dev_preview_request = False
518
+ preview_modal(_book, FEATURES, sheet_names=["Train","Test"])
519
 
520
+ # Run model
521
  if run_btn and st.session_state.dev_file_bytes:
522
  with st.status("Processing…", expanded=False) as status:
523
  book = read_book_bytes(st.session_state.dev_file_bytes)
 
526
  sh_train = find_sheet(book, ["Train","Training","training2","train","training"])
527
  sh_test = find_sheet(book, ["Test","Testing","testing2","test","testing"])
528
  if sh_train is None or sh_test is None:
529
+ status.update(label="Workbook must include Train... and Test...", state="error"); st.stop()
530
  df_tr = book[sh_train].copy(); df_te = book[sh_test].copy()
531
  if not (ensure_cols(df_tr, FEATURES + [TARGET]) and ensure_cols(df_te, FEATURES + [TARGET])):
532
  status.update(label="Missing required columns.", state="error"); st.stop()
 
546
  "RMSE": rmse(df_te[TARGET], df_te["UCS_Pred"]),
547
  "MAE": mean_absolute_error(df_te[TARGET], df_te["UCS_Pred"]),
548
  }
 
549
  tr_min = df_tr[FEATURES].min().to_dict(); tr_max = df_tr[FEATURES].max().to_dict()
550
  st.session_state.train_ranges = {f:(float(tr_min[f]), float(tr_max[f])) for f in FEATURES}
551
 
552
  st.session_state.dev_ready = True
553
  status.update(label="Done ✓", state="complete"); st.rerun()
554
 
555
+ # Results
556
  if ("Train" in st.session_state.results) or ("Test" in st.session_state.results):
557
  tab1, tab2 = st.tabs(["Training", "Testing"])
558
+ cfg = {"displayModeBar": False, "scrollZoom": True}
559
+
560
  if "Train" in st.session_state.results:
561
  with tab1:
562
  df = st.session_state.results["Train"]; m = st.session_state.results["metrics_train"]
563
  c1,c2,c3 = st.columns(3)
564
  c1.metric("R²", f"{m['R2']:.4f}"); c2.metric("RMSE", f"{m['RMSE']:.4f}"); c3.metric("MAE", f"{m['MAE']:.4f}")
565
+ left, right = st.columns([1.0, 0.7])
566
  with left:
567
+ st.plotly_chart(cross_plot_interactive(df[TARGET], df["UCS_Pred"]),
568
+ use_container_width=False, config=cfg)
 
 
569
  with right:
570
+ st.plotly_chart(depth_or_index_track_interactive(df, include_actual=True),
571
+ use_container_width=False, config=cfg)
572
+
 
 
 
 
 
 
 
 
 
 
 
 
 
573
  if "Test" in st.session_state.results:
574
  with tab2:
575
  df = st.session_state.results["Test"]; m = st.session_state.results["metrics_test"]
576
  c1,c2,c3 = st.columns(3)
577
  c1.metric("R²", f"{m['R2']:.4f}"); c2.metric("RMSE", f"{m['RMSE']:.4f}"); c3.metric("MAE", f"{m['MAE']:.4f}")
578
+ left, right = st.columns([1.0, 0.7])
579
  with left:
580
+ st.plotly_chart(cross_plot_interactive(df[TARGET], df["UCS_Pred"]),
581
+ use_container_width=False, config=cfg)
 
 
582
  with right:
583
+ st.plotly_chart(depth_or_index_track_interactive(df, include_actual=True),
584
+ use_container_width=False, config=cfg)
 
 
 
 
 
 
 
 
 
 
 
 
 
585
 
586
  st.markdown("---")
587
  sheets = {}; rows = []
 
608
  except Exception as e:
609
  st.warning(str(e))
610
 
611
+
612
  # =========================
613
+ # VALIDATE THE MODEL
614
  # =========================
615
  if st.session_state.app_step == "validate":
616
+ st.sidebar.header("Validate the Model")
617
  validation_file = st.sidebar.file_uploader("Upload Validation Excel", type=["xlsx","xls"], key="val_upload")
 
618
  if validation_file is not None:
619
+ _book_tmp = read_book_bytes(validation_file.getvalue())
 
620
  if _book_tmp:
621
  first_df = next(iter(_book_tmp.values()))
 
622
  st.sidebar.caption(f"**Data loaded:** {validation_file.name} • {first_df.shape[0]} rows × {first_df.shape[1]} cols")
623
 
624
+ if st.sidebar.button("Preview data", use_container_width=True, disabled=(validation_file is None)):
625
+ _book = read_book_bytes(validation_file.getvalue())
626
+ preview_modal(_book, FEATURES, sheet_names=["Validation","Val","Validate"])
627
 
628
  predict_btn = st.sidebar.button("Run Validation", type="primary", use_container_width=True)
 
629
  st.sidebar.button("⬅ Back to Case Building", on_click=lambda: st.session_state.update(app_step="dev"), use_container_width=True)
630
+ st.sidebar.button("Proceed to Prediction ▶", on_click=lambda: st.session_state.update(app_step="predict"), use_container_width=True)
 
631
 
632
+ st.subheader("Validate the Model")
633
+ st.write("Upload a dataset with **actual UCS** to evaluate model performance on unseen data.")
 
634
 
635
+ if predict_btn and validation_file is not None:
 
 
 
 
 
636
  with st.status("Validating…", expanded=False) as status:
637
+ vbook = read_book_bytes(validation_file.getvalue())
638
  if not vbook: status.update(label="Could not read the Validation Excel.", state="error"); st.stop()
639
  status.update(label="Workbook read ✓")
640
  vname = find_sheet(vbook, ["Validation","Validate","validation2","Val","val"]) or list(vbook.keys())[0]
641
  df_val = vbook[vname].copy()
642
+ if not ensure_cols(df_val, FEATURES + [TARGET]): status.update(label="Missing required columns.", state="error"); st.stop()
643
  status.update(label="Columns validated ✓")
644
  df_val["UCS_Pred"] = model.predict(df_val[FEATURES])
645
  st.session_state.results["Validate"] = df_val
 
653
  offenders["Violations"] = pd.DataFrame(viol).loc[any_viol].apply(lambda r: ", ".join([c for c,v in r.items() if v]), axis=1)
654
  offenders.index = offenders.index + 1; oor_table = offenders
655
 
656
+ metrics_val = {
657
+ "R2": r2_score(df_val[TARGET], df_val["UCS_Pred"]),
658
+ "RMSE": rmse(df_val[TARGET], df_val["UCS_Pred"]),
659
+ "MAE": mean_absolute_error(df_val[TARGET], df_val["UCS_Pred"])
660
+ }
 
 
661
  st.session_state.results["metrics_val"] = metrics_val
662
  st.session_state.results["summary_val"] = {
663
  "n_points": len(df_val),
 
666
  "oor_pct": oor_pct
667
  }
668
  st.session_state.results["oor_table"] = oor_table
669
+ status.update(label="Predictions ready ✓", state="complete")
670
 
671
  if "Validate" in st.session_state.results:
672
+ cfg = {"displayModeBar": False, "scrollZoom": True}
673
+ df = st.session_state.results["Validate"]
674
+ m = st.session_state.results.get("metrics_val"); sv = st.session_state.results["summary_val"]
675
  if sv["oor_pct"] > 0:
676
  st.warning("Some validation inputs fall outside the **training min–max** ranges. Interpret predictions with caution.")
677
+ c1,c2,c3 = st.columns(3)
678
+ c1.metric("R²", f"{m['R2']:.4f}"); c2.metric("RMSE", f"{m['RMSE']:.4f}"); c3.metric("MAE", f"{m['MAE']:.4f}")
679
 
680
+ left, right = st.columns([1.0, 0.7])
 
 
 
 
 
 
 
 
 
 
 
 
681
  with left:
682
+ st.plotly_chart(cross_plot_interactive(df[TARGET], df["UCS_Pred"]), use_container_width=False, config=cfg)
 
 
 
 
 
 
 
 
 
 
683
  with right:
684
+ st.plotly_chart(depth_or_index_track_interactive(df, include_actual=True), use_container_width=False, config=cfg)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
685
 
686
+ if st.session_state.results.get("oor_table") is not None:
687
  st.write("*Out-of-range rows (vs. Training min–max):*")
688
+ st.dataframe(st.session_state.results["oor_table"], use_container_width=True)
689
 
690
+ # Export
691
  st.markdown("---")
692
+ sheets = {"Validate_with_pred": df}
693
  rows = []
694
  for name, key in [("Train","metrics_train"), ("Test","metrics_test"), ("Validate","metrics_val")]:
695
+ mm = st.session_state.results.get(key)
696
+ if mm: rows.append({"Split": name, **{k: round(v,6) for k,v in mm.items()}})
697
  summary_df = pd.DataFrame(rows) if rows else None
698
  try:
699
  buf = io.BytesIO()
 
711
  except Exception as e:
712
  st.warning(str(e))
713
 
714
+ # OOR footnote
715
+ st.caption("**★ OOR %**: fraction of rows with any input feature outside the training min–max ranges.")
716
+
717
+
718
  # =========================
719
+ # PREDICTION (no actual UCS)
720
  # =========================
721
  if st.session_state.app_step == "predict":
722
  st.sidebar.header("Prediction")
723
+ pred_file = st.sidebar.file_uploader("Upload Production Excel", type=["xlsx","xls"], key="pred_upload")
 
724
  if pred_file is not None:
725
+ _book_tmp = read_book_bytes(pred_file.getvalue())
 
726
  if _book_tmp:
727
  first_df = next(iter(_book_tmp.values()))
 
728
  st.sidebar.caption(f"**Data loaded:** {pred_file.name} • {first_df.shape[0]} rows × {first_df.shape[1]} cols")
729
 
730
+ if st.sidebar.button("Preview data", use_container_width=True, disabled=(pred_file is None)):
731
+ _book = read_book_bytes(pred_file.getvalue())
732
+ preview_modal(_book, FEATURES, sheet_names=["Prediction","Pred"])
733
 
734
+ run_pred = st.sidebar.button("Predict", type="primary", use_container_width=True)
735
  st.sidebar.button("⬅ Back to Validation", on_click=lambda: st.session_state.update(app_step="validate"), use_container_width=True)
736
+ st.sidebar.button("⬅ Back to Case Building", on_click=lambda: st.session_state.update(app_step="dev"), use_container_width=True)
737
 
738
+ st.subheader("Prediction")
739
+ st.write("Upload a production dataset (**without UCS target**) to generate predictions.")
 
 
 
 
 
 
740
 
741
+ if run_pred and pred_file is not None:
742
  with st.status("Predicting…", expanded=False) as status:
743
+ pbook = read_book_bytes(pred_file.getvalue())
744
+ if not pbook: status.update(label="Could not read the Excel.", state="error"); st.stop()
745
  status.update(label="Workbook read ✓")
746
+ pname = find_sheet(pbook, ["Prediction","Pred"]) or list(pbook.keys())[0]
747
+ dfp = pbook[pname].copy()
748
+ if not ensure_cols(dfp, FEATURES): status.update(label="Missing required columns.", state="error"); st.stop()
749
  status.update(label="Columns validated ✓")
750
+ dfp["UCS_Pred"] = model.predict(dfp[FEATURES])
751
+ st.session_state.results["Prod"] = dfp
752
 
753
+ ranges = st.session_state.train_ranges; oor_pct = 0.0
754
  if ranges:
755
+ viol = {f: (dfp[f] < ranges[f][0]) | (dfp[f] > ranges[f][1]) for f in FEATURES}
756
  any_viol = pd.DataFrame(viol).any(axis=1); oor_pct = float(any_viol.mean()*100.0)
757
+ st.session_state.results["summary_prod"] = {
758
+ "n_points": len(dfp),
759
+ "pred_min": float(dfp["UCS_Pred"].min()),
760
+ "pred_max": float(dfp["UCS_Pred"].max()),
761
+ "pred_mean": float(dfp["UCS_Pred"].mean()),
762
+ "pred_std": float(dfp["UCS_Pred"].std(ddof=0)),
 
 
 
 
 
763
  "oor_pct": oor_pct
764
  }
 
765
  status.update(label="Predictions ready ✓", state="complete")
766
 
767
+ if "Prod" in st.session_state.results:
768
+ cfg = {"displayModeBar": False, "scrollZoom": True}
769
+ dfp = st.session_state.results["Prod"]
770
+ sv = st.session_state.results["summary_prod"]
771
 
772
+ # Small summary table on the LEFT (where cross-plot would be)
773
+ left, right = st.columns([0.7, 1.0])
774
  with left:
775
+ tbl = pd.DataFrame({
776
+ "Metric": ["# points","Pred min","Pred max","Pred mean","Pred std","OOR %"],
777
+ "Value": [sv["n_points"], sv["pred_min"], sv["pred_max"], sv["pred_mean"], sv["pred_std"], sv["oor_pct"]]
778
+ })
779
+ st.dataframe(tbl, use_container_width=True)
780
+ st.caption("**★ OOR %**: fraction of rows with any input feature outside the training min–max ranges.")
781
+
 
 
 
 
 
 
 
 
 
782
  with right:
 
 
 
 
 
783
  st.plotly_chart(
784
+ depth_or_index_track_interactive(dfp, include_actual=False),
785
+ use_container_width=False, config=cfg
 
 
786
  )
787
 
788
+ # Export
 
 
 
789
  st.markdown("---")
790
  try:
791
  buf = io.BytesIO()
792
  with pd.ExcelWriter(buf, engine="openpyxl") as xw:
793
+ dfp.to_excel(xw, sheet_name="Prediction_with_UCS_Pred", index=False)
794
  pd.DataFrame([sv]).to_excel(xw, sheet_name="Summary", index=False)
795
  st.download_button(
796
  "Export Prediction Results to Excel",
 
801
  except Exception as e:
802
  st.warning(str(e))
803
 
804
+
805
  # =========================
806
  # Footer
807
  # =========================