UCS2014 committed on
Commit
9350442
·
verified ·
1 Parent(s): 310d95f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +576 -554
app.py CHANGED
@@ -25,13 +25,13 @@ MODEL_FALLBACKS = [MODELS_DIR / "model.joblib", MODELS_DIR / "model.pkl"]
25
  COLORS = {"pred": "#1f77b4", "actual": "#f2b702", "ref": "#5a5a5a"}
26
 
27
  # ---- Plot sizing controls ----
28
- CROSS_W = 400 # px (matplotlib figure size; Streamlit will still scale)
29
- CROSS_H = 400
30
- TRACK_H = 1000 # px (plotly height; width auto-fits column)
31
  # NEW: Add a TRACK_W variable to control the width
32
- TRACK_W = 500 # px (plotly width)
33
- FONT_SZ = 13
34
- BOLD_FONT = "Arial Black, Arial, sans-serif" # used for bold axis titles & ticks
35
 
36
  # =========================
37
  # Page / CSS
@@ -41,16 +41,16 @@ st.set_page_config(page_title="ST_GeoMech_UCS", page_icon="logo.png", layout="wi
41
  # General CSS (logo helpers etc.)
42
  st.markdown("""
43
  <style>
44
- .brand-logo { width: 200px; height: auto; object-fit: contain; }
45
- .sidebar-header { display:flex; align-items:center; gap:12px; }
46
- .sidebar-header .text h1 { font-size: 1.05rem; margin:0; line-height:1.1; }
47
- .sidebar-header .text .tag { font-size: .85rem; color:#6b7280; margin:2px 0 0; }
48
- .centered-container {
49
- display: flex;
50
- flex-direction: column;
51
- align-items: center;
52
- text-align: center;
53
- }
54
  </style>
55
  """, unsafe_allow_html=True)
56
 
@@ -59,12 +59,12 @@ st.markdown("""
59
  <style>
60
  /* This targets the main content area */
61
  .main .block-container {
62
- overflow: unset !important;
63
  }
64
 
65
  /* This targets the vertical block that holds all your elements */
66
  div[data-testid="stVerticalBlock"] {
67
- overflow: unset !important;
68
  }
69
  </style>
70
  """, unsafe_allow_html=True)
@@ -87,52 +87,52 @@ section[data-testid="stFileUploader"] p, section[data-testid="stFileUploader"] s
87
  st.markdown("""
88
  <style>
89
  div[data-testid="stExpander"] > details > summary {
90
- position: sticky;
91
- top: 0;
92
- z-index: 10;
93
- background: #fff;
94
- border-bottom: 1px solid #eee;
95
  }
96
  div[data-testid="stExpander"] div[data-baseweb="tab-list"] {
97
- position: sticky;
98
- top: 42px; /* adjust if your expander header height differs */
99
- z-index: 9;
100
- background: #fff;
101
- padding-top: 6px;
102
  }
103
  </style>
104
  """, unsafe_allow_html=True)
105
 
106
  # Center text in all pandas Styler tables (headers + cells)
107
  TABLE_CENTER_CSS = [
108
- dict(selector="th", props=[("text-align", "center")]),
109
- dict(selector="td", props=[("text-align", "center")]),
110
  ]
111
 
112
  # NEW: CSS for the message box
113
  st.markdown("""
114
  <style>
115
  .st-message-box {
116
- background-color: #f0f2f6;
117
- color: #333333;
118
- padding: 10px;
119
- border-radius: 10px;
120
- border: 1px solid #e6e9ef;
121
  }
122
  .st-message-box.st-success {
123
- background-color: #d4edda;
124
- color: #155724;
125
- border-color: #c3e6cb;
126
  }
127
  .st-message-box.st-warning {
128
- background-color: #fff3cd;
129
- color: #856404;
130
- border-color: #ffeeba;
131
  }
132
  .st-message-box.st-error {
133
- background-color: #f8d7da;
134
- color: #721c24;
135
- border-color: #f5c6cb;
136
  }
137
  </style>
138
  """, unsafe_allow_html=True)
@@ -141,42 +141,42 @@ st.markdown("""
141
  # Password gate
142
  # =========================
143
  def inline_logo(path="logo.png") -> str:
144
- try:
145
- p = Path(path)
146
- if not p.exists(): return ""
147
- return f"data:image/png;base64,{base64.b64encode(p.read_bytes()).decode('ascii')}"
148
- except Exception:
149
- return ""
150
 
151
  def add_password_gate() -> None:
152
- try:
153
- required = st.secrets.get("APP_PASSWORD", "")
154
- except Exception:
155
- required = os.environ.get("APP_PASSWORD", "")
156
-
157
- if not required:
158
- st.warning("Set APP_PASSWORD in Secrets (or environment) and restart.")
159
- st.stop()
160
-
161
- if st.session_state.get("auth_ok", False):
162
- return
163
-
164
- st.sidebar.markdown(f"""
165
- <div class="centered-container">
166
- <img src="{inline_logo('logo.png')}" style="width: 200px; height: auto; object-fit: contain;">
167
- <div style='font-weight:800;font-size:1.2rem; margin-top: 10px;'>ST_GeoMech_UCS</div>
168
- <div style='color:#667085;'>Smart Thinking • Secure Access</div>
169
- </div>
170
- """, unsafe_allow_html=True
171
- )
172
- pwd = st.sidebar.text_input("Access key", type="password", placeholder="••••••••")
173
- if st.sidebar.button("Unlock", type="primary"):
174
- if pwd == required:
175
- st.session_state.auth_ok = True
176
- st.rerun()
177
- else:
178
- st.error("Incorrect key.")
179
- st.stop()
180
 
181
  add_password_gate()
182
 
@@ -184,260 +184,260 @@ add_password_gate()
184
  # Utilities
185
  # =========================
186
  def rmse(y_true, y_pred) -> float:
187
- return float(np.sqrt(mean_squared_error(y_true, y_pred)))
188
 
189
  def pearson_r(y_true, y_pred) -> float:
190
- a = np.asarray(y_true, dtype=float)
191
- p = np.asarray(y_pred, dtype=float)
192
- if a.size < 2: return float("nan")
193
- return float(np.corrcoef(a, p)[0, 1])
194
 
195
  @st.cache_resource(show_spinner=False)
196
  def load_model(model_path: str):
197
- return joblib.load(model_path)
198
 
199
  @st.cache_data(show_spinner=False)
200
  def parse_excel(data_bytes: bytes):
201
- bio = io.BytesIO(data_bytes)
202
- xl = pd.ExcelFile(bio)
203
- return {sh: xl.parse(sh) for sh in xl.sheet_names}
204
 
205
  def read_book_bytes(b: bytes): return parse_excel(b) if b else {}
206
 
207
  def ensure_cols(df, cols):
208
- miss = [c for c in cols if c not in df.columns]
209
- if miss:
210
- st.error(f"Missing columns: {miss}\nFound: {list(df.columns)}")
211
- return False
212
- return True
213
 
214
  def find_sheet(book, names):
215
- low2orig = {k.lower(): k for k in book.keys()}
216
- for nm in names:
217
- if nm.lower() in low2orig: return low2orig[nm.lower()]
218
- return None
219
 
220
  def _nice_tick0(xmin: float, step: int = 100) -> float:
221
- return step * math.floor(xmin / step) if np.isfinite(xmin) else xmin
222
 
223
  def df_centered_rounded(df: pd.DataFrame, hide_index=True):
224
- """Center headers & cells; format numeric columns to 2 decimals."""
225
- out = df.copy()
226
- numcols = out.select_dtypes(include=[np.number]).columns
227
- styler = (
228
- out.style
229
- .format({c: "{:.2f}" for c in numcols})
230
- .set_properties(**{"text-align": "center"})
231
- .set_table_styles(TABLE_CENTER_CSS)
232
- )
233
- st.dataframe(styler, use_container_width=True, hide_index=hide_index)
234
 
235
  # =========================
236
  # Cross plot (Matplotlib, fixed limits & ticks)
237
  # =========================
238
  def cross_plot_static(actual, pred):
239
- a = pd.Series(actual, dtype=float)
240
- p = pd.Series(pred, dtype=float)
241
 
242
- fixed_min, fixed_max = 6000, 10000
243
- ticks = np.arange(fixed_min, fixed_max + 1, 1000)
244
 
245
- dpi = 110
246
- fig, ax = plt.subplots(
247
- figsize=(CROSS_W / dpi, CROSS_H / dpi),
248
- dpi=dpi,
249
- constrained_layout=False
250
- )
251
 
252
- ax.scatter(a, p, s=14, c=COLORS["pred"], alpha=0.9, linewidths=0)
253
- ax.plot([fixed_min, fixed_max], [fixed_min, fixed_max],
254
- linestyle="--", linewidth=1.2, color=COLORS["ref"])
255
 
256
- ax.set_xlim(fixed_min, fixed_max)
257
- ax.set_ylim(fixed_min, fixed_max)
258
- ax.set_xticks(ticks)
259
- ax.set_yticks(ticks)
260
- ax.set_aspect("equal", adjustable="box") # true 45°
261
 
262
- fmt = FuncFormatter(lambda x, _: f"{int(x):,}")
263
- ax.xaxis.set_major_formatter(fmt)
264
- ax.yaxis.set_major_formatter(fmt)
265
 
266
- ax.set_xlabel("Actual UCS (psi)", fontweight="bold", fontsize=12, color="black")
267
- ax.set_ylabel("Predicted UCS (psi)", fontweight="bold", fontsize=12, color="black")
268
- ax.tick_params(labelsize=8, colors="black")
269
 
270
- ax.grid(True, linestyle=":", alpha=0.3)
271
- for spine in ax.spines.values():
272
- spine.set_linewidth(1.1)
273
- spine.set_color("#444")
274
 
275
- fig.subplots_adjust(left=0.16, bottom=0.16, right=0.98, top=0.98)
276
- return fig
277
 
278
  # =========================
279
  # Track plot (Plotly)
280
  # =========================
281
  def track_plot(df, include_actual=True):
282
- depth_col = next((c for c in df.columns if 'depth' in str(c).lower()), None)
283
- if depth_col is not None:
284
- y = pd.Series(df[depth_col]).astype(float)
285
- ylab = depth_col
286
- y_range = [float(y.max()), float(y.min())] # reverse
287
- else:
288
- y = pd.Series(np.arange(1, len(df) + 1))
289
- ylab = "Point Index"
290
- y_range = [float(y.max()), float(y.min())]
291
-
292
- # X (UCS) range & ticks
293
- x_series = pd.Series(df.get("UCS_Pred", pd.Series(dtype=float))).astype(float)
294
- if include_actual and TARGET in df.columns:
295
- x_series = pd.concat([x_series, pd.Series(df[TARGET]).astype(float)], ignore_index=True)
296
- x_lo, x_hi = float(x_series.min()), float(x_series.max())
297
- x_pad = 0.03 * (x_hi - x_lo if x_hi > x_lo else 1.0)
298
- xmin, xmax = x_lo - x_pad, x_hi + x_pad
299
- tick0 = _nice_tick0(xmin, step=100)
300
-
301
- fig = go.Figure()
302
- fig.add_trace(go.Scatter(
303
- x=df["UCS_Pred"], y=y, mode="lines",
304
- line=dict(color=COLORS["pred"], width=1.8),
305
- name="UCS_Pred",
306
- hovertemplate="UCS_Pred: %{x:.0f}<br>"+ylab+": %{y}<extra></extra>"
307
- ))
308
- if include_actual and TARGET in df.columns:
309
- fig.add_trace(go.Scatter(
310
- x=df[TARGET], y=y, mode="lines",
311
- line=dict(color=COLORS["actual"], width=2.0, dash="dot"),
312
- name="UCS (actual)",
313
- hovertemplate="UCS (actual): %{x:.0f}<br>"+ylab+": %{y}<extra></extra>"
314
- ))
315
-
316
- fig.update_layout(
317
- height=TRACK_H,
318
- width=TRACK_W, # Set the width here
319
- autosize=False, # Disable autosizing to respect the width
320
- paper_bgcolor="#fff", plot_bgcolor="#fff",
321
- margin=dict(l=64, r=16, t=36, b=48), hovermode="closest",
322
- font=dict(size=FONT_SZ, color="#000"),
323
- legend=dict(
324
- x=0.98, y=0.05, xanchor="right", yanchor="bottom",
325
- bgcolor="rgba(255,255,255,0.75)", bordercolor="#ccc", borderwidth=1
326
- ),
327
- legend_title_text=""
328
- )
329
-
330
- # Bold, black axis titles & ticks
331
- fig.update_xaxes(
332
- title_text="UCS (psi)",
333
- title_font=dict(size=20, family=BOLD_FONT, color="#000"),
334
- tickfont=dict(size=15, family=BOLD_FONT, color="#000"),
335
- side="top",
336
- range=[xmin, xmax],
337
- ticks="outside",
338
- tickformat=",.0f",
339
- tickmode="auto",
340
- tick0=tick0,
341
- showline=True, linewidth=1.2, linecolor="#444", mirror=True,
342
- showgrid=True, gridcolor="rgba(0,0,0,0.12)", automargin=True
343
- )
344
- fig.update_yaxes(
345
- title_text=ylab,
346
- title_font=dict(size=20, family=BOLD_FONT, color="#000"),
347
- tickfont=dict(size=15, family=BOLD_FONT, color="#000"),
348
- range=y_range,
349
- ticks="outside",
350
- showline=True, linewidth=1.2, linecolor="#444", mirror=True,
351
- showgrid=True, gridcolor="rgba(0,0,0,0.12)", automargin=True
352
- )
353
-
354
- return fig
355
 
356
  # ---------- Preview modal (matplotlib) ----------
357
  def preview_tracks(df: pd.DataFrame, cols: list[str]):
358
- cols = [c for c in cols if c in df.columns]
359
- n = len(cols)
360
- if n == 0:
361
- fig, ax = plt.subplots(figsize=(4, 2))
362
- ax.text(0.5,0.5,"No selected columns",ha="center",va="center"); ax.axis("off")
363
- return fig
364
- fig, axes = plt.subplots(1, n, figsize=(2.2*n, 7.0), sharey=True, dpi=100)
365
- if n == 1: axes = [axes]
366
- idx = np.arange(1, len(df) + 1)
367
- for ax, col in zip(axes, cols):
368
- ax.plot(df[col], idx, '-', lw=1.4, color="#333")
369
- ax.set_xlabel(col); ax.xaxis.set_label_position('top'); ax.xaxis.tick_top(); ax.invert_yaxis()
370
- ax.grid(True, linestyle=":", alpha=0.3)
371
- for s in ax.spines.values(): s.set_visible(True)
372
- axes[0].set_ylabel("Point Index")
373
- return fig
374
 
375
  # Modal wrapper (Streamlit compatibility)
376
  try:
377
- dialog = st.dialog
378
  except AttributeError:
379
- def dialog(title):
380
- def deco(fn):
381
- def wrapper(*args, **kwargs):
382
- with st.expander(title, expanded=True):
383
- return fn(*args, **kwargs)
384
- return wrapper
385
- return deco
386
 
387
  def preview_modal(book: dict[str, pd.DataFrame]):
388
- if not book:
389
- st.info("No data loaded yet."); return
390
- names = list(book.keys())
391
- tabs = st.tabs(names)
392
- for t, name in zip(tabs, names):
393
- with t:
394
- df = book[name]
395
- t1, t2 = st.tabs(["Tracks", "Summary"])
396
- with t1:
397
- st.pyplot(preview_tracks(df, FEATURES), use_container_width=True)
398
- with t2:
399
- tbl = (df[FEATURES]
400
- .agg(['min','max','mean','std'])
401
- .T.rename(columns={"min":"Min","max":"Max","mean":"Mean","std":"Std"}))
402
- df_centered_rounded(tbl.reset_index(names="Feature"))
403
 
404
  # =========================
405
  # Load model
406
  # =========================
407
  def ensure_model() -> Path|None:
408
- for p in [DEFAULT_MODEL, *MODEL_FALLBACKS]:
409
- if p.exists() and p.stat().st_size > 0: return p
410
- url = os.environ.get("MODEL_URL", "")
411
- if not url: return None
412
- try:
413
- import requests
414
- DEFAULT_MODEL.parent.mkdir(parents=True, exist_ok=True)
415
- with requests.get(url, stream=True, timeout=30) as r:
416
- r.raise_for_status()
417
- with open(DEFAULT_MODEL, "wb") as f:
418
- for chunk in r.iter_content(1<<20):
419
- if chunk: f.write(chunk)
420
- return DEFAULT_MODEL
421
- except Exception:
422
- return None
423
 
424
  mpath = ensure_model()
425
  if not mpath:
426
- st.error("Model not found. Upload models/ucs_rf.joblib (or set MODEL_URL).")
427
- st.stop()
428
  try:
429
- model = load_model(str(mpath))
430
  except Exception as e:
431
- st.error(f"Failed to load model: {e}")
432
- st.stop()
433
 
434
  meta_path = MODELS_DIR / "meta.json"
435
  if meta_path.exists():
436
- try:
437
- meta = json.loads(meta_path.read_text(encoding="utf-8"))
438
- FEATURES = meta.get("features", FEATURES); TARGET = meta.get("target", TARGET)
439
- except Exception:
440
- pass
441
 
442
  # =========================
443
  # Session state
@@ -455,321 +455,343 @@ st.session_state.setdefault("show_preview_modal", False) # New state variable
455
  # Branding in Sidebar
456
  # =========================
457
  st.sidebar.markdown(f"""
458
- <div class="centered-container">
459
- <img src="{inline_logo('logo.png')}" style="width: 200px; height: auto; object-fit: contain;">
460
- <div style='font-weight:800;font-size:1.2rem;'>ST_GeoMech_UCS</div>
461
- <div style='color:#667085;'>Real-Time UCS Tracking While Drilling</div>
462
- </div>
463
- """, unsafe_allow_html=True
464
  )
465
 
466
  # =========================
467
  # Reusable Sticky Header Function
468
  # =========================
469
  def sticky_header(title, message):
470
- st.markdown(
471
- f"""
472
- <style>
473
- .sticky-container {{
474
- position: sticky;
475
- top: 0;
476
- background-color: white;
477
- z-index: 100;
478
- padding-top: 10px;
479
- padding-bottom: 10px;
480
- border-bottom: 1px solid #eee;
481
- }}
482
- </style>
483
- <div class="sticky-container">
484
- <h3>{title}</h3>
485
- <p>{message}</p>
486
- </div>
487
- """,
488
- unsafe_allow_html=True
489
- )
490
 
491
  # =========================
492
  # INTRO
493
  # =========================
494
  if st.session_state.app_step == "intro":
495
- st.header("Welcome!")
496
- st.markdown("This software is developed by *Smart Thinking AI-Solutions Team* to estimate UCS from drilling data.")
497
- st.subheader("How It Works")
498
- st.markdown(
499
- "1) **Upload your data to build the case and preview the performance of our model.** \n"
500
- "2) Click **Run Model** to compute metrics and plots. \n"
501
- "3) **Proceed to Validation** (with actual UCS) or **Proceed to Prediction** (no UCS)."
502
- )
503
- if st.button("Start Showcase", type="primary"):
504
- st.session_state.app_step = "dev"; st.rerun()
505
 
506
  # =========================
507
  # CASE BUILDING
508
  # =========================
509
  if st.session_state.app_step == "dev":
510
- st.sidebar.header("Case Building")
511
- up = st.sidebar.file_uploader("Upload Your Data File", type=["xlsx","xls"])
512
- if up is not None:
513
- st.session_state.dev_file_bytes = up.getvalue()
514
- st.session_state.dev_file_name = up.name
515
- st.session_state.dev_file_loaded = True
516
- st.session_state.dev_preview = False
517
- if st.session_state.dev_file_loaded:
518
- tmp = read_book_bytes(st.session_state.dev_file_bytes)
519
- if tmp:
520
- df0 = next(iter(tmp.values()))
521
- st.sidebar.caption(f"**Data loaded:** {st.session_state.dev_file_name} • {df0.shape[0]} rows × {df0.shape[1]} cols")
522
-
523
- if st.sidebar.button("Preview data", use_container_width=True, disabled=not st.session_state.dev_file_loaded):
524
- st.session_state.show_preview_modal = True # Set state to show modal
525
- st.session_state.dev_preview = True
526
-
527
- run = st.sidebar.button("Run Model", type="primary", use_container_width=True)
528
- if st.sidebar.button("Proceed to Validation ▶", use_container_width=True): st.session_state.app_step="validate"; st.rerun()
529
- if st.sidebar.button("Proceed to Prediction ▶", use_container_width=True): st.session_state.app_step="predict"; st.rerun()
530
-
531
- # Apply sticky header
532
- if st.session_state.dev_file_loaded and st.session_state.dev_preview:
533
- sticky_header("Case Building", "Previewed ✓ — now click **Run Model**.")
534
- elif st.session_state.dev_file_loaded:
535
- sticky_header("Case Building", "📄 **Preview uploaded data** using the sidebar button, then click **Run Model**.")
536
- else:
537
- sticky_header("Case Building", "**Upload your data to build a case, then run the model to review development performance.**")
538
-
539
- if run and st.session_state.dev_file_bytes:
540
- book = read_book_bytes(st.session_state.dev_file_bytes)
541
- sh_train = find_sheet(book, ["Train","Training","training2","train","training"])
542
- sh_test = find_sheet(book, ["Test","Testing","testing2","test","testing"])
543
- if sh_train is None or sh_test is None:
544
- st.markdown('<div class="st-message-box st-error">Workbook must include Train/Training/training2 and Test/Testing/testing2 sheets.</div>', unsafe_allow_html=True)
545
- st.stop()
546
- tr = book[sh_train].copy(); te = book[sh_test].copy()
547
- if not (ensure_cols(tr, FEATURES+[TARGET]) and ensure_cols(te, FEATURES+[TARGET])):
548
- st.markdown('<div class="st-message-box st-error">Missing required columns.</div>', unsafe_allow_html=True)
549
- st.stop()
550
- tr["UCS_Pred"] = model.predict(tr[FEATURES])
551
- te["UCS_Pred"] = model.predict(te[FEATURES])
552
-
553
- st.session_state.results["Train"]=tr; st.session_state.results["Test"]=te
554
- st.session_state.results["m_train"]={
555
- "R": pearson_r(tr[TARGET], tr["UCS_Pred"]),
556
- "RMSE": rmse(tr[TARGET], tr["UCS_Pred"]),
557
- "MAE": mean_absolute_error(tr[TARGET], tr["UCS_Pred"])
558
- }
559
- st.session_state.results["m_test"]={
560
- "R": pearson_r(te[TARGET], te["UCS_Pred"]),
561
- "RMSE": rmse(te[TARGET], te["UCS_Pred"]),
562
- "MAE": mean_absolute_error(te[TARGET], te["UCS_Pred"])
563
- }
564
-
565
- tr_min = tr[FEATURES].min().to_dict(); tr_max = tr[FEATURES].max().to_dict()
566
- st.session_state.train_ranges = {f:(float(tr_min[f]), float(tr_max[f])) for f in FEATURES}
567
- st.markdown('<div class="st-message-box st-success">Case has been built and results are displayed below.</div>', unsafe_allow_html=True)
568
-
569
- def _dev_block(df, m):
570
- c1,c2,c3 = st.columns(3)
571
- c1.metric("R", f"{m['R']:.2f}"); c2.metric("RMSE", f"{m['RMSE']:.2f}"); c3.metric("MAE", f"{m['MAE']:.2f}")
572
-
573
- # 2-column layout, big gap (prevents overlap)
574
- col_cross, col_track = st.columns([3, 2], gap="large")
575
- with col_cross:
576
- st.pyplot(cross_plot_static(df[TARGET], df["UCS_Pred"]), use_container_width=False)
577
- with col_track:
578
- st.plotly_chart(
579
- track_plot(df, include_actual=True),
580
- use_container_width=False, # Set to False to honor the width in track_plot()
581
- config={"displayModeBar": False, "scrollZoom": True}
582
- )
583
-
584
- if "Train" in st.session_state.results or "Test" in st.session_state.results:
585
- tab1, tab2 = st.tabs(["Training", "Testing"])
586
- if "Train" in st.session_state.results:
587
- with tab1: _dev_block(st.session_state.results["Train"], st.session_state.results["m_train"])
588
- if "Test" in st.session_state.results:
589
- with tab2: _dev_block(st.session_state.results["Test"], st.session_state.results["m_test"])
 
 
 
 
 
 
 
 
 
 
 
590
 
591
  # =========================
592
  # VALIDATION (with actual UCS)
593
  # =========================
594
  if st.session_state.app_step == "validate":
595
- st.sidebar.header("Validate the Model")
596
- up = st.sidebar.file_uploader("Upload Validation Excel", type=["xlsx","xls"])
597
- if up is not None:
598
- book = read_book_bytes(up.getvalue())
599
- if book:
600
- df0 = next(iter(book.values()))
601
- st.sidebar.caption(f"**Data loaded:** {up.name} • {df0.shape[0]} rows × {df0.shape[1]} cols")
602
- if st.sidebar.button("Preview data", use_container_width=True, disabled=(up is None)):
603
- st.session_state.show_preview_modal = True # Set state to show modal
604
- go_btn = st.sidebar.button("Predict", type="primary", use_container_width=True)
605
- if st.sidebar.button("⬅ Back to Case Building", use_container_width=True): st.session_state.app_step="dev"; st.rerun()
606
- if st.sidebar.button("Proceed to Prediction ▶", use_container_width=True): st.session_state.app_step="predict"; st.rerun()
607
-
608
- sticky_header("Validate the Model", "Upload a dataset with the same **features** and **UCS** to evaluate performance.")
609
-
610
- if go_btn and up is not None:
611
- book = read_book_bytes(up.getvalue())
612
- name = find_sheet(book, ["Validation","Validate","validation2","Val","val"]) or list(book.keys())[0]
613
- df = book[name].copy()
614
- if not ensure_cols(df, FEATURES+[TARGET]): st.markdown('<div class="st-message-box st-error">Missing required columns.</div>', unsafe_allow_html=True); st.stop()
615
- df["UCS_Pred"] = model.predict(df[FEATURES])
616
- st.session_state.results["Validate"]=df
617
-
618
- ranges = st.session_state.train_ranges; oor_pct = 0.0; tbl=None
619
- if ranges:
620
- any_viol = pd.DataFrame({f:(df[f]<ranges[f][0])|(df[f]>ranges[f][1]) for f in FEATURES}).any(axis=1)
621
- oor_pct = float(any_viol.mean()*100.0)
622
- if any_viol.any():
623
- tbl = df.loc[any_viol, FEATURES].copy()
624
- for c in FEATURES:
625
- if pd.api.types.is_numeric_dtype(tbl[c]): tbl[c] = tbl[c].round(2)
626
- tbl["Violations"] = pd.DataFrame({f:(df[f]<ranges[f][0])|(df[f]>ranges[f][1]) for f in FEATURES}).loc[any_viol].apply(lambda r:", ".join([c for c,v in r.items() if v]), axis=1)
627
- st.session_state.results["m_val"]={
628
- "R": pearson_r(df[TARGET], df["UCS_Pred"]),
629
- "RMSE": rmse(df[TARGET], df["UCS_Pred"]),
630
- "MAE": mean_absolute_error(df[TARGET], df["UCS_Pred"])
631
- }
632
- st.session_state.results["sv_val"]={"n":len(df),"pred_min":float(df["UCS_Pred"].min()),"pred_max":float(df["UCS_Pred"].max()),"oor":oor_pct}
633
- st.session_state.results["oor_tbl"]=tbl
634
-
635
- if "Validate" in st.session_state.results:
636
- m = st.session_state.results["m_val"]
637
- c1,c2,c3 = st.columns(3)
638
- c1.metric("R", f"{m['R']:.2f}"); c2.metric("RMSE", f"{m['RMSE']:.2f}"); c3.metric("MAE", f"{m['MAE']:.2f}")
639
-
640
- col_cross, col_track = st.columns([3, 2], gap="large")
641
- with col_cross:
642
- st.pyplot(
643
- cross_plot_static(st.session_state.results["Validate"][TARGET],
644
- st.session_state.results["Validate"]["UCS_Pred"]),
645
- use_container_width=True
646
- )
647
- with col_track:
648
- st.plotly_chart(
649
- track_plot(st.session_state.results["Validate"], include_actual=True),
650
- use_container_width=False, # Set to False to honor the width in track_plot()
651
- config={"displayModeBar": False, "scrollZoom": True}
652
- )
653
-
654
- sv = st.session_state.results["sv_val"]
655
- if sv["oor"] > 0: st.markdown('<div class="st-message-box st-warning">Some inputs fall outside **training min–max** ranges.</div>', unsafe_allow_html=True)
656
- if st.session_state.results["oor_tbl"] is not None:
657
- st.write("*Out-of-range rows (vs. Training min–max):*")
658
- df_centered_rounded(st.session_state.results["oor_tbl"])
 
 
 
 
 
 
 
 
 
 
 
659
 
660
  # =========================
661
  # PREDICTION (no actual UCS)
662
  # =========================
663
  if st.session_state.app_step == "predict":
664
- st.sidebar.header("Prediction (No Actual UCS)")
665
- up = st.sidebar.file_uploader("Upload Prediction Excel", type=["xlsx","xls"])
666
- if up is not None:
667
- book = read_book_bytes(up.getvalue())
668
- if book:
669
- df0 = next(iter(book.values()))
670
- st.sidebar.caption(f"**Data loaded:** {up.name} • {df0.shape[0]} rows × {df0.shape[1]} cols")
671
- if st.sidebar.button("Preview data", use_container_width=True, disabled=(up is None)):
672
- st.session_state.show_preview_modal = True # Set state to show modal
673
- go_btn = st.sidebar.button("Predict", type="primary", use_container_width=True)
674
- if st.sidebar.button("⬅ Back to Case Building", use_container_width=True): st.session_state.app_step="dev"; st.rerun()
675
-
676
- sticky_header("Prediction", "Upload a dataset with the feature columns (no **UCS**).")
677
-
678
- if go_btn and up is not None:
679
- book = read_book_bytes(up.getvalue()); name = list(book.keys())[0]
680
- df = book[name].copy()
681
- if not ensure_cols(df, FEATURES): st.markdown('<div class="st-message-box st-error">Missing required columns.</div>', unsafe_allow_html=True); st.stop()
682
- df["UCS_Pred"] = model.predict(df[FEATURES])
683
- st.session_state.results["PredictOnly"]=df
684
-
685
- ranges = st.session_state.train_ranges; oor_pct = 0.0
686
- if ranges:
687
- any_viol = pd.DataFrame({f:(df[f]<ranges[f][0])|(df[f]>ranges[f][1]) for f in FEATURES}).any(axis=1)
688
- oor_pct = float(any_viol.mean()*100.0)
689
- st.session_state.results["sv_pred"]={
690
- "n":len(df),
691
- "pred_min":float(df["UCS_Pred"].min()),
692
- "pred_max":float(df["UCS_Pred"].max()),
693
- "pred_mean":float(df["UCS_Pred"].mean()),
694
- "pred_std":float(df["UCS_Pred"].std(ddof=0)),
695
- "oor":oor_pct
696
- }
697
-
698
- if "PredictOnly" in st.session_state.results:
699
- df = st.session_state.results["PredictOnly"]; sv = st.session_state.results["sv_pred"]
700
-
701
- col_left, col_right = st.columns([2,3], gap="large")
702
- with col_left:
703
- table = pd.DataFrame({
704
- "Metric": ["# points","Pred min","Pred max","Pred mean","Pred std","OOR %"],
705
- "Value": [sv["n"],
706
- round(sv["pred_min"],2),
707
- round(sv["pred_max"],2),
708
- round(sv["pred_mean"],2),
709
- round(sv["pred_std"],2),
710
- f'{sv["oor"]:.1f}%']
711
- })
712
- st.markdown('<div class="st-message-box st-success">Predictions ready ✓</div>', unsafe_allow_html=True)
713
- df_centered_rounded(table, hide_index=True)
714
- st.caption("**★ OOR** = % of rows whose input features fall outside the training min–max range.")
715
- with col_right:
716
- st.plotly_chart(
717
- track_plot(df, include_actual=False),
718
- use_container_width=False, # Set to False to honor the width in track_plot()
719
- config={"displayModeBar": False, "scrollZoom": True}
720
- )
721
 
722
  # =========================
723
  # Run preview modal after all other elements
724
  # =========================
725
  if st.session_state.show_preview_modal:
726
- # Get the correct book based on the current app step
727
- book_to_preview = {}
728
- if st.session_state.app_step == "dev":
729
- book_to_preview = read_book_bytes(st.session_state.dev_file_bytes)
730
- elif st.session_state.app_step in ["validate", "predict"] and up is not None:
731
- book_to_preview = read_book_bytes(up.getvalue())
732
-
733
- # Use a try-except block to handle the case where 'up' was never assigned
734
- # (NameError); a None value is already excluded by the explicit checks below.
735
- try:
736
- if st.session_state.app_step == "validate" and up is not None:
737
- book_to_preview = read_book_bytes(up.getvalue())
738
- elif st.session_state.app_step == "predict" and up is not None:
739
- book_to_preview = read_book_bytes(up.getvalue())
740
- except NameError:
741
- book_to_preview = {}
742
-
743
- with st.expander("Preview data", expanded=True):
744
- if not book_to_preview:
745
- st.markdown('<div class="st-message-box">No data loaded yet.</div>', unsafe_allow_html=True)
746
- else:
747
- names = list(book_to_preview.keys())
748
- tabs = st.tabs(names)
749
- for t, name in zip(tabs, names):
750
- with t:
751
- df = book_to_preview[name]
752
- t1, t2 = st.tabs(["Tracks", "Summary"])
753
- with t1:
754
- st.pyplot(preview_tracks(df, FEATURES), use_container_width=True)
755
- with t2:
756
- tbl = (df[FEATURES]
757
- .agg(['min','max','mean','std'])
758
- .T.rename(columns={"min":"Min","max":"Max","mean":"Mean","std":"Std"}))
759
- df_centered_rounded(tbl.reset_index(names="Feature"))
760
- # Reset the state variable after the modal is displayed
761
- st.session_state.show_preview_modal = False
762
 
763
  # =========================
764
  # Footer
765
  # =========================
766
  st.markdown("---")
767
  st.markdown(
768
- """
769
- <div style='text-align:center; color:#6b7280; line-height:1.6'>
770
- ST_GeoMech_UCS • © Smart Thinking<br/>
771
- <strong>Visit our website:</strong> <a href='https://www.smartthinking.com.sa' target='_blank'>smartthinking.com.sa</a>
772
- </div>
773
- """,
774
- unsafe_allow_html=True
775
  )
 
25
  COLORS = {"pred": "#1f77b4", "actual": "#f2b702", "ref": "#5a5a5a"}
26
 
27
  # ---- Plot sizing controls ----
28
+ CROSS_W = 250       # px (matplotlib figure size; Streamlit will still scale)
29
+ CROSS_H = 250
30
+ TRACK_H = 1000      # px (plotly height; width auto-fits column)
31
  # NEW: Add a TRACK_W variable to control the width
32
+ TRACK_W = 500       # px (plotly width)
33
+ FONT_SZ  = 13
34
+ BOLD_FONT = "Arial Black, Arial, sans-serif"  # used for bold axis titles & ticks
35
 
36
  # =========================
37
  # Page / CSS
 
41
  # General CSS (logo helpers etc.)
42
  st.markdown("""
43
  <style>
44
+   .brand-logo { width: 200px; height: auto; object-fit: contain; }
45
+   .sidebar-header { display:flex; align-items:center; gap:12px; }
46
+   .sidebar-header .text h1 { font-size: 1.05rem; margin:0; line-height:1.1; }
47
+   .sidebar-header .text .tag { font-size: .85rem; color:#6b7280; margin:2px 0 0; }
48
+   .centered-container {
49
+     display: flex;
50
+     flex-direction: column;
51
+     align-items: center;
52
+     text-align: center;
53
+   }
54
  </style>
55
  """, unsafe_allow_html=True)
56
 
 
59
  <style>
60
  /* This targets the main content area */
61
  .main .block-container {
62
+     overflow: unset !important;
63
  }
64
 
65
  /* This targets the vertical block that holds all your elements */
66
  div[data-testid="stVerticalBlock"] {
67
+     overflow: unset !important;
68
  }
69
  </style>
70
  """, unsafe_allow_html=True)
 
87
  st.markdown("""
88
  <style>
89
  div[data-testid="stExpander"] > details > summary {
90
+   position: sticky;
91
+   top: 0;
92
+   z-index: 10;
93
+   background: #fff;
94
+   border-bottom: 1px solid #eee;
95
  }
96
  div[data-testid="stExpander"] div[data-baseweb="tab-list"] {
97
+   position: sticky;
98
+   top: 42px;    /* adjust if your expander header height differs */
99
+   z-index: 9;
100
+   background: #fff;
101
+   padding-top: 6px;
102
  }
103
  </style>
104
  """, unsafe_allow_html=True)
105
 
106
  # Center text in all pandas Styler tables (headers + cells)
107
  TABLE_CENTER_CSS = [
108
+     dict(selector="th", props=[("text-align", "center")]),
109
+     dict(selector="td", props=[("text-align", "center")]),
110
  ]
111
 
# Stylesheet for the custom message boxes: a neutral base class plus
# success / warning / error colour variants.
_MESSAGE_BOX_CSS = """
<style>
.st-message-box {
    background-color: #f0f2f6;
    color: #333333;
    padding: 10px;
    border-radius: 10px;
    border: 1px solid #e6e9ef;
}
.st-message-box.st-success {
    background-color: #d4edda;
    color: #155724;
    border-color: #c3e6cb;
}
.st-message-box.st-warning {
    background-color: #fff3cd;
    color: #856404;
    border-color: #ffeeba;
}
.st-message-box.st-error {
    background-color: #f8d7da;
    color: #721c24;
    border-color: #f5c6cb;
}
</style>
"""
st.markdown(_MESSAGE_BOX_CSS, unsafe_allow_html=True)
 
141
  # Password gate
142
  # =========================
def inline_logo(path="logo.png") -> str:
    """Return the image at *path* as a base64 ``data:`` URI.

    The MIME type is guessed from the file extension; when the guess is
    missing or is not an image type, ``image/png`` is used.

    Args:
        path: Location of the image file.

    Returns:
        A ``data:<mime>;base64,...`` string, or ``""`` when the file is
        missing, unreadable, or *path* is not a usable path value.
    """
    import mimetypes  # local import: only this helper needs it

    try:
        raw = Path(path).read_bytes()
    except (OSError, TypeError, ValueError):
        # A missing or unreadable logo is not fatal; callers embed the
        # result directly in an <img src="..."> attribute.
        return ""

    mime, _ = mimetypes.guess_type(str(path))
    if not mime or not mime.startswith("image/"):
        mime = "image/png"
    return f"data:{mime};base64,{base64.b64encode(raw).decode('ascii')}"
150
 
def add_password_gate() -> None:
    """Block the app behind a shared access key entered in the sidebar.

    The expected key is read from ``st.secrets["APP_PASSWORD"]`` and, when
    that is unavailable or empty, from the ``APP_PASSWORD`` environment
    variable. If no key is configured the app shows a warning and stops.
    Once the user has unlocked the app, ``st.session_state.auth_ok`` is set
    and later runs return immediately. Otherwise the function renders the
    unlock form and halts the script with ``st.stop()``.
    """
    import hmac  # local import: only used for the constant-time comparison

    try:
        required = st.secrets.get("APP_PASSWORD", "") or ""
    except Exception:
        # Reading st.secrets can raise; the exception type is not pinned
        # here, so keep the broad guard and fall through to the environment.
        required = ""
    if not required:
        # Also consult the environment when the secret is simply absent,
        # matching the "(or environment)" hint shown in the warning below.
        required = os.environ.get("APP_PASSWORD", "")

    if not required:
        st.warning("Set APP_PASSWORD in Secrets (or environment) and restart.")
        st.stop()

    if st.session_state.get("auth_ok", False):
        return

    st.sidebar.markdown(f"""
        <div class="centered-container">
            <img src="{inline_logo('logo.png')}" style="width: 200px; height: auto; object-fit: contain;">
            <div style='font-weight:800;font-size:1.2rem; margin-top: 10px;'>ST_GeoMech_UCS</div>
            <div style='color:#667085;'>Smart Thinking • Secure Access</div>
        </div>
        """, unsafe_allow_html=True
    )
    pwd = st.sidebar.text_input("Access key", type="password", placeholder="••••••••")
    if st.sidebar.button("Unlock", type="primary"):
        # Compare encoded bytes in constant time so the check does not leak
        # how many leading characters matched (and non-ASCII keys work).
        supplied = str(pwd).encode("utf-8")
        expected = str(required).encode("utf-8")
        if hmac.compare_digest(supplied, expected):
            st.session_state.auth_ok = True
            st.rerun()
        else:
            st.error("Incorrect key.")
    st.stop()
180
 
181
  add_password_gate()
182
 
 
184
  # Utilities
185
  # =========================
def rmse(y_true, y_pred) -> float:
    """Return the root-mean-square error between *y_true* and *y_pred* as a float."""
    mse = mean_squared_error(y_true, y_pred)
    root = np.sqrt(mse)
    return float(root)
188
 
def pearson_r(y_true, y_pred) -> float:
    """Return the Pearson correlation coefficient between two value sequences.

    Inputs are converted to flat float arrays. ``nan`` is returned when there
    are fewer than two points or when either input is constant (correlation
    is undefined there). NaN values in the inputs propagate to the result.

    Args:
        y_true: Reference values.
        y_pred: Predicted values, same length as *y_true*.

    Returns:
        The correlation coefficient as a Python float, or ``nan``.
    """
    a = np.asarray(y_true, dtype=float).ravel()
    p = np.asarray(y_pred, dtype=float).ravel()
    if a.size < 2:
        return float("nan")
    # A zero-variance input makes np.corrcoef divide by zero and emit a
    # RuntimeWarning; short-circuit to the same nan result without the noise.
    if np.ptp(a) == 0 or np.ptp(p) == 0:
        return float("nan")
    return float(np.corrcoef(a, p)[0, 1])
194
 
@st.cache_resource(show_spinner=False)
def load_model(model_path: str):
    """Load the object stored at *model_path* with joblib and return it.

    Wrapped in ``st.cache_resource`` with the spinner disabled.
    """
    estimator = joblib.load(model_path)
    return estimator
198
 
@st.cache_data(show_spinner=False)
def parse_excel(data_bytes: bytes):
    """Parse an Excel workbook held in memory into one DataFrame per sheet.

    Args:
        data_bytes: Raw bytes of the workbook.

    Returns:
        A dict mapping each sheet name to its parsed DataFrame, in the
        workbook's sheet order.
    """
    # Use the ExcelFile as a context manager so its reader is closed as soon
    # as every sheet has been parsed.
    with pd.ExcelFile(io.BytesIO(data_bytes)) as xl:
        return {sh: xl.parse(sh) for sh in xl.sheet_names}
204
 
def read_book_bytes(b: bytes):
    """Return the parsed workbook for *b*, or an empty dict when *b* is falsy."""
    if not b:
        return {}
    return parse_excel(b)
206
 
def ensure_cols(df, cols):
    """Check that every name in *cols* is a column of *df*.

    Shows an ``st.error`` listing the missing and the found columns when the
    check fails.

    Returns:
        ``True`` when all columns are present, ``False`` otherwise.
    """
    absent = []
    for name in cols:
        if name not in df.columns:
            absent.append(name)
    if not absent:
        return True
    st.error(f"Missing columns: {absent}\nFound: {list(df.columns)}")
    return False
213
 
def find_sheet(book, names):
    """Return the first key of *book* matching one of *names*.

    Matching ignores case and surrounding whitespace, so a sheet called
    ``" Training "`` is found by the candidate ``"training"``.

    Args:
        book: Mapping of sheet name to sheet data.
        names: Candidate names, tried in order of preference.

    Returns:
        The original key from *book*, or ``None`` when no candidate matches.
    """
    def _norm(label) -> str:
        return str(label).strip().casefold()

    lookup = {_norm(key): key for key in book}
    for candidate in names:
        wanted = _norm(candidate)
        if wanted in lookup:
            return lookup[wanted]
    return None
219
 
def _nice_tick0(xmin: float, step: int = 100) -> float:
    """Round *xmin* down to the nearest multiple of *step*.

    Args:
        xmin: Lower axis bound.
        step: Tick spacing to align to.

    Returns:
        The aligned value as a float. *xmin* is returned unchanged when it is
        not finite or when *step* is zero (nothing to align to).
    """
    if not step or not np.isfinite(xmin):
        return xmin
    return float(step * math.floor(xmin / step))
222
 
def df_centered_rounded(df: pd.DataFrame, hide_index=True):
    """Render *df* with centered text and numeric columns shown to 2 decimals.

    Works on a copy of *df* and displays it through ``st.dataframe`` at
    container width.
    """
    frame = df.copy()
    numeric_cols = frame.select_dtypes(include=[np.number]).columns
    two_decimals = {col: "{:.2f}" for col in numeric_cols}

    styled = frame.style.format(two_decimals)
    styled = styled.set_properties(**{"text-align": "center"})
    styled = styled.set_table_styles(TABLE_CENTER_CSS)

    st.dataframe(styled, use_container_width=True, hide_index=hide_index)
234
 
235
  # =========================
236
  # Cross plot (Matplotlib, fixed limits & ticks)
237
  # =========================
def cross_plot_static(actual, pred, *, fixed_min=6000, fixed_max=10000, tick_step=1000):
    """Build a square actual-vs-predicted scatter with a dashed 1:1 line.

    Both axes share the same fixed limits and ticks, and the aspect ratio is
    locked so the reference line sits at a true 45 degrees. Points outside
    ``[fixed_min, fixed_max]`` are not visible.

    Args:
        actual: Measured values (x axis).
        pred: Predicted values (y axis).
        fixed_min: Lower limit for both axes.
        fixed_max: Upper limit for both axes.
        tick_step: Spacing between ticks on both axes.

    Returns:
        The Matplotlib figure. It is already removed from pyplot's figure
        registry, so the caller does not need to close it.
    """
    a = pd.Series(actual, dtype=float)
    p = pd.Series(pred, dtype=float)

    ticks = np.arange(fixed_min, fixed_max + 1, tick_step)

    dpi = 110
    fig, ax = plt.subplots(
        figsize=(CROSS_W / dpi, CROSS_H / dpi),
        dpi=dpi,
        constrained_layout=False,
    )

    ax.scatter(a, p, s=14, c=COLORS["pred"], alpha=0.9, linewidths=0)
    ax.plot([fixed_min, fixed_max], [fixed_min, fixed_max],
            linestyle="--", linewidth=1.2, color=COLORS["ref"])

    ax.set_xlim(fixed_min, fixed_max)
    ax.set_ylim(fixed_min, fixed_max)
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.set_aspect("equal", adjustable="box")  # true 45°

    # Thousands separators on both axes.
    fmt = FuncFormatter(lambda x, _: f"{int(x):,}")
    ax.xaxis.set_major_formatter(fmt)
    ax.yaxis.set_major_formatter(fmt)

    ax.set_xlabel("Actual UCS (psi)", fontweight="bold", fontsize=4, color="black")
    ax.set_ylabel("Predicted UCS (psi)", fontweight="bold", fontsize=4, color="black")
    ax.tick_params(labelsize=2, colors="black")

    ax.grid(True, linestyle=":", alpha=0.3)
    for spine in ax.spines.values():
        spine.set_linewidth(1.1)
        spine.set_color("#444")

    fig.subplots_adjust(left=0.16, bottom=0.16, right=0.98, top=0.98)

    # plt.subplots registers the figure with pyplot; without this every
    # script rerun would leave one more figure alive. The returned Figure
    # object can still be rendered after being unregistered.
    plt.close(fig)
    return fig
277
 
278
  # =========================
279
  # Track plot (Plotly)
280
  # =========================
def track_plot(df, include_actual=True):
    """Build the Plotly track of predicted (and optionally actual) UCS.

    The vertical axis is the first column whose name contains "depth"
    (case-insensitive), or a 1-based point index when no such column exists;
    it is drawn reversed so values increase downwards. The horizontal axis
    sits on top and is padded by 3 % of the data span.

    Args:
        df: Frame with a ``UCS_Pred`` column and, optionally, the ``TARGET``
            column.
        include_actual: Also draw the ``TARGET`` curve when that column is
            present.

    Returns:
        The configured ``go.Figure``.

    Raises:
        KeyError: If *df* has no ``UCS_Pred`` column.
    """
    depth_col = next((c for c in df.columns if 'depth' in str(c).lower()), None)
    if depth_col is not None:
        y = pd.Series(df[depth_col]).astype(float)
        ylab = depth_col
    else:
        y = pd.Series(np.arange(1, len(df) + 1))
        ylab = "Point Index"
    y_hi, y_lo = float(y.max()), float(y.min())
    # Reversed axis; fall back to Plotly's autorange when the data is empty
    # or all-NaN, because NaN bounds are not a usable range.
    y_range = [y_hi, y_lo] if np.isfinite(y_hi) and np.isfinite(y_lo) else None

    show_actual = include_actual and TARGET in df.columns

    # X (UCS) range & ticks
    x_series = pd.Series(df["UCS_Pred"]).astype(float)
    if show_actual:
        x_series = pd.concat([x_series, pd.Series(df[TARGET]).astype(float)], ignore_index=True)
    x_lo, x_hi = float(x_series.min()), float(x_series.max())
    if np.isfinite(x_lo) and np.isfinite(x_hi):
        x_pad = 0.03 * (x_hi - x_lo if x_hi > x_lo else 1.0)
        xmin, xmax = x_lo - x_pad, x_hi + x_pad
        x_range = [xmin, xmax]
        tick0 = _nice_tick0(xmin, step=100)
    else:
        x_range = None
        tick0 = None

    # str() keeps the hover text valid even for a non-string column label.
    y_hover = str(ylab) + ": %{y}<extra></extra>"

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=df["UCS_Pred"], y=y, mode="lines",
        line=dict(color=COLORS["pred"], width=1.8),
        name="UCS_Pred",
        hovertemplate="UCS_Pred: %{x:.0f}<br>" + y_hover
    ))
    if show_actual:
        fig.add_trace(go.Scatter(
            x=df[TARGET], y=y, mode="lines",
            line=dict(color=COLORS["actual"], width=2.0, dash="dot"),
            name="UCS (actual)",
            hovertemplate="UCS (actual): %{x:.0f}<br>" + y_hover
        ))

    fig.update_layout(
        height=TRACK_H,
        width=TRACK_W,   # fixed width; callers pass use_container_width=False
        autosize=False,  # disable autosizing so the width above is respected
        paper_bgcolor="#fff", plot_bgcolor="#fff",
        margin=dict(l=64, r=16, t=36, b=48), hovermode="closest",
        font=dict(size=FONT_SZ, color="#000"),
        legend=dict(
            x=0.98, y=0.05, xanchor="right", yanchor="bottom",
            bgcolor="rgba(255,255,255,0.75)", bordercolor="#ccc", borderwidth=1
        ),
        legend_title_text=""
    )

    # Bold, black axis titles & ticks
    fig.update_xaxes(
        title_text="UCS (psi)",
        title_font=dict(size=20, family=BOLD_FONT, color="#000"),
        tickfont=dict(size=15, family=BOLD_FONT, color="#000"),
        side="top",
        range=x_range,
        ticks="outside",
        tickformat=",.0f",
        tickmode="auto",
        tick0=tick0,
        showline=True, linewidth=1.2, linecolor="#444", mirror=True,
        showgrid=True, gridcolor="rgba(0,0,0,0.12)", automargin=True
    )
    fig.update_yaxes(
        title_text=ylab,
        title_font=dict(size=20, family=BOLD_FONT, color="#000"),
        tickfont=dict(size=15, family=BOLD_FONT, color="#000"),
        range=y_range,
        ticks="outside",
        showline=True, linewidth=1.2, linecolor="#444", mirror=True,
        showgrid=True, gridcolor="rgba(0,0,0,0.12)", automargin=True
    )

    return fig
