UCS2014 commited on
Commit
c08ab06
·
verified ·
1 Parent(s): fce6e42

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +127 -56
app.py CHANGED
@@ -12,7 +12,7 @@ matplotlib.use("Agg")
12
  import matplotlib.pyplot as plt
13
 
14
  import plotly.graph_objects as go
15
- from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
16
 
17
  # =========================
18
  # Constants (simple & robust)
@@ -22,6 +22,7 @@ TARGET = "UCS"
22
  MODELS_DIR = Path("models")
23
  DEFAULT_MODEL = MODELS_DIR / "ucs_rf.joblib"
24
  MODEL_FALLBACKS = [MODELS_DIR / "model.joblib", MODELS_DIR / "model.pkl"]
 
25
  COLORS = {"pred": "#1f77b4", "actual": "#f2b702", "ref": "#5a5a5a"}
26
 
27
  # ---- Plot sizing controls (edit here) ----
@@ -29,14 +30,15 @@ CROSS_W = 500; CROSS_H = 500 # square cross-plot (Build + Validate)
29
  TRACK_W = 400; TRACK_H = 950 # log-strip style (all pages)
30
  FONT_SZ = 13
31
  PLOT_COLS = [14, 0.5, 10] # 3-column band: left • spacer • right (Build + Validate)
32
- CROSS_NUDGE = 0.5 # push cross-plot to the RIGHT inside its band:
33
- # inner columns [CROSS_NUDGE : 1] → bigger = more right
34
 
35
  # =========================
36
  # Page / CSS
37
  # =========================
38
  st.set_page_config(page_title="ST_GeoMech_UCS", page_icon="logo.png", layout="wide")
39
  st.markdown("<style>header, footer{visibility:hidden !important;}</style>", unsafe_allow_html=True)
 
 
