UCS2014 commited on
Commit
1a93a5b
·
verified ·
1 Parent(s): eab342c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +119 -199
app.py CHANGED
@@ -6,16 +6,17 @@ import pandas as pd
6
  import numpy as np
7
  import joblib
8
 
9
- # matplotlib only for PREVIEW modal
10
  import matplotlib
11
  matplotlib.use("Agg")
12
  import matplotlib.pyplot as plt
 
13
 
14
  import plotly.graph_objects as go
15
- from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
16
 
17
  # =========================
18
- # Constants (simple & robust)
19
  # =========================
20
  FEATURES = ["Q, gpm", "SPP(psi)", "T (kft.lbf)", "WOB (klbf)", "ROP (ft/h)"]
21
  TARGET = "UCS"
@@ -24,46 +25,35 @@ DEFAULT_MODEL = MODELS_DIR / "ucs_rf.joblib"
24
  MODEL_FALLBACKS = [MODELS_DIR / "model.joblib", MODELS_DIR / "model.pkl"]
25
  COLORS = {"pred": "#1f77b4", "actual": "#f2b702", "ref": "#5a5a5a"}
26
 
27
- # ---- Plot sizing controls (edit here) ----
28
- CROSS_W = 450; CROSS_H = 450 # square cross-plot (Build + Validate)
29
- TRACK_W = 400; TRACK_H = 950 # log-strip style (all pages)
30
  FONT_SZ = 15
31
- PLOT_COLS = [30, 1, 20] # 3-column band: left • spacer • right (Build + Validate)
32
- CROSS_NUDGE = 0.02 # push cross-plot to the RIGHT inside its band
33
 
34
  # =========================
35
  # Page / CSS
36
  # =========================
37
  st.set_page_config(page_title="ST_GeoMech_UCS", page_icon="logo.png", layout="wide")
 
 
38
  st.markdown("""
39
  <style>
40
- /* Hide the helper text in file uploader */
41
  section[data-testid="stFileUploader"] div[data-testid="stMarkdownContainer"] {
42
- display: none !important;
43
  }
 
 
 
 
 
44
  </style>
45
  """, unsafe_allow_html=True)
46
- st.markdown("<style>header, footer{visibility:hidden !important;}</style>", unsafe_allow_html=True)
47
- st.markdown(
48
- """
49
- <style>
50
- .stApp { background:#fff; }
51
- section[data-testid="stSidebar"] { background:#F6F9FC; }
52
- .block-container { padding-top:.5rem; padding-bottom:.5rem; }
53
- .stButton>button { background:#007bff; color:#fff; font-weight:600; border-radius:8px; border:none; }
54
- .stButton>button:hover { background:#0056b3; }
55
- .st-hero { display:flex; align-items:center; gap:16px; padding-top: 4px; }
56
- .st-hero .brand { width:110px; height:110px; object-fit:contain; }
57
- .st-hero h1 { margin:0; line-height:1.05; }
58
- .st-hero .tagline { margin:2px 0 0 2px; color:#6b7280; font-size:1.05rem; font-style:italic; }
59
- [data-testid="stBlock"]{ margin-top:0 !important; }
60
- </style>
61
- """,
62
- unsafe_allow_html=True
63
- )
64
 
65
  # =========================
66
- # Password gate (define first, then call)
67
  # =========================
68
  def inline_logo(path="logo.png") -> str:
69
  try:
@@ -80,44 +70,16 @@ def add_password_gate() -> None:
80
  required = os.environ.get("APP_PASSWORD", "")
81
 
82
  if not required:
83
- st.markdown(
84
- f"""
85
- <div style="display:flex;align-items:center;gap:14px;margin:8px 0 6px 0;">
86
- <img src="{inline_logo()}" style="width:56px;height:56px;object-fit:contain"/>
87
- <div>
88
- <div style="font-size:1.9rem;font-weight:800;">ST_GeoMech_UCS</div>
89
- <div style="color:#667085;">Smart Thinking • Secure Access</div>
90
- </div>
91
- </div>
92
- <div style="font-size:1.25rem;font-weight:700;margin:8px 0 4px 0;">Protected Area</div>
93
- <div style="color:#6b7280;margin-bottom:14px;">
94
- Set <code>APP_PASSWORD</code> in <b>Settings → Secrets</b> (or environment) and restart.
95
- </div>
96
- """,
97
- unsafe_allow_html=True,
98
- )
99
  st.stop()
100
 
101
  if st.session_state.get("auth_ok", False):
102
  return
103
 