355
 
356
  # ---------- Preview modal (matplotlib) ----------
def preview_tracks(df: pd.DataFrame, cols: list[str]):
    """Plot each requested column of *df* as a vertical track against row index.

    Columns missing from *df* are skipped. When none remain, a small
    placeholder figure with the text "No selected columns" is returned.

    Args:
        df: Data to preview.
        cols: Column names to plot, one panel per column.

    Returns:
        The Matplotlib figure. It is already removed from pyplot's figure
        registry, so the caller does not need to close it.
    """
    cols = [c for c in cols if c in df.columns]
    n = len(cols)
    if n == 0:
        fig, ax = plt.subplots(figsize=(4, 2))
        ax.text(0.5, 0.5, "No selected columns", ha="center", va="center")
        ax.axis("off")
        plt.close(fig)  # unregister from pyplot so reruns do not pile up figures
        return fig

    fig, axes = plt.subplots(1, n, figsize=(2.2*n, 7.0), sharey=True, dpi=100)
    if n == 1:
        axes = [axes]  # plt.subplots returns a bare Axes for a single panel
    idx = np.arange(1, len(df) + 1)
    for ax, col in zip(axes, cols):
        ax.plot(df[col], idx, '-', lw=1.4, color="#333")
        ax.set_xlabel(col)
        ax.xaxis.set_label_position('top')
        ax.xaxis.tick_top()
        ax.invert_yaxis()
        ax.grid(True, linestyle=":", alpha=0.3)
        for s in ax.spines.values():
            s.set_visible(True)
    axes[0].set_ylabel("Point Index")
    plt.close(fig)  # unregister from pyplot so reruns do not pile up figures
    return fig
