UCS2014 commited on
Commit
53297b4
·
verified ·
1 Parent(s): 843472e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +604 -410
app.py CHANGED
@@ -4,82 +4,75 @@ import streamlit as st
4
  import pandas as pd
5
  import numpy as np
6
  import joblib
 
 
 
 
7
 
8
- # --- Plotly (interactive) ---
9
  import plotly.graph_objects as go
10
- from plotly.subplots import make_subplots
11
-
12
- from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
13
 
14
- # =========================================================
15
- # Defaults (overridden by models/meta.json or model.feature_names_in_)
16
- # =========================================================
17
  FEATURES = ["Q, gpm", "SPP(psi)", "T (kft.lbf)", "WOB (klbf)", "ROP (ft/h)"]
18
  TARGET = "UCS"
19
  MODELS_DIR = Path("models")
20
  DEFAULT_MODEL = MODELS_DIR / "ucs_rf.joblib"
21
  MODEL_FALLBACKS = [MODELS_DIR / "model.joblib", MODELS_DIR / "model.pkl"]
22
 
23
- # =========================================================
24
- # Page / Theme + CSS
25
- # =========================================================
26
- st.set_page_config(page_title="ST_GeoMech_UCS", page_icon="logo.png", layout="wide")
27
 
 
 
 
 
 
28
  st.markdown(
29
  """
30
  <style>
31
- /* App + sidebar background */
32
  .stApp { background: #FFFFFF; }
33
  section[data-testid="stSidebar"] { background: #F6F9FC; }
34
-
35
- /* Tighten top spacing */
36
- [data-testid="stBlock"]{ margin-top: 0 !important; }
37
-
38
- /* Hero row */
39
- .st-hero { display:flex; align-items:center; gap:16px; padding-top: 6px; }
40
- .st-hero .brand { width:90px; height:90px; object-fit:contain; }
41
  .st-hero h1 { margin:0; line-height:1.05; }
42
  .st-hero .tagline { margin:2px 0 0 2px; color:#6b7280; font-size:1.05rem; font-style:italic; }
43
-
44
- /* Sidebar button palette (order-based within the Sidebar section)
45
- 1) Preview (orange) 2) Run (blue) 3) Proceed (green)
46
- We scope to the sidebar and to stButton blocks only. */
47
- section[data-testid="stSidebar"] div.stButton > button {
48
- font-weight:700; border-radius:10px; border:none; padding:10px 20px;
49
- }
50
- section[data-testid="stSidebar"] div.stButton:nth-of-type(1) > button { /* Preview */
51
- background:#f59e0b; color:#fff;
52
- }
53
- section[data-testid="stSidebar"] div.stButton:nth-of-type(2) > button { /* Run (blue) */
54
- background:#2563eb; color:#fff;
55
- }
56
- section[data-testid="stSidebar"] div.stButton:nth-of-type(3) > button { /* Proceed (green) */
57
- background:#10b981; color:#fff;
58
- }
59
- section[data-testid="stSidebar"] div.stButton:nth-of-type(3) > button:disabled {
60
- background:#a7f3d0 !important; color:#064e3b !important; opacity:.7 !important;
61
- }
62
-
63
- /* Modal tabs spacing */
64
- .stTabs [data-baseweb="tab-list"] { gap: 6px; }
65
  </style>
66
  """,
67
  unsafe_allow_html=True
68
  )
69
 
70
- # =========================================================
71
  # Helpers
72
- # =========================================================
73
- def inline_logo(path="logo.png") -> str:
74
- try:
75
- p = Path(path)
76
- if not p.exists(): return ""
77
- return f"data:image/png;base64,{base64.b64encode(p.read_bytes()).decode('ascii')}"
78
- except Exception:
79
- return ""
 
 
 
 
 
 
 
80
 
81
  def rmse(y_true, y_pred): return float(np.sqrt(mean_squared_error(y_true, y_pred)))
82
 
 
 
 
 
 
 
 
83
  @st.cache_resource(show_spinner=False)
84
  def load_model(model_path: str):
85
  return joblib.load(model_path)
@@ -90,187 +83,348 @@ def parse_excel(data_bytes: bytes):
90
  xl = pd.ExcelFile(bio)
91
  return {sh: xl.parse(sh) for sh in xl.sheet_names}
92
 
93
- def ensure_cols(df, cols):
94
- miss = [c for c in cols if c not in df.columns]
95
- if miss:
96
- st.error(f"Missing columns: {miss}\nFound: {list(df.columns)}")
97
- return False
98
- return True
99
 
100
- def infer_features_from_model(m):
101
- try:
102
- if hasattr(m, "feature_names_in_") and len(getattr(m, "feature_names_in_")):
103
- return [str(x) for x in m.feature_names_in_]
104
- except Exception: pass
105
- try:
106
- if hasattr(m, "steps") and len(m.steps):
107
- last = m.steps[-1][1]
108
- if hasattr(last, "feature_names_in_") and len(last.feature_names_in_):
109
- return [str(x) for x in last.feature_names_in_]
110
- except Exception: pass
111
  return None
112
 
113
- def export_workbook(sheets_dict, summary_df=None):
114
- try: import openpyxl # ensure engine is available
115
- except Exception:
116
- raise RuntimeError("Export requires openpyxl. Please add it to requirements.txt.")
117
- buf = io.BytesIO()
118
- with pd.ExcelWriter(buf, engine="openpyxl") as xw:
119
- for name, frame in sheets_dict.items():
120
- frame.to_excel(xw, sheet_name=name[:31], index=False)
121
- if summary_df is not None:
122
- summary_df.to_excel(xw, sheet_name="Summary", index=False)
123
- return buf.getvalue()
124
-
125
- # -------------------- Plotly styling blocks --------------------
126
- AXES_STYLE = dict(
127
- showline=True, linewidth=1.4, linecolor="#444",
128
- mirror=True, ticks="outside", ticklen=4, tickwidth=1,
129
- showgrid=True, gridcolor="rgba(0,0,0,0.08)"
130
- )
131
- FONT = dict(color="#111", size=13)
132
 
133
- def style_layout(fig, width=None, height=None, margins=(12,18,36,12)):
134
- t, r, b, l = margins
135
- fig.update_layout(
136
- margin=dict(t=t, r=r, b=b, l=l),
137
- paper_bgcolor="white",
138
- plot_bgcolor="white",
139
- font=FONT,
140
- legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1)
141
- )
142
- if width: fig.update_layout(width=width)
143
- if height: fig.update_layout(height=height)
144
- # Apply to all axes
145
- fig.update_xaxes(**AXES_STYLE, title_font=dict(size=14, color="#111"))
146
- fig.update_yaxes(**AXES_STYLE, title_font=dict(size=14, color="#111"))
 
 
 
 
 
 
 
147
  return fig
148
 
149
- def make_cross_plotly(A, P, height=440, width=640):
150
- a = pd.Series(A).astype(float)
151
- p = pd.Series(P).astype(float)
 
 
152
  lo = float(np.nanmin([a.min(), p.min()]))
153
  hi = float(np.nanmax([a.max(), p.max()]))
 
 
154
 
155
  fig = go.Figure()
 
 
156
  fig.add_trace(go.Scatter(
157
- x=a, y=p, mode="markers", name="Points", marker=dict(size=6)
 
 
 
 
 
158
  ))
 
 
159
  fig.add_trace(go.Scatter(
160
- x=[lo, hi], y=[lo, hi], mode="lines", name="1:1",
161
- line=dict(color="#666", width=2, dash="dash")
 
 
 
 
162
  ))
163
- fig.update_xaxes(range=[lo, hi], title="Actual UCS")
164
- fig.update_yaxes(range=[lo, hi], title="Predicted UCS", scaleanchor="x", scaleratio=1)
165
- style_layout(fig, width=width, height=height, margins=(8,10,36,50))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
  return fig
167
 
168
- def make_depth_track_plotly(df, include_actual=True, height=640, width=360):
169
- idx = np.arange(1, len(df) + 1)
 
 
 
 
 
 
 
 
170
  fig = go.Figure()
 
171
  # Predicted (solid blue)