40
  st.markdown(
41
  """
42
  <style>
@@ -45,11 +47,23 @@ st.markdown(
45
  .block-container { padding-top:.5rem; padding-bottom:.5rem; }
46
  .stButton>button { background:#007bff; color:#fff; font-weight:600; border-radius:8px; border:none; }
47
  .stButton>button:hover { background:#0056b3; }
48
- .st-hero { display:flex; align-items:center; gap:16px; padding-top: 4px; }
49
  .st-hero .brand { width:110px; height:110px; object-fit:contain; }
50
  .st-hero h1 { margin:0; line-height:1.05; }
51
  .st-hero .tagline { margin:2px 0 0 2px; color:#6b7280; font-size:1.05rem; font-style:italic; }
52
  [data-testid="stBlock"]{ margin-top:0 !important; }
 
 
 
 
 
 
 
 
 
 
 
 
53
  </style>
54
  """,
55
  unsafe_allow_html=True
@@ -134,7 +148,18 @@ except AttributeError:
134
  return wrapper
135
  return deco
136
 
137
- def rmse(y_true, y_pred): return float(np.sqrt(mean_squared_error(y_true, y_pred)))
 
 
 
 
 
 
 
 
 
 
 
138
 
139
  @st.cache_resource(show_spinner=False)
140
  def load_model(model_path: str):
@@ -161,20 +186,36 @@ def find_sheet(book, names):
161
  if nm.lower() in low2orig: return low2orig[nm.lower()]
162
  return None
163
 
164
- def _nice_tick0(xmin: float, step: int = 100) -> float:
165
- """Round xmin down to a sensible multiple so the first tick sits at the left edge."""
166
- if not np.isfinite(xmin):
167
- return xmin
 
 
 
 
 
 
 
 
168
  return step * math.floor(xmin / step)
169
 
 
 
 
 
170
  # ---------- Plot builders ----------
171
  def cross_plot(actual, pred):
172
  a = pd.Series(actual).astype(float)
173
  p = pd.Series(pred).astype(float)
 
 
174
  lo = float(np.nanmin([a.min(), p.min()]))
175
  hi = float(np.nanmax([a.max(), p.max()]))
176
  pad = 0.03 * (hi - lo if hi > lo else 1.0)
177
- x0, x1 = lo - pad, hi + pad
 
 
178
 
179
  fig = go.Figure()
180
  fig.add_trace(go.Scatter(
@@ -183,25 +224,28 @@ def cross_plot(actual, pred):
183
  hovertemplate="Actual: %{x:.0f}<br>Pred: %{y:.0f}<extra></extra>",
184
  showlegend=False
185
  ))
 
186
  fig.add_trace(go.Scatter(
187
- x=[x0, x1], y=[x0, x1], mode="lines",
188
  line=dict(color=COLORS["ref"], width=1.2, dash="dash"),
189
  hoverinfo="skip", showlegend=False
190
  ))
 
191
  fig.update_layout(
192
  width=CROSS_W, height=CROSS_H, paper_bgcolor="#fff", plot_bgcolor="#fff",
193
  margin=dict(l=64, r=18, t=10, b=48), hovermode="closest",
194
  font=dict(size=FONT_SZ)
195
  )
196
- fig.update_xaxes(title_text="<b>Actual UCS (psi)</b>", range=[x0, x1],
197
- ticks="outside", tickformat=",.0f",
198
- showline=True, linewidth=1.2, linecolor="#444", mirror=True,
199
- showgrid=True, gridcolor="rgba(0,0,0,0.12)", automargin=True)
200
- fig.update_yaxes(title_text="<b>Predicted UCS (psi)</b>", range=[x0, x1],
201
- ticks="outside", tickformat=",.0f",
202
- showline=True, linewidth=1.2, linecolor="#444", mirror=True,
203
- showgrid=True, gridcolor="rgba(0,0,0,0.12)",
204
- scaleanchor="x", scaleratio=1, automargin=True)
 
205
  return fig
206
 
207
  def track_plot(df, include_actual=True):
@@ -224,7 +268,8 @@ def track_plot(df, include_actual=True):
224
  x_lo, x_hi = float(x_series.min()), float(x_series.max())
225
  x_pad = 0.03 * (x_hi - x_lo if x_hi > x_lo else 1.0)
226
  xmin, xmax = x_lo - x_pad, x_hi + x_pad
227
- tick0 = _nice_tick0(xmin, step=100) # sensible first tick at left border
 
228
 
229
  fig = go.Figure()
230
  fig.add_trace(go.Scatter(
@@ -254,7 +299,7 @@ def track_plot(df, include_actual=True):
254
  fig.update_xaxes(
255
  title_text="<b>UCS (psi)</b>", side="top", range=[xmin, xmax],
256
  ticks="outside", tickformat=",.0f",
257
- tickmode="auto", tick0=tick0,
258
  showline=True, linewidth=1.2, linecolor="#444", mirror=True,
259
  showgrid=True, gridcolor="rgba(0,0,0,0.12)", automargin=True
260
  )
@@ -309,8 +354,10 @@ def preview_modal(book: dict[str, pd.DataFrame]):
309
  t1, t2 = st.tabs(["Tracks", "Summary"])
310
  with t1: st.pyplot(preview_tracks(df, FEATURES), use_container_width=True)
311
  with t2:
312
- tbl = df[FEATURES].agg(['min','max','mean','std']).T.rename(columns={"min":"Min","max":"Max","mean":"Mean","std":"Std"})
313
- st.dataframe(tbl.reset_index(names="Feature"), use_container_width=True)
 
 
314
 
315
  # =========================
316
  # Load model (simple)
@@ -370,7 +417,7 @@ st.markdown(
370
  <img src="{inline_logo()}" class="brand" />
371
  <div>
372
  <h1>ST_GeoMech_UCS</h1>
373
- <div class="tagline">Real-Time UCS Tracking While Drilling — Cloud Ready</div>
374
  </div>
375
  </div>
376
  """,
@@ -396,7 +443,7 @@ if st.session_state.app_step == "intro":
396
  # CASE BUILDING
397
  # =========================
398
  if st.session_state.app_step == "dev":
399
- st.sidebar.header("Case Building (Development)")
400
  up = st.sidebar.file_uploader("Upload Train/Test Excel", type=["xlsx","xls"])
401
  if up is not None:
402
  st.session_state.dev_file_bytes = up.getvalue()
@@ -409,19 +456,22 @@ if st.session_state.app_step == "dev":
409
  df0 = next(iter(tmp.values()))
410
  st.sidebar.caption(f"**Data loaded:** {st.session_state.dev_file_name} • {df0.shape[0]} rows × {df0.shape[1]} cols")
411
 
412
- if st.sidebar.button("Preview data", use_container_width=True, disabled=not st.session_state.dev_file_loaded):
413
- preview_modal(read_book_bytes(st.session_state.dev_file_bytes))
414
- st.session_state.dev_preview = True
 
 
 
 
415
 
416
  run = st.sidebar.button("Run Model", type="primary", use_container_width=True)
417
- # always available nav
418
  if st.sidebar.button("Proceed to Validation ▶", use_container_width=True): st.session_state.app_step="validate"; st.rerun()
419
  if st.sidebar.button("Proceed to Prediction ▶", use_container_width=True): st.session_state.app_step="predict"; st.rerun()
420
 
421
  # ---- Pinned helper at the very top of the page ----
422
  helper_top = st.container()
423
  with helper_top:
424
- st.subheader("Case Building (Development)")
425
  if st.session_state.dev_file_loaded and st.session_state.dev_preview:
426
  st.info("Previewed ✓ — now click **Run Model**.")
427
  elif st.session_state.dev_file_loaded:
@@ -438,12 +488,21 @@ if st.session_state.app_step == "dev":
438
  tr = book[sh_train].copy(); te = book[sh_test].copy()
439
  if not (ensure_cols(tr, FEATURES+[TARGET]) and ensure_cols(te, FEATURES+[TARGET])):
440
  st.error("Missing required columns."); st.stop()
 
441
  tr["UCS_Pred"] = model.predict(tr[FEATURES])
442
  te["UCS_Pred"] = model.predict(te[FEATURES])
443
 
444
  st.session_state.results["Train"]=tr; st.session_state.results["Test"]=te
445
- st.session_state.results["m_train"]={"R2":r2_score(tr[TARGET],tr["UCS_Pred"]), "RMSE":rmse(tr[TARGET],tr["UCS_Pred"]), "MAE":mean_absolute_error(tr[TARGET],tr["UCS_Pred"])}
446
- st.session_state.results["m_test"] ={"R2":r2_score(te[TARGET],te["UCS_Pred"]), "RMSE":rmse(te[TARGET],te["UCS_Pred"]), "MAE":mean_absolute_error(te[TARGET],te["UCS_Pred"])}
 
 
 
 
 
 
 
 
447
 
448
  tr_min = tr[FEATURES].min().to_dict(); tr_max = tr[FEATURES].max().to_dict()
449
  st.session_state.train_ranges = {f:(float(tr_min[f]), float(tr_max[f])) for f in FEATURES}
@@ -451,22 +510,16 @@ if st.session_state.app_step == "dev":
451
 
452
  def _dev_block(df, m):
453
  c1,c2,c3 = st.columns(3)
454
- c1.metric("R²", f"{m['R2']:.4f}"); c2.metric("RMSE", f"{m['RMSE']:.4f}"); c3.metric("MAE", f"{m['MAE']:.4f}")
455
  left, spacer, right = st.columns(PLOT_COLS)
456
  with left:
457
  pad, plotcol = left.columns([CROSS_NUDGE, 1]) # shift cross-plot right inside its band
458
  with plotcol:
459
- st.plotly_chart(
460
- cross_plot(df[TARGET], df["UCS_Pred"]),
461
- use_container_width=False,
462
- config={"displayModeBar": False, "scrollZoom": True}
463
- )
464
  with right:
465
- st.plotly_chart(
466
- track_plot(df, include_actual=True),
467
- use_container_width=False,
468
- config={"displayModeBar": False, "scrollZoom": True}
469
- )
470
 
471
  if "Train" in st.session_state.results or "Test" in st.session_state.results:
472
  tab1, tab2 = st.tabs(["Training", "Testing"])
@@ -486,12 +539,18 @@ if st.session_state.app_step == "validate":
486
  if book:
487
  df0 = next(iter(book.values()))
488
  st.sidebar.caption(f"**Data loaded:** {up.name} • {df0.shape[0]} rows × {df0.shape[1]} cols")
489
- if st.sidebar.button("Preview data", use_container_width=True, disabled=(up is None)):
490
- preview_modal(read_book_bytes(up.getvalue()))
 
 
 
 
 
491
  go_btn = st.sidebar.button("Predict", type="primary", use_container_width=True)
492
  if st.sidebar.button("⬅ Back to Case Building", use_container_width=True): st.session_state.app_step="dev"; st.rerun()
493
  if st.sidebar.button("Proceed to Prediction ▶", use_container_width=True): st.session_state.app_step="predict"; st.rerun()
494
 
 
495
  st.subheader("Validate the Model")
496
  st.write("Upload a dataset with the same **features** and **UCS** to evaluate performance.")
497
 
@@ -510,18 +569,23 @@ if st.session_state.app_step == "validate":
510
  if any_viol.any():
511
  tbl = df.loc[any_viol, FEATURES].copy()
512
  tbl["Violations"] = pd.DataFrame({f:(df[f]<ranges[f][0])|(df[f]>ranges[f][1]) for f in FEATURES}).loc[any_viol].apply(lambda r:", ".join([c for c,v in r.items() if v]), axis=1)
513
- st.session_state.results["m_val"]={"R2":r2_score(df[TARGET],df["UCS_Pred"]), "RMSE":rmse(df[TARGET],df["UCS_Pred"]), "MAE":mean_absolute_error(df[TARGET],df["UCS_Pred"])}
514
- st.session_state.results["sv_val"]={"n":len(df),"pred_min":float(df["UCS_Pred"].min()),"pred_max":float(df["UCS_Pred"].max()),"oor":oor_pct}
 
 
 
 
 
515
  st.session_state.results["oor_tbl"]=tbl
516
 
517
  if "Validate" in st.session_state.results:
518
  m = st.session_state.results["m_val"]
519
  c1,c2,c3 = st.columns(3)
520
- c1.metric("R²", f"{m['R2']:.4f}"); c2.metric("RMSE", f"{m['RMSE']:.4f}"); c3.metric("MAE", f"{m['MAE']:.4f}")
521
 
522
  left, spacer, right = st.columns(PLOT_COLS)
523
  with left:
524
- pad, plotcol = left.columns([CROSS_NUDGE, 1]) # same nudge
525
  with plotcol:
526
  st.plotly_chart(
527
  cross_plot(st.session_state.results["Validate"][TARGET],
@@ -538,7 +602,7 @@ if st.session_state.app_step == "validate":
538
  if sv["oor"] > 0: st.warning("Some inputs fall outside **training min–max** ranges.")
539
  if st.session_state.results["oor_tbl"] is not None:
540
  st.write("*Out-of-range rows (vs. Training min–max):*")
541
- st.dataframe(st.session_state.results["oor_tbl"], use_container_width=True)
542
 
543
  # =========================
544
  # PREDICTION (no actual UCS)
@@ -551,11 +615,17 @@ if st.session_state.app_step == "predict":
551
  if book:
552
  df0 = next(iter(book.values()))
553
  st.sidebar.caption(f"**Data loaded:** {up.name} • {df0.shape[0]} rows × {df0.shape[1]} cols")
554
- if st.sidebar.button("Preview data", use_container_width=True, disabled=(up is None)):
555
- preview_modal(read_book_bytes(up.getvalue()))
 
 
 
 
 
556
  go_btn = st.sidebar.button("Predict", type="primary", use_container_width=True)
557
  if st.sidebar.button("⬅ Back to Case Building", use_container_width=True): st.session_state.app_step="dev"; st.rerun()
558
 
 
559
  st.subheader("Prediction")
560
  st.write("Upload a dataset with the feature columns (no **UCS**).")
561
 
@@ -584,12 +654,13 @@ if st.session_state.app_step == "predict":
584
 
585
  left, spacer, right = st.columns(PLOT_COLS)
586
  with left:
 
587
  table = pd.DataFrame({
588
  "Metric": ["# points","Pred min","Pred max","Pred mean","Pred std","OOR %"],
589
- "Value": [sv["n"], sv["pred_min"], sv["pred_max"], sv["pred_mean"], sv["pred_std"], f'{sv["oor"]:.1f}%']
 
590
  })
591
- st.success("Predictions ready ✓")
592
- st.dataframe(table, use_container_width=True, hide_index=True)
593
  st.caption("**★ OOR** = % of rows whose input features fall outside the training min–max range.")
594
  with right:
595
  st.plotly_chart(
 
12
  import matplotlib.pyplot as plt
13
 
14
  import plotly.graph_objects as go
15
+ from sklearn.metrics import mean_squared_error, mean_absolute_error
16
 
17
  # =========================
18
  # Constants (simple & robust)
 
22
  MODELS_DIR = Path("models")
23
  DEFAULT_MODEL = MODELS_DIR / "ucs_rf.joblib"
24
  MODEL_FALLBACKS = [MODELS_DIR / "model.joblib", MODELS_DIR / "model.pkl"]
25
+
26
  COLORS = {"pred": "#1f77b4", "actual": "#f2b702", "ref": "#5a5a5a"}
27
 
28
  # ---- Plot sizing controls (edit here) ----
 
30
  TRACK_W = 400; TRACK_H = 950 # log-strip style (all pages)
31
  FONT_SZ = 13
32
  PLOT_COLS = [14, 0.5, 10] # 3-column band: left • spacer • right (Build + Validate)
33
+ CROSS_NUDGE = 0.5 # inner columns [CROSS_NUDGE : 1] bigger = more right
 
34
 
35
  # =========================
36
  # Page / CSS
37
  # =========================
38
  st.set_page_config(page_title="ST_GeoMech_UCS", page_icon="logo.png", layout="wide")
39
  st.markdown("<style>header, footer{visibility:hidden !important;}</style>", unsafe_allow_html=True)
40
+
41
+ # Hide drag-n-drop helper texts inside uploaders; keep the Browse button
42
  st.markdown(
43
  """
44
  <style>
 
47
  .block-container { padding-top:.5rem; padding-bottom:.5rem; }
48
  .stButton>button { background:#007bff; color:#fff; font-weight:600; border-radius:8px; border:none; }
49
  .stButton>button:hover { background:#0056b3; }
50
+ .st-hero { display:flex; align-items:center; gap:16px; padding-top:4px; }
51
  .st-hero .brand { width:110px; height:110px; object-fit:contain; }
52
  .st-hero h1 { margin:0; line-height:1.05; }
53
  .st-hero .tagline { margin:2px 0 0 2px; color:#6b7280; font-size:1.05rem; font-style:italic; }
54
  [data-testid="stBlock"]{ margin-top:0 !important; }
55
+
56
+ /* Remove drag & drop copy and limit line in uploader */
57
+ [data-testid="stFileUploadDropzone"] [data-testid="stFileUploaderInstructions"],
58
+ [data-testid="stFileUploadDropzone"] [data-testid="stCaptionContainer"],
59
+ .stFileUploader .dz-message {display:none !important;}
60
+
61
+ /* Center our HTML tables */
62
+ .table-center table { margin-left:auto; margin-right:auto; border-collapse:collapse; }
63
+ .table-center table th, .table-center table td {
64
+ text-align:center !important; padding:6px 10px; border:1px solid #e5e7eb;
65
+ }
66
+ .table-center table thead th { background:#f8fafc; }
67
  </style>
68
  """,
69
  unsafe_allow_html=True
 
148
  return wrapper
149
  return deco
150
 
151
+ def rmse(y_true, y_pred) -> float:
152
+ return float(np.sqrt(mean_squared_error(y_true, y_pred)))
153
+
154
+ def corrcoef_safe(y_true, y_pred) -> float:
155
+ a = pd.Series(y_true, dtype=float)
156
+ b = pd.Series(y_pred, dtype=float)
157
+ a = a.replace([np.inf, -np.inf], np.nan).dropna()
158
+ b = b.replace([np.inf, -np.inf], np.nan).dropna()
159
+ n = min(len(a), len(b))
160
+ if n == 0:
161
+ return float("nan")
162
+ return float(np.corrcoef(a.iloc[:n], b.iloc[:n])[0, 1])
163
 
164
  @st.cache_resource(show_spinner=False)
165
  def load_model(model_path: str):
 
186
  if nm.lower() in low2orig: return low2orig[nm.lower()]
187
  return None
188
 
189
+ def _nice_step(lo: float, hi: float, target_ticks: int = 6) -> float:
190
+ rng = max(hi - lo, 1.0)
191
+ raw = rng / max(target_ticks, 1)
192
+ mag = 10 ** math.floor(math.log10(raw))
193
+ for m in [1, 2, 2.5, 5, 10]:
194
+ step = m * mag
195
+ if raw <= step:
196
+ return step
197
+ return mag * 10
198
+
199
+ def _nice_tick0(xmin: float, step: float) -> float:
200
+ if not np.isfinite(xmin): return xmin
201
  return step * math.floor(xmin / step)
202
 
203
+ def html_table_center(df: pd.DataFrame, index: bool = False):
204
+ html = df.to_html(index=index, classes="table-center")
205
+ st.markdown(html, unsafe_allow_html=True)
206
+
207
  # ---------- Plot builders ----------
208
  def cross_plot(actual, pred):
209
  a = pd.Series(actual).astype(float)
210
  p = pd.Series(pred).astype(float)
211
+
212
+ # Symmetric / identical axis range & ticks
213
  lo = float(np.nanmin([a.min(), p.min()]))
214
  hi = float(np.nanmax([a.max(), p.max()]))
215
  pad = 0.03 * (hi - lo if hi > lo else 1.0)
216
+ lo -= pad; hi += pad
217
+ step = _nice_step(lo, hi, target_ticks=6)
218
+ tick0 = _nice_tick0(lo, step)
219
 
220
  fig = go.Figure()
221
  fig.add_trace(go.Scatter(
 
224
  hovertemplate="Actual: %{x:.0f}<br>Pred: %{y:.0f}<extra></extra>",
225
  showlegend=False
226
  ))
227
+ # 45° reference
228
  fig.add_trace(go.Scatter(
229
+ x=[lo, hi], y=[lo, hi], mode="lines",
230
  line=dict(color=COLORS["ref"], width=1.2, dash="dash"),
231
  hoverinfo="skip", showlegend=False
232
  ))
233
+
234
  fig.update_layout(
235
  width=CROSS_W, height=CROSS_H, paper_bgcolor="#fff", plot_bgcolor="#fff",
236
  margin=dict(l=64, r=18, t=10, b=48), hovermode="closest",
237
  font=dict(size=FONT_SZ)
238
  )
239
+ # identical x & y ranges/ticks; stays locked on zoom
240
+ axis_common = dict(
241
+ range=[lo, hi], ticks="outside", tickformat=",.0f",
242
+ tick0=tick0, dtick=step,
243
+ showline=True, linewidth=1.2, linecolor="#444", mirror=True,
244
+ showgrid=True, gridcolor="rgba(0,0,0,0.12)", automargin=True
245
+ )
246
+ fig.update_xaxes(title_text="<b>Actual UCS (psi)</b>", **axis_common)
247
+ fig.update_yaxes(title_text="<b>Predicted UCS (psi)</b>", **axis_common,
248
+ scaleanchor="x", scaleratio=1)
249
  return fig
250
 
251
  def track_plot(df, include_actual=True):
 
268
  x_lo, x_hi = float(x_series.min()), float(x_series.max())
269
  x_pad = 0.03 * (x_hi - x_lo if x_hi > x_lo else 1.0)
270
  xmin, xmax = x_lo - x_pad, x_hi + x_pad
271
+ x_step = _nice_step(xmin, xmax, target_ticks=6)
272
+ tick0 = _nice_tick0(xmin, x_step)
273
 
274
  fig = go.Figure()
275
  fig.add_trace(go.Scatter(
 
299
  fig.update_xaxes(
300
  title_text="<b>UCS (psi)</b>", side="top", range=[xmin, xmax],
301
  ticks="outside", tickformat=",.0f",
302
+ tickmode="linear", tick0=tick0, dtick=x_step,
303
  showline=True, linewidth=1.2, linecolor="#444", mirror=True,
304
  showgrid=True, gridcolor="rgba(0,0,0,0.12)", automargin=True
305
  )
 
354
  t1, t2 = st.tabs(["Tracks", "Summary"])
355
  with t1: st.pyplot(preview_tracks(df, FEATURES), use_container_width=True)
356
  with t2:
357
+ tbl = df[FEATURES].agg(['min','max','mean','std']).T.rename(
358
+ columns={"min":"Min","max":"Max","mean":"Mean","std":"Std"}
359
+ ).reset_index(names="Feature")
360
+ html_table_center(tbl, index=False)
361
 
362
  # =========================
363
  # Load model (simple)
 
417
  <img src="{inline_logo()}" class="brand" />
418
  <div>
419
  <h1>ST_GeoMech_UCS</h1>
420
+ <div class="tagline">Real-Time UCS Tracking While Drilling</div>
421
  </div>
422
  </div>
423
  """,
 
443
  # CASE BUILDING
444
  # =========================
445
  if st.session_state.app_step == "dev":
446
+ st.sidebar.header("Case Building")
447
  up = st.sidebar.file_uploader("Upload Train/Test Excel", type=["xlsx","xls"])
448
  if up is not None:
449
  st.session_state.dev_file_bytes = up.getvalue()
 
456
  df0 = next(iter(tmp.values()))
457
  st.sidebar.caption(f"**Data loaded:** {st.session_state.dev_file_name} • {df0.shape[0]} rows × {df0.shape[1]} cols")
458
 
459
+ # Preview button ALWAYS enabled
460
+ if st.sidebar.button("Preview data", use_container_width=True):
461
+ if not st.session_state.dev_file_loaded:
462
+ st.warning("Upload an Excel file first, then preview.")
463
+ else:
464
+ preview_modal(read_book_bytes(st.session_state.dev_file_bytes))
465
+ st.session_state.dev_preview = True
466
 
467
  run = st.sidebar.button("Run Model", type="primary", use_container_width=True)
 
468
  if st.sidebar.button("Proceed to Validation ▶", use_container_width=True): st.session_state.app_step="validate"; st.rerun()
469
  if st.sidebar.button("Proceed to Prediction ▶", use_container_width=True): st.session_state.app_step="predict"; st.rerun()
470
 
471
  # ---- Pinned helper at the very top of the page ----
472
  helper_top = st.container()
473
  with helper_top:
474
+ st.subheader("Case Building")
475
  if st.session_state.dev_file_loaded and st.session_state.dev_preview:
476
  st.info("Previewed ✓ — now click **Run Model**.")
477
  elif st.session_state.dev_file_loaded:
 
488
  tr = book[sh_train].copy(); te = book[sh_test].copy()
489
  if not (ensure_cols(tr, FEATURES+[TARGET]) and ensure_cols(te, FEATURES+[TARGET])):
490
  st.error("Missing required columns."); st.stop()
491
+
492
  tr["UCS_Pred"] = model.predict(tr[FEATURES])
493
  te["UCS_Pred"] = model.predict(te[FEATURES])
494
 
495
  st.session_state.results["Train"]=tr; st.session_state.results["Test"]=te
496
+ st.session_state.results["m_train"]={
497
+ "R": corrcoef_safe(tr[TARGET], tr["UCS_Pred"]),
498
+ "RMSE": rmse(tr[TARGET], tr["UCS_Pred"]),
499
+ "MAE": mean_absolute_error(tr[TARGET], tr["UCS_Pred"])
500
+ }
501
+ st.session_state.results["m_test"] ={
502
+ "R": corrcoef_safe(te[TARGET], te["UCS_Pred"]),
503
+ "RMSE": rmse(te[TARGET], te["UCS_Pred"]),
504
+ "MAE": mean_absolute_error(te[TARGET], te["UCS_Pred"])
505
+ }
506
 
507
  tr_min = tr[FEATURES].min().to_dict(); tr_max = tr[FEATURES].max().to_dict()
508
  st.session_state.train_ranges = {f:(float(tr_min[f]), float(tr_max[f])) for f in FEATURES}
 
510
 
511
  def _dev_block(df, m):
512
  c1,c2,c3 = st.columns(3)
513
+ c1.metric("R", f"{m['R']:.4f}"); c2.metric("RMSE", f"{m['RMSE']:.4f}"); c3.metric("MAE", f"{m['MAE']:.4f}")
514
  left, spacer, right = st.columns(PLOT_COLS)
515
  with left:
516
  pad, plotcol = left.columns([CROSS_NUDGE, 1]) # shift cross-plot right inside its band
517
  with plotcol:
518
+ st.plotly_chart(cross_plot(df[TARGET], df["UCS_Pred"]),
519
+ use_container_width=False, config={"displayModeBar": False, "scrollZoom": True})
 
 
 
520
  with right:
521
+ st.plotly_chart(track_plot(df, include_actual=True),
522
+ use_container_width=False, config={"displayModeBar": False, "scrollZoom": True})
 
 
 
523
 
524
  if "Train" in st.session_state.results or "Test" in st.session_state.results:
525
  tab1, tab2 = st.tabs(["Training", "Testing"])
 
539
  if book:
540
  df0 = next(iter(book.values()))
541
  st.sidebar.caption(f"**Data loaded:** {up.name} • {df0.shape[0]} rows × {df0.shape[1]} cols")
542
+
543
+ if st.sidebar.button("Preview data", use_container_width=True):
544
+ if up is None:
545
+ st.warning("Upload an Excel file first, then preview.")
546
+ else:
547
+ preview_modal(read_book_bytes(up.getvalue()))
548
+
549
  go_btn = st.sidebar.button("Predict", type="primary", use_container_width=True)
550
  if st.sidebar.button("⬅ Back to Case Building", use_container_width=True): st.session_state.app_step="dev"; st.rerun()
551
  if st.sidebar.button("Proceed to Prediction ▶", use_container_width=True): st.session_state.app_step="predict"; st.rerun()
552
 
553
+ # pinned
554
  st.subheader("Validate the Model")
555
  st.write("Upload a dataset with the same **features** and **UCS** to evaluate performance.")
556
 
 
569
  if any_viol.any():
570
  tbl = df.loc[any_viol, FEATURES].copy()
571
  tbl["Violations"] = pd.DataFrame({f:(df[f]<ranges[f][0])|(df[f]>ranges[f][1]) for f in FEATURES}).loc[any_viol].apply(lambda r:", ".join([c for c,v in r.items() if v]), axis=1)
572
+ st.session_state.results["m_val"]={
573
+ "R": corrcoef_safe(df[TARGET], df["UCS_Pred"]),
574
+ "RMSE": rmse(df[TARGET], df["UCS_Pred"]),
575
+ "MAE": mean_absolute_error(df[TARGET], df["UCS_Pred"])
576
+ }
577
+ st.session_state.results["sv_val"]={"n":len(df),"pred_min":float(df["UCS_Pred"].min()),
578
+ "pred_max":float(df["UCS_Pred"].max()),"oor":oor_pct}
579
  st.session_state.results["oor_tbl"]=tbl
580
 
581
  if "Validate" in st.session_state.results:
582
  m = st.session_state.results["m_val"]
583
  c1,c2,c3 = st.columns(3)
584
+ c1.metric("R", f"{m['R']:.4f}"); c2.metric("RMSE", f"{m['RMSE']:.4f}"); c3.metric("MAE", f"{m['MAE']:.4f}")
585
 
586
  left, spacer, right = st.columns(PLOT_COLS)
587
  with left:
588
+ pad, plotcol = left.columns([CROSS_NUDGE, 1])
589
  with plotcol:
590
  st.plotly_chart(
591
  cross_plot(st.session_state.results["Validate"][TARGET],
 
602
  if sv["oor"] > 0: st.warning("Some inputs fall outside **training min–max** ranges.")
603
  if st.session_state.results["oor_tbl"] is not None:
604
  st.write("*Out-of-range rows (vs. Training min–max):*")
605
+ html_table_center(st.session_state.results["oor_tbl"].reset_index(drop=True), index=False)
606
 
607
  # =========================
608
  # PREDICTION (no actual UCS)
 
615
  if book:
616
  df0 = next(iter(book.values()))
617
  st.sidebar.caption(f"**Data loaded:** {up.name} • {df0.shape[0]} rows × {df0.shape[1]} cols")
618
+
619
+ if st.sidebar.button("Preview data", use_container_width=True):
620
+ if up is None:
621
+ st.warning("Upload an Excel file first, then preview.")
622
+ else:
623
+ preview_modal(read_book_bytes(up.getvalue()))
624
+
625
  go_btn = st.sidebar.button("Predict", type="primary", use_container_width=True)
626
  if st.sidebar.button("⬅ Back to Case Building", use_container_width=True): st.session_state.app_step="dev"; st.rerun()
627
 
628
+ # pinned
629
  st.subheader("Prediction")
630
  st.write("Upload a dataset with the feature columns (no **UCS**).")
631
 
 
654
 
655
  left, spacer, right = st.columns(PLOT_COLS)
656
  with left:
657
+ st.success("Predictions ready ✓")
658
  table = pd.DataFrame({
659
  "Metric": ["# points","Pred min","Pred max","Pred mean","Pred std","OOR %"],
660
+ "Value": [sv["n"], sv["pred_min"], sv["pred_max"],
661
+ sv["pred_mean"], sv["pred_std"], f'{sv["oor"]:.1f}%']
662
  })
663
+ html_table_center(table, index=False)
 
664
  st.caption("**★ OOR** = % of rows whose input features fall outside the training min–max range.")
665
  with right:
666
  st.plotly_chart(