374
 
375
  # Modal wrapper (Streamlit compatibility)
try:
    dialog = st.dialog
except AttributeError:
    import functools

    def dialog(title):
        """Fallback for Streamlit builds without ``st.dialog``.

        Returns a decorator that runs the wrapped function inside an
        expanded ``st.expander`` labelled *title*.
        """
        def deco(fn):
            @functools.wraps(fn)  # keep the wrapped function's name and docstring
            def wrapper(*args, **kwargs):
                with st.expander(title, expanded=True):
                    return fn(*args, **kwargs)
            return wrapper
        return deco
386
 
def preview_modal(book: dict[str, pd.DataFrame]):
    """Render a per-sheet preview of *book*: feature tracks and summary stats.

    Each sheet gets its own tab with two nested tabs: "Tracks" (one panel per
    feature column) and "Summary" (min/max/mean/std per feature). Shows an
    info message and returns when *book* is empty.
    """
    if not book:
        st.info("No data loaded yet.")
        return
    names = list(book.keys())
    for sheet_tab, name in zip(st.tabs(names), names):
        with sheet_tab:
            df = book[name]
            # Only summarise features this sheet actually has; indexing with
            # the full FEATURES list raises KeyError on sheets that lack one.
            present = [c for c in FEATURES if c in df.columns]
            t1, t2 = st.tabs(["Tracks", "Summary"])
            with t1:
                st.pyplot(preview_tracks(df, FEATURES), use_container_width=True)
            with t2:
                if not present:
                    st.info("This sheet contains none of the model feature columns.")
                else:
                    tbl = (df[present]
                           .agg(['min', 'max', 'mean', 'std'])
                           .T.rename(columns={"min": "Min", "max": "Max", "mean": "Mean", "std": "Std"}))
                    df_centered_rounded(tbl.reset_index(names="Feature"))