104
- st.markdown(
105
- f"""
106
- <div style="display:flex;align-items:center;gap:14px;margin:8px 0 6px 0;">
107
- <img src="{inline_logo()}" style="width:56px;height:56px;object-fit:contain"/>
108
- <div>
109
- <div style="font-size:1.9rem;font-weight:800;">ST_GeoMech_UCS</div>
110
- <div style="color:#667085;">Smart Thinking • Secure Access</div>
111
- </div>
112
- </div>
113
- <div style="font-size:1.25rem;font-weight:700;margin:8px 0 4px 0;">Protected</div>
114
- <div style="color:#6b7280;margin-bottom:14px;">Please enter your access key to continue.</div>
115
- """,
116
- unsafe_allow_html=True
117
- )
118
-
119
- pwd = st.text_input("Access key", type="password", placeholder="••••••••")
120
- if st.button("Unlock", type="primary"):
121
  if pwd == required:
122
  st.session_state.auth_ok = True
123
  st.rerun()
@@ -141,17 +103,14 @@ except AttributeError:
141
  return wrapper
142
  return deco
143
 
144
- def rmse(y_true, y_pred):
145
  return float(np.sqrt(mean_squared_error(y_true, y_pred)))
146
 
147
- def r_value(y_true, y_pred):
148
- """Pearson correlation coefficient (R)."""
149
- y_true = np.asarray(y_true, dtype=float)
150
- y_pred = np.asarray(y_pred, dtype=float)
151
- mask = np.isfinite(y_true) & np.isfinite(y_pred)
152
- if mask.sum() < 2:
153
- return float("nan")
154
- return float(np.corrcoef(y_true[mask], y_pred[mask])[0, 1])
155
 
156
  @st.cache_resource(show_spinner=False)
157
  def load_model(model_path: str):
@@ -163,8 +122,7 @@ def parse_excel(data_bytes: bytes):
163
  xl = pd.ExcelFile(bio)
164
  return {sh: xl.parse(sh) for sh in xl.sheet_names}
165
 
166
- def read_book_bytes(b: bytes):
167
- return parse_excel(b) if b else {}
168
 
169
  def ensure_cols(df, cols):
170
  miss = [c for c in cols if c not in df.columns]
@@ -176,116 +134,99 @@ def ensure_cols(df, cols):
176
  def find_sheet(book, names):
177
  low2orig = {k.lower(): k for k in book.keys()}
178
  for nm in names:
179
- if nm.lower() in low2orig:
180
- return low2orig[nm.lower()]
181
  return None
182
 
183
  def _nice_tick0(xmin: float, step: int = 100) -> float:
184
  return step * math.floor(xmin / step) if np.isfinite(xmin) else xmin
185
 
186
- # ---------- cross_plot ----------
187
- def cross_plot(actual, pred):
 
 
188
  a = pd.Series(actual).astype(float)
189
  p = pd.Series(pred).astype(float)
190
 
191
- # Fixed extents + identical tick settings
192
- AX_MIN, AX_MAX = 6000, 10000
193
- TICK0, DTICK = 6000, 1000
194
 
195
- fig = go.Figure()
196
 
197
- # Points
198
- fig.add_trace(go.Scatter(
199
- x=a, y=p, mode="markers",
200
- marker=dict(size=6, color=COLORS["pred"]),
201
- hovertemplate="Actual: %{x:.0f}<br>Pred: %{y:.0f}<extra></extra>",
202
- showlegend=False
203
- ))
204
 
205
- # 1:1 reference line (full diagonal)
206
- fig.add_trace(go.Scatter(
207
- x=[AX_MIN, AX_MAX], y=[AX_MIN, AX_MAX], mode="lines",
208
- line=dict(color=COLORS["ref"], width=1.2, dash="dash"),
209
- hoverinfo="skip", showlegend=False
210
- ))
211
 
212
- fig.update_layout(
213
- width=CROSS_W, height=CROSS_H,
214
- paper_bgcolor="#fff", plot_bgcolor="#fff",
215
- margin=dict(l=64, r=18, t=10, b=48),
216
- hovermode="closest",
217
- font=dict(size=FONT_SZ)
218
- )
219
 
