UCS2014 commited on
Commit
617b9e0
·
verified ·
1 Parent(s): 148589b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +90 -128
app.py CHANGED
@@ -20,20 +20,23 @@ MODEL_FALLBACKS = [MODELS_DIR / "model.joblib", MODELS_DIR / "model.pkl"]
20
  # Page / Theme
21
  # =========================
22
  st.set_page_config(page_title="ST_GeoMech_UCS", page_icon="logo.png", layout="wide")
 
 
23
  st.markdown("<style>header, footer{visibility:hidden !important;}</style>", unsafe_allow_html=True)
 
 
24
  st.markdown(
25
  """
26
  <style>
27
  .stApp { background: #FFFFFF; }
28
  section[data-testid="stSidebar"] { background: #F6F9FC; }
 
29
  .stButton>button{ background:#007bff; color:#fff; font-weight:bold; border-radius:8px; border:none; padding:10px 24px; }
30
  .stButton>button:hover{ background:#0056b3; }
31
- /* Hero header */
32
- .st-hero { display:flex; align-items:center; gap:14px; padding: 6px 0 0 0; }
33
- .st-hero .brand { width:70px; height:70px; object-fit:contain; }
34
  .st-hero h1 { margin:0; line-height:1.05; }
35
  .st-hero .tagline { margin:2px 0 0 2px; color:#6b7280; font-size:1.05rem; font-style:italic; }
36
- /* Ensure hero is tight to the top */
37
  [data-testid="stBlock"]{ margin-top:0 !important; }
38
  </style>
39
  """,
@@ -44,11 +47,8 @@ st.markdown(
44
  # Small helpers
45
  # =========================
46
  def _get_model_url():
47
- """Avoid Streamlit secrets error when secrets.toml is absent."""
48
- try:
49
- return (st.secrets.get("MODEL_URL", "") or os.environ.get("MODEL_URL", "") or "").strip()
50
- except Exception:
51
- return (os.environ.get("MODEL_URL", "") or "").strip()
52
 
53
  def rmse(y_true, y_pred):
54
  return float(np.sqrt(mean_squared_error(y_true, y_pred)))
@@ -86,12 +86,16 @@ def find_sheet(book, names):
86
  return low2orig[nm.lower()]
87
  return None
88
 
89
- def cross_plot(actual, pred, title, size=(5.6, 5.6)):
90
- fig, ax = plt.subplots(figsize=size)
91
- ax.scatter(actual, pred, s=16, alpha=0.7)
92
  lo = float(np.nanmin([actual.min(), pred.min()]))
93
  hi = float(np.nanmax([actual.max(), pred.max()]))
94
- ax.plot([lo, hi], [lo, hi], '--')
 
 
 
 
95
  ax.set_xlabel("Actual UCS")
96
  ax.set_ylabel("Predicted UCS")
97
  ax.set_title(title)
@@ -99,35 +103,35 @@ def cross_plot(actual, pred, title, size=(5.6, 5.6)):
99
  return fig
100
 
101
  def depth_or_index_track(df, title, include_actual=True):
102
- # If a depth-like column exists, plot UCS vs Depth (depth downward); else index track
103
  depth_col = None
104
  for c in df.columns:
105
  if 'depth' in str(c).lower():
106
  depth_col = c
107
  break
108
- fig, ax = plt.subplots(figsize=(5.8, 7.5))
 
 
 
109
  if depth_col is not None:
110
- ax.plot(df["UCS_Pred"], df[depth_col], label="UCS_Pred")
111
  if include_actual and TARGET in df.columns:
112
- ax.plot(df[TARGET], df[depth_col], alpha=0.7, label="UCS (actual)")
113
  ax.set_ylabel(depth_col)
114
  ax.set_xlabel("UCS")
115
- ax.xaxis.set_label_position('top')
116
- ax.xaxis.tick_top()
117
- ax.invert_yaxis()
118
  else:
119
  idx = np.arange(1, len(df) + 1)
120
- ax.plot(df["UCS_Pred"], idx, label="UCS_Pred")
121
  if include_actual and TARGET in df.columns:
122
- ax.plot(df[TARGET], idx, alpha=0.7, label="UCS (actual)")
123
  ax.set_ylabel("Point Index")
124
  ax.set_xlabel("UCS")
125
- ax.xaxis.set_label_position('top')
126
- ax.xaxis.tick_top()
127
- ax.invert_yaxis()
128
  ax.grid(True, linestyle=":", alpha=0.4)
129
- ax.set_title(title, pad=12)
130
- ax.legend()
131
  return fig
132
 
133
  def export_workbook(sheets_dict, summary_df=None):
@@ -180,18 +184,13 @@ MODEL_URL = _get_model_url()
180
 
181
  def ensure_model_present() -> Path:
182
  """Return a local model path, trying local files first, then (optionally) downloading with timeout."""
183
- # Check local paths first
184
  for p in [DEFAULT_MODEL, *MODEL_FALLBACKS]:
185
  if p.exists() and p.stat().st_size > 0:
186
  return p
187
-
188
- # If no URL set, we cannot download
189
  if not MODEL_URL:
190
  return None
191
-
192
- # Try to download with a short timeout so startup can't hang
193
  try:
194
- import requests # only when needed
195
  DEFAULT_MODEL.parent.mkdir(parents=True, exist_ok=True)
196
  with st.status("Downloading model…", expanded=False):
197
  with requests.get(MODEL_URL, stream=True, timeout=30) as r:
@@ -210,7 +209,6 @@ if not model_path:
210
  st.error("Model not found. Upload models/ucs_rf.joblib (or set MODEL_URL in Settings → Variables).")
211
  st.stop()
212
 
213
- # Load model (fix: correct try/except block)
214
  try:
215
  model = load_model(str(model_path))
216
  except Exception as e:
@@ -278,7 +276,7 @@ if st.session_state.app_step == "intro":
278
  st.markdown(
279
  "1. *Upload the Model Development Data.* This should contain your training and testing sets.\n"
280
  "2. Click *Run Model* to view metrics, cross-plots, and a track plot.\n"
281
- "3. Click *Go to Prediction* and upload a new dataset to get predictions.\n"
282
  "4. *Export* everything to Excel for further analysis."
283
  )
284
  if st.button("Start Showcase", type="primary", key="start_showcase"):
@@ -290,10 +288,18 @@ if st.session_state.app_step == "intro":
290
  # =========================
291
  if st.session_state.app_step == "dev":
292
  st.sidebar.header("Model Development Data")
293
- train_test_file = st.sidebar.file_uploader("Upload Train/Test Excel", type=["xlsx", "xls"], key="dev_upload")
 
 
 
 
 
 
 
 
 
 
294
  run_btn = st.sidebar.button("Run Model", type="primary", use_container_width=True)
295
- if "Train" in st.session_state.results or "Test" in st.session_state.results:
296
- st.sidebar.button("Go to Prediction ▶", use_container_width=True, on_click=lambda: st.session_state.update(app_step="predict"))
297
 
298
  st.subheader("Model Development")
299
  if run_btn and train_test_file is not None:
@@ -304,91 +310,82 @@ if st.session_state.app_step == "dev":
304
  st.stop()
305
  status.update(label="Workbook read ✓")
306
 
 
307
  sh_train = find_sheet(book, ["Train", "Training", "training2", "train", "training"])
308
- sh_test = find_sheet(book, ["Test", "Testing", "testing2", "test", "testing"])
309
  if sh_train is None or sh_test is None:
310
  status.update(label="Workbook must include Train/Training/training2 and Test/Testing/testing2.", state="error")
311
  st.stop()
312
 
313
- df_tr = book[sh_train].copy()
314
- df_te = book[sh_test].copy()
315
  if not (ensure_cols(df_tr, FEATURES + [TARGET]) and ensure_cols(df_te, FEATURES + [TARGET])):
316
- status.update(label="Missing required columns.", state="error")
317
- st.stop()
318
 
319
  status.update(label="Columns validated ✓")
320
  status.update(label="Predicting…")
321
 
322
  df_tr["UCS_Pred"] = model.predict(df_tr[FEATURES])
323
  df_te["UCS_Pred"] = model.predict(df_te[FEATURES])
324
- st.session_state.results["Train"] = df_tr
325
- st.session_state.results["Test"] = df_te
326
 
327
  st.session_state.results["metrics_train"] = {
328
  "R2": r2_score(df_tr[TARGET], df_tr["UCS_Pred"]),
329
  "RMSE": rmse(df_tr[TARGET], df_tr["UCS_Pred"]),
330
  "MAE": mean_absolute_error(df_tr[TARGET], df_tr["UCS_Pred"]),
331
  }
332
- st.session_state.results["metrics_test"] = {
333
  "R2": r2_score(df_te[TARGET], df_te["UCS_Pred"]),
334
  "RMSE": rmse(df_te[TARGET], df_te["UCS_Pred"]),
335
  "MAE": mean_absolute_error(df_te[TARGET], df_te["UCS_Pred"]),
336
  }
337
 
338
- tr_min = df_tr[FEATURES].min().to_dict()
339
- tr_max = df_tr[FEATURES].max().to_dict()
340
- st.session_state.train_ranges = {f: (float(tr_min[f]), float(tr_max[f])) for f in FEATURES}
341
 
342
- status.update(label="Done ✓", state="complete")
343
- toast("Model run complete 🚀")
344
 
345
- if "Train" in st.session_state.results or "Test" in st.session_state.results:
346
  tab1, tab2 = st.tabs(["Training", "Testing"])
347
  if "Train" in st.session_state.results:
348
  with tab1:
349
- df = st.session_state.results["Train"]
350
- m = st.session_state.results["metrics_train"]
351
- c1, c2, c3 = st.columns(3)
352
  c1.metric("R²", f"{m['R2']:.4f}")
353
  c2.metric("RMSE", f"{m['RMSE']:.4f}")
354
  c3.metric("MAE", f"{m['MAE']:.4f}")
355
- left, right = st.columns(2)
356
  with left:
357
  st.pyplot(cross_plot(df[TARGET], df["UCS_Pred"], "Training: Actual vs Predicted"), use_container_width=True)
358
  with right:
359
  st.pyplot(depth_or_index_track(df, "Training: Depth/Index Track", include_actual=True), use_container_width=True)
 
360
  if "Test" in st.session_state.results:
361
  with tab2:
362
- df = st.session_state.results["Test"]
363
- m = st.session_state.results["metrics_test"]
364
- c1, c2, c3 = st.columns(3)
365
  c1.metric("R²", f"{m['R2']:.4f}")
366
  c2.metric("RMSE", f"{m['RMSE']:.4f}")
367
  c3.metric("MAE", f"{m['MAE']:.4f}")
368
- left, right = st.columns(2)
369
  with left:
370
  st.pyplot(cross_plot(df[TARGET], df["UCS_Pred"], "Testing: Actual vs Predicted"), use_container_width=True)
371
  with right:
372
  st.pyplot(depth_or_index_track(df, "Testing: Depth/Index Track", include_actual=True), use_container_width=True)
373
 
374
  st.markdown("---")
375
- sheets = {}
376
- rows = []
377
  if "Train" in st.session_state.results:
378
  sheets["Train_with_pred"] = st.session_state.results["Train"]
379
- rows.append({"Split": "Train", **{k: round(v, 6) for k, v in st.session_state.results["metrics_train"].items()}})
380
  if "Test" in st.session_state.results:
381
  sheets["Test_with_pred"] = st.session_state.results["Test"]
382
- rows.append({"Split": "Test", **{k: round(v, 6) for k, v in st.session_state.results["metrics_test"].items()}})
383
  summary_df = pd.DataFrame(rows) if rows else None
384
  try:
385
  data_bytes = export_workbook(sheets, summary_df)
386
- st.download_button(
387
- "Export Train/Test Results to Excel",
388
- data=data_bytes,
389
- file_name="UCS_Dev_Results.xlsx",
390
- mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
391
- )
392
  except RuntimeError as e:
393
  st.warning(str(e))
394
 
@@ -397,7 +394,7 @@ if st.session_state.app_step == "dev":
397
  # =========================
398
  if st.session_state.app_step == "predict":
399
  st.sidebar.header("Prediction (Validation)")
400
- validation_file = st.sidebar.file_uploader("Upload Validation Excel", type=["xlsx", "xls"], key="val_upload")
401
  predict_btn = st.sidebar.button("Predict", type="primary", use_container_width=True)
402
  st.sidebar.button("⬅ Back", on_click=lambda: st.session_state.update(app_step="dev"), use_container_width=True)
403
 
@@ -407,83 +404,55 @@ if st.session_state.app_step == "predict":
407
  if predict_btn and validation_file is not None:
408
  with st.status("Predicting…", expanded=False) as status:
409
  vbook = read_book(validation_file)
410
- if not vbook:
411
- status.update(label="Could not read the Validation Excel.", state="error")
412
- st.stop()
413
  status.update(label="Workbook read ✓")
414
- vname = find_sheet(vbook, ["Validation", "Validate", "validation2", "Val", "val"]) or list(vbook.keys())[0]
415
  df_val = vbook[vname].copy()
416
- if not ensure_cols(df_val, FEATURES):
417
- status.update(label="Missing required columns.", state="error")
418
- st.stop()
419
  status.update(label="Columns validated ✓")
420
  df_val["UCS_Pred"] = model.predict(df_val[FEATURES])
421
  st.session_state.results["Validate"] = df_val
422
 
423
- # OOR check against training min–max
424
- ranges = st.session_state.train_ranges
425
- oor_table = None
426
- oor_pct = 0.0
427
  if ranges:
428
  viol = {f: (df_val[f] < ranges[f][0]) | (df_val[f] > ranges[f][1]) for f in FEATURES}
429
- any_viol = pd.DataFrame(viol).any(axis=1)
430
- oor_pct = float(any_viol.mean() * 100.0)
431
  if any_viol.any():
432
  offenders = df_val.loc[any_viol, FEATURES].copy()
433
- offenders["Violations"] = pd.DataFrame(viol).loc[any_viol].apply(
434
- lambda r: ", ".join([c for c, v in r.items() if v]), axis=1
435
- )
436
- offenders.index = offenders.index + 1
437
- oor_table = offenders
438
 
439
  metrics_val = None
440
  if TARGET in df_val.columns:
441
  metrics_val = {
442
  "R2": r2_score(df_val[TARGET], df_val["UCS_Pred"]),
443
  "RMSE": rmse(df_val[TARGET], df_val["UCS_Pred"]),
444
- "MAE": mean_absolute_error(df_val[TARGET], df_val["UCS_Pred"]),
445
  }
446
  st.session_state.results["metrics_val"] = metrics_val
447
  st.session_state.results["summary_val"] = {
448
  "n_points": len(df_val),
449
  "pred_min": float(df_val["UCS_Pred"].min()),
450
  "pred_max": float(df_val["UCS_Pred"].max()),
451
- "oor_pct": oor_pct,
452
  }
453
  st.session_state.results["oor_table"] = oor_table
454
  status.update(label="Predictions ready ✓", state="complete")
455
 
456
  if "Validate" in st.session_state.results:
457
  st.subheader("Validation Results")
458
- sv = st.session_state.results["summary_val"]
459
- oor_table = st.session_state.results.get("oor_table")
460
- c1, c2, c3, c4 = st.columns(4)
461
- c1.metric("points", f"{sv['n_points']}")
462
- c2.metric("Pred min", f"{sv['pred_min']:.2f}")
463
- c3.metric("Pred max", f"{sv['pred_max']:.2f}")
464
- c4.metric("OOR %", f"{sv['oor_pct']:.1f}%")
465
- left, right = st.columns(2)
466
  with left:
467
  if TARGET in st.session_state.results["Validate"].columns:
468
- st.pyplot(
469
- cross_plot(
470
- st.session_state.results["Validate"][TARGET],
471
- st.session_state.results["Validate"]["UCS_Pred"],
472
- "Validation: Actual vs Predicted",
473
- ),
474
- use_container_width=True,
475
- )
476
  else:
477
  st.info("Actual UCS values are not available in the validation data. Cross-plot cannot be generated.")
478
  with right:
479
- st.pyplot(
480
- depth_or_index_track(
481
- st.session_state.results["Validate"],
482
- "Validation: Depth/Index Track",
483
- include_actual=(TARGET in st.session_state.results["Validate"].columns),
484
- ),
485
- use_container_width=True,
486
- )
487
  if oor_table is not None:
488
  st.write("*Out-of-range rows (vs. Training min–max):*")
489
  st.dataframe(oor_table, use_container_width=True)
@@ -491,19 +460,15 @@ if st.session_state.app_step == "predict":
491
  st.markdown("---")
492
  sheets = {"Validate_with_pred": st.session_state.results["Validate"]}
493
  rows = []
494
- for name, key in [("Train", "metrics_train"), ("Test", "metrics_test"), ("Validate", "metrics_val")]:
495
  m = st.session_state.results.get(key)
496
- if m:
497
- rows.append({"Split": name, **{k: round(v, 6) for k, v in m.items()}})
498
  summary_df = pd.DataFrame(rows) if rows else None
499
  try:
500
  data_bytes = export_workbook(sheets, summary_df)
501
- st.download_button(
502
- "Export Validation Results to Excel",
503
- data=data_bytes,
504
- file_name="UCS_Validation_Results.xlsx",
505
- mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
506
- )
507
  except RuntimeError as e:
508
  st.warning(str(e))
509
 
@@ -511,7 +476,4 @@ if st.session_state.app_step == "predict":
511
  # Footer
512
  # =========================
513
  st.markdown("---")
514
- st.markdown(
515
- "<div style='text-align:center; color:#6b7280;'>ST_GeoMech_UCS • © Smart Thinking</div>",
516
- unsafe_allow_html=True,
517
- )
 
20
  # Page / Theme
21
  # =========================
22
  st.set_page_config(page_title="ST_GeoMech_UCS", page_icon="logo.png", layout="wide")
23
+
24
+ # Hide Streamlit default header/footer
25
  st.markdown("<style>header, footer{visibility:hidden !important;}</style>", unsafe_allow_html=True)
26
+
27
+ # Compact page, bigger logo, tidy hero
28
  st.markdown(
29
  """
30
  <style>
31
  .stApp { background: #FFFFFF; }
32
  section[data-testid="stSidebar"] { background: #F6F9FC; }
33
+ .block-container { padding-top: .5rem; padding-bottom: .5rem; } /* less vertical padding */
34
  .stButton>button{ background:#007bff; color:#fff; font-weight:bold; border-radius:8px; border:none; padding:10px 24px; }
35
  .stButton>button:hover{ background:#0056b3; }
36
+ .st-hero { display:flex; align-items:center; gap:16px; padding-top: 4px; }
37
+ .st-hero .brand { width:110px; height:110px; object-fit:contain; } /* enlarged logo */
 
38
  .st-hero h1 { margin:0; line-height:1.05; }
39
  .st-hero .tagline { margin:2px 0 0 2px; color:#6b7280; font-size:1.05rem; font-style:italic; }
 
40
  [data-testid="stBlock"]{ margin-top:0 !important; }
41
  </style>
42
  """,
 
47
  # Small helpers
48
  # =========================
49
  def _get_model_url():
50
+ """HuggingFace exposes Space variables via environment; avoid st.secrets to prevent red banner."""
51
+ return (os.environ.get("MODEL_URL", "") or "").strip()
 
 
 
52
 
53
  def rmse(y_true, y_pred):
54
  return float(np.sqrt(mean_squared_error(y_true, y_pred)))
 
86
  return low2orig[nm.lower()]
87
  return None
88
 
89
+ def cross_plot(actual, pred, title, size=(5.0, 5.0)):
90
+ fig, ax = plt.subplots(figsize=size, dpi=100)
91
+ ax.scatter(actual, pred, s=14, alpha=0.75)
92
  lo = float(np.nanmin([actual.min(), pred.min()]))
93
  hi = float(np.nanmax([actual.max(), pred.max()]))
94
+ pad = 0.03 * (hi - lo if hi > lo else 1.0)
95
+ ax.plot([lo - pad, hi + pad], [lo - pad, hi + pad], '--', lw=1.2, color=(0.35, 0.35, 0.35))
96
+ ax.set_xlim(lo - pad, hi + pad)
97
+ ax.set_ylim(lo - pad, hi + pad)
98
+ ax.set_aspect('equal', 'box') # perfect 1:1
99
  ax.set_xlabel("Actual UCS")
100
  ax.set_ylabel("Predicted UCS")
101
  ax.set_title(title)
 
103
  return fig
104
 
105
  def depth_or_index_track(df, title, include_actual=True):
106
+ # depth-like column?
107
  depth_col = None
108
  for c in df.columns:
109
  if 'depth' in str(c).lower():
110
  depth_col = c
111
  break
112
+ # taller for a log-profile look
113
+ fig_h = 8.8 if depth_col is not None else 8.0
114
+ fig, ax = plt.subplots(figsize=(6.2, fig_h), dpi=100)
115
+
116
  if depth_col is not None:
117
+ ax.plot(df["UCS_Pred"], df[depth_col], '--', lw=1.6, label="UCS_Pred")
118
  if include_actual and TARGET in df.columns:
119
+ ax.plot(df[TARGET], df[depth_col], '-', lw=2.0, alpha=0.8, label="UCS (actual)")
120
  ax.set_ylabel(depth_col)
121
  ax.set_xlabel("UCS")
122
+ ax.xaxis.set_label_position('top'); ax.xaxis.tick_top(); ax.invert_yaxis()
 
 
123
  else:
124
  idx = np.arange(1, len(df) + 1)
125
+ ax.plot(df["UCS_Pred"], idx, '--', lw=1.6, label="UCS_Pred")
126
  if include_actual and TARGET in df.columns:
127
+ ax.plot(df[TARGET], idx, '-', lw=2.0, alpha=0.8, label="UCS (actual)")
128
  ax.set_ylabel("Point Index")
129
  ax.set_xlabel("UCS")
130
+ ax.xaxis.set_label_position('top'); ax.xaxis.tick_top(); ax.invert_yaxis()
131
+
 
132
  ax.grid(True, linestyle=":", alpha=0.4)
133
+ ax.set_title(title, pad=10)
134
+ ax.legend(loc="best")
135
  return fig
136
 
137
  def export_workbook(sheets_dict, summary_df=None):
 
184
 
185
  def ensure_model_present() -> Path:
186
  """Return a local model path, trying local files first, then (optionally) downloading with timeout."""
 
187
  for p in [DEFAULT_MODEL, *MODEL_FALLBACKS]:
188
  if p.exists() and p.stat().st_size > 0:
189
  return p
 
 
190
  if not MODEL_URL:
191
  return None
 
 
192
  try:
193
+ import requests
194
  DEFAULT_MODEL.parent.mkdir(parents=True, exist_ok=True)
195
  with st.status("Downloading model…", expanded=False):
196
  with requests.get(MODEL_URL, stream=True, timeout=30) as r:
 
209
  st.error("Model not found. Upload models/ucs_rf.joblib (or set MODEL_URL in Settings → Variables).")
210
  st.stop()
211
 
 
212
  try:
213
  model = load_model(str(model_path))
214
  except Exception as e:
 
276
  st.markdown(
277
  "1. *Upload the Model Development Data.* This should contain your training and testing sets.\n"
278
  "2. Click *Run Model* to view metrics, cross-plots, and a track plot.\n"
279
+ "3. Click *Proceed to Prediction* and upload a new dataset to get predictions.\n"
280
  "4. *Export* everything to Excel for further analysis."
281
  )
282
  if st.button("Start Showcase", type="primary", key="start_showcase"):
 
288
  # =========================
289
  if st.session_state.app_step == "dev":
290
  st.sidebar.header("Model Development Data")
291
+ train_test_file = st.sidebar.file_uploader("Upload Data (Excel)", type=["xlsx", "xls"], key="dev_upload")
292
+
293
+ # Always show the nav button, disabled until results exist once.
294
+ ready_for_pred = ("Train" in st.session_state.results) or ("Test" in st.session_state.results)
295
+ st.sidebar.button(
296
+ "Proceed to Prediction ▶",
297
+ use_container_width=True,
298
+ disabled=not ready_for_pred,
299
+ on_click=(lambda: st.session_state.update(app_step="predict")) if ready_for_pred else None,
300
+ )
301
+
302
  run_btn = st.sidebar.button("Run Model", type="primary", use_container_width=True)
 
 
303
 
304
  st.subheader("Model Development")
305
  if run_btn and train_test_file is not None:
 
310
  st.stop()
311
  status.update(label="Workbook read ✓")
312
 
313
+ # still expect Train/Test sheets internally
314
  sh_train = find_sheet(book, ["Train", "Training", "training2", "train", "training"])
315
+ sh_test = find_sheet(book, ["Test", "Testing", "testing2", "test", "testing"])
316
  if sh_train is None or sh_test is None:
317
  status.update(label="Workbook must include Train/Training/training2 and Test/Testing/testing2.", state="error")
318
  st.stop()
319
 
320
+ df_tr = book[sh_train].copy(); df_te = book[sh_test].copy()
 
321
  if not (ensure_cols(df_tr, FEATURES + [TARGET]) and ensure_cols(df_te, FEATURES + [TARGET])):
322
+ status.update(label="Missing required columns.", state="error"); st.stop()
 
323
 
324
  status.update(label="Columns validated ✓")
325
  status.update(label="Predicting…")
326
 
327
  df_tr["UCS_Pred"] = model.predict(df_tr[FEATURES])
328
  df_te["UCS_Pred"] = model.predict(df_te[FEATURES])
329
+ st.session_state.results["Train"] = df_tr; st.session_state.results["Test"] = df_te
 
330
 
331
  st.session_state.results["metrics_train"] = {
332
  "R2": r2_score(df_tr[TARGET], df_tr["UCS_Pred"]),
333
  "RMSE": rmse(df_tr[TARGET], df_tr["UCS_Pred"]),
334
  "MAE": mean_absolute_error(df_tr[TARGET], df_tr["UCS_Pred"]),
335
  }
336
+ st.session_state.results["metrics_test"] = {
337
  "R2": r2_score(df_te[TARGET], df_te["UCS_Pred"]),
338
  "RMSE": rmse(df_te[TARGET], df_te["UCS_Pred"]),
339
  "MAE": mean_absolute_error(df_te[TARGET], df_te["UCS_Pred"]),
340
  }
341
 
342
+ tr_min = df_tr[FEATURES].min().to_dict(); tr_max = df_tr[FEATURES].max().to_dict()
343
+ st.session_state.train_ranges = {f:(float(tr_min[f]), float(tr_max[f])) for f in FEATURES}
 
344
 
345
+ status.update(label="Done ✓", state="complete"); toast("Model run complete 🚀")
 
346
 
347
+ if ("Train" in st.session_state.results) or ("Test" in st.session_state.results):
348
  tab1, tab2 = st.tabs(["Training", "Testing"])
349
  if "Train" in st.session_state.results:
350
  with tab1:
351
+ df = st.session_state.results["Train"]; m = st.session_state.results["metrics_train"]
352
+ c1,c2,c3 = st.columns(3)
 
353
  c1.metric("R²", f"{m['R2']:.4f}")
354
  c2.metric("RMSE", f"{m['RMSE']:.4f}")
355
  c3.metric("MAE", f"{m['MAE']:.4f}")
356
+ left,right = st.columns([1,1])
357
  with left:
358
  st.pyplot(cross_plot(df[TARGET], df["UCS_Pred"], "Training: Actual vs Predicted"), use_container_width=True)
359
  with right:
360
  st.pyplot(depth_or_index_track(df, "Training: Depth/Index Track", include_actual=True), use_container_width=True)
361
+
362
  if "Test" in st.session_state.results:
363
  with tab2:
364
+ df = st.session_state.results["Test"]; m = st.session_state.results["metrics_test"]
365
+ c1,c2,c3 = st.columns(3)
 
366
  c1.metric("R²", f"{m['R2']:.4f}")
367
  c2.metric("RMSE", f"{m['RMSE']:.4f}")
368
  c3.metric("MAE", f"{m['MAE']:.4f}")
369
+ left,right = st.columns([1,1])
370
  with left:
371
  st.pyplot(cross_plot(df[TARGET], df["UCS_Pred"], "Testing: Actual vs Predicted"), use_container_width=True)
372
  with right:
373
  st.pyplot(depth_or_index_track(df, "Testing: Depth/Index Track", include_actual=True), use_container_width=True)
374
 
375
  st.markdown("---")
376
+ sheets = {}; rows = []
 
377
  if "Train" in st.session_state.results:
378
  sheets["Train_with_pred"] = st.session_state.results["Train"]
379
+ rows.append({"Split":"Train", **{k:round(v,6) for k,v in st.session_state.results["metrics_train"].items()}})
380
  if "Test" in st.session_state.results:
381
  sheets["Test_with_pred"] = st.session_state.results["Test"]
382
+ rows.append({"Split":"Test", **{k:round(v,6) for k,v in st.session_state.results["metrics_test"].items()}})
383
  summary_df = pd.DataFrame(rows) if rows else None
384
  try:
385
  data_bytes = export_workbook(sheets, summary_df)
386
+ st.download_button("Export Train/Test Results to Excel",
387
+ data=data_bytes, file_name="UCS_Dev_Results.xlsx",
388
+ mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet")
 
 
 
389
  except RuntimeError as e:
390
  st.warning(str(e))
391
 
 
394
  # =========================
395
  if st.session_state.app_step == "predict":
396
  st.sidebar.header("Prediction (Validation)")
397
+ validation_file = st.sidebar.file_uploader("Upload Validation Excel", type=["xlsx","xls"], key="val_upload")
398
  predict_btn = st.sidebar.button("Predict", type="primary", use_container_width=True)
399
  st.sidebar.button("⬅ Back", on_click=lambda: st.session_state.update(app_step="dev"), use_container_width=True)
400
 
 
404
  if predict_btn and validation_file is not None:
405
  with st.status("Predicting…", expanded=False) as status:
406
  vbook = read_book(validation_file)
407
+ if not vbook: status.update(label="Could not read the Validation Excel.", state="error"); st.stop()
 
 
408
  status.update(label="Workbook read ✓")
409
+ vname = find_sheet(vbook, ["Validation","Validate","validation2","Val","val"]) or list(vbook.keys())[0]
410
  df_val = vbook[vname].copy()
411
+ if not ensure_cols(df_val, FEATURES): status.update(label="Missing required columns.", state="error"); st.stop()
 
 
412
  status.update(label="Columns validated ✓")
413
  df_val["UCS_Pred"] = model.predict(df_val[FEATURES])
414
  st.session_state.results["Validate"] = df_val
415
 
416
+ ranges = st.session_state.train_ranges; oor_table = None; oor_pct = 0.0
 
 
 
417
  if ranges:
418
  viol = {f: (df_val[f] < ranges[f][0]) | (df_val[f] > ranges[f][1]) for f in FEATURES}
419
+ any_viol = pd.DataFrame(viol).any(axis=1); oor_pct = float(any_viol.mean()*100.0)
 
420
  if any_viol.any():
421
  offenders = df_val.loc[any_viol, FEATURES].copy()
422
+ offenders["Violations"] = pd.DataFrame(viol).loc[any_viol].apply(lambda r: ", ".join([c for c,v in r.items() if v]), axis=1)
423
+ offenders.index = offenders.index + 1; oor_table = offenders
 
 
 
424
 
425
  metrics_val = None
426
  if TARGET in df_val.columns:
427
  metrics_val = {
428
  "R2": r2_score(df_val[TARGET], df_val["UCS_Pred"]),
429
  "RMSE": rmse(df_val[TARGET], df_val["UCS_Pred"]),
430
+ "MAE": mean_absolute_error(df_val[TARGET], df_val["UCS_Pred"])
431
  }
432
  st.session_state.results["metrics_val"] = metrics_val
433
  st.session_state.results["summary_val"] = {
434
  "n_points": len(df_val),
435
  "pred_min": float(df_val["UCS_Pred"].min()),
436
  "pred_max": float(df_val["UCS_Pred"].max()),
437
+ "oor_pct": oor_pct
438
  }
439
  st.session_state.results["oor_table"] = oor_table
440
  status.update(label="Predictions ready ✓", state="complete")
441
 
442
  if "Validate" in st.session_state.results:
443
  st.subheader("Validation Results")
444
+ sv = st.session_state.results["summary_val"]; oor_table = st.session_state.results.get("oor_table")
445
+ c1,c2,c3,c4 = st.columns(4)
446
+ c1.metric("points", f"{sv['n_points']}"); c2.metric("Pred min", f"{sv['pred_min']:.2f}")
447
+ c3.metric("Pred max", f"{sv['pred_max']:.2f}"); c4.metric("OOR %", f"{sv['oor_pct']:.1f}%")
448
+ left,right = st.columns([1,1])
 
 
 
449
  with left:
450
  if TARGET in st.session_state.results["Validate"].columns:
451
+ st.pyplot(cross_plot(st.session_state.results["Validate"][TARGET], st.session_state.results["Validate"]["UCS_Pred"], "Validation: Actual vs Predicted"), use_container_width=True)
 
 
 
 
 
 
 
452
  else:
453
  st.info("Actual UCS values are not available in the validation data. Cross-plot cannot be generated.")
454
  with right:
455
+ st.pyplot(depth_or_index_track(st.session_state.results["Validate"], "Validation: Depth/Index Track", include_actual=(TARGET in st.session_state.results["Validate"].columns)), use_container_width=True)
 
 
 
 
 
 
 
456
  if oor_table is not None:
457
  st.write("*Out-of-range rows (vs. Training min–max):*")
458
  st.dataframe(oor_table, use_container_width=True)
 
460
  st.markdown("---")
461
  sheets = {"Validate_with_pred": st.session_state.results["Validate"]}
462
  rows = []
463
+ for name, key in [("Train","metrics_train"), ("Test","metrics_test"), ("Validate","metrics_val")]:
464
  m = st.session_state.results.get(key)
465
+ if m: rows.append({"Split": name, **{k: round(v,6) for k,v in m.items()}})
 
466
  summary_df = pd.DataFrame(rows) if rows else None
467
  try:
468
  data_bytes = export_workbook(sheets, summary_df)
469
+ st.download_button("Export Validation Results to Excel",
470
+ data=data_bytes, file_name="UCS_Validation_Results.xlsx",
471
+ mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet")
 
 
 
472
  except RuntimeError as e:
473
  st.warning(str(e))
474
 
 
476
  # Footer
477
  # =========================
478
  st.markdown("---")
479
+ st.markdown("<div style='text-align:center; color:#6b7280;'>ST_GeoMech_UCS • © Smart Thinking</div>", unsafe_allow_html=True)