403
 
404
  # =========================
405
  # Load model
406
  # =========================
def ensure_model() -> Path|None:
    """Locate the model file, downloading it from ``MODEL_URL`` if needed.

    The first existing, non-empty file among ``DEFAULT_MODEL`` and
    ``MODEL_FALLBACKS`` is returned. Otherwise, when the ``MODEL_URL``
    environment variable is set, the file is downloaded to ``DEFAULT_MODEL``.

    Returns:
        Path to the model file, or ``None`` when no local file exists and the
        download is not configured or fails.
    """
    for p in [DEFAULT_MODEL, *MODEL_FALLBACKS]:
        if p.exists() and p.stat().st_size > 0:
            return p
    url = os.environ.get("MODEL_URL", "")
    if not url:
        return None

    # NOTE(review): the downloaded file is later deserialized with joblib,
    # which can execute arbitrary code; MODEL_URL must point to a trusted source.
    # Download to a side file and move it into place only when complete, so an
    # interrupted transfer never leaves a truncated file at DEFAULT_MODEL that
    # the check above would accept on the next run.
    tmp = DEFAULT_MODEL.with_name(DEFAULT_MODEL.name + ".part")
    try:
        import requests
        DEFAULT_MODEL.parent.mkdir(parents=True, exist_ok=True)
        with requests.get(url, stream=True, timeout=30) as r:
            r.raise_for_status()
            with open(tmp, "wb") as f:
                for chunk in r.iter_content(1 << 20):
                    if chunk:
                        f.write(chunk)
        if tmp.stat().st_size == 0:
            return None
        os.replace(tmp, DEFAULT_MODEL)
        return DEFAULT_MODEL
    except Exception:
        # Best effort: a missing `requests`, a network error or a filesystem
        # error all mean "no model"; the caller reports that to the user.
        return None
    finally:
        try:
            tmp.unlink(missing_ok=True)
        except OSError:
            pass