220
- # Make the two axes identical: same range, same ticks, same aspect
221
- fig.update_xaxes(
222
- title_text="<b>Actual UCS (psi)</b>",
223
- title_font=dict(size=18, family="Arial", color="#000"),
224
- range=[AX_MIN, AX_MAX],
225
- tick0=TICK0, dtick=DTICK, tickformat=",.0f", ticks="outside",
226
- showline=True, linewidth=1.2, linecolor="#444", mirror=True,
227
- showgrid=True, gridcolor="rgba(0,0,0,0.12)",
228
- scaleanchor="y", scaleratio=1, # lock aspect to keep 45° line exact
229
- fixedrange=True, # keep the range fixed (no zoom/pan)
230
- automargin=True
231
- )
232
- fig.update_yaxes(
233
- title_text="<b>Predicted UCS (psi)</b>",
234
- title_font=dict(size=18, family="Arial", color="#000"),
235
- range=[AX_MIN, AX_MAX],
236
- tick0=TICK0, dtick=DTICK, tickformat=",.0f", ticks="outside",
237
- showline=True, linewidth=1.2, linecolor="#444", mirror=True,
238
- showgrid=True, gridcolor="rgba(0,0,0,0.12)",
239
- fixedrange=True, # same fixed range
240
- automargin=True
241
- )
242
 
243
- return fig
 
 
 
 
 
 
 
244
 
 
 
 
 
 
 
245
 
246
- # ---------- track_plot ----------
 
 
 
 
 
247
  def track_plot(df, include_actual=True):
248
  depth_col = next((c for c in df.columns if 'depth' in str(c).lower()), None)
249
- if depth_col:
250
  y = pd.Series(df[depth_col]).astype(float)
251
  ylab = depth_col
 
252
  else:
253
  y = pd.Series(np.arange(1, len(df) + 1))
254
  ylab = "Point Index"
 
255
 
256
- y_range = [float(y.max()), float(y.min())]
257
-
258
  x_series = pd.Series(df.get("UCS_Pred", pd.Series(dtype=float))).astype(float)
259
  if include_actual and TARGET in df.columns:
260
  x_series = pd.concat([x_series, pd.Series(df[TARGET]).astype(float)], ignore_index=True)
261
-
262
  x_lo, x_hi = float(x_series.min()), float(x_series.max())
263
  x_pad = 0.03 * (x_hi - x_lo if x_hi > x_lo else 1.0)
264
  xmin, xmax = x_lo - x_pad, x_hi + x_pad
265
  tick0 = _nice_tick0(xmin, step=100)
266
 
267
  fig = go.Figure()
268
-
269
  fig.add_trace(go.Scatter(
270
  x=df["UCS_Pred"], y=y, mode="lines",
271
  line=dict(color=COLORS["pred"], width=1.8),
272
  name="UCS_Pred",
273
- hovertemplate="UCS_Pred: %{x:.0f}<br>" + ylab + ": %{y}<extra></extra>"
274
  ))
275
-
276
  if include_actual and TARGET in df.columns:
277
  fig.add_trace(go.Scatter(
278
  x=df[TARGET], y=y, mode="lines",
279
  line=dict(color=COLORS["actual"], width=2.0, dash="dot"),
280
  name="UCS (actual)",
281
- hovertemplate="UCS (actual): %{x:.0f}<br>" + ylab + ": %{y}<extra></extra>"
282
  ))
283
 
284
  fig.update_layout(
285
- width=TRACK_W, height=TRACK_H,
286
- paper_bgcolor="#fff", plot_bgcolor="#fff",
287
- margin=dict(l=72, r=18, t=36, b=48),
288
- hovermode="closest",
289
  font=dict(size=FONT_SZ),
290
  legend=dict(
291
  x=0.98, y=0.05, xanchor="right", yanchor="bottom",
@@ -293,26 +234,19 @@ def track_plot(df, include_actual=True):
293
  ),
294
  legend_title_text=""
295
  )
296
-
297
  fig.update_xaxes(
298
- title_text="<b>UCS (psi)</b>",
299
- title_font=dict(size=18, family="Arial", color="#000"),
300
  side="top", range=[xmin, xmax],
301
- tick0=tick0, tickmode="auto", tickformat=",.0f",
302
- ticks="outside",
303
  showline=True, linewidth=1.2, linecolor="#444", mirror=True,
304
  showgrid=True, gridcolor="rgba(0,0,0,0.12)", automargin=True
305
  )
306
-
307
  fig.update_yaxes(
308
- title_text=f"<b>{ylab}</b>",
309
- title_font=dict(size=18, family="Arial", color="#000"),
310
- range=y_range,
311
- ticks="outside",
312
  showline=True, linewidth=1.2, linecolor="#444", mirror=True,
313
  showgrid=True, gridcolor="rgba(0,0,0,0.12)", automargin=True
314
  )
315
-
316
  return fig
317
 
318
  # ---------- Preview modal (matplotlib) ----------
@@ -358,11 +292,13 @@ def preview_modal(book: dict[str, pd.DataFrame]):
358
  t1, t2 = st.tabs(["Tracks", "Summary"])
359
  with t1: st.pyplot(preview_tracks(df, FEATURES), use_container_width=True)
360
  with t2:
361
- tbl = df[FEATURES].agg(['min','max','mean','std']).T.rename(columns={"min":"Min","max":"Max","mean":"Mean","std":"Std"})
 
 
362
  st.dataframe(tbl.reset_index(names="Feature"), use_container_width=True)
363
 
364
  # =========================
365
- # Load model (simple)
366
  # =========================
367
  def ensure_model() -> Path|None:
368
  for p in [DEFAULT_MODEL, *MODEL_FALLBACKS]:
@@ -411,19 +347,13 @@ st.session_state.setdefault("dev_file_loaded",False)
411
  st.session_state.setdefault("dev_preview",False)
412
 
413
  # =========================
414
- # Hero
415
  # =========================
416
- st.markdown(
417
- f"""
418
- <div class="st-hero">
419
- <img src="{inline_logo()}" class="brand" />
420
- <div>
421
- <h1>ST_GeoMech_UCS</h1>
422
- <div class="tagline">Real-Time UCS Tracking While Drilling</div>
423
- </div>
424
- </div>
425
- """,
426
- unsafe_allow_html=True,
427
  )
428
 
429
  # =========================
@@ -463,11 +393,11 @@ if st.session_state.app_step == "dev":
463
  st.session_state.dev_preview = True
464
 
465
  run = st.sidebar.button("Run Model", type="primary", use_container_width=True)
466
- # always available nav
467
  if st.sidebar.button("Proceed to Validation ▶", use_container_width=True): st.session_state.app_step="validate"; st.rerun()
468
  if st.sidebar.button("Proceed to Prediction ▶", use_container_width=True): st.session_state.app_step="predict"; st.rerun()
469
 
470
- # ---- Pinned helper at the very top of the page ----
471
  helper_top = st.container()
472
  with helper_top:
473
  st.subheader("Case Building")
@@ -490,16 +420,14 @@ if st.session_state.app_step == "dev":
490
  tr["UCS_Pred"] = model.predict(tr[FEATURES])
491
  te["UCS_Pred"] = model.predict(te[FEATURES])
492
 
493
- # ---- metrics (R, RMSE, MAE) ----
494
- st.session_state.results["Train"]=tr
495
- st.session_state.results["Test"]=te
496
  st.session_state.results["m_train"]={
497
- "R": r_value(tr[TARGET], tr["UCS_Pred"]),
498
  "RMSE": rmse(tr[TARGET], tr["UCS_Pred"]),
499
  "MAE": mean_absolute_error(tr[TARGET], tr["UCS_Pred"])
500
  }
501
  st.session_state.results["m_test"]={
502
- "R": r_value(te[TARGET], te["UCS_Pred"]),
503
  "RMSE": rmse(te[TARGET], te["UCS_Pred"]),
504
  "MAE": mean_absolute_error(te[TARGET], te["UCS_Pred"])
505
  }
@@ -510,18 +438,13 @@ if st.session_state.app_step == "dev":
510
 
511
  def _dev_block(df, m):
512
  c1,c2,c3 = st.columns(3)
513
- c1.metric("R", f"{m['R']:.2f}")
514
- c2.metric("RMSE", f"{m['RMSE']:.2f}")
515
- c3.metric("MAE", f"{m['MAE']:.2f}")
516
  left, spacer, right = st.columns(PLOT_COLS)
517
  with left:
518
- pad, plotcol = left.columns([CROSS_NUDGE, 1]) # shift cross-plot right inside its band
519
  with plotcol:
520
- st.plotly_chart(
521
- cross_plot(df[TARGET], df["UCS_Pred"]),
522
- use_container_width=False,
523
- config={"displayModeBar": False, "scrollZoom": True}
524
- )
525
  with right:
526
  st.plotly_chart(
527
  track_plot(df, include_actual=True),
@@ -570,38 +493,31 @@ if st.session_state.app_step == "validate":
570
  oor_pct = float(any_viol.mean()*100.0)
571
  if any_viol.any():
572
  tbl = df.loc[any_viol, FEATURES].copy()
573
- tbl["Violations"] = pd.DataFrame({f:(df[f]<ranges[f][0])|(df[f]>ranges[f][1]) for f in FEATURES}).loc[any_viol].apply(
574
- lambda r: ", ".join([c for c,v in r.items() if v]),
575
- axis=1
576
- )
577
  st.session_state.results["m_val"]={
578
- "R": r_value(df[TARGET], df["UCS_Pred"]),
579
  "RMSE": rmse(df[TARGET], df["UCS_Pred"]),
580
  "MAE": mean_absolute_error(df[TARGET], df["UCS_Pred"])
581
  }
582
- st.session_state.results["sv_val"]={
583
- "n":len(df),
584
- "pred_min":float(df["UCS_Pred"].min()),
585
- "pred_max":float(df["UCS_Pred"].max()),
586
- "oor":oor_pct
587
- }
588
  st.session_state.results["oor_tbl"]=tbl
589
 
590
  if "Validate" in st.session_state.results:
591
  m = st.session_state.results["m_val"]
592
  c1,c2,c3 = st.columns(3)
593
- c1.metric("R", f"{m['R']:.2f}")
594
- c2.metric("RMSE", f"{m['RMSE']:.2f}")
595
- c3.metric("MAE", f"{m['MAE']:.2f}")
596
 
597
  left, spacer, right = st.columns(PLOT_COLS)
598
  with left:
599
  pad, plotcol = left.columns([CROSS_NUDGE, 1])
600
  with plotcol:
601
- st.plotly_chart(
602
- cross_plot(st.session_state.results["Validate"][TARGET],
603
- st.session_state.results["Validate"]["UCS_Pred"]),
604
- use_container_width=False, config={"displayModeBar": False, "scrollZoom": True}
605
  )
606
  with right:
607
  st.plotly_chart(
@@ -656,12 +572,16 @@ if st.session_state.app_step == "predict":
656
 
657
  if "PredictOnly" in st.session_state.results:
658
  df = st.session_state.results["PredictOnly"]; sv = st.session_state.results["sv_pred"]
659
-
660
  left, spacer, right = st.columns(PLOT_COLS)
661
  with left:
662
  table = pd.DataFrame({
663
  "Metric": ["# points","Pred min","Pred max","Pred mean","Pred std","OOR %"],
664
- "Value": [sv["n"], sv["pred_min"], sv["pred_max"], sv["pred_mean"], sv["pred_std"], f'{sv["oor"]:.1f}%']
 
 
 
 
 
665
  })
666
  st.success("Predictions ready ✓")
667
  st.dataframe(table, use_container_width=True, hide_index=True)
@@ -684,4 +604,4 @@ st.markdown(
684
  </div>
685
  """,
686
  unsafe_allow_html=True
687
- )
 
6
  import numpy as np
7
  import joblib
8
 
9
+ # matplotlib for PREVIEW modal and for the CROSS-PLOT (static)
10
  import matplotlib
11
  matplotlib.use("Agg")
12
  import matplotlib.pyplot as plt
13
+ from matplotlib.ticker import FuncFormatter
14
 
15
  import plotly.graph_objects as go
16
+ from sklearn.metrics import mean_squared_error, mean_absolute_error
17
 
18
  # =========================
19
+ # Constants
20
  # =========================
21
  FEATURES = ["Q, gpm", "SPP(psi)", "T (kft.lbf)", "WOB (klbf)", "ROP (ft/h)"]
22
  TARGET = "UCS"
 
25
  MODEL_FALLBACKS = [MODELS_DIR / "model.joblib", MODELS_DIR / "model.pkl"]
26
  COLORS = {"pred": "#1f77b4", "actual": "#f2b702", "ref": "#5a5a5a"}
27
 
28
+ # ---- Plot sizing controls ----
29
+ CROSS_W = 450; CROSS_H = 450 # square cross-plot
30
+ TRACK_W = 400; TRACK_H = 950 # log-strip style
31
  FONT_SZ = 15
32
+ PLOT_COLS = [30, 1, 20] # left • spacer • right
33
+ CROSS_NUDGE = 0.02
34
 
35
  # =========================
36
  # Page / CSS
37
  # =========================
38
  st.set_page_config(page_title="ST_GeoMech_UCS", page_icon="logo.png", layout="wide")
39
+
40
+ # Hide file-uploader helper text + center dataframes (headers & cells)
41
  st.markdown("""
42
  <style>
43
+ /* Hide 'Drag and drop file here' and limit note in uploader */
44
  section[data-testid="stFileUploader"] div[data-testid="stMarkdownContainer"] {
45
+ display: none !important;
46
  }
47
+ /* Center st.dataframe headers and cells */
48
+ div[data-testid="stDataFrame"] div[role="columnheader"] { justify-content: center; }
49
+ div[data-testid="stDataFrame"] div[role="gridcell"] { justify-content: center; }
50
+ /* Remove default app header/footer */
51
+ header, footer { visibility: hidden !important; }
52
  </style>
53
  """, unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
  # =========================
56
+ # Password gate
57
  # =========================
58
  def inline_logo(path="logo.png") -> str:
59
  try:
 
70
  required = os.environ.get("APP_PASSWORD", "")
71
 
72
  if not required:
73
+ st.warning("Set APP_PASSWORD in Secrets (or environment) and restart.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  st.stop()
75
 
76
  if st.session_state.get("auth_ok", False):
77
  return
78
 
79
+ st.sidebar.image("logo.png", use_column_width=True)
80
+ st.sidebar.markdown("### ST_GeoMech_UCS\nSmart Thinking • Secure Access")
81
+ pwd = st.sidebar.text_input("Access key", type="password", placeholder="••••••••")
82
+ if st.sidebar.button("Unlock", type="primary"):
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  if pwd == required:
84
  st.session_state.auth_ok = True
85
  st.rerun()
 
103
  return wrapper
104
  return deco
105
 
106
+ def rmse(y_true, y_pred) -> float:
107
  return float(np.sqrt(mean_squared_error(y_true, y_pred)))
108
 
109
+ def pearson_r(y_true, y_pred) -> float:
110
+ a = np.asarray(y_true, dtype=float)
111
+ p = np.asarray(y_pred, dtype=float)
112
+ if a.size < 2: return float("nan")
113
+ return float(np.corrcoef(a, p)[0, 1])
 
 
 
114
 
115
  @st.cache_resource(show_spinner=False)
116
  def load_model(model_path: str):
 
122
  xl = pd.ExcelFile(bio)
123
  return {sh: xl.parse(sh) for sh in xl.sheet_names}
124
 
125
+ def read_book_bytes(b: bytes): return parse_excel(b) if b else {}
 
126
 
127
  def ensure_cols(df, cols):
128
  miss = [c for c in cols if c not in df.columns]
 
134
  def find_sheet(book, names):
135
  low2orig = {k.lower(): k for k in book.keys()}
136
  for nm in names:
137
+ if nm.lower() in low2orig: return low2orig[nm.lower()]
 
138
  return None
139
 
140
  def _nice_tick0(xmin: float, step: int = 100) -> float:
141
  return step * math.floor(xmin / step) if np.isfinite(xmin) else xmin
142
 
143
+ # =========================
144
+ # Cross-plot (Matplotlib, static)
145
+ # =========================
146
+ def cross_plot_static(actual, pred):
147
  a = pd.Series(actual).astype(float)
148
  p = pd.Series(pred).astype(float)
149
 
150
+ fixed_min, fixed_max = 6000, 10000
151
+ ticks = np.arange(6000, 10001, 1000)
 
152
 
153
+ fig, ax = plt.subplots(figsize=(CROSS_W/100, CROSS_H/100), dpi=100)
154
 
155
+ ax.scatter(a, p, s=14, c=COLORS["pred"], alpha=0.9, edgecolors="none")
 
 
 
 
 
 
156
 
157
+ # 1:1 diagonal
158
+ ax.plot([fixed_min, fixed_max], [fixed_min, fixed_max],
159
+ linestyle="--", linewidth=1.2, color=COLORS["ref"])
 
 
 
160
 
161
+ # Limits and ticks
162
+ ax.set_xlim(fixed_min, fixed_max)
163
+ ax.set_ylim(fixed_min, fixed_max)
164
+ ax.set_xticks(ticks)
165
+ ax.set_yticks(ticks)
 
 
166
 
167
+ # Equal aspect for true 45°
168
+ ax.set_aspect('equal', adjustable='box')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
 
170
+ # Thousands formatter
171
+ fmt = FuncFormatter(lambda x, _: f"{int(x):,}")
172
+ ax.xaxis.set_major_formatter(fmt)
173
+ ax.yaxis.set_major_formatter(fmt)
174
+
175
+ # Labels (bold, larger)
176
+ ax.set_xlabel("Actual UCS (psi)", fontweight="bold", fontsize=16)
177
+ ax.set_ylabel("Predicted UCS (psi)", fontweight="bold", fontsize=16)
178
 
179
+ # Grid & frame
180
+ ax.grid(True, linestyle=":", alpha=0.35)
181
+ for spine in ax.spines.values():
182
+ spine.set_visible(True)
183
+ spine.set_linewidth(1.2)
184
+ spine.set_color("#444")
185
 
186
+ fig.tight_layout()
187
+ return fig
188
+
189
+ # =========================
190
+ # Track plot (Plotly)
191
+ # =========================
192
  def track_plot(df, include_actual=True):
193
  depth_col = next((c for c in df.columns if 'depth' in str(c).lower()), None)
194
+ if depth_col is not None:
195
  y = pd.Series(df[depth_col]).astype(float)
196
  ylab = depth_col
197
+ y_range = [float(y.max()), float(y.min())] # reverse
198
  else:
199
  y = pd.Series(np.arange(1, len(df) + 1))
200
  ylab = "Point Index"
201
+ y_range = [float(y.max()), float(y.min())]
202
 
203
+ # X (UCS) range & ticks
 
204
  x_series = pd.Series(df.get("UCS_Pred", pd.Series(dtype=float))).astype(float)
205
  if include_actual and TARGET in df.columns:
206
  x_series = pd.concat([x_series, pd.Series(df[TARGET]).astype(float)], ignore_index=True)
 
207
  x_lo, x_hi = float(x_series.min()), float(x_series.max())
208
  x_pad = 0.03 * (x_hi - x_lo if x_hi > x_lo else 1.0)
209
  xmin, xmax = x_lo - x_pad, x_hi + x_pad
210
  tick0 = _nice_tick0(xmin, step=100)
211
 
212
  fig = go.Figure()
 
213
  fig.add_trace(go.Scatter(
214
  x=df["UCS_Pred"], y=y, mode="lines",
215
  line=dict(color=COLORS["pred"], width=1.8),
216
  name="UCS_Pred",
217
+ hovertemplate="UCS_Pred: %{x:.0f}<br>"+ylab+": %{y}<extra></extra>"
218
  ))
 
219
  if include_actual and TARGET in df.columns:
220
  fig.add_trace(go.Scatter(
221
  x=df[TARGET], y=y, mode="lines",
222
  line=dict(color=COLORS["actual"], width=2.0, dash="dot"),
223
  name="UCS (actual)",
224
+ hovertemplate="UCS (actual): %{x:.0f}<br>"+ylab+": %{y}<extra></extra>"
225
  ))
226
 
227
  fig.update_layout(
228
+ width=TRACK_W, height=TRACK_H, paper_bgcolor="#fff", plot_bgcolor="#fff",
229
+ margin=dict(l=72, r=18, t=36, b=48), hovermode="closest",
 
 
230
  font=dict(size=FONT_SZ),
231
  legend=dict(
232
  x=0.98, y=0.05, xanchor="right", yanchor="bottom",
 
234
  ),
235
  legend_title_text=""
236
  )
 
237
  fig.update_xaxes(
238
+ title_text="<b>UCS (psi)</b>", title_font=dict(size=18),
 
239
  side="top", range=[xmin, xmax],
240
+ ticks="outside", tickformat=",.0f", tickmode="auto", tick0=tick0,
 
241
  showline=True, linewidth=1.2, linecolor="#444", mirror=True,
242
  showgrid=True, gridcolor="rgba(0,0,0,0.12)", automargin=True
243
  )
 
244
  fig.update_yaxes(
245
+ title_text=f"<b>{ylab}</b>", title_font=dict(size=18),
246
+ range=y_range, ticks="outside",
 
 
247
  showline=True, linewidth=1.2, linecolor="#444", mirror=True,
248
  showgrid=True, gridcolor="rgba(0,0,0,0.12)", automargin=True
249
  )
 
250
  return fig
251
 
252
  # ---------- Preview modal (matplotlib) ----------
 
292
  t1, t2 = st.tabs(["Tracks", "Summary"])
293
  with t1: st.pyplot(preview_tracks(df, FEATURES), use_container_width=True)
294
  with t2:
295
+ tbl = df[FEATURES].agg(['min','max','mean','std']).T.rename(
296
+ columns={"min":"Min","max":"Max","mean":"Mean","std":"Std"}
297
+ ).round(2)
298
  st.dataframe(tbl.reset_index(names="Feature"), use_container_width=True)
299
 
300
  # =========================
301
+ # Load model
302
  # =========================
303
  def ensure_model() -> Path|None:
304
  for p in [DEFAULT_MODEL, *MODEL_FALLBACKS]:
 
347
  st.session_state.setdefault("dev_preview",False)
348
 
349
  # =========================
350
+ # Branding in Sidebar
351
  # =========================
352
+ st.sidebar.image("logo.png", use_column_width=True)
353
+ st.sidebar.markdown(
354
+ "<div style='font-weight:800;font-size:1.2rem;'>ST_GeoMech_UCS</div>"
355
+ "<div style='color:#667085;'>Real-Time UCS Tracking While Drilling</div>",
356
+ unsafe_allow_html=True
 
 
 
 
 
 
357
  )
358
 
359
  # =========================
 
393
  st.session_state.dev_preview = True
394
 
395
  run = st.sidebar.button("Run Model", type="primary", use_container_width=True)
396
+ # nav
397
  if st.sidebar.button("Proceed to Validation ▶", use_container_width=True): st.session_state.app_step="validate"; st.rerun()
398
  if st.sidebar.button("Proceed to Prediction ▶", use_container_width=True): st.session_state.app_step="predict"; st.rerun()
399
 
400
+ # ---- Pinned helper at the very top ----
401
  helper_top = st.container()
402
  with helper_top:
403
  st.subheader("Case Building")
 
420
  tr["UCS_Pred"] = model.predict(tr[FEATURES])
421
  te["UCS_Pred"] = model.predict(te[FEATURES])
422
 
423
+ st.session_state.results["Train"]=tr; st.session_state.results["Test"]=te
 
 
424
  st.session_state.results["m_train"]={
425
+ "R": pearson_r(tr[TARGET], tr["UCS_Pred"]),
426
  "RMSE": rmse(tr[TARGET], tr["UCS_Pred"]),
427
  "MAE": mean_absolute_error(tr[TARGET], tr["UCS_Pred"])
428
  }
429
  st.session_state.results["m_test"]={
430
+ "R": pearson_r(te[TARGET], te["UCS_Pred"]),
431
  "RMSE": rmse(te[TARGET], te["UCS_Pred"]),
432
  "MAE": mean_absolute_error(te[TARGET], te["UCS_Pred"])
433
  }
 
438
 
439
  def _dev_block(df, m):
440
  c1,c2,c3 = st.columns(3)
441
+ c1.metric("R", f"{m['R']:.2f}"); c2.metric("RMSE", f"{m['RMSE']:.2f}"); c3.metric("MAE", f"{m['MAE']:.2f}")
442
+
 
443
  left, spacer, right = st.columns(PLOT_COLS)
444
  with left:
445
+ pad, plotcol = left.columns([CROSS_NUDGE, 1])
446
  with plotcol:
447
+ st.pyplot(cross_plot_static(df[TARGET], df["UCS_Pred"]), use_container_width=False)
 
 
 
 
448
  with right:
449
  st.plotly_chart(
450
  track_plot(df, include_actual=True),
 
493
  oor_pct = float(any_viol.mean()*100.0)
494
  if any_viol.any():
495
  tbl = df.loc[any_viol, FEATURES].copy()
496
+ for c in FEATURES:
497
+ if pd.api.types.is_numeric_dtype(tbl[c]):
498
+ tbl[c] = tbl[c].round(2)
499
+ tbl["Violations"] = pd.DataFrame({f:(df[f]<ranges[f][0])|(df[f]>ranges[f][1]) for f in FEATURES}).loc[any_viol].apply(lambda r:", ".join([c for c,v in r.items() if v]), axis=1)
500
  st.session_state.results["m_val"]={
501
+ "R": pearson_r(df[TARGET], df["UCS_Pred"]),
502
  "RMSE": rmse(df[TARGET], df["UCS_Pred"]),
503
  "MAE": mean_absolute_error(df[TARGET], df["UCS_Pred"])
504
  }
505
+ st.session_state.results["sv_val"]={"n":len(df),"pred_min":float(df["UCS_Pred"].min()),"pred_max":float(df["UCS_Pred"].max()),"oor":oor_pct}
 
 
 
 
 
506
  st.session_state.results["oor_tbl"]=tbl
507
 
508
  if "Validate" in st.session_state.results:
509
  m = st.session_state.results["m_val"]
510
  c1,c2,c3 = st.columns(3)
511
+ c1.metric("R", f"{m['R']:.2f}"); c2.metric("RMSE", f"{m['RMSE']:.2f}"); c3.metric("MAE", f"{m['MAE']:.2f}")
 
 
512
 
513
  left, spacer, right = st.columns(PLOT_COLS)
514
  with left:
515
  pad, plotcol = left.columns([CROSS_NUDGE, 1])
516
  with plotcol:
517
+ st.pyplot(
518
+ cross_plot_static(st.session_state.results["Validate"][TARGET],
519
+ st.session_state.results["Validate"]["UCS_Pred"]),
520
+ use_container_width=False
521
  )
522
  with right:
523
  st.plotly_chart(
 
572
 
573
  if "PredictOnly" in st.session_state.results:
574
  df = st.session_state.results["PredictOnly"]; sv = st.session_state.results["sv_pred"]
 
575
  left, spacer, right = st.columns(PLOT_COLS)
576
  with left:
577
  table = pd.DataFrame({
578
  "Metric": ["# points","Pred min","Pred max","Pred mean","Pred std","OOR %"],
579
+ "Value": [sv["n"],
580
+ round(sv["pred_min"],2),
581
+ round(sv["pred_max"],2),
582
+ round(sv["pred_mean"],2),
583
+ round(sv["pred_std"],2),
584
+ f'{sv["oor"]:.1f}%']
585
  })
586
  st.success("Predictions ready ✓")
587
  st.dataframe(table, use_container_width=True, hide_index=True)
 
604
  </div>
605
  """,
606
  unsafe_allow_html=True
607
+ )