172
  fig.add_trace(go.Scatter(
173
- x=df["UCS_Pred"], y=idx, mode="lines", name="UCS_Pred",
174
- line=dict(color="#1f77b4", width=2)
 
 
 
175
  ))
 
176
  # Actual (dotted yellow)
177
  if include_actual and TARGET in df.columns:
178
  fig.add_trace(go.Scatter(
179
- x=df[TARGET], y=idx, mode="lines", name="UCS (actual)",
180
- line=dict(color="#f2b01e", width=2, dash="dot")
 
 
 
181
  ))
182
- fig.update_yaxes(autorange="reversed", title="Point Index")
183
- fig.update_xaxes(title="UCS")
184
- style_layout(fig, width=width, height=height, margins=(8,12,36,60))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
  return fig
186
 
187
- def make_index_tracks_plotly(df, feature_cols, height=640, width=980):
188
- n = len(feature_cols)
189
- fig = make_subplots(rows=1, cols=n, shared_yaxes=True, horizontal_spacing=0.05)
190
- idx = np.arange(1, len(df) + 1)
 
 
 
 
 
191
 
192
- for i, col in enumerate(feature_cols, start=1):
193
- fig.add_trace(
194
- go.Scatter(x=df[col], y=idx, mode="lines", line=dict(color="#444", width=1.2), name=col, showlegend=False),
195
- row=1, col=i
196
- )
197
- fig.update_xaxes(title=col, row=1, col=i)
198
- fig.update_yaxes(autorange="reversed", title="Point Index", row=1, col=1)
199
- style_layout(fig, width=width, height=height, margins=(6,8,36,60))
200
- return fig
201
 
202
- # =========================================================
203
- # Model availability (cloud-safe)
204
- # =========================================================
205
- def _get_model_url():
206
  try:
207
- return (st.secrets.get("MODEL_URL", "") or os.environ.get("MODEL_URL", "") or "").strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
  except Exception:
209
- return (os.environ.get("MODEL_URL", "") or "").strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
 
211
- def ensure_model_present() -> Path | None:
212
- # local candidates
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
213
  for p in [DEFAULT_MODEL, *MODEL_FALLBACKS]:
214
- if p.exists():
215
  return p
216
- # cloud download
217
- MODEL_URL = _get_model_url()
218
- if MODEL_URL:
219
- try:
220
- import requests
221
- DEFAULT_MODEL.parent.mkdir(parents=True, exist_ok=True)
222
- with requests.get(MODEL_URL, stream=True) as r:
223
  r.raise_for_status()
224
  with open(DEFAULT_MODEL, "wb") as f:
225
  for chunk in r.iter_content(chunk_size=1<<20):
226
- f.write(chunk)
227
- return DEFAULT_MODEL
228
- except Exception as e:
229
- st.error(f"Failed to download model from MODEL_URL: {e}")
230
- return None
231
 
232
  model_path = ensure_model_present()
233
  if not model_path:
234
- st.error("Model not found. Upload models/ucs_rf.joblib (or set MODEL_URL).")
235
  st.stop()
236
 
237
- # Load model
238
  try:
239
  model = load_model(str(model_path))
240
  except Exception as e:
241
  st.error(f"Failed to load model: {model_path}\n{e}")
242
  st.stop()
243
 
244
- # Meta overrides
245
  meta_path = MODELS_DIR / "meta.json"
246
  if meta_path.exists():
247
  try:
248
  meta = json.loads(meta_path.read_text(encoding="utf-8"))
249
- FEATURES = meta.get("features", FEATURES)
250
- TARGET = meta.get("target", TARGET)
251
- except Exception:
252
- pass
253
  else:
254
  infer = infer_features_from_model(model)
255
  if infer: FEATURES = infer
256
 