423
 
# Resolve and load the model; stop the script with an error when impossible.
mpath = ensure_model()
if not mpath:
    st.error("Model not found. Upload models/ucs_rf.joblib (or set MODEL_URL).")
    st.stop()
try:
    model = load_model(str(mpath))
except Exception as e:
    st.error(f"Failed to load model: {e}")
    st.stop()

# Optional metadata next to the model can override the feature list and the
# target column name.
meta_path = MODELS_DIR / "meta.json"
if meta_path.exists():
    try:
        meta = json.loads(meta_path.read_text(encoding="utf-8"))
    except (OSError, ValueError) as e:
        # Unreadable or malformed metadata: keep the defaults, but say so
        # instead of silently running with possibly mismatched columns.
        st.warning(f"Ignoring models/meta.json (could not be read): {e}")
    else:
        if isinstance(meta, dict):
            FEATURES = meta.get("features", FEATURES)
            TARGET = meta.get("target", TARGET)
        else:
            st.warning("Ignoring models/meta.json (expected a JSON object).")
441
 
442
  # =========================
443
  # Session state
 
455
  # Branding in Sidebar
456
  # =========================
# Sidebar branding: logo (embedded as a data URI), app name and tagline.
_logo_uri = inline_logo('logo.png')
_branding_html = f"""
    <div class="centered-container">
        <img src="{_logo_uri}" style="width: 200px; height: auto; object-fit: contain;">
        <div style='font-weight:800;font-size:1.2rem;'>ST_GeoMech_UCS</div>
        <div style='color:#667085;'>Real-Time UCS Tracking While Drilling</div>
    </div>
    """
