UCS2014 committed on
Commit
2971cdf
·
verified ·
1 Parent(s): cbd43ea

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +562 -575
app.py CHANGED
@@ -30,8 +30,8 @@ CROSS_H = 250
30
  TRACK_H = 1000 # px (plotly height; width auto-fits column)
31
  # NEW: Add a TRACK_W variable to control the width
32
  TRACK_W = 500 # px (plotly width)
33
- FONT_SZ  = 13
34
- BOLD_FONT = "Arial Black, Arial, sans-serif"  # used for bold axis titles & ticks
35
 
36
  # =========================
37
  # Page / CSS
@@ -41,16 +41,16 @@ st.set_page_config(page_title="ST_GeoMech_UCS", page_icon="logo.png", layout="wi
41
  # General CSS (logo helpers etc.)
42
  st.markdown("""
43
  <style>
44
-   .brand-logo { width: 200px; height: auto; object-fit: contain; }
45
-   .sidebar-header { display:flex; align-items:center; gap:12px; }
46
-   .sidebar-header .text h1 { font-size: 1.05rem; margin:0; line-height:1.1; }
47
-   .sidebar-header .text .tag { font-size: .85rem; color:#6b7280; margin:2px 0 0; }
48
-   .centered-container {
49
-     display: flex;
50
-     flex-direction: column;
51
-     align-items: center;
52
-     text-align: center;
53
-   }
54
  </style>
55
  """, unsafe_allow_html=True)
56
 
@@ -59,12 +59,12 @@ st.markdown("""
59
  <style>
60
  /* This targets the main content area */
61
  .main .block-container {
62
-     overflow: unset !important;
63
  }
64
 
65
  /* This targets the vertical block that holds all your elements */
66
  div[data-testid="stVerticalBlock"] {
67
-     overflow: unset !important;
68
  }
69
  </style>
70
  """, unsafe_allow_html=True)
@@ -87,52 +87,52 @@ section[data-testid="stFileUploader"] p, section[data-testid="stFileUploader"] s
87
  st.markdown("""
88
  <style>
89
  div[data-testid="stExpander"] > details > summary {
90
-   position: sticky;
91
-   top: 0;
92
-   z-index: 10;
93
-   background: #fff;
94
-   border-bottom: 1px solid #eee;
95
  }
96
  div[data-testid="stExpander"] div[data-baseweb="tab-list"] {
97
-   position: sticky;
98
-   top: 42px;    /* adjust if your expander header height differs */
99
-   z-index: 9;
100
-   background: #fff;
101
-   padding-top: 6px;
102
  }
103
  </style>
104
  """, unsafe_allow_html=True)
105
 
106
  # Center text in all pandas Styler tables (headers + cells)
107
  TABLE_CENTER_CSS = [
108
-     dict(selector="th", props=[("text-align", "center")]),
109
-     dict(selector="td", props=[("text-align", "center")]),
110
  ]
111
 
112
  # NEW: CSS for the message box
113
  st.markdown("""
114
  <style>
115
  .st-message-box {
116
-     background-color: #f0f2f6;
117
-     color: #333333;
118
-     padding: 10px;
119
-     border-radius: 10px;
120
-     border: 1px solid #e6e9ef;
121
  }
122
  .st-message-box.st-success {
123
-     background-color: #d4edda;
124
-     color: #155724;
125
-     border-color: #c3e6cb;
126
  }
127
  .st-message-box.st-warning {
128
-     background-color: #fff3cd;
129
-     color: #856404;
130
-     border-color: #ffeeba;
131
  }
132
  .st-message-box.st-error {
133
-     background-color: #f8d7da;
134
-     color: #721c24;
135
-     border-color: #f5c6cb;
136
  }
137
  </style>
138
  """, unsafe_allow_html=True)
@@ -141,42 +141,42 @@ st.markdown("""
141
  # Password gate
142
  # =========================
143
  def inline_logo(path="logo.png") -> str:
144
-     try:
145
-         p = Path(path)
146
-         if not p.exists(): return ""
147
-         return f"data:image/png;base64,{base64.b64encode(p.read_bytes()).decode('ascii')}"
148
-     except Exception:
149
-         return ""
150
 
151
  def add_password_gate() -> None:
152
-     try:
153
-         required = st.secrets.get("APP_PASSWORD", "")
154
-     except Exception:
155
-         required = os.environ.get("APP_PASSWORD", "")
156
-
157
-     if not required:
158
-         st.warning("Set APP_PASSWORD in Secrets (or environment) and restart.")
159
-         st.stop()
160
-
161
-     if st.session_state.get("auth_ok", False):
162
-         return
163
-
164
-     st.sidebar.markdown(f"""
165
-         <div class="centered-container">
166
-             <img src="{inline_logo('logo.png')}" style="width: 200px; height: auto; object-fit: contain;">
167
-             <div style='font-weight:800;font-size:1.2rem; margin-top: 10px;'>ST_GeoMech_UCS</div>
168
-             <div style='color:#667085;'>Smart Thinking • Secure Access</div>
169
-         </div>
170
-         """, unsafe_allow_html=True
171
-     )
172
-     pwd = st.sidebar.text_input("Access key", type="password", placeholder="••••••••")
173
-     if st.sidebar.button("Unlock", type="primary"):
174
-         if pwd == required:
175
-             st.session_state.auth_ok = True
176
-             st.rerun()
177
-         else:
178
-             st.error("Incorrect key.")
179
-     st.stop()
180
 
181
  add_password_gate()
182
 
@@ -184,260 +184,260 @@ add_password_gate()
184
  # Utilities
185
  # =========================
186
  def rmse(y_true, y_pred) -> float:
187
-     return float(np.sqrt(mean_squared_error(y_true, y_pred)))
188
 
189
  def pearson_r(y_true, y_pred) -> float:
190
-     a = np.asarray(y_true, dtype=float)
191
-     p = np.asarray(y_pred,   dtype=float)
192
-     if a.size < 2: return float("nan")
193
-     return float(np.corrcoef(a, p)[0, 1])
194
 
195
  @st.cache_resource(show_spinner=False)
196
  def load_model(model_path: str):
197
-     return joblib.load(model_path)
198
 
199
  @st.cache_data(show_spinner=False)
200
  def parse_excel(data_bytes: bytes):
201
-     bio = io.BytesIO(data_bytes)
202
-     xl = pd.ExcelFile(bio)
203
-     return {sh: xl.parse(sh) for sh in xl.sheet_names}
204
 
205
  def read_book_bytes(b: bytes): return parse_excel(b) if b else {}
206
 
207
  def ensure_cols(df, cols):
208
-     miss = [c for c in cols if c not in df.columns]
209
-     if miss:
210
-         st.error(f"Missing columns: {miss}\nFound: {list(df.columns)}")
211
-         return False
212
-     return True
213
 
214
  def find_sheet(book, names):
215
-     low2orig = {k.lower(): k for k in book.keys()}
216
-     for nm in names:
217
-         if nm.lower() in low2orig: return low2orig[nm.lower()]
218
-     return None
219
 
220
  def _nice_tick0(xmin: float, step: int = 100) -> float:
221
-     return step * math.floor(xmin / step) if np.isfinite(xmin) else xmin
222
 
223
  def df_centered_rounded(df: pd.DataFrame, hide_index=True):
224
-     """Center headers & cells; format numeric columns to 2 decimals."""
225
-     out = df.copy()
226
-     numcols = out.select_dtypes(include=[np.number]).columns
227
-     styler = (
228
-         out.style
229
-             .format({c: "{:.2f}" for c in numcols})
230
-             .set_properties(**{"text-align": "center"})
231
-             .set_table_styles(TABLE_CENTER_CSS)
232
-     )
233
-     st.dataframe(styler, use_container_width=True, hide_index=hide_index)
234
 
235
  # =========================
236
  # Cross plot (Matplotlib, fixed limits & ticks)
237
  # =========================
238
  def cross_plot_static(actual, pred):
239
-     a = pd.Series(actual, dtype=float)
240
-     p = pd.Series(pred,   dtype=float)
241
 
242
-     fixed_min, fixed_max = 6000, 10000
243
-     ticks = np.arange(fixed_min, fixed_max + 1, 1000)
244
 
245
-     dpi = 110
246
-     fig, ax = plt.subplots(
247
-         figsize=(CROSS_W / dpi, CROSS_H / dpi),
248
-         dpi=dpi,
249
-         constrained_layout=False
250
-     )
251
 
252
-     ax.scatter(a, p, s=14, c=COLORS["pred"], alpha=0.9, linewidths=0)
253
-     ax.plot([fixed_min, fixed_max], [fixed_min, fixed_max],
254
-             linestyle="--", linewidth=1.2, color=COLORS["ref"])
255
 
256
-     ax.set_xlim(fixed_min, fixed_max)
257
-     ax.set_ylim(fixed_min, fixed_max)
258
-     ax.set_xticks(ticks)
259
-     ax.set_yticks(ticks)
260
-     ax.set_aspect("equal", adjustable="box")  # true 45°
261
 
262
-     fmt = FuncFormatter(lambda x, _: f"{int(x):,}")
263
-     ax.xaxis.set_major_formatter(fmt)
264
-     ax.yaxis.set_major_formatter(fmt)
265
 
266
-     ax.set_xlabel("Actual UCS (psi)", fontweight="bold", fontsize=4, color="black")
267
-     ax.set_ylabel("Predicted UCS (psi)", fontweight="bold", fontsize=4, color="black")
268
-     ax.tick_params(labelsize=2, colors="black")
269
 
270
-     ax.grid(True, linestyle=":", alpha=0.3)
271
-     for spine in ax.spines.values():
272
-         spine.set_linewidth(1.1)
273
-         spine.set_color("#444")
274
 
275
-     fig.subplots_adjust(left=0.16, bottom=0.16, right=0.98, top=0.98)
276
-     return fig
277
 
278
  # =========================
279
  # Track plot (Plotly)
280
  # =========================
281
  def track_plot(df, include_actual=True):
282
-     depth_col = next((c for c in df.columns if 'depth' in str(c).lower()), None)
283
-     if depth_col is not None:
284
-         y = pd.Series(df[depth_col]).astype(float)
285
-         ylab = depth_col
286
-         y_range = [float(y.max()), float(y.min())]  # reverse
287
-     else:
288
-         y = pd.Series(np.arange(1, len(df) + 1))
289
-         ylab = "Point Index"
290
-         y_range = [float(y.max()), float(y.min())]
291
-
292
-     # X (UCS) range & ticks
293
-     x_series = pd.Series(df.get("UCS_Pred", pd.Series(dtype=float))).astype(float)
294
-     if include_actual and TARGET in df.columns:
295
-         x_series = pd.concat([x_series, pd.Series(df[TARGET]).astype(float)], ignore_index=True)
296
-     x_lo, x_hi = float(x_series.min()), float(x_series.max())
297
-     x_pad = 0.03 * (x_hi - x_lo if x_hi > x_lo else 1.0)
298
-     xmin, xmax = x_lo - x_pad, x_hi + x_pad
299
-     tick0 = _nice_tick0(xmin, step=100)
300
-
301
-     fig = go.Figure()
302
-     fig.add_trace(go.Scatter(
303
-         x=df["UCS_Pred"], y=y, mode="lines",
304
-         line=dict(color=COLORS["pred"], width=1.8),
305
-         name="UCS_Pred",
306
-         hovertemplate="UCS_Pred: %{x:.0f}<br>"+ylab+": %{y}<extra></extra>"
307
-     ))
308
-     if include_actual and TARGET in df.columns:
309
-         fig.add_trace(go.Scatter(
310
-             x=df[TARGET], y=y, mode="lines",
311
-             line=dict(color=COLORS["actual"], width=2.0, dash="dot"),
312
-             name="UCS (actual)",
313
-             hovertemplate="UCS (actual): %{x:.0f}<br>"+ylab+": %{y}<extra></extra>"
314
-         ))
315
-
316
-     fig.update_layout(
317
-         height=TRACK_H,
318
-         width=TRACK_W, # Set the width here
319
-         autosize=False, # Disable autosizing to respect the width
320
-         paper_bgcolor="#fff", plot_bgcolor="#fff",
321
-         margin=dict(l=64, r=16, t=36, b=48), hovermode="closest",
322
-         font=dict(size=FONT_SZ, color="#000"),
323
-         legend=dict(
324
-             x=0.98, y=0.05, xanchor="right", yanchor="bottom",
325
-             bgcolor="rgba(255,255,255,0.75)", bordercolor="#ccc", borderwidth=1
326
-         ),
327
-         legend_title_text=""
328
-     )
329
-
330
-     # Bold, black axis titles & ticks
331
-     fig.update_xaxes(
332
-         title_text="UCS (psi)",
333
-         title_font=dict(size=20, family=BOLD_FONT, color="#000"),
334
-         tickfont=dict(size=15, family=BOLD_FONT, color="#000"),
335
-         side="top",
336
-         range=[xmin, xmax],
337
-         ticks="outside",
338
-         tickformat=",.0f",
339
-         tickmode="auto",
340
-         tick0=tick0,
341
-         showline=True, linewidth=1.2, linecolor="#444", mirror=True,
342
-         showgrid=True, gridcolor="rgba(0,0,0,0.12)", automargin=True
343
-     )
344
-     fig.update_yaxes(
345
-         title_text=ylab,
346
-         title_font=dict(size=20, family=BOLD_FONT, color="#000"),
347
-         tickfont=dict(size=15, family=BOLD_FONT, color="#000"),
348
-         range=y_range,
349
-         ticks="outside",
350
-         showline=True, linewidth=1.2, linecolor="#444", mirror=True,
351
-         showgrid=True, gridcolor="rgba(0,0,0,0.12)", automargin=True
352
-     )
353
-
354
-     return fig
355
 
356
  # ---------- Preview modal (matplotlib) ----------
357
  def preview_tracks(df: pd.DataFrame, cols: list[str]):
358
-     cols = [c for c in cols if c in df.columns]
359
-     n = len(cols)
360
-     if n == 0:
361
-         fig, ax = plt.subplots(figsize=(4, 2))
362
-         ax.text(0.5,0.5,"No selected columns",ha="center",va="center"); ax.axis("off")
363
-         return fig
364
-     fig, axes = plt.subplots(1, n, figsize=(2.2*n, 7.0), sharey=True, dpi=100)
365
-     if n == 1: axes = [axes]
366
-     idx = np.arange(1, len(df) + 1)
367
-     for ax, col in zip(axes, cols):
368
-         ax.plot(df[col], idx, '-', lw=1.4, color="#333")
369
-         ax.set_xlabel(col); ax.xaxis.set_label_position('top'); ax.xaxis.tick_top(); ax.invert_yaxis()
370
-         ax.grid(True, linestyle=":", alpha=0.3)
371
-         for s in ax.spines.values(): s.set_visible(True)
372
-     axes[0].set_ylabel("Point Index")
373
-     return fig
374
 
375
  # Modal wrapper (Streamlit compatibility)
376
  try:
377
-     dialog = st.dialog
378
  except AttributeError:
379
-     def dialog(title):
380
-         def deco(fn):
381
-             def wrapper(*args, **kwargs):
382
-                 with st.expander(title, expanded=True):
383
-                     return fn(*args, **kwargs)
384
-             return wrapper
385
-         return deco
386
 
387
  def preview_modal(book: dict[str, pd.DataFrame]):
388
-     if not book:
389
-         st.info("No data loaded yet."); return
390
-     names = list(book.keys())
391
-     tabs = st.tabs(names)
392
-     for t, name in zip(tabs, names):
393
-         with t:
394
-             df = book[name]
395
-             t1, t2 = st.tabs(["Tracks", "Summary"])
396
-             with t1:
397
-                 st.pyplot(preview_tracks(df, FEATURES), use_container_width=True)
398
-             with t2:
399
-                 tbl = (df[FEATURES]
400
-                         .agg(['min','max','mean','std'])
401
-                         .T.rename(columns={"min":"Min","max":"Max","mean":"Mean","std":"Std"}))
402
-                 df_centered_rounded(tbl.reset_index(names="Feature"))
403
 
404
  # =========================
405
  # Load model
406
  # =========================
407
  def ensure_model() -> Path|None:
408
-     for p in [DEFAULT_MODEL, *MODEL_FALLBACKS]:
409
-         if p.exists() and p.stat().st_size > 0: return p
410
-     url = os.environ.get("MODEL_URL", "")
411
-     if not url: return None
412
-     try:
413
-         import requests
414
-         DEFAULT_MODEL.parent.mkdir(parents=True, exist_ok=True)
415
-         with requests.get(url, stream=True, timeout=30) as r:
416
-             r.raise_for_status()
417
-             with open(DEFAULT_MODEL, "wb") as f:
418
-                 for chunk in r.iter_content(1<<20):
419
-                     if chunk: f.write(chunk)
420
-         return DEFAULT_MODEL
421
-     except Exception:
422
-         return None
423
 
424
  mpath = ensure_model()
425
  if not mpath:
426
-     st.error("Model not found. Upload models/ucs_rf.joblib (or set MODEL_URL).")
427
-     st.stop()
428
  try:
429
-     model = load_model(str(mpath))
430
  except Exception as e:
431
-     st.error(f"Failed to load model: {e}")
432
-     st.stop()
433
 
434
  meta_path = MODELS_DIR / "meta.json"
435
  if meta_path.exists():
436
-     try:
437
-         meta = json.loads(meta_path.read_text(encoding="utf-8"))
438
-         FEATURES = meta.get("features", FEATURES); TARGET = meta.get("target", TARGET)
439
-     except Exception:
440
-         pass
441
 
442
  # =========================
443
  # Session state
@@ -455,343 +455,330 @@ st.session_state.setdefault("show_preview_modal", False) # New state variable
455
  # Branding in Sidebar
456
  # =========================
457
  st.sidebar.markdown(f"""
458
-     <div class="centered-container">
459
-         <img src="{inline_logo('logo.png')}" style="width: 200px; height: auto; object-fit: contain;">
460
-         <div style='font-weight:800;font-size:1.2rem;'>ST_GeoMech_UCS</div>
461
-         <div style='color:#667085;'>Real-Time UCS Tracking While Drilling</div>
462
-     </div>
463
-     """, unsafe_allow_html=True
464
  )
465
 
466
  # =========================
467
  # Reusable Sticky Header Function
468
  # =========================
469
  def sticky_header(title, message):
470
-     st.markdown(
471
-         f"""
472
-         <style>
473
-         .sticky-container {{
474
-             position: sticky;
475
-             top: 0;
476
-             background-color: white;
477
-             z-index: 100;
478
-             padding-top: 10px;
479
-             padding-bottom: 10px;
480
-             border-bottom: 1px solid #eee;
481
-         }}
482
-         </style>
483
-         <div class="sticky-container">
484
-             <h3>{title}</h3>
485
-             <p>{message}</p>
486
-         </div>
487
-         """,
488
-         unsafe_allow_html=True
489
-     )
490
 
491
  # =========================
492
  # INTRO
493
  # =========================
494
  if st.session_state.app_step == "intro":
495
-     st.header("Welcome!")
496
-     st.markdown("This software is developed by *Smart Thinking AI-Solutions Team* to estimate UCS from drilling data.")
497
-     st.subheader("How It Works")
498
-     st.markdown(
499
-         "1) **Upload your data to build the case and preview the performance of our model.** \n"
500
-         "2) Click **Run Model** to compute metrics and plots.  \n"
501
-         "3) **Proceed to Validation** (with actual UCS) or **Proceed to Prediction** (no UCS)."
502
-     )
503
-     if st.button("Start Showcase", type="primary"):
504
-         st.session_state.app_step = "dev"; st.rerun()
505
 
506
  # =========================
507
  # CASE BUILDING
508
  # =========================
509
  if st.session_state.app_step == "dev":
510
-     st.sidebar.header("Case Building")
511
-     up = st.sidebar.file_uploader("Upload Your Data File", type=["xlsx","xls"])
512
-     if up is not None:
513
-         st.session_state.dev_file_bytes = up.getvalue()
514
-         st.session_state.dev_file_name = up.name
515
-         st.session_state.dev_file_loaded = True
516
-         st.session_state.dev_preview = False
517
-     if st.session_state.dev_file_loaded:
518
-         tmp = read_book_bytes(st.session_state.dev_file_bytes)
519
-         if tmp:
520
-             df0 = next(iter(tmp.values()))
521
-             st.sidebar.caption(f"**Data loaded:** {st.session_state.dev_file_name} • {df0.shape[0]} rows × {df0.shape[1]} cols")
522
-
523
-     if st.sidebar.button("Preview data", use_container_width=True, disabled=not st.session_state.dev_file_loaded):
524
-         st.session_state.show_preview_modal = True  # Set state to show modal
525
-         st.session_state.dev_preview = True
526
-
527
-     run = st.sidebar.button("Run Model", type="primary", use_container_width=True)
528
-     if st.sidebar.button("Proceed to Validation ▶", use_container_width=True): st.session_state.app_step="validate"; st.rerun()
529
-     if st.sidebar.button("Proceed to Prediction ▶", use_container_width=True): st.session_state.app_step="predict"; st.rerun()
530
-
531
-     # Apply sticky header
532
-     if st.session_state.dev_file_loaded and st.session_state.dev_preview:
533
-         sticky_header("Case Building", "Previewed ✓ — now click **Run Model**.")
534
-     elif st.session_state.dev_file_loaded:
535
-         sticky_header("Case Building", "📄 **Preview uploaded data** using the sidebar button, then click **Run Model**.")
536
-     else:
537
-         sticky_header("Case Building", "**Upload your data to build a case, then run the model to review development performance.**")
538
-
539
-     if run and st.session_state.dev_file_bytes:
540
-         book = read_book_bytes(st.session_state.dev_file_bytes)
541
-         sh_train = find_sheet(book, ["Train","Training","training2","train","training"])
542
-         sh_test  = find_sheet(book, ["Test","Testing","testing2","test","testing"])
543
-         if sh_train is None or sh_test is None:
544
-             st.markdown('<div class="st-message-box st-error">Workbook must include Train/Training/training2 and Test/Testing/testing2 sheets.</div>', unsafe_allow_html=True)
545
-             st.stop()
546
-         tr = book[sh_train].copy(); te = book[sh_test].copy()
547
-         if not (ensure_cols(tr, FEATURES+[TARGET]) and ensure_cols(te, FEATURES+[TARGET])):
548
-             st.markdown('<div class="st-message-box st-error">Missing required columns.</div>', unsafe_allow_html=True)
549
-             st.stop()
550
-         tr["UCS_Pred"] = model.predict(tr[FEATURES])
551
-         te["UCS_Pred"] = model.predict(te[FEATURES])
552
-
553
-         st.session_state.results["Train"]=tr; st.session_state.results["Test"]=te
554
-         st.session_state.results["m_train"]={
555
-             "R": pearson_r(tr[TARGET], tr["UCS_Pred"]),
556
-             "RMSE": rmse(tr[TARGET], tr["UCS_Pred"]),
557
-             "MAE": mean_absolute_error(tr[TARGET], tr["UCS_Pred"])
558
-         }
559
-         st.session_state.results["m_test"]={
560
-             "R": pearson_r(te[TARGET], te["UCS_Pred"]),
561
-             "RMSE": rmse(te[TARGET], te["UCS_Pred"]),
562
-             "MAE": mean_absolute_error(te[TARGET], te["UCS_Pred"])
563
-         }
564
-
565
-         tr_min = tr[FEATURES].min().to_dict(); tr_max = tr[FEATURES].max().to_dict()
566
-         st.session_state.train_ranges = {f:(float(tr_min[f]), float(tr_max[f])) for f in FEATURES}
567
-         st.markdown('<div class="st-message-box st-success">Case has been built and results are displayed below.</div>', unsafe_allow_html=True)
568
-
569
-     def _dev_block(df, m):
570
-         c1,c2,c3 = st.columns(3)
571
-         c1.metric("R", f"{m['R']:.2f}")
572
-         c2.metric("RMSE", f"{m['RMSE']:.2f}")
573
-         c3.metric("MAE", f"{m['MAE']:.2f}")
574
-
575
-         # NEW: Footer for metric abbreviations
576
-         st.markdown("""
577
-             <div style='text-align: left; font-size: 0.8em; color: #6b7280; margin-top: -16px; margin-bottom: 8px;'>
578
-                 <strong>R:</strong> Pearson Correlation Coefficient<br>
579
-                 <strong>RMSE:</strong> Root Mean Square Error<br>
580
-                 <strong>MAE:</strong> Mean Absolute Error
581
-             </div>
582
-         """, unsafe_allow_html=True)
583
-
584
-         # 2-column layout, big gap (prevents overlap)
585
-         col_cross, col_track = st.columns([3, 2], gap="large")
586
-         with col_cross:
587
-             st.pyplot(cross_plot_static(df[TARGET], df["UCS_Pred"]), use_container_width=False)
588
-         with col_track:
589
-             st.plotly_chart(
590
-                 track_plot(df, include_actual=True),
591
-                 use_container_width=False, # Set to False to honor the width in track_plot()
592
-                 config={"displayModeBar": False, "scrollZoom": True}
593
-             )
594
-
595
-     if "Train" in st.session_state.results or "Test" in st.session_state.results:
596
-         tab1, tab2 = st.tabs(["Training", "Testing"])
597
-         if "Train" in st.session_state.results:
598
-             with tab1: _dev_block(st.session_state.results["Train"], st.session_state.results["m_train"])
599
-         if "Test" in st.session_state.results:
600
-             with tab2: _dev_block(st.session_state.results["Test"],  st.session_state.results["m_test"])
601
 
602
  # =========================
603
  # VALIDATION (with actual UCS)
604
  # =========================
605
  if st.session_state.app_step == "validate":
606
-     st.sidebar.header("Validate the Model")
607
-     up = st.sidebar.file_uploader("Upload Validation Excel", type=["xlsx","xls"])
608
-     if up is not None:
609
-         book = read_book_bytes(up.getvalue())
610
-         if book:
611
-             df0 = next(iter(book.values()))
612
-             st.sidebar.caption(f"**Data loaded:** {up.name} • {df0.shape[0]} rows × {df0.shape[1]} cols")
613
-     if st.sidebar.button("Preview data", use_container_width=True, disabled=(up is None)):
614
-         st.session_state.show_preview_modal = True  # Set state to show modal
615
-     go_btn = st.sidebar.button("Predict", type="primary", use_container_width=True)
616
-     if st.sidebar.button("⬅ Back to Case Building", use_container_width=True): st.session_state.app_step="dev"; st.rerun()
617
-     if st.sidebar.button("Proceed to Prediction ▶", use_container_width=True): st.session_state.app_step="predict"; st.rerun()
618
-
619
-     sticky_header("Validate the Model", "Upload a dataset with the same **features** and **UCS** to evaluate performance.")
620
-
621
-     if go_btn and up is not None:
622
-         book = read_book_bytes(up.getvalue())
623
-         name = find_sheet(book, ["Validation","Validate","validation2","Val","val"]) or list(book.keys())[0]
624
-         df = book[name].copy()
625
-         if not ensure_cols(df, FEATURES+[TARGET]): st.markdown('<div class="st-message-box st-error">Missing required columns.</div>', unsafe_allow_html=True); st.stop()
626
-         df["UCS_Pred"] = model.predict(df[FEATURES])
627
-         st.session_state.results["Validate"]=df
628
-
629
-         ranges = st.session_state.train_ranges; oor_pct = 0.0; tbl=None
630
-         if ranges:
631
-             any_viol = pd.DataFrame({f:(df[f]<ranges[f][0])|(df[f]>ranges[f][1]) for f in FEATURES}).any(axis=1)
632
-             oor_pct = float(any_viol.mean()*100.0)
633
-             if any_viol.any():
634
-                 tbl = df.loc[any_viol, FEATURES].copy()
635
-                 for c in FEATURES:
636
-                     if pd.api.types.is_numeric_dtype(tbl[c]): tbl[c] = tbl[c].round(2)
637
-                 tbl["Violations"] = pd.DataFrame({f:(df[f]<ranges[f][0])|(df[f]>ranges[f][1]) for f in FEATURES}).loc[any_viol].apply(lambda r:", ".join([c for c,v in r.items() if v]), axis=1)
638
-         st.session_state.results["m_val"]={
639
-             "R": pearson_r(df[TARGET], df["UCS_Pred"]),
640
-             "RMSE": rmse(df[TARGET], df["UCS_Pred"]),
641
-             "MAE": mean_absolute_error(df[TARGET], df["UCS_Pred"])
642
-         }
643
-         st.session_state.results["sv_val"]={"n":len(df),"pred_min":float(df["UCS_Pred"].min()),"pred_max":float(df["UCS_Pred"].max()),"oor":oor_pct}
644
-         st.session_state.results["oor_tbl"]=tbl
645
-
646
-     if "Validate" in st.session_state.results:
647
-         m = st.session_state.results["m_val"]
648
-         c1,c2,c3 = st.columns(3)
649
-         c1.metric("R", f"{m['R']:.2f}")
650
-         c2.metric("RMSE", f"{m['RMSE']:.2f}")
651
-         c3.metric("MAE", f"{m['MAE']:.2f}")
652
-
653
-         # NEW: Footer for metric abbreviations
654
-         st.markdown("""
655
-             <div style='text-align: left; font-size: 0.8em; color: #6b7280; margin-top: -16px; margin-bottom: 8px;'>
656
-                 <strong>R:</strong> Pearson Correlation Coefficient<br>
657
-                 <strong>RMSE:</strong> Root Mean Square Error<br>
658
-                 <strong>MAE:</strong> Mean Absolute Error
659
-             </div>
660
-         """, unsafe_allow_html=True)
661
-    
662
-         col_cross, col_track = st.columns([3, 2], gap="large")
663
-         with col_cross:
664
-             st.pyplot(
665
-                 cross_plot_static(st.session_state.results["Validate"][TARGET],
666
-                                      st.session_state.results["Validate"]["UCS_Pred"]),
667
-                 use_container_width=False
668
-             )
669
-         with col_track:
670
-             st.plotly_chart(
671
-                 track_plot(st.session_state.results["Validate"], include_actual=True),
672
-                 use_container_width=False, # Set to False to honor the width in track_plot()
673
-                 config={"displayModeBar": False, "scrollZoom": True}
674
-             )
675
-
676
-         sv = st.session_state.results["sv_val"]
677
-         if sv["oor"] > 0: st.markdown('<div class="st-message-box st-warning">Some inputs fall outside **training min–max** ranges.</div>', unsafe_allow_html=True)
678
-         if st.session_state.results["oor_tbl"] is not None:
679
-             st.write("*Out-of-range rows (vs. Training min–max):*")
680
-             df_centered_rounded(st.session_state.results["oor_tbl"])
681
 
682
  # =========================
683
  # PREDICTION (no actual UCS)
684
  # =========================
685
  if st.session_state.app_step == "predict":
686
-     st.sidebar.header("Prediction (No Actual UCS)")
687
-     up = st.sidebar.file_uploader("Upload Prediction Excel", type=["xlsx","xls"])
688
-     if up is not None:
689
-         book = read_book_bytes(up.getvalue())
690
-         if book:
691
-             df0 = next(iter(book.values()))
692
-             st.sidebar.caption(f"**Data loaded:** {up.name} • {df0.shape[0]} rows × {df0.shape[1]} cols")
693
-     if st.sidebar.button("Preview data", use_container_width=True, disabled=(up is None)):
694
-         st.session_state.show_preview_modal = True  # Set state to show modal
695
-     go_btn = st.sidebar.button("Predict", type="primary", use_container_width=True)
696
-     if st.sidebar.button("⬅ Back to Case Building", use_container_width=True): st.session_state.app_step="dev"; st.rerun()
697
-
698
-     sticky_header("Prediction", "Upload a dataset with the feature columns (no **UCS**).")
699
-
700
-     if go_btn and up is not None:
701
-         book = read_book_bytes(up.getvalue()); name = list(book.keys())[0]
702
-         df = book[name].copy()
703
-         if not ensure_cols(df, FEATURES): st.markdown('<div class="st-message-box st-error">Missing required columns.</div>', unsafe_allow_html=True); st.stop()
704
-         df["UCS_Pred"] = model.predict(df[FEATURES])
705
-         st.session_state.results["PredictOnly"]=df
706
-
707
-         ranges = st.session_state.train_ranges; oor_pct = 0.0
708
-         if ranges:
709
-             any_viol = pd.DataFrame({f:(df[f]<ranges[f][0])|(df[f]>ranges[f][1]) for f in FEATURES}).any(axis=1)
710
-             oor_pct = float(any_viol.mean()*100.0)
711
-         st.session_state.results["sv_pred"]={
712
-             "n":len(df),
713
-             "pred_min":float(df["UCS_Pred"].min()),
714
-             "pred_max":float(df["UCS_Pred"].max()),
715
-             "pred_mean":float(df["UCS_Pred"].mean()),
716
-             "pred_std":float(df["UCS_Pred"].std(ddof=0)),
717
-             "oor":oor_pct
718
-         }
719
-
720
-     if "PredictOnly" in st.session_state.results:
721
-         df = st.session_state.results["PredictOnly"]; sv = st.session_state.results["sv_pred"]
722
-
723
-         col_left, col_right = st.columns([2,3], gap="large")
724
-         with col_left:
725
-             table = pd.DataFrame({
726
-                 "Metric": ["# points","Pred min","Pred max","Pred mean","Pred std","OOR %"],
727
-                 "Value":  [sv["n"],
728
-                           round(sv["pred_min"],2),
729
-                           round(sv["pred_max"],2),
730
-                           round(sv["pred_mean"],2),
731
-                           round(sv["pred_std"],2),
732
-                           f'{sv["oor"]:.1f}%']
733
-             })
734
-             st.markdown('<div class="st-message-box st-success">Predictions ready ✓</div>', unsafe_allow_html=True)
735
-             df_centered_rounded(table, hide_index=True)
736
-             st.caption("**★ OOR** = % of rows whose input features fall outside the training min–max range.")
737
-         with col_right:
738
-             st.plotly_chart(
739
-                 track_plot(df, include_actual=False),
740
-                 use_container_width=False, # Set to False to honor the width in track_plot()
741
-                 config={"displayModeBar": False, "scrollZoom": True}
742
-             )
743
 
744
  # =========================
745
  # Run preview modal after all other elements
746
  # =========================
747
  if st.session_state.show_preview_modal:
748
-     # Get the correct book based on the current app step
749
-     book_to_preview = {}
750
-     if st.session_state.app_step == "dev":
751
-         book_to_preview = read_book_bytes(st.session_state.dev_file_bytes)
752
-     elif st.session_state.app_step in ["validate", "predict"] and up is not None:
753
-         book_to_preview = read_book_bytes(up.getvalue())
754
-
755
-     # Use a try-except block to handle cases where 'up' might be None
756
-     # and the logic tries to access its attributes.
757
-     try:
758
-         if st.session_state.app_step == "validate" and up is not None:
759
-               book_to_preview = read_book_bytes(up.getvalue())
760
-         elif st.session_state.app_step == "predict" and up is not None:
761
-     ��         book_to_preview = read_book_bytes(up.getvalue())
762
-     except NameError:
763
-         book_to_preview = {}
764
-
765
-     with st.expander("Preview data", expanded=True):
766
-         if not book_to_preview:
767
-             st.markdown('<div class="st-message-box">No data loaded yet.</div>', unsafe_allow_html=True)
768
-         else:
769
-             names = list(book_to_preview.keys())
770
-             tabs = st.tabs(names)
771
-             for t, name in zip(tabs, names):
772
-                 with t:
773
-                     df = book_to_preview[name]
774
-                     t1, t2 = st.tabs(["Tracks", "Summary"])
775
-                     with t1:
776
-                         st.pyplot(preview_tracks(df, FEATURES), use_container_width=True)
777
-                     with t2:
778
-                         tbl = (df[FEATURES]
779
-                                  .agg(['min','max','mean','std'])
780
-                                  .T.rename(columns={"min":"Min","max":"Max","mean":"Mean","std":"Std"}))
781
-                         df_centered_rounded(tbl.reset_index(names="Feature"))
782
-     # Reset the state variable after the modal is displayed
783
-     st.session_state.show_preview_modal = False
784
 
785
  # =========================
786
  # Footer
787
  # =========================
788
- st.markdown("---")
789
- st.markdown(
790
-     """
791
-     <div style='text-align:center; color:#6b7280; line-height:1.6'>
792
-       ST_GeoMech_UCS © Smart Thinking<br/>
793
-       <strong>Visit our website:</strong> <a href='https://www.smartthinking.com.sa' target='_blank'>smartthinking.com.sa</a>
794
-     </div>
795
-     """,
796
-     unsafe_allow_html=True
797
- )
 
30
  TRACK_H = 1000 # px (plotly height; width auto-fits column)
31
  # NEW: Add a TRACK_W variable to control the width
32
  TRACK_W = 500 # px (plotly width)
33
+ FONT_SZ = 13
34
+ BOLD_FONT = "Arial Black, Arial, sans-serif" # used for bold axis titles & ticks
35
 
36
  # =========================
37
  # Page / CSS
 
41
  # General CSS (logo helpers etc.)
42
  st.markdown("""
43
  <style>
44
+ .brand-logo { width: 200px; height: auto; object-fit: contain; }
45
+ .sidebar-header { display:flex; align-items:center; gap:12px; }
46
+ .sidebar-header .text h1 { font-size: 1.05rem; margin:0; line-height:1.1; }
47
+ .sidebar-header .text .tag { font-size: .85rem; color:#6b7280; margin:2px 0 0; }
48
+ .centered-container {
49
+ display: flex;
50
+ flex-direction: column;
51
+ align-items: center;
52
+ text-align: center;
53
+ }
54
  </style>
55
  """, unsafe_allow_html=True)
56
 
 
59
  <style>
60
  /* This targets the main content area */
61
  .main .block-container {
62
+ overflow: unset !important;
63
  }
64
 
65
  /* This targets the vertical block that holds all your elements */
66
  div[data-testid="stVerticalBlock"] {
67
+ overflow: unset !important;
68
  }
69
  </style>
70
  """, unsafe_allow_html=True)
 
87
  st.markdown("""
88
  <style>
89
  div[data-testid="stExpander"] > details > summary {
90
+ position: sticky;
91
+ top: 0;
92
+ z-index: 10;
93
+ background: #fff;
94
+ border-bottom: 1px solid #eee;
95
  }
96
  div[data-testid="stExpander"] div[data-baseweb="tab-list"] {
97
+ position: sticky;
98
+ top: 42px; /* adjust if your expander header height differs */
99
+ z-index: 9;
100
+ background: #fff;
101
+ padding-top: 6px;
102
  }
103
  </style>
104
  """, unsafe_allow_html=True)
105
 
106
  # Center text in all pandas Styler tables (headers + cells)
107
  TABLE_CENTER_CSS = [
108
+ dict(selector="th", props=[("text-align", "center")]),
109
+ dict(selector="td", props=[("text-align", "center")]),
110
  ]
111
 
112
  # NEW: CSS for the message box
113
  st.markdown("""
114
  <style>
115
  .st-message-box {
116
+ background-color: #f0f2f6;
117
+ color: #333333;
118
+ padding: 10px;
119
+ border-radius: 10px;
120
+ border: 1px solid #e6e9ef;
121
  }
122
  .st-message-box.st-success {
123
+ background-color: #d4edda;
124
+ color: #155724;
125
+ border-color: #c3e6cb;
126
  }
127
  .st-message-box.st-warning {
128
+ background-color: #fff3cd;
129
+ color: #856404;
130
+ border-color: #ffeeba;
131
  }
132
  .st-message-box.st-error {
133
+ background-color: #f8d7da;
134
+ color: #721c24;
135
+ border-color: #f5c6cb;
136
  }
137
  </style>
138
  """, unsafe_allow_html=True)
 
141
  # Password gate
142
  # =========================
def inline_logo(path="logo.png") -> str:
    """Return the image at *path* as a base64 ``data:`` URI, or ``""``.

    The MIME type is derived from the file suffix (falling back to
    ``image/png``) so non-PNG logos are labelled correctly.

    Args:
        path: Location of the image file.

    Returns:
        A ``data:<mime>;base64,<payload>`` string, or an empty string when
        the file is missing or cannot be read.
    """
    import mimetypes  # local import keeps this block self-contained

    try:
        logo = Path(path)
        if not logo.is_file():
            return ""
        mime = mimetypes.guess_type(logo.name)[0] or "image/png"
        payload = base64.b64encode(logo.read_bytes()).decode("ascii")
    except (OSError, TypeError, ValueError):
        # Best effort: a broken logo must never take the page down.
        return ""
    return f"data:{mime};base64,{payload}"
150
 
def add_password_gate() -> None:
    """Halt the script until the visitor enters the configured access key.

    The expected key is read from ``st.secrets["APP_PASSWORD"]`` and, when
    that is unavailable or empty, from the ``APP_PASSWORD`` environment
    variable.  Once the key has been accepted, ``st.session_state.auth_ok``
    is set and later runs return immediately.  In every other case the
    function ends the current run with ``st.stop()``.
    """
    import hmac  # local import: constant-time comparison of the key

    try:
        required = st.secrets.get("APP_PASSWORD", "")
    except Exception:
        # Reading st.secrets can raise (presumably when no secrets are configured).
        required = ""
    # Fall back to the environment when the secret is absent *or* empty.
    required = required or os.environ.get("APP_PASSWORD", "")

    if not required:
        st.warning("Set APP_PASSWORD in Secrets (or environment) and restart.")
        st.stop()

    if st.session_state.get("auth_ok", False):
        return

    st.sidebar.markdown(f"""
        <div class="centered-container">
        <img src="{inline_logo('logo.png')}" style="width: 200px; height: auto; object-fit: contain;">
        <div style='font-weight:800;font-size:1.2rem; margin-top: 10px;'>ST_GeoMech_UCS</div>
        <div style='color:#667085;'>Smart Thinking • Secure Access</div>
        </div>
        """, unsafe_allow_html=True
    )
    pwd = st.sidebar.text_input("Access key", type="password", placeholder="••••••••")
    if st.sidebar.button("Unlock", type="primary"):
        # compare_digest avoids leaking how many leading characters matched.
        supplied = str(pwd).encode("utf-8")
        expected = str(required).encode("utf-8")
        if hmac.compare_digest(supplied, expected):
            st.session_state.auth_ok = True
            st.rerun()
        else:
            st.error("Incorrect key.")
    st.stop()
180
 
181
  add_password_gate()
182
 
 
184
  # Utilities
185
  # =========================
def rmse(y_true, y_pred) -> float:
    """Return the root-mean-square error between *y_true* and *y_pred*."""
    mse = mean_squared_error(y_true, y_pred)
    return float(np.sqrt(mse))
188
 
def pearson_r(y_true, y_pred) -> float:
    """Return the Pearson correlation coefficient of two 1-D sequences.

    Args:
        y_true: Observed values.
        y_pred: Predicted values, same length as *y_true*.

    Returns:
        The correlation as a float, or ``nan`` when it is undefined
        (fewer than two points, or either input is constant).
    """
    actual = np.asarray(y_true, dtype=float)
    predicted = np.asarray(y_pred, dtype=float)
    if actual.size < 2:
        return float("nan")
    # Zero variance makes the coefficient 0/0; answer explicitly instead of
    # letting np.corrcoef emit a RuntimeWarning on the way to the same nan.
    if np.ptp(actual) == 0 or np.ptp(predicted) == 0:
        return float("nan")
    return float(np.corrcoef(actual, predicted)[0, 1])
194
 
@st.cache_resource(show_spinner=False)
def load_model(model_path: str):
    """Load the serialized model stored at *model_path* (cached by Streamlit).

    NOTE: joblib.load unpickles the file, so only load model files that come
    from a trusted source.
    """
    estimator = joblib.load(model_path)
    return estimator
198
 
@st.cache_data(show_spinner=False)
def parse_excel(data_bytes: bytes):
    """Parse an in-memory Excel workbook into ``{sheet_name: DataFrame}``.

    Args:
        data_bytes: Raw bytes of the workbook.

    Returns:
        A dict with one DataFrame per sheet, keyed by sheet name.
    """
    # The context manager closes the workbook handle once every sheet is read.
    with pd.ExcelFile(io.BytesIO(data_bytes)) as workbook:
        return {sheet: workbook.parse(sheet) for sheet in workbook.sheet_names}
204
 
def read_book_bytes(b: bytes):
    """Return the parsed workbook for *b*, or an empty dict when *b* is falsy."""
    if not b:
        return {}
    return parse_excel(b)
206
 
def ensure_cols(df, cols):
    """Check that every name in *cols* is a column of *df*.

    Shows an ``st.error`` listing the missing and the available columns when
    the check fails.

    Returns:
        True when all columns are present, False otherwise.
    """
    missing = [name for name in cols if name not in df.columns]
    if not missing:
        return True
    st.error(f"Missing columns: {missing}\nFound: {list(df.columns)}")
    return False
213
 
def find_sheet(book, names):
    """Return the first sheet of *book* matching one of *names*, else None.

    Matching ignores case and surrounding whitespace, so a sheet saved as
    ``" Training "`` is still found by the candidate ``"training"``.

    Args:
        book: Mapping of sheet name to sheet contents.
        names: Candidate sheet names, in priority order.

    Returns:
        The sheet's original key in *book*, or None when nothing matches.
    """
    normalized = {str(key).strip().lower(): key for key in book}
    for candidate in names:
        wanted = str(candidate).strip().lower()
        if wanted in normalized:
            return normalized[wanted]
    return None
219
 
220
  def _nice_tick0(xmin: float, step: int = 100) -> float:
221
+ return step * math.floor(xmin / step) if np.isfinite(xmin) else xmin
222
 
def df_centered_rounded(df: pd.DataFrame, hide_index=True):
    """Show *df* in Streamlit with centered text and 2-decimal numeric columns.

    Works on a copy, so the caller's frame is left untouched.
    """
    frame = df.copy()
    numeric_cols = frame.select_dtypes(include=[np.number]).columns
    two_decimals = {col: "{:.2f}" for col in numeric_cols}

    styled = frame.style.format(two_decimals)
    styled = styled.set_properties(**{"text-align": "center"})
    styled = styled.set_table_styles(TABLE_CENTER_CSS)

    st.dataframe(styled, use_container_width=True, hide_index=hide_index)
234
 
235
  # =========================
236
  # Cross plot (Matplotlib, fixed limits & ticks)
237
  # =========================
def cross_plot_static(actual, pred, *, axis_min=6000, axis_max=10000, tick_step=1000):
    """Draw a square predicted-vs-actual UCS scatter with a 1:1 reference line.

    Both axes share the same fixed limits so the dashed reference line sits
    at a true 45 degrees.  Points outside ``[axis_min, axis_max]`` are not
    visible; widen the window through the keyword arguments when needed.

    Args:
        actual: Measured UCS values.
        pred: Predicted UCS values.
        axis_min: Lower limit applied to both axes.
        axis_max: Upper limit applied to both axes.
        tick_step: Spacing between major ticks on both axes.

    Returns:
        The Matplotlib figure.
    """
    actual_vals = pd.Series(actual, dtype=float)
    pred_vals = pd.Series(pred, dtype=float)
    ticks = np.arange(axis_min, axis_max + 1, tick_step)

    # Figure size is given in pixels (CROSS_W x CROSS_H) at this dpi.
    dpi = 110
    fig, ax = plt.subplots(
        figsize=(CROSS_W / dpi, CROSS_H / dpi),
        dpi=dpi,
        constrained_layout=False,
    )

    ax.scatter(actual_vals, pred_vals, s=14, c=COLORS["pred"], alpha=0.9, linewidths=0)
    ax.plot(
        [axis_min, axis_max], [axis_min, axis_max],
        linestyle="--", linewidth=1.2, color=COLORS["ref"],
    )

    ax.set_xlim(axis_min, axis_max)
    ax.set_ylim(axis_min, axis_max)
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.set_aspect("equal", adjustable="box")  # true 45°

    thousands = FuncFormatter(lambda value, _pos: f"{int(value):,}")
    ax.xaxis.set_major_formatter(thousands)
    ax.yaxis.set_major_formatter(thousands)

    ax.set_xlabel("Actual UCS (psi)", fontweight="bold", fontsize=4, color="black")
    ax.set_ylabel("Predicted UCS (psi)", fontweight="bold", fontsize=4, color="black")
    ax.tick_params(labelsize=2, colors="black")

    ax.grid(True, linestyle=":", alpha=0.3)
    for spine in ax.spines.values():
        spine.set_linewidth(1.1)
        spine.set_color("#444")

    fig.subplots_adjust(left=0.16, bottom=0.16, right=0.98, top=0.98)
    return fig
277
 
278
  # =========================
279
  # Track plot (Plotly)
280
  # =========================
def track_plot(df, include_actual=True):
    """Build the Plotly track (log-style) figure of UCS against depth.

    The y-axis uses the first column whose name contains "depth"
    (case-insensitive); without one, a 1-based point index is used.  The
    y-axis is reversed so values increase downward.  The predicted curve is
    always drawn; the actual curve is added when *include_actual* is true
    and the ``TARGET`` column exists.

    Args:
        df: Frame holding ``UCS_Pred`` and optionally ``TARGET`` / a depth column.
        include_actual: Whether to overlay the measured UCS curve.

    Returns:
        A ``plotly.graph_objects.Figure`` sized ``TRACK_W`` x ``TRACK_H``.
    """
    show_actual = include_actual and TARGET in df.columns

    # --- vertical axis: depth column if present, else point index ---
    depth_col = None
    for col in df.columns:
        if 'depth' in str(col).lower():
            depth_col = col
            break
    if depth_col is None:
        y = pd.Series(np.arange(1, len(df) + 1))
        ylab = "Point Index"
    else:
        y = pd.Series(df[depth_col]).astype(float)
        ylab = depth_col
    y_range = [float(y.max()), float(y.min())]  # max first => reversed axis

    # --- horizontal axis: UCS range padded by 3 % on each side ---
    x_series = pd.Series(df.get("UCS_Pred", pd.Series(dtype=float))).astype(float)
    if show_actual:
        x_series = pd.concat([x_series, pd.Series(df[TARGET]).astype(float)], ignore_index=True)
    x_lo = float(x_series.min())
    x_hi = float(x_series.max())
    span = x_hi - x_lo if x_hi > x_lo else 1.0
    x_pad = 0.03 * span
    xmin = x_lo - x_pad
    xmax = x_hi + x_pad
    tick0 = _nice_tick0(xmin, step=100)

    # --- traces ---
    fig = go.Figure()
    pred_trace = go.Scatter(
        x=df["UCS_Pred"], y=y, mode="lines",
        line=dict(color=COLORS["pred"], width=1.8),
        name="UCS_Pred",
        hovertemplate="UCS_Pred: %{x:.0f}<br>"+ylab+": %{y}<extra></extra>",
    )
    fig.add_trace(pred_trace)
    if show_actual:
        actual_trace = go.Scatter(
            x=df[TARGET], y=y, mode="lines",
            line=dict(color=COLORS["actual"], width=2.0, dash="dot"),
            name="UCS (actual)",
            hovertemplate="UCS (actual): %{x:.0f}<br>"+ylab+": %{y}<extra></extra>",
        )
        fig.add_trace(actual_trace)

    # --- layout: fixed pixel size, so autosize is disabled ---
    legend_box = dict(
        x=0.98, y=0.05, xanchor="right", yanchor="bottom",
        bgcolor="rgba(255,255,255,0.75)", bordercolor="#ccc", borderwidth=1,
    )
    fig.update_layout(
        height=TRACK_H,
        width=TRACK_W,
        autosize=False,
        paper_bgcolor="#fff", plot_bgcolor="#fff",
        margin=dict(l=64, r=16, t=36, b=48), hovermode="closest",
        font=dict(size=FONT_SZ, color="#000"),
        legend=legend_box,
        legend_title_text="",
    )

    # --- axes: bold black titles/ticks, boxed frame, light grid ---
    shared_axis_style = dict(
        title_font=dict(size=20, family=BOLD_FONT, color="#000"),
        tickfont=dict(size=15, family=BOLD_FONT, color="#000"),
        ticks="outside",
        showline=True, linewidth=1.2, linecolor="#444", mirror=True,
        showgrid=True, gridcolor="rgba(0,0,0,0.12)", automargin=True,
    )
    # NOTE(review): tick0 is passed together with tickmode="auto"; confirm in
    # the Plotly axis reference that tick0 has the intended effect in that mode.
    fig.update_xaxes(
        title_text="UCS (psi)",
        side="top",
        range=[xmin, xmax],
        tickformat=",.0f",
        tickmode="auto",
        tick0=tick0,
        **shared_axis_style,
    )
    fig.update_yaxes(
        title_text=ylab,
        range=y_range,
        **shared_axis_style,
    )

    return fig
355
 
356
  # ---------- Preview modal (matplotlib) ----------
def preview_tracks(df: pd.DataFrame, cols: list[str]):
    """Plot each requested column of *df* as a vertical track against row index.

    Columns absent from *df* are skipped.  When none remain, a small
    placeholder figure reading "No selected columns" is returned instead.

    Args:
        df: Data to preview.
        cols: Column names to plot, one subplot per column.

    Returns:
        The Matplotlib figure.
    """
    present = [c for c in cols if c in df.columns]
    if not present:
        fig, ax = plt.subplots(figsize=(4, 2))
        ax.text(0.5, 0.5, "No selected columns", ha="center", va="center")
        ax.axis("off")
        return fig

    n_tracks = len(present)
    fig, axes = plt.subplots(1, n_tracks, figsize=(2.2 * n_tracks, 7.0), sharey=True, dpi=100)
    if n_tracks == 1:
        axes = [axes]  # plt.subplots returns a bare Axes for a single subplot

    point_idx = np.arange(1, len(df) + 1)
    for ax, col in zip(axes, present):
        ax.plot(df[col], point_idx, '-', lw=1.4, color="#333")
        ax.set_xlabel(col)
        ax.xaxis.set_label_position('top')
        ax.xaxis.tick_top()
        ax.grid(True, linestyle=":", alpha=0.3)
        for spine in ax.spines.values():
            spine.set_visible(True)

    # The y-axis is shared, so invert it exactly once: inverting inside the
    # loop toggles the shared axis per subplot and cancels out for an even
    # number of tracks.
    axes[0].invert_yaxis()
    axes[0].set_ylabel("Point Index")
    return fig
374
 
375
  # Modal wrapper (Streamlit compatibility)
# Use Streamlit's native dialog decorator when this version provides one;
# otherwise fall back to a decorator that renders inside an expander.
dialog = getattr(st, "dialog", None)
if dialog is None:
    def dialog(title):
        """Return a decorator that runs the wrapped function inside an expander titled *title*."""
        def decorate(fn):
            def run_in_expander(*args, **kwargs):
                with st.expander(title, expanded=True):
                    return fn(*args, **kwargs)
            return run_in_expander
        return decorate
386
 
def preview_modal(book: dict[str, pd.DataFrame]):
    """Render a preview of every sheet in *book*: track plots plus a summary table.

    Each sheet gets its own tab containing a "Tracks" tab (one track per
    feature column) and a "Summary" tab (min / max / mean / std per feature).
    Feature columns missing from a sheet are skipped rather than raising.

    Args:
        book: Mapping of sheet name to DataFrame.
    """
    if not book:
        st.info("No data loaded yet.")
        return

    names = list(book.keys())
    for sheet_tab, name in zip(st.tabs(names), names):
        with sheet_tab:
            df = book[name]
            tracks_tab, summary_tab = st.tabs(["Tracks", "Summary"])
            with tracks_tab:
                st.pyplot(preview_tracks(df, FEATURES), use_container_width=True)
            with summary_tab:
                # df[FEATURES] raises KeyError if any feature is absent, so
                # summarize only the feature columns this sheet actually has.
                present = [c for c in FEATURES if c in df.columns]
                if not present:
                    st.info("No feature columns found in this sheet.")
                else:
                    tbl = (df[present]
                           .agg(['min', 'max', 'mean', 'std'])
                           .T.rename(columns={"min": "Min", "max": "Max", "mean": "Mean", "std": "Std"}))
                    df_centered_rounded(tbl.rename_axis("Feature").reset_index())
403
 
404
  # =========================
405
  # Load model
406
  # =========================
def ensure_model() -> Path|None:
    """Locate the model file, downloading it from ``MODEL_URL`` if necessary.

    Checks ``DEFAULT_MODEL`` and then each entry of ``MODEL_FALLBACKS`` and
    returns the first one that exists and is non-empty.  If none is found
    and the ``MODEL_URL`` environment variable is set, the model is
    downloaded to ``DEFAULT_MODEL``.

    Returns:
        Path of a usable model file, or None when no model is available or
        the download fails.
    """
    for candidate in [DEFAULT_MODEL, *MODEL_FALLBACKS]:
        if candidate.exists() and candidate.stat().st_size > 0:
            return candidate

    url = os.environ.get("MODEL_URL", "")
    if not url:
        return None

    # Download to a sibling temp file and rename only on success, so an
    # interrupted transfer never leaves a truncated file at DEFAULT_MODEL
    # (which would pass the exists/size check above on the next run).
    partial = DEFAULT_MODEL.with_name(DEFAULT_MODEL.name + ".part")
    try:
        import requests
        DEFAULT_MODEL.parent.mkdir(parents=True, exist_ok=True)
        with requests.get(url, stream=True, timeout=30) as r:
            r.raise_for_status()
            with open(partial, "wb") as f:
                for chunk in r.iter_content(1 << 20):
                    if chunk:
                        f.write(chunk)
        if partial.stat().st_size == 0:
            raise OSError("downloaded model file is empty")
        os.replace(partial, DEFAULT_MODEL)
        return DEFAULT_MODEL
    except Exception:
        # Best effort: callers treat None as "model not found".
        try:
            partial.unlink(missing_ok=True)
        except OSError:
            pass
        return None
423
 
# Resolve and load the model; without one the app cannot do anything useful,
# so the run is stopped with an error message.
mpath = ensure_model()
if not mpath:
    st.error("Model not found. Upload models/ucs_rf.joblib (or set MODEL_URL).")
    st.stop()
try:
    model = load_model(str(mpath))
except Exception as exc:
    st.error(f"Failed to load model: {exc}")
    st.stop()

# Optional metadata next to the model may override the feature list and the
# target column name; a missing or unreadable file keeps the defaults.
meta_path = MODELS_DIR / "meta.json"
if meta_path.exists():
    try:
        meta = json.loads(meta_path.read_text(encoding="utf-8"))
        FEATURES = meta.get("features", FEATURES)
        TARGET = meta.get("target", TARGET)
    except Exception:
        pass
441
 
442
  # =========================
443
  # Session state
 
455
  # Branding in Sidebar
456
  # =========================
457
  st.sidebar.markdown(f"""
458
+ <div class="centered-container">
459
+ <img src="{inline_logo('logo.png')}" style="width: 200px; height: auto; object-fit: contain;">
460
+ <div style='font-weight:800;font-size:1.2rem;'>ST_GeoMech_UCS</div>
461
+ <div style='color:#667085;'>Real-Time UCS Tracking While Drilling</div>
462
+ </div>
463
+ """, unsafe_allow_html=True
464
  )
465
 
466
  # =========================
467
  # Reusable Sticky Header Function
468
  # =========================
def sticky_header(title, message):
    """Render a page header that stays pinned to the top while scrolling.

    The header is emitted as raw HTML, where Markdown is not interpreted, so
    ``**bold**`` spans in *message* are converted to ``<strong>`` tags before
    being inserted.

    Args:
        title: Heading text (placed in an ``<h3>``).
        message: Guidance text shown under the heading; may contain
            ``**bold**`` spans and inline HTML.
    """
    import re  # local import keeps this block self-contained

    message_html = re.sub(r"\*\*(.+?)\*\*", r"<strong>\1</strong>", str(message))
    st.markdown(
        f"""
        <style>
        .sticky-container {{
            position: sticky;
            top: 0;
            background-color: white;
            z-index: 100;
            padding-top: 10px;
            padding-bottom: 10px;
            border-bottom: 1px solid #eee;
        }}
        </style>
        <div class="sticky-container">
            <h3>{title}</h3>
            <p>{message_html}</p>
        </div>
        """,
        unsafe_allow_html=True
    )
490
 
491
  # =========================
492
  # INTRO
493
  # =========================
494
  if st.session_state.app_step == "intro":
495
+ st.header("Welcome!")
496
+ st.markdown("This software is developed by *Smart Thinking AI-Solutions Team* to estimate UCS from drilling data.")
497
+ st.subheader("How It Works")
498
+ st.markdown(
499
+ "1) **Upload your data to build the case and preview the performance of our model.** \n"
500
+ "2) Click **Run Model** to compute metrics and plots. \n"
501
+ "3) **Proceed to Validation** (with actual UCS) or **Proceed to Prediction** (no UCS)."
502
+ )
503
+ if st.button("Start Showcase", type="primary"):
504
+ st.session_state.app_step = "dev"; st.rerun()
505
 
506
  # =========================
507
  # CASE BUILDING
508
  # =========================
# Case Building step: upload a workbook with Train/Test sheets, run the model
# on both, and show metrics, cross plot and track plot per sheet.
if st.session_state.app_step == "dev":
    st.sidebar.header("Case Building")
    up = st.sidebar.file_uploader("Upload Your Data File", type=["xlsx","xls"])
    if up is not None:
        uploaded_bytes = up.getvalue()
        # The uploader keeps returning the same file on every rerun, so only
        # clear the "previewed" flag when the file itself has changed.
        is_new_file = (
            uploaded_bytes != st.session_state.get("dev_file_bytes")
            or up.name != st.session_state.get("dev_file_name")
        )
        st.session_state.dev_file_bytes = uploaded_bytes
        st.session_state.dev_file_name = up.name
        st.session_state.dev_file_loaded = True
        if is_new_file:
            st.session_state.dev_preview = False
    if st.session_state.dev_file_loaded:
        tmp = read_book_bytes(st.session_state.dev_file_bytes)
        if tmp:
            df0 = next(iter(tmp.values()))
            st.sidebar.caption(f"**Data loaded:** {st.session_state.dev_file_name} • {df0.shape[0]} rows × {df0.shape[1]} cols")

    if st.sidebar.button("Preview data", use_container_width=True, disabled=not st.session_state.dev_file_loaded):
        st.session_state.show_preview_modal = True  # rendered at the end of the script
        st.session_state.dev_preview = True

    run = st.sidebar.button("Run Model", type="primary", use_container_width=True)
    if st.sidebar.button("Proceed to Validation ▶", use_container_width=True):
        st.session_state.app_step = "validate"
        st.rerun()
    if st.sidebar.button("Proceed to Prediction ▶", use_container_width=True):
        st.session_state.app_step = "predict"
        st.rerun()

    # Sticky header text follows the upload / preview progress.
    if st.session_state.dev_file_loaded and st.session_state.dev_preview:
        sticky_header("Case Building", "Previewed ✓ — now click **Run Model**.")
    elif st.session_state.dev_file_loaded:
        sticky_header("Case Building", "📄 **Preview uploaded data** using the sidebar button, then click **Run Model**.")
    else:
        sticky_header("Case Building", "**Upload your data to build a case, then run the model to review development performance.**")

    if run and st.session_state.dev_file_bytes:
        book = read_book_bytes(st.session_state.dev_file_bytes)
        sh_train = find_sheet(book, ["Train","Training","training2","train","training"])
        sh_test = find_sheet(book, ["Test","Testing","testing2","test","testing"])
        if sh_train is None or sh_test is None:
            st.markdown('<div class="st-message-box st-error">Workbook must include Train/Training/training2 and Test/Testing/testing2 sheets.</div>', unsafe_allow_html=True)
            st.stop()
        tr = book[sh_train].copy()
        te = book[sh_test].copy()
        if not (ensure_cols(tr, FEATURES+[TARGET]) and ensure_cols(te, FEATURES+[TARGET])):
            st.markdown('<div class="st-message-box st-error">Missing required columns.</div>', unsafe_allow_html=True)
            st.stop()
        tr["UCS_Pred"] = model.predict(tr[FEATURES])
        te["UCS_Pred"] = model.predict(te[FEATURES])

        st.session_state.results["Train"] = tr
        st.session_state.results["Test"] = te
        st.session_state.results["m_train"] = {
            "R": pearson_r(tr[TARGET], tr["UCS_Pred"]),
            "RMSE": rmse(tr[TARGET], tr["UCS_Pred"]),
            "MAE": mean_absolute_error(tr[TARGET], tr["UCS_Pred"])
        }
        st.session_state.results["m_test"] = {
            "R": pearson_r(te[TARGET], te["UCS_Pred"]),
            "RMSE": rmse(te[TARGET], te["UCS_Pred"]),
            "MAE": mean_absolute_error(te[TARGET], te["UCS_Pred"])
        }

        # Training min/max per feature, used later for out-of-range checks.
        tr_min = tr[FEATURES].min().to_dict()
        tr_max = tr[FEATURES].max().to_dict()
        st.session_state.train_ranges = {f: (float(tr_min[f]), float(tr_max[f])) for f in FEATURES}
        st.markdown('<div class="st-message-box st-success">Case has been built and results are displayed below.</div>', unsafe_allow_html=True)

    def _dev_block(df, m):
        """Render metrics, cross plot and track plot for one result sheet."""
        c1, c2, c3 = st.columns(3)
        c1.metric("R", f"{m['R']:.2f}")
        c2.metric("RMSE", f"{m['RMSE']:.2f}")
        c3.metric("MAE", f"{m['MAE']:.2f}")

        # Legend for the metric abbreviations
        st.markdown("""
            <div style='text-align: left; font-size: 0.8em; color: #6b7280; margin-top: -16px; margin-bottom: 8px;'>
            <strong>R:</strong> Pearson Correlation Coefficient<br>
            <strong>RMSE:</strong> Root Mean Square Error<br>
            <strong>MAE:</strong> Mean Absolute Error
            </div>
            """, unsafe_allow_html=True)

        # 2-column layout, big gap (prevents overlap)
        col_cross, col_track = st.columns([3, 2], gap="large")
        with col_cross:
            st.pyplot(cross_plot_static(df[TARGET], df["UCS_Pred"]), use_container_width=False)
        with col_track:
            st.plotly_chart(
                track_plot(df, include_actual=True),
                use_container_width=False,  # honor the width set in track_plot()
                config={"displayModeBar": False, "scrollZoom": True}
            )

    if "Train" in st.session_state.results or "Test" in st.session_state.results:
        tab1, tab2 = st.tabs(["Training", "Testing"])
        if "Train" in st.session_state.results:
            with tab1:
                _dev_block(st.session_state.results["Train"], st.session_state.results["m_train"])
        if "Test" in st.session_state.results:
            with tab2:
                _dev_block(st.session_state.results["Test"], st.session_state.results["m_test"])
601
 
602
  # =========================
603
  # VALIDATION (with actual UCS)
604
  # =========================
# Validation step: predict on a labelled dataset, report metrics, and list the
# rows whose features fall outside the training min-max ranges.
if st.session_state.app_step == "validate":
    st.sidebar.header("Validate the Model")
    up = st.sidebar.file_uploader("Upload Validation Excel", type=["xlsx","xls"])
    if up is not None:
        book = read_book_bytes(up.getvalue())
        if book:
            df0 = next(iter(book.values()))
            st.sidebar.caption(f"**Data loaded:** {up.name} • {df0.shape[0]} rows × {df0.shape[1]} cols")
    if st.sidebar.button("Preview data", use_container_width=True, disabled=(up is None)):
        st.session_state.show_preview_modal = True  # rendered at the end of the script
    go_btn = st.sidebar.button("Predict", type="primary", use_container_width=True)
    if st.sidebar.button("⬅ Back to Case Building", use_container_width=True):
        st.session_state.app_step = "dev"
        st.rerun()
    if st.sidebar.button("Proceed to Prediction ▶", use_container_width=True):
        st.session_state.app_step = "predict"
        st.rerun()

    sticky_header("Validate the Model", "Upload a dataset with the same **features** and **UCS** to evaluate performance.")

    if go_btn and up is not None:
        book = read_book_bytes(up.getvalue())
        # Prefer a sheet named like "Validation"; otherwise use the first sheet.
        name = find_sheet(book, ["Validation","Validate","validation2","Val","val"]) or list(book.keys())[0]
        df = book[name].copy()
        if not ensure_cols(df, FEATURES+[TARGET]):
            st.markdown('<div class="st-message-box st-error">Missing required columns.</div>', unsafe_allow_html=True)
            st.stop()
        df["UCS_Pred"] = model.predict(df[FEATURES])
        st.session_state.results["Validate"] = df

        ranges = st.session_state.train_ranges
        oor_pct = 0.0
        tbl = None
        if ranges:
            # One boolean column per feature: True where the value is outside
            # the training [min, max].  Built once and reused below.
            violations = pd.DataFrame({f: (df[f] < ranges[f][0]) | (df[f] > ranges[f][1]) for f in FEATURES})
            any_viol = violations.any(axis=1)
            oor_pct = float(any_viol.mean()*100.0)
            if any_viol.any():
                tbl = df.loc[any_viol, FEATURES].copy()
                for c in FEATURES:
                    if pd.api.types.is_numeric_dtype(tbl[c]):
                        tbl[c] = tbl[c].round(2)
                tbl["Violations"] = violations.loc[any_viol].apply(
                    lambda r: ", ".join([c for c, v in r.items() if v]), axis=1
                )
        st.session_state.results["m_val"] = {
            "R": pearson_r(df[TARGET], df["UCS_Pred"]),
            "RMSE": rmse(df[TARGET], df["UCS_Pred"]),
            "MAE": mean_absolute_error(df[TARGET], df["UCS_Pred"])
        }
        st.session_state.results["sv_val"] = {
            "n": len(df),
            "pred_min": float(df["UCS_Pred"].min()),
            "pred_max": float(df["UCS_Pred"].max()),
            "oor": oor_pct
        }
        st.session_state.results["oor_tbl"] = tbl

    if "Validate" in st.session_state.results:
        m = st.session_state.results["m_val"]
        c1, c2, c3 = st.columns(3)
        c1.metric("R", f"{m['R']:.2f}")
        c2.metric("RMSE", f"{m['RMSE']:.2f}")
        c3.metric("MAE", f"{m['MAE']:.2f}")

        # Legend for the metric abbreviations
        st.markdown("""
            <div style='text-align: left; font-size: 0.8em; color: #6b7280; margin-top: -16px; margin-bottom: 8px;'>
            <strong>R:</strong> Pearson Correlation Coefficient<br>
            <strong>RMSE:</strong> Root Mean Square Error<br>
            <strong>MAE:</strong> Mean Absolute Error
            </div>
            """, unsafe_allow_html=True)

        col_cross, col_track = st.columns([3, 2], gap="large")
        with col_cross:
            st.pyplot(
                cross_plot_static(st.session_state.results["Validate"][TARGET],
                                  st.session_state.results["Validate"]["UCS_Pred"]),
                use_container_width=False
            )
        with col_track:
            st.plotly_chart(
                track_plot(st.session_state.results["Validate"], include_actual=True),
                use_container_width=False,  # honor the width set in track_plot()
                config={"displayModeBar": False, "scrollZoom": True}
            )

        sv = st.session_state.results["sv_val"]
        if sv["oor"] > 0:
            # Markdown is not interpreted inside the raw HTML box, so the
            # emphasis is written as <strong> rather than **...**.
            st.markdown('<div class="st-message-box st-warning">Some inputs fall outside <strong>training min–max</strong> ranges.</div>', unsafe_allow_html=True)
        if st.session_state.results["oor_tbl"] is not None:
            st.write("*Out-of-range rows (vs. Training min–max):*")
            df_centered_rounded(st.session_state.results["oor_tbl"])
681
 
682
  # =========================
683
  # PREDICTION (no actual UCS)
684
  # =========================
# Prediction step: predict UCS for an unlabelled dataset and summarize the
# predictions next to the track plot.
if st.session_state.app_step == "predict":
    st.sidebar.header("Prediction (No Actual UCS)")
    up = st.sidebar.file_uploader("Upload Prediction Excel", type=["xlsx","xls"])
    if up is not None:
        book = read_book_bytes(up.getvalue())
        if book:
            df0 = next(iter(book.values()))
            st.sidebar.caption(f"**Data loaded:** {up.name} • {df0.shape[0]} rows × {df0.shape[1]} cols")
    if st.sidebar.button("Preview data", use_container_width=True, disabled=(up is None)):
        st.session_state.show_preview_modal = True  # rendered at the end of the script
    go_btn = st.sidebar.button("Predict", type="primary", use_container_width=True)
    if st.sidebar.button("⬅ Back to Case Building", use_container_width=True):
        st.session_state.app_step = "dev"
        st.rerun()

    sticky_header("Prediction", "Upload a dataset with the feature columns (no **UCS**).")

    if go_btn and up is not None:
        book = read_book_bytes(up.getvalue())
        name = list(book.keys())[0]  # always the first sheet
        df = book[name].copy()
        if not ensure_cols(df, FEATURES):
            st.markdown('<div class="st-message-box st-error">Missing required columns.</div>', unsafe_allow_html=True)
            st.stop()
        df["UCS_Pred"] = model.predict(df[FEATURES])
        st.session_state.results["PredictOnly"] = df

        # Share of rows with at least one feature outside the training range;
        # stays 0.0 when no training ranges have been recorded.
        ranges = st.session_state.train_ranges
        oor_pct = 0.0
        if ranges:
            any_viol = pd.DataFrame({f: (df[f] < ranges[f][0]) | (df[f] > ranges[f][1]) for f in FEATURES}).any(axis=1)
            oor_pct = float(any_viol.mean()*100.0)
        st.session_state.results["sv_pred"] = {
            "n": len(df),
            "pred_min": float(df["UCS_Pred"].min()),
            "pred_max": float(df["UCS_Pred"].max()),
            "pred_mean": float(df["UCS_Pred"].mean()),
            "pred_std": float(df["UCS_Pred"].std(ddof=0)),
            "oor": oor_pct
        }

    if "PredictOnly" in st.session_state.results:
        df = st.session_state.results["PredictOnly"]
        sv = st.session_state.results["sv_pred"]

        col_left, col_right = st.columns([2,3], gap="large")
        with col_left:
            table = pd.DataFrame({
                "Metric": ["# points","Pred min","Pred max","Pred mean","Pred std","OOR %"],
                "Value": [
                    sv["n"],
                    round(sv["pred_min"],2),
                    round(sv["pred_max"],2),
                    round(sv["pred_mean"],2),
                    round(sv["pred_std"],2),
                    f'{sv["oor"]:.1f}%',
                ]
            })
            st.markdown('<div class="st-message-box st-success">Predictions ready ✓</div>', unsafe_allow_html=True)
            df_centered_rounded(table, hide_index=True)
            st.caption("**★ OOR** = % of rows whose input features fall outside the training min–max range.")
        with col_right:
            st.plotly_chart(
                track_plot(df, include_actual=False),
                use_container_width=False,  # honor the width set in track_plot()
                config={"displayModeBar": False, "scrollZoom": True}
            )
743
 
744
  # =========================
745
  # Run preview modal after all other elements
746
  # =========================
# Preview expander: rendered last so it appears below the step's own content.
# The flag is set by the sidebar "Preview data" buttons and cleared here, so
# the preview is shown for a single run.
if st.session_state.show_preview_modal:
    # Pick the workbook that belongs to the current step.
    book_to_preview = {}
    if st.session_state.app_step == "dev":
        book_to_preview = read_book_bytes(st.session_state.dev_file_bytes)
    elif st.session_state.app_step in ["validate", "predict"] and 'up' in locals() and up is not None:
        book_to_preview = read_book_bytes(up.getvalue())

    with st.expander("Preview data", expanded=True):
        if not book_to_preview:
            st.markdown('<div class="st-message-box">No data loaded yet.</div>', unsafe_allow_html=True)
        else:
            names = list(book_to_preview.keys())
            tabs = st.tabs(names)
            for t, name in zip(tabs, names):
                with t:
                    df = book_to_preview[name]
                    t1, t2 = st.tabs(["Tracks", "Summary"])
                    with t1:
                        st.pyplot(preview_tracks(df, FEATURES), use_container_width=True)
                    with t2:
                        # df[FEATURES] raises KeyError if any feature is
                        # absent, so summarize only the columns present.
                        present = [c for c in FEATURES if c in df.columns]
                        if present:
                            tbl = (df[present]
                                   .agg(['min','max','mean','std'])
                                   .T.rename(columns={"min":"Min","max":"Max","mean":"Mean","std":"Std"}))
                            df_centered_rounded(tbl.rename_axis("Feature").reset_index())
                        else:
                            st.markdown('<div class="st-message-box">No feature columns found in this sheet.</div>', unsafe_allow_html=True)
    # Reset the flag once the preview has been displayed
    st.session_state.show_preview_modal = False
 
 
 
 
 
 
 
 
 
 
774
 
775
  # =========================
776
  # Footer
777
  # =========================
778
+ st.markdown("""
779
+ <br><br><br>
780
+ <hr>
781
+ <div style='text-align:center;color:#6b7280;font-size:0.8em;'>
782
+ © 2024 Smart Thinking AI-Solutions Team. All rights reserved.
783
+ </div>
784
+ """, unsafe_allow_html=True)