257
- # =========================================================
258
- # Session state defaults
259
- # =========================================================
260
- ss = st.session_state
261
- ss.setdefault("app_step", "intro") # we start on Intro
262
- ss.setdefault("dev_bytes", None)
263
- ss.setdefault("dev_book", None)
264
- ss.setdefault("dev_previewed", False)
265
- ss.setdefault("dev_ran", False)
266
- ss.setdefault("results", {})
267
- ss.setdefault("train_ranges", None)
268
- ss.setdefault("val_bytes", None)
269
- ss.setdefault("val_book", None)
270
-
271
- # =========================================================
272
- # HERO (logo + title)
273
- # =========================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
274
  st.markdown(
275
  f"""
276
  <div class="st-hero">
@@ -284,277 +438,317 @@ st.markdown(
284
  unsafe_allow_html=True,
285
  )
286
 
287
- # =========================================================
288
  # INTRO PAGE
289
- # =========================================================
290
- if ss.app_step == "intro":
291
  st.header("Welcome!")
292
  st.markdown(
293
- "1. **Upload your data to build the case** and preview the performance of our model. \n"
294
- "2. Click **Run Model** to compute metrics, cross-plots, and the index track. \n"
295
- "3. Click **Proceed to Prediction** to validate on a new dataset."
296
  )
297
- if st.button("Start", type="primary"):
298
- ss.app_step = "dev"
299
- st.rerun()
300
-
301
- # =========================================================
302
- # Helper banner (stays at top of Development page)
303
- # =========================================================
304
- def render_dev_helper():
305
- st.subheader("Model Development")
306
- if not ss.dev_bytes:
307
- st.info("Upload your data to build the case and preview the performance of our model.")
308
- elif ss.dev_bytes and not ss.dev_previewed and not ss.dev_ran:
309
- st.info("File loaded click **Preview data**.")
310
- elif ss.dev_previewed and not ss.dev_ran:
311
- st.info("Previewed ✓ — now click **Run Model** to build the case.")
312
- elif ss.dev_ran:
313
- st.success("Case built — results are displayed below. You can now **Proceed to Prediction**.")
314
-
315
- # =========================================================
316
- # PREVIEW MODAL
317
- # =========================================================
318
- def preview_modal_dev(book, feature_cols):
319
- sh_train = None
320
- sh_test = None
321
- # try common names
322
- low2orig = {k.lower(): k for k in book.keys()}
323
- for nm in ["train","training","training2"]:
324
- if nm in low2orig: sh_train = low2orig[nm]; break
325
- for nm in ["test","testing","testing2"]:
326
- if nm in low2orig: sh_test = low2orig[nm]; break
327
-
328
- tabs = st.tabs(["Tracks", "Summary"])
329
- with tabs[0]:
330
- # prefer Train if available; else first sheet
331
- pick = sh_train or list(book.keys())[0]
332
- df = book[pick]
333
- # only numeric columns needed for plotting
334
- ok_cols = [c for c in feature_cols if c in df.columns]
335
- if not ok_cols:
336
- st.warning("No matching feature columns found for plotting.")
337
- else:
338
- fig = make_index_tracks_plotly(df, ok_cols, height=640, width=1000)
339
- st.plotly_chart(fig, use_container_width=True, theme=None)
340
- with tabs[1]:
341
- pick = sh_train or list(book.keys())[0]
342
- df = book[pick]
343
- st.dataframe(
344
- df.describe().T.rename(columns={
345
- "mean":"Mean","std":"Std","min":"Min","max":"Max"
346
- })[["Min","Max","Mean","Std"]].round(4),
347
- use_container_width=True
 
 
 
348
  )
349
 
350
- # =========================================================
351
- # DEVELOPMENT PAGE
352
- # =========================================================
353
- if ss.app_step == "dev":
354
- render_dev_helper()
355
-
356
- with st.sidebar:
357
- st.header("Model Development Data")
358
-
359
- def _on_dev_upload():
360
- file = st.session_state.get("dev_upload")
361
- if file is not None:
362
- ss.dev_bytes = file.getvalue()
363
- ss.dev_book = parse_excel(ss.dev_bytes)
364
- ss.dev_previewed = False
365
- ss.dev_ran = False
366
-
367
- st.file_uploader("Replace data (Excel)", type=["xlsx","xls"], key="dev_upload",
368
- on_change=_on_dev_upload, help="Limit 200MB per file • XLSX, XLS")
369
-
370
- if ss.dev_bytes and ss.dev_book:
371
- # Small status line under upload
372
- any_sheet = next(iter(ss.dev_book.values()))
373
- st.caption(f"Data loaded: {getattr(st.session_state.get('dev_upload'), 'name', 'file')} • "
374
- f"{any_sheet.shape[0]} rows × {any_sheet.shape[1]} cols")
375
-
376
- preview_clicked = st.button("Preview data", disabled=not bool(ss.dev_book))
377
- run_clicked = st.button("Run Model", disabled=not bool(ss.dev_book))
378
- proceed_clicked = st.button("Proceed to Prediction ▶", disabled=not ss.get("dev_ran", False))
379
-
380
- # Modal preview (does NOT clear the uploaded file)
381
- if preview_clicked and ss.dev_book:
382
- with st.modal("Preview data"):
383
- st.write("Use the tabs below to inspect the uploaded data before running the model.")
384
- preview_modal_dev(ss.dev_book, FEATURES)
385
- if st.button("Close", type="primary"):
386
- ss.dev_previewed = True
387
- st.rerun()
388
-
389
- # Run model
390
- if run_clicked and ss.dev_book:
391
- # pick sheets
392
- book = ss.dev_book
393
- low2orig = {k.lower(): k for k in book.keys()}
394
- sh_train = None; sh_test=None
395
- for nm in ["train","training","training2"]:
396
- if nm in low2orig: sh_train = low2orig[nm]; break
397
- for nm in ["test","testing","testing2"]:
398
- if nm in low2orig: sh_test = low2orig[nm]; break
399
-
400
- if sh_train is None or sh_test is None:
401
- st.error("Workbook must include sheets named *Train/Training* and *Test/Testing* (any one of those).")
402
  else:
403
- df_tr = book[sh_train].copy()
404
- df_te = book[sh_test].copy()
405
- if ensure_cols(df_tr, FEATURES+[TARGET]) and ensure_cols(df_te, FEATURES+[TARGET]):
406
- # predict
407
- df_tr["UCS_Pred"] = model.predict(df_tr[FEATURES])
408
- df_te["UCS_Pred"] = model.predict(df_te[FEATURES])
409
- ss.results["Train"] = df_tr
410
- ss.results["Test"] = df_te
411
- ss.results["metrics_train"] = {
412
- "R2": r2_score(df_tr[TARGET], df_tr["UCS_Pred"]),
413
- "RMSE": rmse(df_tr[TARGET], df_tr["UCS_Pred"]),
414
- "MAE": mean_absolute_error(df_tr[TARGET], df_tr["UCS_Pred"])
415
- }
416
- ss.results["metrics_test"] = {
417
- "R2": r2_score(df_te[TARGET], df_te["UCS_Pred"]),
418
- "RMSE": rmse(df_te[TARGET], df_te["UCS_Pred"]),
419
- "MAE": mean_absolute_error(df_te[TARGET], df_te["UCS_Pred"])
420
- }
421
- tr_min = df_tr[FEATURES].min().to_dict()
422
- tr_max = df_tr[FEATURES].max().to_dict()
423
- ss.train_ranges = {f:(float(tr_min[f]), float(tr_max[f])) for f in FEATURES}
424
- ss.dev_ran = True
425
- st.rerun()
426
-
427
- # Results (if available)
428
- if ss.results.get("Train") is not None or ss.results.get("Test") is not None:
429
- tab1, tab2 = st.tabs(["Training", "Testing"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
430
 
431
- if ss.results.get("Train") is not None:
 
 
 
432
  with tab1:
433
- df = ss.results["Train"]; m = ss.results["metrics_train"]
434
  c1,c2,c3 = st.columns(3)
435
- c1.metric("R²", f"{m['R2']:.4f}")
436
- c2.metric("RMSE", f"{m['RMSE']:.4f}")
437
- c3.metric("MAE", f"{m['MAE']:.4f}")
438
-
439
- left, right = st.columns([0.58, 0.42])
440
  with left:
441
- st.plotly_chart(make_cross_plotly(df[TARGET], df["UCS_Pred"], height=440, width=640),
442
- use_container_width=True, theme=None)
 
 
443
  with right:
444
- st.plotly_chart(make_depth_track_plotly(df, include_actual=True, height=640, width=360),
445
- use_container_width=True, theme=None)
446
-
447
- if ss.results.get("Test") is not None:
 
448
  with tab2:
449
- df = ss.results["Test"]; m = ss.results["metrics_test"]
450
  c1,c2,c3 = st.columns(3)
451
- c1.metric("R²", f"{m['R2']:.4f}")
452
- c2.metric("RMSE", f"{m['RMSE']:.4f}")
453
- c3.metric("MAE", f"{m['MAE']:.4f}")
454
-
455
- left, right = st.columns([0.58, 0.42])
456
  with left:
457
- st.plotly_chart(make_cross_plotly(df[TARGET], df["UCS_Pred"], height=440, width=640),
458
- use_container_width=True, theme=None)
 
 
459
  with right:
460
- st.plotly_chart(make_depth_track_plotly(df, include_actual=True, height=640, width=360),
461
- use_container_width=True, theme=None)
462
-
463
- # =========================================================
464
- # PREDICTION PAGE
465
- # =========================================================
466
- if ss.app_step == "dev" and st.sidebar.button("→ Open Prediction in main area", key="force_pred"):
467
- ss.app_step = "predict"; st.rerun()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
468
 
469
- if ss.app_step == "predict":
470
  st.subheader("Prediction")
471
  st.write("Upload a new dataset to generate UCS predictions and evaluate performance on unseen data.")
472
 
473
- with st.sidebar:
474
- st.header("Prediction (Validation)")
475
- def _on_val_upload():
476
- file = st.session_state.get("val_upload")
477
- if file is not None:
478
- ss.val_bytes = file.getvalue()
479
- ss.val_book = parse_excel(ss.val_bytes)
480
-
481
- st.file_uploader("Upload Validation Excel", type=["xlsx","xls"], key="val_upload", on_change=_on_val_upload)
482
- predict_clicked = st.button("Predict", type="primary", use_container_width=True)
483
- st.button("⬅ Back", on_click=lambda: ss.update(app_step="dev"))
484
-
485
- if predict_clicked and ss.val_book:
486
- vname = list(ss.val_book.keys())[0]
487
- df_val = ss.val_book[vname].copy()
488
- if not ensure_cols(df_val, FEATURES):
489
- st.stop()
490
- df_val["UCS_Pred"] = model.predict(df_val[FEATURES])
491
- ss.results["Validate"] = df_val
492
-
493
- ranges = ss.train_ranges
494
- oor_table = None; oor_pct = 0.0
495
- if ranges:
496
- viol = {f: (df_val[f] < ranges[f][0]) | (df_val[f] > ranges[f][1]) for f in FEATURES}
497
- any_viol = pd.DataFrame(viol).any(axis=1); oor_pct = float(any_viol.mean()*100.0)
498
- if any_viol.any():
499
- offenders = df_val.loc[any_viol, FEATURES].copy()
500
- offenders["Violations"] = pd.DataFrame(viol).loc[any_viol].apply(
501
- lambda r: ", ".join([c for c,v in r.items() if v]), axis=1)
502
- offenders.index = offenders.index + 1; oor_table = offenders
503
-
504
- metrics_val = None
505
- if TARGET in df_val.columns:
506
- metrics_val = {
507
- "R2": r2_score(df_val[TARGET], df_val["UCS_Pred"]),
508
- "RMSE": rmse(df_val[TARGET], df_val["UCS_Pred"]),
509
- "MAE": mean_absolute_error(df_val[TARGET], df_val["UCS_Pred"])
510
  }
511
- ss.results["metrics_val"] = metrics_val
512
- ss.results["summary_val"] = {
513
- "n_points": len(df_val),
514
- "pred_min": float(df_val["UCS_Pred"].min()),
515
- "pred_max": float(df_val["UCS_Pred"].max()),
516
- "oor_pct": oor_pct
517
- }
518
- ss.results["oor_table"] = oor_table
519
- st.experimental_rerun()
520
-
521
- if ss.results.get("Validate") is not None:
522
  st.subheader("Validation Results")
523
- sv = ss.results["summary_val"]; oor_table = ss.results.get("oor_table")
 
 
 
 
524
  c1,c2,c3,c4 = st.columns(4)
525
- c1.metric("# points", f"{sv['n_points']}"); c2.metric("Pred min", f"{sv['pred_min']:.2f}")
526
  c3.metric("Pred max", f"{sv['pred_max']:.2f}"); c4.metric("OOR %", f"{sv['oor_pct']:.1f}%")
527
 
528
- left,right = st.columns([0.58, 0.42])
529
  with left:
530
- if TARGET in ss.results["Validate"].columns:
531
  st.plotly_chart(
532
- make_cross_plotly(ss.results["Validate"][TARGET], ss.results["Validate"]["UCS_Pred"], height=440, width=640),
533
- use_container_width=True, theme=None
 
 
 
534
  )
535
  else:
536
- st.info("Actual UCS values are not available in the validation data.")
537
  with right:
538
  st.plotly_chart(
539
- make_depth_track_plotly(ss.results["Validate"], include_actual=(TARGET in ss.results["Validate"].columns),
540
- height=640, width=360),
541
- use_container_width=True, theme=None
 
 
542
  )
 
543
  if oor_table is not None:
544
- st.warning("Some validation rows contain inputs **outside** the training min–max. Review the table below.")
545
  st.dataframe(oor_table, use_container_width=True)
546
 
547
- # =========================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
548
  # Footer
549
- # =========================================================
550
  st.markdown("---")
551
  st.markdown(
552
- "<div style='text-align:center; color:#6b7280;'>"
553
- "ST_GeoMech_UCS © Smart Thinking"
554
- "</div>"
555
- "<div style='text-align:center; color:#6b7280;'>"
556
- "Visit our Website: "
557
- "<a href='https://www.smartthinking.com.sa' target='_blank'>smartthinking.com.sa</a>"
558
- "</div>",
559
  unsafe_allow_html=True
560
  )
 
4
  import pandas as pd
5
  import numpy as np
6
  import joblib
7
+ import matplotlib
8
+ matplotlib.use("Agg")
9
+ import matplotlib.pyplot as plt
10
+ from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
11
 
12
+ # NEW: Plotly for interactive charts (keeps styling the same)
13
  import plotly.graph_objects as go
 
 
 
14
 
15
+ # =========================
16
+ # Defaults
17
+ # =========================
18
  FEATURES = ["Q, gpm", "SPP(psi)", "T (kft.lbf)", "WOB (klbf)", "ROP (ft/h)"]
19
  TARGET = "UCS"
20
  MODELS_DIR = Path("models")
21
  DEFAULT_MODEL = MODELS_DIR / "ucs_rf.joblib"
22
  MODEL_FALLBACKS = [MODELS_DIR / "model.joblib", MODELS_DIR / "model.pkl"]
23
 
24
+ COLORS = {"pred": "#1f77b4", "actual": "#f2b702", "ref": "#5a5a5a"}
 
 
 
25
 
26
+ # =========================
27
+ # Page / Theme
28
+ # =========================
29
+ st.set_page_config(page_title="ST_GeoMech_UCS", page_icon="logo.png", layout="wide")
30
+ st.markdown("<style>header, footer{visibility:hidden !important;}</style>", unsafe_allow_html=True)
31
  st.markdown(
32
  """
33
  <style>
 
34
  .stApp { background: #FFFFFF; }
35
  section[data-testid="stSidebar"] { background: #F6F9FC; }
36
+ .block-container { padding-top: .5rem; padding-bottom: .5rem; }
37
+ .stButton>button{ background:#007bff; color:#fff; font-weight:bold; border-radius:8px; border:none; padding:10px 24px; }
38
+ .stButton>button:hover{ background:#0056b3; }
39
+ .st-hero { display:flex; align-items:center; gap:16px; padding-top: 4px; }
40
+ .st-hero .brand { width:110px; height:110px; object-fit:contain; }
 
 
41
  .st-hero h1 { margin:0; line-height:1.05; }
42
  .st-hero .tagline { margin:2px 0 0 2px; color:#6b7280; font-size:1.05rem; font-style:italic; }
43
+ [data-testid="stBlock"]{ margin-top:0 !important; }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  </style>
45
  """,
46
  unsafe_allow_html=True
47
  )
48
 
49
+ # =========================
50
  # Helpers
51
+ # =========================
52
+ try:
53
+ dialog = st.dialog
54
+ except AttributeError:
55
+ # Fallback (expander) if st.dialog is unavailable
56
+ def dialog(title):
57
+ def deco(fn):
58
+ def wrapper(*args, **kwargs):
59
+ with st.expander(title, expanded=True):
60
+ return fn(*args, **kwargs)
61
+ return wrapper
62
+ return deco
63
+
64
+ def _get_model_url():
65
+ return (os.environ.get("MODEL_URL", "") or "").strip()
66
 
67
  def rmse(y_true, y_pred): return float(np.sqrt(mean_squared_error(y_true, y_pred)))
68
 
69
+ def ensure_cols(df, cols):
70
+ miss = [c for c in cols if c not in df.columns]
71
+ if miss:
72
+ st.error(f"Missing columns: {miss}\nFound: {list(df.columns)}")
73
+ return False
74
+ return True
75
+
76
  @st.cache_resource(show_spinner=False)
77
  def load_model(model_path: str):
78
  return joblib.load(model_path)
 
83
  xl = pd.ExcelFile(bio)
84
  return {sh: xl.parse(sh) for sh in xl.sheet_names}
85
 
86
+ def read_book_bytes(data_bytes: bytes):
87
+ if not data_bytes: return {}
88
+ try: return parse_excel(data_bytes)
89
+ except Exception as e:
90
+ st.error(f"Failed to read Excel: {e}"); return {}
 
91
 
92
+ def find_sheet(book, names):
93
+ low2orig = {k.lower(): k for k in book.keys()}
94
+ for nm in names:
95
+ if nm.lower() in low2orig: return low2orig[nm.lower()]
 
 
 
 
 
 
 
96
  return None
97
 
98
+ # ---------- ORIGINAL Matplotlib plotters (kept for reference) ----------
99
+ def cross_plot(actual, pred, title, size=(3.9, 3.9)):
100
+ fig, ax = plt.subplots(figsize=size, dpi=100)
101
+ ax.scatter(actual, pred, s=14, alpha=0.85, color=COLORS["pred"])
102
+ lo = float(np.nanmin([actual.min(), pred.min()]))
103
+ hi = float(np.nanmax([actual.max(), pred.max()]))
104
+ pad = 0.03 * (hi - lo if hi > lo else 1.0)
105
+ ax.plot([lo - pad, hi + pad], [lo - pad, hi + pad], '--', lw=1.2, color=COLORS["ref"])
106
+ ax.set_xlim(lo - pad, hi + pad); ax.set_ylim(lo - pad, hi + pad)
107
+ ax.set_aspect('equal', 'box')
108
+ ax.set_xlabel("Actual UCS"); ax.set_ylabel("Predicted UCS"); ax.set_title(title)
109
+ ax.grid(True, ls=":", alpha=0.4)
110
+ return fig
 
 
 
 
 
 
111
 
112
+ def depth_or_index_track(df, title=None, include_actual=True):
113
+ depth_col = next((c for c in df.columns if 'depth' in str(c).lower()), None)
114
+ fig_w = 3.1
115
+ fig_h = 7.6 if depth_col is not None else 7.2
116
+ fig, ax = plt.subplots(figsize=(fig_w, fig_h), dpi=100)
117
+ if depth_col is not None:
118
+ ax.plot(df["UCS_Pred"], df[depth_col], '-', lw=1.8, color=COLORS["pred"], label="UCS_Pred")
119
+ if include_actual and TARGET in df.columns:
120
+ ax.plot(df[TARGET], df[depth_col], ':', lw=2.0, color=COLORS["actual"], alpha=0.95, label="UCS (actual)")
121
+ ax.set_ylabel(depth_col); ax.set_xlabel("UCS")
122
+ ax.xaxis.set_label_position('top'); ax.xaxis.tick_top(); ax.invert_yaxis()
123
+ else:
124
+ idx = np.arange(1, len(df) + 1)
125
+ ax.plot(df["UCS_Pred"], idx, '-', lw=1.8, color=COLORS["pred"], label="UCS_Pred")
126
+ if include_actual and TARGET in df.columns:
127
+ ax.plot(df[TARGET], idx, ':', lw=2.0, color=COLORS["actual"], alpha=0.95, label="UCS (actual)")
128
+ ax.set_ylabel("Point Index"); ax.set_xlabel("UCS")
129
+ ax.xaxis.set_label_position('top'); ax.xaxis.tick_top(); ax.invert_yaxis()
130
+ ax.grid(True, linestyle=":", alpha=0.4)
131
+ if title: ax.set_title(title, pad=8)
132
+ ax.legend(loc="best")
133
  return fig
134
 
135
+ # ---------- NEW: Plotly equivalents (interactive, same specs) ----------
136
+ def cross_plot_interactive(actual, pred, title, size=(3.9, 3.9)):
137
+ """Interactive cross-plot with the same look: blue points, dashed 1:1, equal axes, grid, title."""
138
+ a = pd.Series(actual).astype(float)
139
+ p = pd.Series(pred).astype(float)
140
  lo = float(np.nanmin([a.min(), p.min()]))
141
  hi = float(np.nanmax([a.max(), p.max()]))
142
+ pad = 0.03 * (hi - lo if hi > lo else 1.0)
143
+ x0, x1 = lo - pad, hi + pad
144
 
145
  fig = go.Figure()
146
+
147
+ # points
148
  fig.add_trace(go.Scatter(
149
+ x=a, y=p,
150
+ mode="markers",
151
+ marker=dict(size=6, color=COLORS["pred"]),
152
+ hovertemplate="Actual: %{x:.2f}<br>Pred: %{y:.2f}<extra></extra>",
153
+ name="Points",
154
+ showlegend=False
155
  ))
156
+
157
+ # 1:1
158
  fig.add_trace(go.Scatter(
159
+ x=[x0, x1], y=[x0, x1],
160
+ mode="lines",
161
+ line=dict(color=COLORS["ref"], width=1.2, dash="dash"),
162
+ hoverinfo="skip",
163
+ name="1:1",
164
+ showlegend=False
165
  ))
166
+
167
+ fig.update_layout(
168
+ title=title,
169
+ paper_bgcolor="#ffffff",
170
+ plot_bgcolor="#ffffff",
171
+ margin=dict(l=50, r=10, t=36, b=36),
172
+ hovermode="closest",
173
+ font=dict(size=13)
174
+ )
175
+ fig.update_xaxes(
176
+ title_text="<b>Actual UCS</b>",
177
+ range=[x0, x1],
178
+ ticks="outside", showline=True, linewidth=1.2, linecolor="#444",
179
+ showgrid=True, gridcolor="rgba(0,0,0,0.12)",
180
+ automargin=True
181
+ )
182
+ fig.update_yaxes(
183
+ title_text="<b>Predicted UCS</b>",
184
+ range=[x0, x1],
185
+ ticks="outside", showline=True, linewidth=1.2, linecolor="#444",
186
+ showgrid=True, gridcolor="rgba(0,0,0,0.12)",
187
+ scaleanchor="x", scaleratio=1,
188
+ automargin=True
189
+ )
190
+
191
+ # match your size ~ inches * 100 dpi
192
+ w = int(size[0] * 100)
193
+ h = int(size[1] * 100)
194
+ fig.update_layout(width=w, height=h)
195
  return fig
196
 
197
+ def depth_or_index_track_interactive(df, title=None, include_actual=True):
198
+ """Interactive narrow/tall UCS track: blue solid pred, yellow dotted actual, top x-axis, inverted y."""
199
+ depth_col = next((c for c in df.columns if 'depth' in str(c).lower()), None)
200
+ if depth_col is not None:
201
+ y = df[depth_col]
202
+ y_label = depth_col
203
+ else:
204
+ y = np.arange(1, len(df) + 1)
205
+ y_label = "Point Index"
206
+
207
  fig = go.Figure()
208
+
209
  # Predicted (solid blue)
210
  fig.add_trace(go.Scatter(
211
+ x=df["UCS_Pred"], y=y,
212
+ mode="lines",
213
+ line=dict(color=COLORS["pred"], width=1.8),
214
+ name="UCS_Pred",
215
+ hovertemplate="UCS_Pred: %{x:.2f}<br>"+y_label+": %{y}<extra></extra>"
216
  ))
217
+
218
  # Actual (dotted yellow)
219
  if include_actual and TARGET in df.columns:
220
  fig.add_trace(go.Scatter(
221
+ x=df[TARGET], y=y,
222
+ mode="lines",
223
+ line=dict(color=COLORS["actual"], width=2.0, dash="dot"),
224
+ name="UCS (actual)",
225
+ hovertemplate="UCS (actual): %{x:.2f}<br>"+y_label+": %{y}<extra></extra>"
226
  ))
227
+
228
+ fig.update_layout(
229
+ title=title if title else None,
230
+ paper_bgcolor="#ffffff",
231
+ plot_bgcolor="#ffffff",
232
+ margin=dict(l=60, r=10, t=36, b=36),
233
+ hovermode="closest",
234
+ legend=dict(orientation="h", yanchor="bottom", y=1.02, x=0),
235
+ font=dict(size=13),
236
+ # keep it tall & narrow like your Matplotlib version (~3.1in x 7.6in @100dpi)
237
+ width=int(3.1 * 100),
238
+ height=int((7.6 if depth_col is not None else 7.2) * 100),
239
+ )
240
+ fig.update_xaxes(
241
+ title_text="<b>UCS</b>", side="top",
242
+ ticks="outside", showline=True, linewidth=1.2, linecolor="#444",
243
+ showgrid=True, gridcolor="rgba(0,0,0,0.12)",
244
+ automargin=True
245
+ )
246
+ fig.update_yaxes(
247
+ title_text=f"<b>{y_label}</b>",
248
+ autorange="reversed",
249
+ ticks="outside", showline=True, linewidth=1.2, linecolor="#444",
250
+ showgrid=True, gridcolor="rgba(0,0,0,0.12)",
251
+ automargin=True
252
+ )
253
  return fig
254
 
255
+ def export_workbook(sheets_dict, summary_df=None):
256
+ try: import openpyxl # noqa
257
+ except Exception: raise RuntimeError("Export requires openpyxl. Please add it to requirements or install it.")
258
+ buf = io.BytesIO()
259
+ with pd.ExcelWriter(buf, engine="openpyxl") as xw:
260
+ for name, frame in sheets_dict.items():
261
+ frame.to_excel(xw, sheet_name=name[:31], index=False)
262
+ if summary_df is not None: summary_df.to_excel(xw, sheet_name="Summary", index=False)
263
+ return buf.getvalue()
264
 
265
+ def toast(msg):
266
+ try: st.toast(msg)
267
+ except Exception: st.info(msg)
 
 
 
 
 
 
268
 
269
+ def infer_features_from_model(m):
 
 
 
270
  try:
271
+ if hasattr(m, "feature_names_in_") and len(getattr(m, "feature_names_in_")):
272
+ return [str(x) for x in m.feature_names_in_]
273
+ except Exception: pass
274
+ try:
275
+ if hasattr(m, "steps") and len(m.steps):
276
+ last = m.steps[-1][1]
277
+ if hasattr(last, "feature_names_in_") and len(last.feature_names_in_):
278
+ return [str(x) for x in last.feature_names_in_]
279
+ except Exception: pass
280
+ return None
281
+
282
+ def inline_logo(path="logo.png") -> str:
283
+ try:
284
+ p = Path(path)
285
+ if not p.exists(): return ""
286
+ return f"data:image/png;base64,{base64.b64encode(p.read_bytes()).decode('ascii')}"
287
  except Exception:
288
+ return ""
289
+
290
+ # ---------- Preview modal helpers (unchanged; still Matplotlib) ----------
291
+ def make_index_tracks(df: pd.DataFrame, cols: list[str]):
292
+ cols = [c for c in cols if c in df.columns]
293
+ n = len(cols)
294
+ if n == 0:
295
+ fig, ax = plt.subplots(figsize=(4, 2))
296
+ ax.text(0.5, 0.5, "No selected columns in sheet", ha="center", va="center")
297
+ ax.axis("off"); return fig
298
+ width_per = 2.2
299
+ fig_h = 7.0
300
+ fig, axes = plt.subplots(1, n, figsize=(width_per * n, fig_h), sharey=True, dpi=100)
301
+ if n == 1: axes = [axes]
302
+ idx = np.arange(1, len(df) + 1)
303
+ for ax, col in zip(axes, cols):
304
+ ax.plot(df[col], idx, '-', lw=1.4, color="#333")
305
+ ax.set_xlabel(col)
306
+ ax.xaxis.set_label_position('top'); ax.xaxis.tick_top(); ax.invert_yaxis()
307
+ ax.grid(True, linestyle=":", alpha=0.3)
308
+ axes[0].set_ylabel("Point Index")
309
+ return fig
310
 
311
+ def stats_table(df: pd.DataFrame, cols: list[str]) -> pd.DataFrame:
312
+ cols = [c for c in cols if c in df.columns]
313
+ if not cols:
314
+ return pd.DataFrame({"Feature": [], "Min": [], "Max": [], "Mean": [], "Std": []})
315
+ out = df[cols].agg(['min', 'max', 'mean', 'std']).T
316
+ out = out.rename(columns={"min": "Min", "max": "Max", "mean": "Mean", "std": "Std"})
317
+ return out.reset_index().rename(columns={"index": "Feature"})
318
+
319
+ @dialog("Preview data")
320
+ def preview_modal_dev(book: dict[str, pd.DataFrame], feature_cols: list[str]):
321
+ if not book:
322
+ st.info("No data loaded yet."); return
323
+ sh_train = find_sheet(book, ["Train","Training","training2","train","training"])
324
+ sh_test = find_sheet(book, ["Test","Testing","testing2","test","testing"])
325
+ tabs, data = [], []
326
+ if sh_train: tabs.append("Train"); data.append(book[sh_train])
327
+ if sh_test: tabs.append("Test"); data.append(book[sh_test])
328
+ if not tabs:
329
+ first_name = list(book.keys())[0]
330
+ tabs = [first_name]; data = [book[first_name]]
331
+ st.write("Use the tabs to switch between Train/Test views (if available).")
332
+ t_objs = st.tabs(tabs)
333
+ for t, df in zip(t_objs, data):
334
+ with t:
335
+ t1, t2 = st.tabs(["Tracks", "Summary"])
336
+ with t1: st.pyplot(make_index_tracks(df, feature_cols), use_container_width=True)
337
+ with t2: st.dataframe(stats_table(df, feature_cols), use_container_width=True)
338
+
339
+ @dialog("Preview data")
340
+ def preview_modal_val(book: dict[str, pd.DataFrame], feature_cols: list[str]):
341
+ if not book:
342
+ st.info("No data loaded yet."); return
343
+ vname = find_sheet(book, ["Validation","Validate","validation2","Val","val"]) or list(book.keys())[0]
344
+ df = book[vname]
345
+ t1, t2 = st.tabs(["Tracks", "Summary"])
346
+ with t1: st.pyplot(make_index_tracks(df, feature_cols), use_container_width=True)
347
+ with t2: st.dataframe(stats_table(df, feature_cols), use_container_width=True)
348
+
349
+ # =========================
350
+ # Model presence
351
+ # =========================
352
+ MODEL_URL = _get_model_url()
353
+
354
+ def ensure_model_present() -> Path:
355
  for p in [DEFAULT_MODEL, *MODEL_FALLBACKS]:
356
+ if p.exists() and p.stat().st_size > 0:
357
  return p
358
+ if not MODEL_URL:
359
+ return None
360
+ try:
361
+ import requests
362
+ DEFAULT_MODEL.parent.mkdir(parents=True, exist_ok=True)
363
+ with st.status("Downloading model…", expanded=False):
364
+ with requests.get(MODEL_URL, stream=True, timeout=30) as r:
365
  r.raise_for_status()
366
  with open(DEFAULT_MODEL, "wb") as f:
367
  for chunk in r.iter_content(chunk_size=1<<20):
368
+ if chunk: f.write(chunk)
369
+ return DEFAULT_MODEL
370
+ except Exception as e:
371
+ st.error(f"Failed to download model from MODEL_URL: {e}")
372
+ return None
373
 
374
  model_path = ensure_model_present()
375
  if not model_path:
376
+ st.error("Model not found. Upload models/ucs_rf.joblib (or set MODEL_URL in Settings → Variables).")
377
  st.stop()
378
 
 
379
  try:
380
  model = load_model(str(model_path))
381
  except Exception as e:
382
  st.error(f"Failed to load model: {model_path}\n{e}")
383
  st.stop()
384
 
385
+ # Meta overrides / inference
386
  meta_path = MODELS_DIR / "meta.json"
387
  if meta_path.exists():
388
  try:
389
  meta = json.loads(meta_path.read_text(encoding="utf-8"))
390
+ FEATURES = meta.get("features", FEATURES); TARGET = meta.get("target", TARGET)
391
+ except Exception: pass
 
 
392
  else:
393
  infer = infer_features_from_model(model)
394
  if infer: FEATURES = infer
395
 
396
+ # =========================
397
+ # Session state
398
+ # =========================
399
+ if "app_step" not in st.session_state: st.session_state.app_step = "intro"
400
+ if "results" not in st.session_state: st.session_state.results = {}
401
+ if "train_ranges" not in st.session_state: st.session_state.train_ranges = None
402
+
403
+ # Dev page state (persist file)
404
+ for k, v in {
405
+ "dev_ready": False,
406
+ "dev_file_loaded": False,
407
+ "dev_previewed": False,
408
+ "dev_file_signature": None,
409
+ "dev_preview_request": False,
410
+ "dev_file_bytes": b"",
411
+ "dev_file_name": "",
412
+ "dev_file_rows": 0,
413
+ "dev_file_cols": 0,
414
+ }.items():
415
+ if k not in st.session_state: st.session_state[k] = v
416
+
417
+ # =========================
418
+ # Hero header
419
+ # =========================
420
+ def inline_logo(path="logo.png") -> str:
421
+ try:
422
+ p = Path(path)
423
+ if not p.exists(): return ""
424
+ return f"data:image/png;base64,{base64.b64encode(p.read_bytes()).decode('ascii')}"
425
+ except Exception:
426
+ return ""
427
+
428
  st.markdown(
429
  f"""
430
  <div class="st-hero">
 
438
  unsafe_allow_html=True,
439
  )
440
 
441
+ # =========================
442
  # INTRO PAGE
443
+ # =========================
444
+ if st.session_state.app_step == "intro":
445
  st.header("Welcome!")
446
  st.markdown(
447
+ "This software is developed by *Smart Thinking AI-Solutions Team* to estimate UCS from drilling data."
 
 
448
  )
449
+ st.subheader("Required Input Columns")
450
+ st.markdown(
451
+ "- Q, gpm — Flow rate (gallons per minute) \n"
452
+ "- SPP(psi) — Stand pipe pressure \n"
453
+ "- T (kft.lbf) — Torque (thousand foot-pounds) \n"
454
+ "- WOB (klbf) Weight on bit \n"
455
+ "- ROP (ft/h) — Rate of penetration"
456
+ )
457
+ st.subheader("How It Works")
458
+ st.markdown(
459
+ "1. **Upload your data to build the case and preview the performance of our model.** \n"
460
+ "2. Click **Run Model** to compute metrics and plots. \n"
461
+ "3. Click **Proceed to Prediction** to validate on a new dataset. \n"
462
+ "4. Export results to Excel at any time."
463
+ )
464
+ if st.button("Start Showcase", type="primary", key="start_showcase"):
465
+ st.session_state.app_step = "dev"; st.rerun()
466
+
467
+ # =========================
468
+ # MODEL DEVELOPMENT
469
+ # =========================
470
+ if st.session_state.app_step == "dev":
471
+ st.sidebar.header("Model Development Data")
472
+ dev_label = "Upload Data (Excel)" if not st.session_state.dev_file_name else "Replace data (Excel)"
473
+ train_test_file = st.sidebar.file_uploader(dev_label, type=["xlsx","xls"], key="dev_upload")
474
+
475
+ # Detect new/changed file and PERSIST BYTES
476
+ if train_test_file is not None:
477
+ try:
478
+ file_bytes = train_test_file.getvalue()
479
+ size = len(file_bytes)
480
+ except Exception:
481
+ file_bytes = b""
482
+ size = 0
483
+ sig = (train_test_file.name, size)
484
+ if sig != st.session_state.dev_file_signature and size > 0:
485
+ st.session_state.dev_file_signature = sig
486
+ st.session_state.dev_file_name = train_test_file.name
487
+ st.session_state.dev_file_bytes = file_bytes
488
+ # Inspect first sheet for rows/cols
489
+ _book_tmp = read_book_bytes(file_bytes)
490
+ if _book_tmp:
491
+ first_df = next(iter(_book_tmp.values()))
492
+ st.session_state.dev_file_rows = int(first_df.shape[0])
493
+ st.session_state.dev_file_cols = int(first_df.shape[1])
494
+ st.session_state.dev_file_loaded = True
495
+ st.session_state.dev_previewed = False
496
+ st.session_state.dev_ready = False
497
+
498
+ # Sidebar caption (from persisted info)
499
+ if st.session_state.dev_file_loaded:
500
+ st.sidebar.caption(
501
+ f"**Data loaded:** {st.session_state.dev_file_name} • "
502
+ f"{st.session_state.dev_file_rows} rows × {st.session_state.dev_file_cols} cols"
503
  )
504
 
505
+ # Sidebar actions
506
+ preview_btn = st.sidebar.button("Preview data", use_container_width=True, disabled=not st.session_state.dev_file_loaded)
507
+ if preview_btn and st.session_state.dev_file_loaded:
508
+ st.session_state.dev_preview_request = True
509
+
510
+ run_btn = st.sidebar.button("Run Model", type="primary", use_container_width=True)
511
+
512
+ proceed_clicked = st.sidebar.button(
513
+ "Proceed to Prediction ▶",
514
+ use_container_width=True,
515
+ disabled=not st.session_state.dev_ready
516
+ )
517
+ if proceed_clicked and st.session_state.dev_ready:
518
+ st.session_state.app_step = "predict"; st.rerun()
519
+
520
+ # ----- ALWAYS-ON TOP: Title + helper -----
521
+ helper_top = st.container()
522
+ with helper_top:
523
+ st.subheader("Model Development")
524
+ if st.session_state.dev_ready:
525
+ st.success("Case has been built and results are displayed below.")
526
+ elif st.session_state.dev_file_loaded and st.session_state.dev_previewed:
527
+ st.info("Previewed — now click **Run Model** to build the case.")
528
+ elif st.session_state.dev_file_loaded:
529
+ st.info("📄 **Preview uploaded data** using the sidebar button, then click **Run Model**.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
530
  else:
531
+ st.write("**Upload your data to build a case, then run the model to review development performance.**")
532
+
533
+ # If user clicked preview, open modal *after* helper so helper stays on top
534
+ if st.session_state.dev_preview_request and st.session_state.dev_file_bytes:
535
+ _book = read_book_bytes(st.session_state.dev_file_bytes)
536
+ st.session_state.dev_previewed = True
537
+ st.session_state.dev_preview_request = False
538
+ preview_modal_dev(_book, FEATURES)
539
+
540
+ # Run model (from persisted bytes)
541
+ if run_btn and st.session_state.dev_file_bytes:
542
+ with st.status("Processing…", expanded=False) as status:
543
+ book = read_book_bytes(st.session_state.dev_file_bytes)
544
+ if not book: status.update(label="Failed to read workbook.", state="error"); st.stop()
545
+ status.update(label="Workbook read ")
546
+ sh_train = find_sheet(book, ["Train","Training","training2","train","training"])
547
+ sh_test = find_sheet(book, ["Test","Testing","testing2","test","testing"])
548
+ if sh_train is None or sh_test is None:
549
+ status.update(label="Workbook must include Train/Training/training2 and Test/Testing/testing2.", state="error"); st.stop()
550
+ df_tr = book[sh_train].copy(); df_te = book[sh_test].copy()
551
+ if not (ensure_cols(df_tr, FEATURES + [TARGET]) and ensure_cols(df_te, FEATURES + [TARGET])):
552
+ status.update(label="Missing required columns.", state="error"); st.stop()
553
+ status.update(label="Columns validated ✓"); status.update(label="Predicting…")
554
+
555
+ df_tr["UCS_Pred"] = model.predict(df_tr[FEATURES])
556
+ df_te["UCS_Pred"] = model.predict(df_te[FEATURES])
557
+ st.session_state.results["Train"] = df_tr; st.session_state.results["Test"] = df_te
558
+
559
+ st.session_state.results["metrics_train"] = {
560
+ "R2": r2_score(df_tr[TARGET], df_tr["UCS_Pred"]),
561
+ "RMSE": rmse(df_tr[TARGET], df_tr["UCS_Pred"]),
562
+ "MAE": mean_absolute_error(df_tr[TARGET], df_tr["UCS_Pred"]),
563
+ }
564
+ st.session_state.results["metrics_test"] = {
565
+ "R2": r2_score(df_te[TARGET], df_te["UCS_Pred"]),
566
+ "RMSE": rmse(df_te[TARGET], df_te["UCS_Pred"]),
567
+ "MAE": mean_absolute_error(df_te[TARGET], df_te["UCS_Pred"]),
568
+ }
569
+
570
+ tr_min = df_tr[FEATURES].min().to_dict(); tr_max = df_tr[FEATURES].max().to_dict()
571
+ st.session_state.train_ranges = {f:(float(tr_min[f]), float(tr_max[f])) for f in FEATURES}
572
+
573
+ st.session_state.dev_ready = True
574
+ status.update(label="Done ✓", state="complete"); toast("Model run complete 🚀")
575
+ st.rerun()
576
 
577
+ # Results (NOW USING INTERACTIVE PLOTS)
578
+ if ("Train" in st.session_state.results) or ("Test" in st.session_state.results):
579
+ tab1, tab2 = st.tabs(["Training", "Testing"])
580
+ if "Train" in st.session_state.results:
581
  with tab1:
582
+ df = st.session_state.results["Train"]; m = st.session_state.results["metrics_train"]
583
  c1,c2,c3 = st.columns(3)
584
+ c1.metric("R²", f"{m['R2']:.4f}"); c2.metric("RMSE", f"{m['RMSE']:.4f}"); c3.metric("MAE", f"{m['MAE']:.4f}")
585
+ left, right = st.columns([0.9, 0.55])
 
 
 
586
  with left:
587
+ st.plotly_chart(
588
+ cross_plot_interactive(df[TARGET], df["UCS_Pred"], "Training: Actual vs Predicted", size=(3.9,3.9)),
589
+ use_container_width=True, config={"displayModeBar": False}
590
+ )
591
  with right:
592
+ st.plotly_chart(
593
+ depth_or_index_track_interactive(df, title=None, include_actual=True),
594
+ use_container_width=True, config={"displayModeBar": False}
595
+ )
596
+ if "Test" in st.session_state.results:
597
  with tab2:
598
+ df = st.session_state.results["Test"]; m = st.session_state.results["metrics_test"]
599
  c1,c2,c3 = st.columns(3)
600
+ c1.metric("R²", f"{m['R2']:.4f}"); c2.metric("RMSE", f"{m['RMSE']:.4f}"); c3.metric("MAE", f"{m['MAE']:.4f}")
601
+ left, right = st.columns([0.9, 0.55])
 
 
 
602
  with left:
603
+ st.plotly_chart(
604
+ cross_plot_interactive(df[TARGET], df["UCS_Pred"], "Testing: Actual vs Predicted", size=(3.9,3.9)),
605
+ use_container_width=True, config={"displayModeBar": False}
606
+ )
607
  with right:
608
+ st.plotly_chart(
609
+ depth_or_index_track_interactive(df, title=None, include_actual=True),
610
+ use_container_width=True, config={"displayModeBar": False}
611
+ )
612
+
613
+ st.markdown("---")
614
+ sheets = {}; rows = []
615
+ if "Train" in st.session_state.results:
616
+ sheets["Train_with_pred"] = st.session_state.results["Train"]
617
+ rows.append({"Split":"Train", **{k:round(v,6) for k,v in st.session_state.results["metrics_train"].items()}})
618
+ if "Test" in st.session_state.results:
619
+ sheets["Test_with_pred"] = st.session_state.results["Test"]
620
+ rows.append({"Split":"Test", **{k:round(v,6) for k,v in st.session_state.results["metrics_test"].items()}})
621
+ summary_df = pd.DataFrame(rows) if rows else None
622
+ try:
623
+ data_bytes = export_workbook(sheets, summary_df)
624
+ st.download_button("Export Development Results to Excel",
625
+ data=data_bytes, file_name="UCS_Dev_Results.xlsx",
626
+ mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet")
627
+ except RuntimeError as e:
628
+ st.warning(str(e))
629
+
630
+ # =========================
631
+ # PREDICTION (Validation)
632
+ # =========================
633
+ if st.session_state.app_step == "predict":
634
+ st.sidebar.header("Prediction (Validation)")
635
+ validation_file = st.sidebar.file_uploader("Upload Validation Excel", type=["xlsx","xls"], key="val_upload")
636
+ if validation_file is not None:
637
+ _book_tmp = read_book_bytes(validation_file.getvalue())
638
+ if _book_tmp:
639
+ first_df = next(iter(_book_tmp.values()))
640
+ st.sidebar.caption(f"**Data loaded:** {validation_file.name} • {first_df.shape[0]} rows × {first_df.shape[1]} cols")
641
+
642
+ preview_val_btn = st.sidebar.button("Preview data", use_container_width=True, disabled=(validation_file is None))
643
+ if preview_val_btn and validation_file is not None:
644
+ _book = read_book_bytes(validation_file.getvalue())
645
+ preview_modal_val(_book, FEATURES)
646
+
647
+ predict_btn = st.sidebar.button("Predict", type="primary", use_container_width=True)
648
+ st.sidebar.button("⬅ Back", on_click=lambda: st.session_state.update(app_step="dev"), use_container_width=True)
649
 
 
650
  st.subheader("Prediction")
651
  st.write("Upload a new dataset to generate UCS predictions and evaluate performance on unseen data.")
652
 
653
+ if predict_btn and validation_file is not None:
654
+ with st.status("Predicting…", expanded=False) as status:
655
+ vbook = read_book_bytes(validation_file.getvalue())
656
+ if not vbook: status.update(label="Could not read the Validation Excel.", state="error"); st.stop()
657
+ status.update(label="Workbook read ✓")
658
+ vname = find_sheet(vbook, ["Validation","Validate","validation2","Val","val"]) or list(vbook.keys())[0]
659
+ df_val = vbook[vname].copy()
660
+ if not ensure_cols(df_val, FEATURES): status.update(label="Missing required columns.", state="error"); st.stop()
661
+ status.update(label="Columns validated ")
662
+ df_val["UCS_Pred"] = model.predict(df_val[FEATURES])
663
+ st.session_state.results["Validate"] = df_val
664
+
665
+ ranges = st.session_state.train_ranges; oor_table = None; oor_pct = 0.0
666
+ if ranges:
667
+ viol = {f: (df_val[f] < ranges[f][0]) | (df_val[f] > ranges[f][1]) for f in FEATURES}
668
+ any_viol = pd.DataFrame(viol).any(axis=1); oor_pct = float(any_viol.mean()*100.0)
669
+ if any_viol.any():
670
+ offenders = df_val.loc[any_viol, FEATURES].copy()
671
+ offenders["Violations"] = pd.DataFrame(viol).loc[any_viol].apply(lambda r: ", ".join([c for c,v in r.items() if v]), axis=1)
672
+ offenders.index = offenders.index + 1; oor_table = offenders
673
+
674
+ metrics_val = None
675
+ if TARGET in df_val.columns:
676
+ metrics_val = {
677
+ "R2": r2_score(df_val[TARGET], df_val["UCS_Pred"]),
678
+ "RMSE": rmse(df_val[TARGET], df_val["UCS_Pred"]),
679
+ "MAE": mean_absolute_error(df_val[TARGET], df_val["UCS_Pred"])
680
+ }
681
+ st.session_state.results["metrics_val"] = metrics_val
682
+ st.session_state.results["summary_val"] = {
683
+ "n_points": len(df_val),
684
+ "pred_min": float(df_val["UCS_Pred"].min()),
685
+ "pred_max": float(df_val["UCS_Pred"].max()),
686
+ "oor_pct": oor_pct
 
 
 
687
  }
688
+ st.session_state.results["oor_table"] = oor_table
689
+ status.update(label="Predictions ready ✓", state="complete")
690
+
691
+ if "Validate" in st.session_state.results:
 
 
 
 
 
 
 
692
  st.subheader("Validation Results")
693
+ sv = st.session_state.results["summary_val"]; oor_table = st.session_state.results.get("oor_table")
694
+
695
+ if sv["oor_pct"] > 0:
696
+ st.warning("Some validation inputs fall outside the **training min–max** ranges. Interpret predictions with caution.")
697
+
698
  c1,c2,c3,c4 = st.columns(4)
699
+ c1.metric("points", f"{sv['n_points']}"); c2.metric("Pred min", f"{sv['pred_min']:.2f}")
700
  c3.metric("Pred max", f"{sv['pred_max']:.2f}"); c4.metric("OOR %", f"{sv['oor_pct']:.1f}%")
701
 
702
+ left, right = st.columns([0.9, 0.55])
703
  with left:
704
+ if TARGET in st.session_state.results["Validate"].columns:
705
  st.plotly_chart(
706
+ cross_plot_interactive(st.session_state.results["Validate"][TARGET],
707
+ st.session_state.results["Validate"]["UCS_Pred"],
708
+ "Validation: Actual vs Predicted",
709
+ size=(3.9,3.9)),
710
+ use_container_width=True, config={"displayModeBar": False}
711
  )
712
  else:
713
+ st.info("Actual UCS values are not available in the validation data. Cross-plot cannot be generated.")
714
  with right:
715
  st.plotly_chart(
716
+ depth_or_index_track_interactive(
717
+ st.session_state.results["Validate"], title=None,
718
+ include_actual=(TARGET in st.session_state.results["Validate"].columns)
719
+ ),
720
+ use_container_width=True, config={"displayModeBar": False}
721
  )
722
+
723
  if oor_table is not None:
724
+ st.write("*Out-of-range rows (vs. Training min–max):*")
725
  st.dataframe(oor_table, use_container_width=True)
726
 
727
+ st.markdown("---")
728
+ sheets = {"Validate_with_pred": st.session_state.results["Validate"]}
729
+ rows = []
730
+ for name, key in [("Train","metrics_train"), ("Test","metrics_test"), ("Validate","metrics_val")]:
731
+ m = st.session_state.results.get(key)
732
+ if m: rows.append({"Split": name, **{k: round(v,6) for k,v in m.items()}})
733
+ summary_df = pd.DataFrame(rows) if rows else None
734
+ try:
735
+ data_bytes = export_workbook(sheets, summary_df)
736
+ st.download_button("Export Validation Results to Excel",
737
+ data=data_bytes, file_name="UCS_Validation_Results.xlsx",
738
+ mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet")
739
+ except RuntimeError as e:
740
+ st.warning(str(e))
741
+
742
+ # =========================
743
  # Footer
744
+ # =========================
745
  st.markdown("---")
746
  st.markdown(
747
+ """
748
+ <div style='text-align:center; color:#6b7280; line-height:1.6'>
749
+ ST_GeoMech_UCS • © Smart Thinking<br/>
750
+ <strong>Visit our website:</strong> <a href='https://www.smartthinking.com.sa' target='_blank'>smartthinking.com.sa</a>
751
+ </div>
752
+ """,
 
753
  unsafe_allow_html=True
754
  )