st.sidebar.markdown(_branding_html, unsafe_allow_html=True)
465
 
466
  # =========================
467
  # Reusable Sticky Header Function
468
  # =========================
def sticky_header(title, message):
    """Render a page header that stays pinned to the top while scrolling.

    Args:
        title: Heading text, placed in an ``<h3>``.
        message: Status line, placed in a ``<p>``. ``**bold**`` spans are
            converted to ``<strong>`` tags.
    """
    import re  # local import: only needed for the bold conversion below

    # The header is emitted as a raw HTML block, and Markdown is not processed
    # inside such a block, so ``**text**`` would show up with literal
    # asterisks. Translate the bold markup to HTML before interpolating.
    message_html = re.sub(r"\*\*(.+?)\*\*", r"<strong>\1</strong>", str(message))

    st.markdown(
        f"""
        <style>
        .sticky-container {{
            position: sticky;
            top: 0;
            background-color: white;
            z-index: 100;
            padding-top: 10px;
            padding-bottom: 10px;
            border-bottom: 1px solid #eee;
        }}
        </style>
        <div class="sticky-container">
            <h3>{title}</h3>
            <p>{message_html}</p>
        </div>
        """,
        unsafe_allow_html=True
    )
490
 
491
  # =========================
492
  # INTRO
493
  # =========================
# Landing page: short description, usage steps, and the button that moves
# the app to the case-building step.
if st.session_state.app_step == "intro":
    st.header("Welcome!")
    st.markdown("This software is developed by *Smart Thinking AI-Solutions Team* to estimate UCS from drilling data.")
    st.subheader("How It Works")
    how_it_works = (
        "1) **Upload your data to build the case and preview the performance of our model.** \n"
        "2) Click **Run Model** to compute metrics and plots.  \n"
        "3) **Proceed to Validation** (with actual UCS) or **Proceed to Prediction** (no UCS)."
    )
    st.markdown(how_it_works)
    start_clicked = st.button("Start Showcase", type="primary")
    if start_clicked:
        st.session_state.app_step = "dev"
        st.rerun()
505
 
506
  # =========================
507
  # CASE BUILDING
508
  # =========================
# Case building: upload a workbook with Train/Test sheets, run the model on
# both, store predictions, metrics and training feature ranges in session
# state, then show metrics, cross plot and track plot per sheet.
if st.session_state.app_step == "dev":
    st.sidebar.header("Case Building")
    up = st.sidebar.file_uploader("Upload Your Data File", type=["xlsx","xls"])
    if up is not None:
        uploaded_bytes = up.getvalue()
        # The uploader keeps returning the same file on every rerun. Only
        # treat it as a new upload when it actually changed; otherwise the
        # "previewed" flag would be cleared again on the next interaction.
        is_new_upload = (
            not st.session_state.dev_file_loaded
            or st.session_state.get("dev_file_name") != up.name
            or st.session_state.get("dev_file_bytes") != uploaded_bytes
        )
        if is_new_upload:
            st.session_state.dev_file_bytes = uploaded_bytes
            st.session_state.dev_file_name = up.name
            st.session_state.dev_file_loaded = True
            st.session_state.dev_preview = False
    if st.session_state.dev_file_loaded:
        tmp = read_book_bytes(st.session_state.dev_file_bytes)
        if tmp:
            df0 = next(iter(tmp.values()))
            st.sidebar.caption(f"**Data loaded:** {st.session_state.dev_file_name} • {df0.shape[0]} rows × {df0.shape[1]} cols")

    if st.sidebar.button("Preview data", use_container_width=True, disabled=not st.session_state.dev_file_loaded):
        st.session_state.show_preview_modal = True  # rendered at the end of the script
        st.session_state.dev_preview = True

    run = st.sidebar.button("Run Model", type="primary", use_container_width=True)
    if st.sidebar.button("Proceed to Validation ▶", use_container_width=True):
        st.session_state.app_step = "validate"
        st.rerun()
    if st.sidebar.button("Proceed to Prediction ▶", use_container_width=True):
        st.session_state.app_step = "predict"
        st.rerun()

    # Sticky header text reflects how far the user is in the workflow.
    if st.session_state.dev_file_loaded and st.session_state.dev_preview:
        sticky_header("Case Building", "Previewed ✓ — now click **Run Model**.")
    elif st.session_state.dev_file_loaded:
        sticky_header("Case Building", "📄 **Preview uploaded data** using the sidebar button, then click **Run Model**.")
    else:
        sticky_header("Case Building", "**Upload your data to build a case, then run the model to review development performance.**")

    if run and st.session_state.dev_file_bytes:
        book = read_book_bytes(st.session_state.dev_file_bytes)
        sh_train = find_sheet(book, ["Train","Training","training2","train","training"])
        sh_test  = find_sheet(book, ["Test","Testing","testing2","test","testing"])
        if sh_train is None or sh_test is None:
            st.markdown('<div class="st-message-box st-error">Workbook must include Train/Training/training2 and Test/Testing/testing2 sheets.</div>', unsafe_allow_html=True)
            st.stop()
        tr = book[sh_train].copy()
        te = book[sh_test].copy()
        if not (ensure_cols(tr, FEATURES+[TARGET]) and ensure_cols(te, FEATURES+[TARGET])):
            st.markdown('<div class="st-message-box st-error">Missing required columns.</div>', unsafe_allow_html=True)
            st.stop()
        tr["UCS_Pred"] = model.predict(tr[FEATURES])
        te["UCS_Pred"] = model.predict(te[FEATURES])

        st.session_state.results["Train"] = tr
        st.session_state.results["Test"] = te
        st.session_state.results["m_train"] = {
            "R": pearson_r(tr[TARGET], tr["UCS_Pred"]),
            "RMSE": rmse(tr[TARGET], tr["UCS_Pred"]),
            "MAE": mean_absolute_error(tr[TARGET], tr["UCS_Pred"])
        }
        st.session_state.results["m_test"] = {
            "R": pearson_r(te[TARGET], te["UCS_Pred"]),
            "RMSE": rmse(te[TARGET], te["UCS_Pred"]),
            "MAE": mean_absolute_error(te[TARGET], te["UCS_Pred"])
        }

        # Per-feature (min, max) of the training sheet; the validation and
        # prediction steps use it to flag out-of-range inputs.
        tr_min = tr[FEATURES].min().to_dict()
        tr_max = tr[FEATURES].max().to_dict()
        st.session_state.train_ranges = {f: (float(tr_min[f]), float(tr_max[f])) for f in FEATURES}
        st.markdown('<div class="st-message-box st-success">Case has been built and results are displayed below.</div>', unsafe_allow_html=True)

    def _dev_block(df, m):
        """Show the metrics row, abbreviation footer, cross plot and track plot."""
        c1, c2, c3 = st.columns(3)
        c1.metric("R", f"{m['R']:.2f}")
        c2.metric("RMSE", f"{m['RMSE']:.2f}")
        c3.metric("MAE", f"{m['MAE']:.2f}")

        # Footer spelling out the metric abbreviations.
        st.markdown("""
            <div style='text-align: left; font-size: 0.8em; color: #6b7280; margin-top: -16px; margin-bottom: 8px;'>
                <strong>R:</strong> Pearson Correlation Coefficient<br>
                <strong>RMSE:</strong> Root Mean Square Error<br>
                <strong>MAE:</strong> Mean Absolute Error
            </div>
        """, unsafe_allow_html=True)

        # 2-column layout, big gap (prevents overlap)
        col_cross, col_track = st.columns([3, 2], gap="large")
        with col_cross:
            st.pyplot(cross_plot_static(df[TARGET], df["UCS_Pred"]), use_container_width=False)
        with col_track:
            st.plotly_chart(
                track_plot(df, include_actual=True),
                use_container_width=False,  # honor the width set in track_plot()
                config={"displayModeBar": False, "scrollZoom": True}
            )

    if "Train" in st.session_state.results or "Test" in st.session_state.results:
        tab1, tab2 = st.tabs(["Training", "Testing"])
        if "Train" in st.session_state.results:
            with tab1:
                _dev_block(st.session_state.results["Train"], st.session_state.results["m_train"])
        if "Test" in st.session_state.results:
            with tab2:
                _dev_block(st.session_state.results["Test"], st.session_state.results["m_test"])
601
 
602
  # =========================
603
  # VALIDATION (with actual UCS)
604
  # =========================
# Validation: upload a workbook that includes the actual target, predict,
# compute metrics, and list rows whose features fall outside the training
# min-max ranges recorded during case building.
if st.session_state.app_step == "validate":
    st.sidebar.header("Validate the Model")
    up = st.sidebar.file_uploader("Upload Validation Excel", type=["xlsx","xls"])
    if up is not None:
        book = read_book_bytes(up.getvalue())
        if book:
            df0 = next(iter(book.values()))
            st.sidebar.caption(f"**Data loaded:** {up.name} • {df0.shape[0]} rows × {df0.shape[1]} cols")
    if st.sidebar.button("Preview data", use_container_width=True, disabled=(up is None)):
        st.session_state.show_preview_modal = True  # rendered at the end of the script
    go_btn = st.sidebar.button("Predict", type="primary", use_container_width=True)
    if st.sidebar.button("⬅ Back to Case Building", use_container_width=True):
        st.session_state.app_step = "dev"
        st.rerun()
    if st.sidebar.button("Proceed to Prediction ▶", use_container_width=True):
        st.session_state.app_step = "predict"
        st.rerun()

    sticky_header("Validate the Model", "Upload a dataset with the same **features** and **UCS** to evaluate performance.")

    if go_btn and up is not None:
        book = read_book_bytes(up.getvalue())
        # Prefer a sheet with a validation-like name, else the first sheet.
        name = find_sheet(book, ["Validation","Validate","validation2","Val","val"]) or list(book.keys())[0]
        df = book[name].copy()
        if not ensure_cols(df, FEATURES+[TARGET]):
            st.markdown('<div class="st-message-box st-error">Missing required columns.</div>', unsafe_allow_html=True)
            st.stop()
        df["UCS_Pred"] = model.predict(df[FEATURES])
        st.session_state.results["Validate"] = df

        ranges = st.session_state.train_ranges
        oor_pct = 0.0
        tbl = None
        if ranges:
            # One boolean column per feature: True where the value lies
            # outside that feature's training (min, max). Built once and
            # reused for both the percentage and the per-row listing.
            viol = pd.DataFrame({f: (df[f] < ranges[f][0]) | (df[f] > ranges[f][1]) for f in FEATURES})
            any_viol = viol.any(axis=1)
            oor_pct = float(any_viol.mean()*100.0)
            if any_viol.any():
                tbl = df.loc[any_viol, FEATURES].copy()
                for c in FEATURES:
                    if pd.api.types.is_numeric_dtype(tbl[c]):
                        tbl[c] = tbl[c].round(2)
                tbl["Violations"] = viol.loc[any_viol].apply(
                    lambda r: ", ".join([c for c, v in r.items() if v]), axis=1
                )
        st.session_state.results["m_val"] = {
            "R": pearson_r(df[TARGET], df["UCS_Pred"]),
            "RMSE": rmse(df[TARGET], df["UCS_Pred"]),
            "MAE": mean_absolute_error(df[TARGET], df["UCS_Pred"])
        }
        st.session_state.results["sv_val"] = {
            "n": len(df),
            "pred_min": float(df["UCS_Pred"].min()),
            "pred_max": float(df["UCS_Pred"].max()),
            "oor": oor_pct
        }
        st.session_state.results["oor_tbl"] = tbl

    if "Validate" in st.session_state.results:
        m = st.session_state.results["m_val"]
        c1, c2, c3 = st.columns(3)
        c1.metric("R", f"{m['R']:.2f}")
        c2.metric("RMSE", f"{m['RMSE']:.2f}")
        c3.metric("MAE", f"{m['MAE']:.2f}")

        # Footer spelling out the metric abbreviations.
        st.markdown("""
            <div style='text-align: left; font-size: 0.8em; color: #6b7280; margin-top: -16px; margin-bottom: 8px;'>
                <strong>R:</strong> Pearson Correlation Coefficient<br>
                <strong>RMSE:</strong> Root Mean Square Error<br>
                <strong>MAE:</strong> Mean Absolute Error
            </div>
        """, unsafe_allow_html=True)

        col_cross, col_track = st.columns([3, 2], gap="large")
        with col_cross:
            st.pyplot(
                cross_plot_static(st.session_state.results["Validate"][TARGET],
                                  st.session_state.results["Validate"]["UCS_Pred"]),
                use_container_width=False
            )
        with col_track:
            st.plotly_chart(
                track_plot(st.session_state.results["Validate"], include_actual=True),
                use_container_width=False,  # honor the width set in track_plot()
                config={"displayModeBar": False, "scrollZoom": True}
            )

        sv = st.session_state.results["sv_val"]
        if sv["oor"] > 0:
            # Markdown bold is not processed inside a raw HTML block, so the
            # emphasis is written as <strong> instead of ** **.
            st.markdown('<div class="st-message-box st-warning">Some inputs fall outside <strong>training min–max</strong> ranges.</div>', unsafe_allow_html=True)
        if st.session_state.results["oor_tbl"] is not None:
            st.write("*Out-of-range rows (vs. Training min–max):*")
            df_centered_rounded(st.session_state.results["oor_tbl"])
681
 
682
  # =========================
683
  # PREDICTION (no actual UCS)
684
  # =========================
# Prediction: upload a workbook with feature columns only, predict on its
# first sheet, and show summary statistics next to the predicted track.
if st.session_state.app_step == "predict":
    st.sidebar.header("Prediction (No Actual UCS)")
    up = st.sidebar.file_uploader("Upload Prediction Excel", type=["xlsx","xls"])
    if up is not None:
        book = read_book_bytes(up.getvalue())
        if book:
            df0 = next(iter(book.values()))
            st.sidebar.caption(f"**Data loaded:** {up.name} • {df0.shape[0]} rows × {df0.shape[1]} cols")

    preview_clicked = st.sidebar.button("Preview data", use_container_width=True, disabled=(up is None))
    if preview_clicked:
        st.session_state.show_preview_modal = True  # rendered at the end of the script

    go_btn = st.sidebar.button("Predict", type="primary", use_container_width=True)

    back_clicked = st.sidebar.button("⬅ Back to Case Building", use_container_width=True)
    if back_clicked:
        st.session_state.app_step = "dev"
        st.rerun()

    sticky_header("Prediction", "Upload a dataset with the feature columns (no **UCS**).")

    if go_btn and up is not None:
        book = read_book_bytes(up.getvalue())
        name = list(book.keys())[0]
        df = book[name].copy()
        if not ensure_cols(df, FEATURES):
            st.markdown('<div class="st-message-box st-error">Missing required columns.</div>', unsafe_allow_html=True)
            st.stop()
        df["UCS_Pred"] = model.predict(df[FEATURES])
        st.session_state.results["PredictOnly"] = df

        # Share of rows with at least one feature outside the training range.
        ranges = st.session_state.train_ranges
        oor_pct = 0.0
        if ranges:
            outside = {f: (df[f] < ranges[f][0]) | (df[f] > ranges[f][1]) for f in FEATURES}
            any_viol = pd.DataFrame(outside).any(axis=1)
            oor_pct = float(any_viol.mean()*100.0)

        preds = df["UCS_Pred"]
        st.session_state.results["sv_pred"] = {
            "n": len(df),
            "pred_min": float(preds.min()),
            "pred_max": float(preds.max()),
            "pred_mean": float(preds.mean()),
            "pred_std": float(preds.std(ddof=0)),
            "oor": oor_pct,
        }

    if "PredictOnly" in st.session_state.results:
        df = st.session_state.results["PredictOnly"]
        sv = st.session_state.results["sv_pred"]

        col_left, col_right = st.columns([2,3], gap="large")
        with col_left:
            metric_names = ["# points","Pred min","Pred max","Pred mean","Pred std","OOR %"]
            metric_values = [
                sv["n"],
                round(sv["pred_min"],2),
                round(sv["pred_max"],2),
                round(sv["pred_mean"],2),
                round(sv["pred_std"],2),
                f'{sv["oor"]:.1f}%',
            ]
            table = pd.DataFrame({"Metric": metric_names, "Value": metric_values})
            st.markdown('<div class="st-message-box st-success">Predictions ready ✓</div>', unsafe_allow_html=True)
            df_centered_rounded(table, hide_index=True)
            st.caption("**★ OOR** = % of rows whose input features fall outside the training min–max range.")
        with col_right:
            predicted_track = track_plot(df, include_actual=False)
            st.plotly_chart(
                predicted_track,
                use_container_width=False,  # honor the width set in track_plot()
                config={"displayModeBar": False, "scrollZoom": True}
            )
743
 
744
  # =========================
745
  # Run preview modal after all other elements
746
  # =========================
# Render the data preview requested by a "Preview data" button earlier in
# this run, then clear the request flag.
if st.session_state.show_preview_modal:
    # Pick the workbook that belongs to the current step. `up` is assigned by
    # the validate/predict sections above whenever one of those steps is active.
    book_to_preview = {}
    if st.session_state.app_step == "dev":
        book_to_preview = read_book_bytes(st.session_state.dev_file_bytes)
    elif st.session_state.app_step in ["validate", "predict"] and up is not None:
        book_to_preview = read_book_bytes(up.getvalue())

    with st.expander("Preview data", expanded=True):
        if not book_to_preview:
            st.markdown('<div class="st-message-box">No data loaded yet.</div>', unsafe_allow_html=True)
        else:
            names = list(book_to_preview.keys())
            tabs = st.tabs(names)
            for t, name in zip(tabs, names):
                with t:
                    df = book_to_preview[name]
                    # Only summarise features this sheet actually has;
                    # indexing with the full FEATURES list raises KeyError
                    # on sheets that lack one.
                    present = [c for c in FEATURES if c in df.columns]
                    t1, t2 = st.tabs(["Tracks", "Summary"])
                    with t1:
                        st.pyplot(preview_tracks(df, FEATURES), use_container_width=True)
                    with t2:
                        if not present:
                            st.markdown('<div class="st-message-box">This sheet contains none of the model feature columns.</div>', unsafe_allow_html=True)
                        else:
                            tbl = (df[present]
                                   .agg(['min', 'max', 'mean', 'std'])
                                   .T.rename(columns={"min": "Min", "max": "Max", "mean": "Mean", "std": "Std"}))
                            df_centered_rounded(tbl.reset_index(names="Feature"))
    # Reset the flag so the preview is shown once per button click.
    st.session_state.show_preview_modal = False
784
 
785
  # =========================
786
  # Footer
787
  # =========================
788
  st.markdown("---")
789
  st.markdown(
790
+     """
791
+     <div style='text-align:center; color:#6b7280; line-height:1.6'>
792
+       ST_GeoMech_UCS • © Smart Thinking<br/>
793
+       <strong>Visit our website:</strong> <a href='https://www.smartthinking.com.sa' target='_blank'>smartthinking.com.sa</a>
794
+     </div>
795
+     """,
796
+     unsafe_allow_html=True
797
  )