UCS2014 committed on
Commit
ac61d22
·
verified ·
1 Parent(s): fa32f46

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +512 -517
app.py CHANGED
@@ -1,114 +1,117 @@
1
- import io, json, os, base64
 
2
  from pathlib import Path
3
- import streamlit as st
4
- import pandas as pd
5
  import numpy as np
 
 
6
  import joblib
7
- import matplotlib
8
- matplotlib.use("Agg")
9
- import matplotlib.pyplot as plt
10
- from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
11
 
12
- # Plotly (for interactivity)
13
- HAVE_PLOTLY = True
14
- try:
15
- import plotly.graph_objects as go
16
- from plotly.subplots import make_subplots
17
- except Exception:
18
- HAVE_PLOTLY = False
19
-
20
- # ---------------- Defaults ----------------
21
  FEATURES = ["Q, gpm", "SPP(psi)", "T (kft.lbf)", "WOB (klbf)", "ROP (ft/h)"]
22
  TARGET = "UCS"
 
23
  MODELS_DIR = Path("models")
24
  DEFAULT_MODEL = MODELS_DIR / "ucs_rf.joblib"
25
  MODEL_FALLBACKS = [MODELS_DIR / "model.joblib", MODELS_DIR / "model.pkl"]
26
 
27
- COLORS = {"pred": "#1f77b4", "actual": "#f2b702", "ref": "#5a5a5a"}
 
 
 
 
28
 
29
- # ---------------- Page / Theme ----------------
 
 
30
  st.set_page_config(page_title="ST_GeoMech_UCS", page_icon="logo.png", layout="wide")
31
- st.markdown("<style>header, footer{visibility:hidden !important;}</style>", unsafe_allow_html=True)
32
- st.markdown(
33
- """
34
- <style>
35
- .stApp { background: #FFFFFF; }
36
- section[data-testid="stSidebar"] { background: #F6F9FC; }
37
- .block-container { padding-top: .5rem; padding-bottom: .5rem; }
38
-
39
- /* Default Streamlit button style (Run, Predict remain blue) */
40
- .stButton>button{
41
- background:#0d6efd; color:#fff; font-weight:bold; border-radius:8px; border:none; padding:10px 24px;
42
- }
43
- .stButton>button:hover{ filter: brightness(0.92); }
44
-
45
- /* Hero */
46
- .st-hero { display:flex; align-items:center; gap:16px; padding-top: 4px; }
47
- .st-hero .brand { width:110px; height:110px; object-fit:contain; }
48
- .st-hero h1 { margin:0; line-height:1.05; }
49
- .st-hero .tagline { margin:2px 0 0 2px; color:#6b7280; font-size:1.05rem; font-style:italic; }
50
- [data-testid="stBlock"]{ margin-top:0 !important; }
51
-
52
- /* Color the sidebar buttons by order inside our wrappers */
53
- .dev-actions > div.stButton:nth-child(1) button { background:#f59e0b !important; } /* Preview (orange) */
54
- .dev-actions > div.stButton:nth-child(2) button { background:#0d6efd !important; } /* Run (blue) */
55
- .dev-actions > div.stButton:nth-child(3) button { background:#198754 !important; } /* Proceed (green) */
56
-
57
- .val-actions > div.stButton:nth-child(1) button { background:#f59e0b !important; } /* Preview (orange) */
58
- .val-actions > div.stButton:nth-child(2) button { background:#0d6efd !important; } /* Predict (blue) */
59
-
60
- .dev-actions .stButton button:disabled,
61
- .val-actions .stButton button:disabled{ filter: grayscale(40%); opacity:.6; }
62
- </style>
63
- """,
64
- unsafe_allow_html=True
65
- )
66
 
67
- # ---------------- Helpers ----------------
68
- try:
69
- dialog = st.dialog
70
- except AttributeError:
71
- def dialog(title):
72
- def deco(fn):
73
- def wrapper(*args, **kwargs):
74
- with st.expander(title, expanded=True):
75
- return fn(*args, **kwargs)
76
- return wrapper
77
- return deco
78
-
79
- def _get_model_url(): return (os.environ.get("MODEL_URL", "") or "").strip()
80
- def rmse(y_true, y_pred): return float(np.sqrt(mean_squared_error(y_true, y_pred)))
81
-
82
- def ensure_cols(df, cols):
83
- miss = [c for c in cols if c not in df.columns]
84
- if miss: st.error(f"Missing columns: {miss}\nFound: {list(df.columns)}"); return False
85
- return True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
- @st.cache_resource(show_spinner=False)
88
- def load_model(model_path: str): return joblib.load(model_path)
 
 
 
 
89
 
90
  @st.cache_data(show_spinner=False)
91
- def parse_excel(data_bytes: bytes):
92
  bio = io.BytesIO(data_bytes)
93
  xl = pd.ExcelFile(bio)
94
  return {sh: xl.parse(sh) for sh in xl.sheet_names}
95
 
96
- def read_book_bytes(data_bytes: bytes):
97
- if not data_bytes: return {}
98
- try: return parse_excel(data_bytes)
99
- except Exception as e: st.error(f"Failed to read Excel: {e}"); return {}
100
-
101
- def find_sheet(book, names):
102
- low2orig = {k.lower(): k for k in book.keys()}
103
- for nm in names:
104
- if nm.lower() in low2orig: return low2orig[nm.lower()]
105
- return None
106
 
107
- def toast(msg):
108
- try: st.toast(msg)
109
- except Exception: st.info(msg)
110
 
111
  def infer_features_from_model(m):
 
112
  try:
113
  if hasattr(m, "feature_names_in_") and len(getattr(m, "feature_names_in_")):
114
  return [str(x) for x in m.feature_names_in_]
@@ -121,56 +124,115 @@ def infer_features_from_model(m):
121
  except Exception: pass
122
  return None
123
 
124
- def inline_logo(path="logo.png") -> str:
125
- try:
126
- p = Path(path)
127
- if not p.exists(): return ""
128
- return f"data:image/png;base64,{base64.b64encode(p.read_bytes()).decode('ascii')}"
129
- except Exception:
130
- return ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
- def export_workbook(sheets_dict: dict[str, pd.DataFrame], summary_df: pd.DataFrame|None):
 
 
133
  try:
134
- import openpyxl # noqa
 
 
135
  except Exception:
136
- raise RuntimeError("Export requires openpyxl. Please add it to requirements.txt")
137
- buf = io.BytesIO()
138
- with pd.ExcelWriter(buf, engine="openpyxl") as xw:
139
- for name, frame in sheets_dict.items():
140
- frame.to_excel(xw, sheet_name=name[:31], index=False)
141
- if summary_df is not None:
142
- summary_df.to_excel(xw, sheet_name="Summary", index=False)
143
- return buf.getvalue()
144
-
145
- # ---------- Plotting (Plotly first, MPL fallback) ----------
146
- def cross_plotly(actual, pred, title):
 
 
 
 
 
 
 
 
 
 
147
  lo = float(np.nanmin([actual.min(), pred.min()]))
148
  hi = float(np.nanmax([actual.max(), pred.max()]))
149
  pad = 0.03 * (hi - lo if hi > lo else 1.0)
 
150
  fig = go.Figure()
151
  fig.add_trace(go.Scatter(
152
  x=actual, y=pred, mode="markers",
153
  marker=dict(size=6, color=COLORS["pred"]),
154
  hovertemplate="Actual: %{x:.2f}<br>Pred: %{y:.2f}<extra></extra>",
155
- name="Points"
156
  ))
157
  fig.add_trace(go.Scatter(
158
  x=[lo - pad, hi + pad], y=[lo - pad, hi + pad],
159
  mode="lines", line=dict(dash="dash", width=1.5, color=COLORS["ref"]),
160
- hoverinfo="skip", showlegend=False
161
  ))
162
- fig.update_layout(title=title, margin=dict(l=10, r=10, t=40, b=10), height=350,
163
- legend=dict(orientation="h", yanchor="bottom", y=1.02, x=0))
164
- fig.update_xaxes(title_text="Actual UCS", scaleanchor="y", scaleratio=1)
165
- fig.update_yaxes(title_text="Predicted UCS")
 
 
 
 
 
 
 
 
166
  return fig
167
 
168
  def track_plotly(df, include_actual=True):
 
169
  depth_col = next((c for c in df.columns if 'depth' in str(c).lower()), None)
170
  if depth_col is not None:
171
  y = df[depth_col]; y_label = depth_col
172
  else:
173
  y = np.arange(1, len(df) + 1); y_label = "Point Index"
 
174
  fig = go.Figure()
175
  fig.add_trace(go.Scatter(
176
  x=df["UCS_Pred"], y=y, mode="lines",
@@ -185,170 +247,88 @@ def track_plotly(df, include_actual=True):
185
  name="UCS (actual)",
186
  hovertemplate="UCS (actual): %{x:.2f}<br>"+y_label+": %{y}<extra></extra>"
187
  ))
188
- fig.update_yaxes(autorange="reversed", title_text=y_label)
189
- fig.update_xaxes(title_text="UCS", side="top")
190
- fig.update_layout(margin=dict(l=10, r=10, t=40, b=10), height=650,
191
- legend=dict(orientation="h", yanchor="bottom", y=1.02, x=0))
 
 
 
 
 
 
 
 
 
 
192
  return fig
193
 
194
  def make_index_tracks_plotly(df: pd.DataFrame, cols: list[str]):
 
 
 
195
  cols = [c for c in cols if c in df.columns]
196
  if not cols:
197
  fig = go.Figure()
198
  fig.add_annotation(text="No selected columns in sheet", showarrow=False, x=0.5, y=0.5)
199
  fig.update_xaxes(visible=False); fig.update_yaxes(visible=False)
200
- fig.update_layout(height=200, margin=dict(l=10,r=10,t=10,b=10))
 
201
  return fig
 
202
  n = len(cols)
203
- fig = make_subplots(rows=1, cols=n, shared_yaxes=True, horizontal_spacing=0.05) # <-- FIX
 
204
  idx = np.arange(1, len(df) + 1)
 
205
  for i, col in enumerate(cols, start=1):
206
  fig.add_trace(
207
- go.Scatter(x=df[col], y=idx, mode="lines", line=dict(color="#333", width=1.2),
208
- hovertemplate=f"{col}: "+"%{x:.2f}<br>Index: %{y}<extra></extra>",
209
- name=col, showlegend=False),
 
 
 
 
 
 
 
 
210
  row=1, col=i
211
  )
212
- fig.update_xaxes(title_text=col, side="top", row=1, col=i)
213
- fig.update_yaxes(autorange="reversed", title_text="Point Index", row=1, col=1)
214
- fig.update_layout(height=650, margin=dict(l=10, r=10, t=40, b=10))
215
- return fig
216
 
217
- # MPL fallbacks
218
- def cross_plot_mpl(actual, pred, title, size=(3.9,3.9)):
219
- fig, ax = plt.subplots(figsize=size, dpi=100)
220
- ax.scatter(actual, pred, s=14, alpha=0.85, color=COLORS["pred"])
221
- lo = float(np.nanmin([actual.min(), pred.min()])); hi = float(np.nanmax([actual.max(), pred.max()]))
222
- pad = 0.03 * (hi - lo if hi > lo else 1.0)
223
- ax.plot([lo-pad, hi+pad], [lo-pad, hi+pad], '--', lw=1.2, color=COLORS["ref"])
224
- ax.set_xlim(lo-pad, hi+pad); ax.set_ylim(lo-pad, hi+pad); ax.set_aspect('equal','box')
225
- ax.set_xlabel("Actual UCS"); ax.set_ylabel("Predicted UCS"); ax.set_title(title); ax.grid(True, ls=":", alpha=0.4)
226
- return fig
227
-
228
- def depth_or_index_track_mpl(df, title=None, include_actual=True):
229
- depth_col = next((c for c in df.columns if 'depth' in str(c).lower()), None)
230
- fig, ax = plt.subplots(figsize=(3.1, 7.2), dpi=100)
231
- if depth_col is not None:
232
- ax.plot(df["UCS_Pred"], df[depth_col], '-', lw=1.8, color=COLORS["pred"], label="UCS_Pred")
233
- if include_actual and TARGET in df.columns:
234
- ax.plot(df[TARGET], df[depth_col], ':', lw=2.0, color=COLORS["actual"], alpha=0.95, label="UCS (actual)")
235
- ax.set_ylabel(depth_col); ax.set_xlabel("UCS"); ax.xaxis.set_label_position('top'); ax.xaxis.tick_top(); ax.invert_yaxis()
236
- else:
237
- idx = np.arange(1, len(df) + 1)
238
- ax.plot(df["UCS_Pred"], idx, '-', lw=1.8, color=COLORS["pred"], label="UCS_Pred")
239
- if include_actual and TARGET in df.columns:
240
- ax.plot(df[TARGET], idx, ':', lw=2.0, color=COLORS["actual"], alpha=0.95, label="UCS (actual)")
241
- ax.set_ylabel("Point Index"); ax.set_xlabel("UCS"); ax.xaxis.set_label_position('top'); ax.xaxis.tick_top(); ax.invert_yaxis()
242
- ax.grid(True, linestyle=":", alpha=0.4);
243
- if title: ax.set_title(title, pad=8)
244
- ax.legend(loc="best")
245
  return fig
246
 
247
- def stats_table(df: pd.DataFrame, cols: list[str]) -> pd.DataFrame:
248
- cols = [c for c in cols if c in df.columns]
249
- if not cols:
250
- return pd.DataFrame({"Feature": [], "Min": [], "Max": [], "Mean": [], "Std": []})
251
- out = df[cols].agg(['min', 'max', 'mean', 'std']).T
252
- out = out.rename(columns={"min": "Min", "max": "Max", "mean": "Mean", "std": "Std"})
253
- return out.reset_index().rename(columns={"index": "Feature"})
254
-
255
- # ---------- Preview dialogs ----------
256
- @dialog("Preview data")
257
- def preview_modal_dev(book: dict[str, pd.DataFrame], feature_cols: list[str]):
258
- if not book:
259
- st.info("No data loaded yet."); return
260
- sh_train = find_sheet(book, ["Train","Training","training2","train","training"])
261
- sh_test = find_sheet(book, ["Test","Testing","testing2","test","testing"])
262
- tabs, data = [], []
263
- if sh_train: tabs.append("Train"); data.append(book[sh_train])
264
- if sh_test: tabs.append("Test"); data.append(book[sh_test])
265
- if not tabs:
266
- first_name = list(book.keys())[0]
267
- tabs = [first_name]; data = [book[first_name]]
268
- st.write("Use the tabs to switch between Train/Test views (if available).")
269
- t_objs = st.tabs(tabs)
270
- for t, df in zip(t_objs, data):
271
- with t:
272
- t1, t2 = st.tabs(["Tracks", "Summary"])
273
- with t1:
274
- if HAVE_PLOTLY:
275
- st.plotly_chart(make_index_tracks_plotly(df, feature_cols), use_container_width=True, theme=None)
276
- else:
277
- st.pyplot(depth_or_index_track_mpl(df, title=None, include_actual=False), use_container_width=True)
278
- with t2:
279
- st.dataframe(stats_table(df, feature_cols), use_container_width=True)
280
-
281
- @dialog("Preview data")
282
- def preview_modal_val(book: dict[str, pd.DataFrame], feature_cols: list[str]):
283
- if not book:
284
- st.info("No data loaded yet."); return
285
- vname = find_sheet(book, ["Validation","Validate","validation2","Val","val"]) or list(book.keys())[0]
286
- df = book[vname]
287
- t1, t2 = st.tabs(["Tracks", "Summary"])
288
- with t1:
289
- if HAVE_PLOTLY:
290
- st.plotly_chart(make_index_tracks_plotly(df, feature_cols), use_container_width=True, theme=None)
291
- else:
292
- st.pyplot(depth_or_index_track_mpl(df, title=None, include_actual=False), use_container_width=True)
293
- with t2:
294
- st.dataframe(stats_table(df, feature_cols), use_container_width=True)
295
-
296
- # ---------------- Model presence ----------------
297
- MODEL_URL = _get_model_url()
298
- def ensure_model_present() -> Path:
299
- for p in [DEFAULT_MODEL, *MODEL_FALLBACKS]:
300
- if p.exists() and p.stat().st_size > 0:
301
- return p
302
- if not MODEL_URL: return None
303
- try:
304
- import requests
305
- DEFAULT_MODEL.parent.mkdir(parents=True, exist_ok=True)
306
- with st.status("Downloading model…", expanded=False):
307
- with requests.get(MODEL_URL, stream=True, timeout=30) as r:
308
- r.raise_for_status()
309
- with open(DEFAULT_MODEL, "wb") as f:
310
- for chunk in r.iter_content(chunk_size=1<<20):
311
- if chunk: f.write(chunk)
312
- return DEFAULT_MODEL
313
- except Exception as e:
314
- st.error(f"Failed to download model from MODEL_URL: {e}")
315
- return None
316
-
317
- model_path = ensure_model_present()
318
- if not model_path:
319
- st.error("Model not found. Upload models/ucs_rf.joblib (or set MODEL_URL in Settings → Variables).")
320
- st.stop()
321
-
322
- try:
323
- model = load_model(str(model_path))
324
- except Exception as e:
325
- st.error(f"Failed to load model: {model_path}\n{e}")
326
- st.stop()
327
-
328
- # Meta overrides / inference
329
- meta_path = MODELS_DIR / "meta.json"
330
- if meta_path.exists():
331
- try:
332
- meta = json.loads(meta_path.read_text(encoding="utf-8"))
333
- FEATURES = meta.get("features", FEATURES); TARGET = meta.get("target", TARGET)
334
- except Exception: pass
335
- else:
336
- infer = infer_features_from_model(model)
337
- if infer: FEATURES = infer
338
-
339
- # ---------------- Session state ----------------
340
- if "app_step" not in st.session_state: st.session_state.app_step = "intro"
341
- if "results" not in st.session_state: st.session_state.results = {}
342
- if "train_ranges" not in st.session_state: st.session_state.train_ranges = None
343
-
344
- for k, v in {
345
- "dev_ready": False, "dev_file_loaded": False, "dev_previewed": False,
346
- "dev_file_signature": None, "dev_file_bytes": b"", "dev_file_name": "",
347
- "dev_file_rows": 0, "dev_file_cols": 0,
348
- }.items():
349
- if k not in st.session_state: st.session_state[k] = v
350
-
351
- # ---------------- Hero ----------------
352
  st.markdown(
353
  f"""
354
  <div class="st-hero">
@@ -359,305 +339,320 @@ st.markdown(
359
  </div>
360
  </div>
361
  """,
362
- unsafe_allow_html=True
363
  )
364
 
365
- # ---------------- INTRO ----------------
366
- if st.session_state.app_step == "intro":
 
 
367
  st.header("Welcome!")
368
- st.markdown("This software is developed by *Smart Thinking AI-Solutions Team* to estimate UCS from drilling data.")
369
- st.subheader("Required Input Columns")
370
  st.markdown(
371
- "- Q, gpm Flow rate (gallons per minute) \n"
372
- "- SPP(psi) Stand pipe pressure \n"
373
- "- T (kft.lbf) Torque (thousand foot-pounds) \n"
374
- "- WOB (klbf) — Weight on bit \n"
375
- "- ROP (ft/h) — Rate of penetration"
376
  )
377
- st.subheader("How It Works")
378
- st.markdown(
379
- "1. **Upload your data to build the case and preview the performance of our model.** \n"
380
- "2. Click **Run Model** to compute metrics and plots. \n"
381
- "3. Click **Proceed to Prediction** to validate on a new dataset. \n"
382
- "4. Export results to Excel at any time."
383
- )
384
- if st.button("Start Showcase", type="primary", key="start_showcase"):
385
- st.session_state.app_step = "dev"; st.rerun()
386
 
387
- # ---------------- DEVELOPMENT ----------------
388
- if st.session_state.app_step == "dev":
 
 
 
389
  st.sidebar.header("Model Development Data")
390
- dev_label = "Upload Data (Excel)" if not st.session_state.dev_file_name else "Replace data (Excel)"
391
- train_test_file = st.sidebar.file_uploader(dev_label, type=["xlsx","xls"], key="dev_upload")
392
-
393
- if train_test_file is not None:
394
- file_bytes = train_test_file.getvalue()
395
- size = len(file_bytes)
396
- sig = (train_test_file.name, size)
397
- if sig != st.session_state.dev_file_signature and size > 0:
398
- st.session_state.dev_file_signature = sig
399
- st.session_state.dev_file_name = train_test_file.name
400
- st.session_state.dev_file_bytes = file_bytes
401
- _book_tmp = read_book_bytes(file_bytes)
402
- if _book_tmp:
403
- first_df = next(iter(_book_tmp.values()))
404
- st.session_state.dev_file_rows = int(first_df.shape[0])
405
- st.session_state.dev_file_cols = int(first_df.shape[1])
406
- st.session_state.dev_file_loaded = True
407
- st.session_state.dev_previewed = False
408
- st.session_state.dev_ready = False
409
-
410
- if st.session_state.dev_file_loaded:
411
- st.sidebar.caption(
412
- f"**Data loaded:** {st.session_state.dev_file_name} • "
413
- f"{st.session_state.dev_file_rows} rows × {st.session_state.dev_file_cols} cols"
414
- )
415
 
416
- st.sidebar.markdown('<div class="dev-actions">', unsafe_allow_html=True)
417
- preview_btn = st.sidebar.button("Preview data", use_container_width=True, disabled=not st.session_state.dev_file_loaded)
418
- run_btn = st.sidebar.button("Run Model", type="primary", use_container_width=True)
419
- proceed_clicked = st.sidebar.button("Proceed to Prediction ▶", use_container_width=True, disabled=not st.session_state.dev_ready)
420
- st.sidebar.markdown('</div>', unsafe_allow_html=True)
421
-
422
- if proceed_clicked and st.session_state.dev_ready:
423
- st.session_state.app_step = "predict"; st.rerun()
424
-
425
- helper_top = st.container()
426
- with helper_top:
427
- st.subheader("Model Development")
428
- if st.session_state.dev_ready:
429
- st.success("Case has been built and results are displayed below.")
430
- elif st.session_state.dev_file_loaded and st.session_state.dev_previewed:
431
- st.info("Previewed ✓ — now click **Run Model** to build the case.")
432
- elif st.session_state.dev_file_loaded:
433
- st.info("📄 **Preview uploaded data** using the sidebar button, then click **Run Model**.")
434
- else:
435
- st.write("**Upload your data to build a case, then run the model to review development performance.**")
436
-
437
- if preview_btn and st.session_state.dev_file_loaded and st.session_state.dev_file_bytes:
438
- _book = read_book_bytes(st.session_state.dev_file_bytes)
439
- st.session_state.dev_previewed = True
440
- preview_modal_dev(_book, FEATURES)
441
-
442
- if run_btn and st.session_state.dev_file_bytes:
443
- with st.status("Processing…", expanded=False) as status:
444
- book = read_book_bytes(st.session_state.dev_file_bytes)
445
- if not book: status.update(label="Failed to read workbook.", state="error"); st.stop()
446
- status.update(label="Workbook read ✓")
447
- sh_train = find_sheet(book, ["Train","Training","training2","train","training"])
448
- sh_test = find_sheet(book, ["Test","Testing","testing2","test","testing"])
449
- if sh_train is None or sh_test is None:
450
- status.update(label="Workbook must include Train/Training/training2 and Test/Testing/testing2.", state="error"); st.stop()
451
- df_tr = book[sh_train].copy(); df_te = book[sh_test].copy()
452
- if not (ensure_cols(df_tr, FEATURES + [TARGET]) and ensure_cols(df_te, FEATURES + [TARGET])):
453
- status.update(label="Missing required columns.", state="error"); st.stop()
454
- status.update(label="Columns validated "); status.update(label="Predicting…")
455
-
456
- df_tr["UCS_Pred"] = model.predict(df_tr[FEATURES])
457
- df_te["UCS_Pred"] = model.predict(df_te[FEATURES])
458
- st.session_state.results["Train"] = df_tr; st.session_state.results["Test"] = df_te
459
-
460
- st.session_state.results["metrics_train"] = {
461
- "R2": r2_score(df_tr[TARGET], df_tr["UCS_Pred"]),
462
- "RMSE": rmse(df_tr[TARGET], df_tr["UCS_Pred"]),
463
- "MAE": mean_absolute_error(df_tr[TARGET], df_tr["UCS_Pred"]),
464
- }
465
- st.session_state.results["metrics_test"] = {
466
- "R2": r2_score(df_te[TARGET], df_te["UCS_Pred"]),
467
- "RMSE": rmse(df_te[TARGET], df_te["UCS_Pred"]),
468
- "MAE": mean_absolute_error(df_te[TARGET], df_te["UCS_Pred"]),
469
- }
470
-
471
- tr_min = df_tr[FEATURES].min().to_dict(); tr_max = df_tr[FEATURES].max().to_dict()
472
- st.session_state.train_ranges = {f:(float(tr_min[f]), float(tr_max[f])) for f in FEATURES}
473
-
474
- st.session_state.dev_ready = True
475
- status.update(label="Done ✓", state="complete"); toast("Model run complete 🚀")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
476
  st.rerun()
 
 
477
 
478
- if ("Train" in st.session_state.results) or ("Test" in st.session_state.results):
479
- tab1, tab2 = st.tabs(["Training", "Testing"])
480
- if "Train" in st.session_state.results:
481
- with tab1:
482
- df = st.session_state.results["Train"]; m = st.session_state.results["metrics_train"]
483
- c1,c2,c3 = st.columns(3)
484
- c1.metric("R²", f"{m['R2']:.4f}"); c2.metric("RMSE", f"{m['RMSE']:.4f}"); c3.metric("MAE", f"{m['MAE']:.4f}")
485
- left, right = st.columns([0.9, 0.55])
486
- with left:
487
- if HAVE_PLOTLY:
488
- st.plotly_chart(cross_plotly(df[TARGET], df["UCS_Pred"], "Training: Actual vs Predicted"),
489
- use_container_width=True, theme=None)
490
- else:
491
- st.pyplot(cross_plot_mpl(df[TARGET], df["UCS_Pred"], "Training: Actual vs Predicted"),
492
- use_container_width=True)
493
- with right:
494
- if HAVE_PLOTLY:
495
- st.plotly_chart(track_plotly(df, include_actual=True), use_container_width=True, theme=None)
496
- else:
497
- st.pyplot(depth_or_index_track_mpl(df, title=None, include_actual=True),
498
- use_container_width=True)
499
- if "Test" in st.session_state.results:
500
- with tab2:
501
- df = st.session_state.results["Test"]; m = st.session_state.results["metrics_test"]
502
- c1,c2,c3 = st.columns(3)
503
- c1.metric("R²", f"{m['R2']:.4f}"); c2.metric("RMSE", f"{m['RMSE']:.4f}"); c3.metric("MAE", f"{m['MAE']:.4f}")
504
- left, right = st.columns([0.9, 0.55])
505
- with left:
506
- if HAVE_PLOTLY:
507
- st.plotly_chart(cross_plotly(df[TARGET], df["UCS_Pred"], "Testing: Actual vs Predicted"),
508
- use_container_width=True, theme=None)
509
- else:
510
- st.pyplot(cross_plot_mpl(df[TARGET], df["UCS_Pred"], "Testing: Actual vs Predicted"),
511
- use_container_width=True)
512
- with right:
513
- if HAVE_PLOTLY:
514
- st.plotly_chart(track_plotly(df, include_actual=True), use_container_width=True, theme=None)
515
- else:
516
- st.pyplot(depth_or_index_track_mpl(df, title=None, include_actual=True),
517
- use_container_width=True)
518
 
519
- st.markdown("---")
520
- sheets = {}; rows = []
521
- if "Train" in st.session_state.results:
522
- sheets["Train_with_pred"] = st.session_state.results["Train"]
523
- rows.append({"Split":"Train", **{k:round(v,6) for k,v in st.session_state.results["metrics_train"].items()}})
524
- if "Test" in st.session_state.results:
525
- sheets["Test_with_pred"] = st.session_state.results["Test"]
526
- rows.append({"Split":"Test", **{k:round(v,6) for k,v in st.session_state.results["metrics_test"].items()}})
527
- summary_df = pd.DataFrame(rows) if rows else None
528
- try:
529
- data_bytes = export_workbook(sheets, summary_df)
530
- st.download_button("Export Development Results to Excel",
531
- data=data_bytes, file_name="UCS_Dev_Results.xlsx",
532
- mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet")
533
- except RuntimeError as e:
534
- st.warning(str(e))
535
 
536
- # ---------------- PREDICTION ----------------
537
- if st.session_state.app_step == "predict":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
538
  st.sidebar.header("Prediction (Validation)")
539
- validation_file = st.sidebar.file_uploader("Upload Validation Excel", type=["xlsx","xls"], key="val_upload")
540
- if validation_file is not None:
541
- _book_tmp = read_book_bytes(validation_file.getvalue())
542
- if _book_tmp:
543
- first_df = next(iter(_book_tmp.values()))
544
- st.sidebar.caption(f"**Data loaded:** {validation_file.name} • {first_df.shape[0]} rows × {first_df.shape[1]} cols")
545
-
546
- st.sidebar.markdown('<div class="val-actions">', unsafe_allow_html=True)
547
- preview_val_btn = st.sidebar.button("Preview data", use_container_width=True, disabled=(validation_file is None))
548
- predict_btn = st.sidebar.button("Predict", type="primary", use_container_width=True)
549
- st.sidebar.button("⬅ Back", on_click=lambda: st.session_state.update(app_step="dev"), use_container_width=True)
550
- st.sidebar.markdown('</div>', unsafe_allow_html=True)
551
 
552
  st.subheader("Prediction")
553
- st.write("Upload a new dataset to generate UCS predictions and evaluate performance on unseen data.")
554
-
555
- if preview_val_btn and validation_file is not None:
556
- _book = read_book_bytes(validation_file.getvalue())
557
- preview_modal_val(_book, FEATURES)
558
-
559
- if predict_btn and validation_file is not None:
560
- with st.status("Predicting…", expanded=False) as status:
561
- vbook = read_book_bytes(validation_file.getvalue())
562
- if not vbook: status.update(label="Could not read the Validation Excel.", state="error"); st.stop()
563
- status.update(label="Workbook read ✓")
564
- vname = find_sheet(vbook, ["Validation","Validate","validation2","Val","val"]) or list(vbook.keys())[0]
 
565
  df_val = vbook[vname].copy()
566
- if not ensure_cols(df_val, FEATURES): status.update(label="Missing required columns.", state="error"); st.stop()
567
- status.update(label="Columns validated ✓")
568
- df_val["UCS_Pred"] = model.predict(df_val[FEATURES])
569
- st.session_state.results["Validate"] = df_val
570
-
571
- ranges = st.session_state.train_ranges; oor_table = None; oor_pct = 0.0
572
- if ranges:
573
- viol = {f: (df_val[f] < ranges[f][0]) | (df_val[f] > ranges[f][1]) for f in FEATURES}
574
- any_viol = pd.DataFrame(viol).any(axis=1); oor_pct = float(any_viol.mean()*100.0)
575
- if any_viol.any():
576
- offenders = df_val.loc[any_viol, FEATURES].copy()
577
- offenders["Violations"] = pd.DataFrame(viol).loc[any_viol].apply(
578
- lambda r: ", ".join([c for c,v in r.items() if v]), axis=1)
579
- offenders.index = offenders.index + 1; oor_table = offenders
580
-
581
- metrics_val = None
582
- if TARGET in df_val.columns:
583
- metrics_val = {
584
- "R2": r2_score(df_val[TARGET], df_val["UCS_Pred"]),
585
- "RMSE": rmse(df_val[TARGET], df_val["UCS_Pred"]),
586
- "MAE": mean_absolute_error(df_val[TARGET], df_val["UCS_Pred"])
587
- }
588
- st.session_state.results["metrics_val"] = metrics_val
589
- st.session_state.results["summary_val"] = {
590
- "n_points": len(df_val),
591
- "pred_min": float(df_val["UCS_Pred"].min()),
592
- "pred_max": float(df_val["UCS_Pred"].max()),
593
- "oor_pct": oor_pct
594
- }
595
- st.session_state.results["oor_table"] = oor_table
596
- status.update(label="Predictions ready ✓", state="complete")
597
-
598
- if "Validate" in st.session_state.results:
599
- st.subheader("Validation Results")
600
- sv = st.session_state.results["summary_val"]; oor_table = st.session_state.results.get("oor_table")
601
 
602
- if sv["oor_pct"] > 0:
603
- st.warning("Some validation inputs fall outside the **training min–max** ranges. Interpret predictions with caution.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
604
 
 
 
 
605
  c1,c2,c3,c4 = st.columns(4)
606
- c1.metric("points", f"{sv['n_points']}"); c2.metric("Pred min", f"{sv['pred_min']:.2f}")
607
- c3.metric("Pred max", f"{sv['pred_max']:.2f}"); c4.metric("OOR %", f"{sv['oor_pct']:.1f}%")
 
 
 
 
608
 
609
- left, right = st.columns([0.9, 0.55])
610
  with left:
611
- if TARGET in st.session_state.results["Validate"].columns:
612
- if HAVE_PLOTLY:
613
- st.plotly_chart(cross_plotly(st.session_state.results["Validate"][TARGET],
614
- st.session_state.results["Validate"]["UCS_Pred"],
615
- "Validation: Actual vs Predicted"),
616
- use_container_width=True, theme=None)
617
- else:
618
- st.pyplot(cross_plot_mpl(st.session_state.results["Validate"][TARGET],
619
- st.session_state.results["Validate"]["UCS_Pred"],
620
- "Validation: Actual vs Predicted"),
621
- use_container_width=True)
622
  else:
623
  st.info("Actual UCS values are not available in the validation data. Cross-plot cannot be generated.")
624
  with right:
625
- if HAVE_PLOTLY:
626
- st.plotly_chart(track_plotly(st.session_state.results["Validate"],
627
- include_actual=(TARGET in st.session_state.results["Validate"].columns)),
628
- use_container_width=True, theme=None)
629
- else:
630
- st.pyplot(depth_or_index_track_mpl(st.session_state.results["Validate"], title=None,
631
- include_actual=(TARGET in st.session_state.results["Validate"].columns)),
632
- use_container_width=True)
633
 
634
  if oor_table is not None:
635
  st.write("*Out-of-range rows (vs. Training min–max):*")
636
  st.dataframe(oor_table, use_container_width=True)
637
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
638
  st.markdown("---")
639
- sheets = {"Validate_with_pred": st.session_state.results["Validate"]}
640
  rows = []
641
  for name, key in [("Train","metrics_train"), ("Test","metrics_test"), ("Validate","metrics_val")]:
642
- m = st.session_state.results.get(key)
643
  if m: rows.append({"Split": name, **{k: round(v,6) for k,v in m.items()}})
644
  summary_df = pd.DataFrame(rows) if rows else None
645
  try:
646
- data_bytes = export_workbook(sheets, summary_df)
647
  st.download_button("Export Validation Results to Excel",
648
  data=data_bytes, file_name="UCS_Validation_Results.xlsx",
649
  mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet")
650
  except RuntimeError as e:
651
  st.warning(str(e))
652
 
653
- # ---------------- Footer ----------------
 
 
654
  st.markdown("---")
655
  st.markdown(
656
- """
657
- <div style='text-align:center; color:#6b7280; line-height:1.6'>
658
- ST_GeoMech_UCS • © Smart Thinking<br/>
659
- <strong>Visit our website:</strong> <a href='https://www.smartthinking.com.sa' target='_blank'>smartthinking.com.sa</a>
660
- </div>
661
- """,
 
662
  unsafe_allow_html=True
663
  )
 
1
+ # app.py
2
+ import io, os, json, base64
3
  from pathlib import Path
4
+
 
5
  import numpy as np
6
+ import pandas as pd
7
+ import streamlit as st
8
  import joblib
 
 
 
 
9
 
10
+ # =========================
11
+ # Constants / defaults
12
+ # =========================
 
 
 
 
 
 
13
  FEATURES = ["Q, gpm", "SPP(psi)", "T (kft.lbf)", "WOB (klbf)", "ROP (ft/h)"]
14
  TARGET = "UCS"
15
+
16
  MODELS_DIR = Path("models")
17
  DEFAULT_MODEL = MODELS_DIR / "ucs_rf.joblib"
18
  MODEL_FALLBACKS = [MODELS_DIR / "model.joblib", MODELS_DIR / "model.pkl"]
19
 
20
+ COLORS = {
21
+ "pred": "#1f77b4", # blue
22
+ "actual": "#f2c94c", # yellow
23
+ "ref": "#444444", # 1:1 line
24
+ }
25
 
26
+ # =========================
27
+ # Page config + CSS
28
+ # =========================
29
st.set_page_config(page_title="ST_GeoMech_UCS", page_icon="logo.png", layout="wide")

# Global stylesheet injected once per script run.
# NOTE(review): the "#preview-btn button" / "#proceed-btn button" rules only
# match if the button ends up as a DOM descendant of the wrapper <div> emitted
# through st.markdown -- confirm in the browser that Streamlit nests it so.
_PAGE_CSS = """
<style>
/* Hide default header/footer chrome */
header, footer {visibility: hidden !important;}
.stApp { background: #ffffff; }

/* Sidebar look */
section[data-testid="stSidebar"] { background: #F6F9FC; }

/* Hero */
.st-hero { display:flex; align-items:center; gap:14px; padding: 4px 0 2px 0; }
.st-hero .brand { width:90px; height:90px; object-fit:contain; }
.st-hero h1 { margin:0; line-height:1.05; }
.st-hero .tagline { margin:2px 0 0 2px; color:#6b7280; font-size:1.05rem; font-style:italic; }

/* Keep hero snug to the top */
[data-testid="stBlock"] { margin-top:0 !important; }

/* Global primary button style (Run Model stays blue) */
.stButton > button {
  background:#2563eb; color:#fff; font-weight:600; border:none; border-radius:8px;
  padding:9px 18px;
}

/* Orange preview button (scoped by wrapper) */
#preview-btn button {
  background:#f59e0b !important; color:#fff !important;
}

/* Green proceed button (scoped by wrapper) */
#proceed-btn button {
  background:#16a34a !important; color:#fff !important;
}

/* Info helper chip */
.helper-note {
  background:#e7f0ff; border-radius:10px; padding:14px 16px; border:1px solid #d4e3ff;
  color:#0f172a;
}

/* Make tab content tighter */
[data-baseweb="tab-border"] { margin-top: 0.2rem; }

/* Plotly charts use white backgrounds via functions below */
</style>
"""
st.markdown(_PAGE_CSS, unsafe_allow_html=True)
77
+
78
+ # =========================
79
+ # Utils
80
+ # =========================
81
def inline_logo(path="logo.png") -> str:
    """Return the file at *path* as a base64 PNG ``data:`` URI.

    An empty string is returned when the file does not exist or when
    reading/encoding it raises, so callers can embed the result directly.
    """
    try:
        logo_file = Path(path)
        if logo_file.exists():
            encoded = base64.b64encode(logo_file.read_bytes()).decode("ascii")
            return f"data:image/png;base64,{encoded}"
        return ""
    except Exception:
        # Best effort: a missing/unreadable logo must not break the page.
        return ""
88
 
89
def _get_model_url():
    """Return the configured model download URL, stripped, or "" if unset.

    ``st.secrets`` is consulted first, then the ``MODEL_URL`` environment
    variable. If touching the secrets store raises, only the environment
    variable is used.
    """
    try:
        candidate = st.secrets.get("MODEL_URL", "") or os.environ.get("MODEL_URL", "") or ""
        return candidate.strip()
    except Exception:
        # Safe access (prevents the "No secrets files" banner)
        env_only = os.environ.get("MODEL_URL", "") or ""
        return env_only.strip()
95
 
96
@st.cache_data(show_spinner=False)
def parse_excel_bytes(data_bytes: bytes):
    """Parse an Excel workbook held in memory.

    Returns a dict mapping every sheet name to the DataFrame produced by
    parsing that sheet. Results are memoised by Streamlit's data cache.
    """
    workbook = pd.ExcelFile(io.BytesIO(data_bytes))
    parsed_sheets = {}
    for sheet_name in workbook.sheet_names:
        parsed_sheets[sheet_name] = workbook.parse(sheet_name)
    return parsed_sheets
101
 
102
def ensure_required_columns(df: pd.DataFrame, cols) -> bool:
    """Check that every name in *cols* is a column of *df*.

    Returns True when nothing is missing. Otherwise an error listing the
    missing and the available columns is shown in the app and False is
    returned.
    """
    missing = [name for name in cols if name not in df.columns]
    if not missing:
        return True
    st.error(f"Missing columns: {missing}\nFound: {list(df.columns)}")
    return False
 
 
 
 
108
 
109
@st.cache_resource(show_spinner=False)
def load_model(model_path: str):
    """Deserialize the model stored at *model_path* with joblib.

    The loaded object is kept in Streamlit's resource cache, keyed by the
    path string.
    """
    # NOTE(review): joblib files are pickle-based, so loading executes code
    # from the file. The file may have been fetched from MODEL_URL -- make
    # sure that URL only ever points at a trusted artefact.
    estimator = joblib.load(model_path)
    return estimator
112
 
113
  def infer_features_from_model(m):
114
+ # Try scikit-learn feature names if present
115
  try:
116
  if hasattr(m, "feature_names_in_") and len(getattr(m, "feature_names_in_")):
117
  return [str(x) for x in m.feature_names_in_]
 
124
  except Exception: pass
125
  return None
126
 
127
def rmse(y_true, y_pred):
    """Return the root-mean-squared error between *y_true* and *y_pred* as a float."""
    # Imported lazily so sklearn is only needed once metrics are computed.
    from sklearn.metrics import mean_squared_error

    mse = mean_squared_error(y_true, y_pred)
    return float(np.sqrt(mse))
130
+
131
+ # =========================
132
+ # Model availability
133
+ # =========================
134
MODEL_URL = _get_model_url()

# (connect, read) timeouts in seconds so a stalled server cannot hang the app.
_MODEL_DOWNLOAD_TIMEOUT = (10, 120)


def ensure_model_present() -> Path | None:
    """Locate the model file, downloading it from MODEL_URL when needed.

    The default path and then each fallback path are checked in order and the
    first existing one is returned. If none exists and MODEL_URL is set, the
    model is downloaded to the default path.

    The download is streamed into a sibling ``*.part`` file and only renamed
    onto the final path once it completed, so an interrupted transfer can
    never leave a truncated file that a later run would mistake for a valid
    model.

    Returns:
        The path of a usable model file, or None when no model is available
        (an error is shown in the app for download problems).
    """
    for candidate in (DEFAULT_MODEL, *MODEL_FALLBACKS):
        if candidate.exists():
            return candidate

    if not MODEL_URL:
        return None

    try:
        import requests
    except Exception:
        st.error("Downloading the model requires 'requests'. Please add it to requirements.txt.")
        return None

    partial = DEFAULT_MODEL.with_name(DEFAULT_MODEL.name + ".part")
    try:
        DEFAULT_MODEL.parent.mkdir(parents=True, exist_ok=True)
        with requests.get(MODEL_URL, stream=True, timeout=_MODEL_DOWNLOAD_TIMEOUT) as response:
            response.raise_for_status()
            with open(partial, "wb") as out:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    if chunk:  # skip keep-alive chunks
                        out.write(chunk)
        # Atomic on the same filesystem: the final name appears fully written.
        partial.replace(DEFAULT_MODEL)
        return DEFAULT_MODEL
    except Exception as e:
        try:
            partial.unlink(missing_ok=True)
        except OSError:
            pass  # cleanup is best effort; the real error is reported below
        st.error(f"Failed to download model from MODEL_URL. {e}")
        return None
157
+
158
# Resolve and load the model; the app cannot do anything useful without it.
model_path = ensure_model_present()
if not model_path:
    st.error("Model not found. Upload models/ucs_rf.joblib (or set MODEL_URL in Settings → Variables).")
    st.stop()

try:
    model = load_model(str(model_path))
except Exception as e:
    st.error(f"Failed to load model: {model_path}\n{e}")
    st.stop()

# Optional meta overrides: models/meta.json may supply "features"/"target".
# A file that cannot be read or is not a JSON object is reported and ignored
# instead of being swallowed silently, and the feature list is then inferred
# from the model exactly as when the file is absent.
meta_path = MODELS_DIR / "meta.json"
meta = None
if meta_path.exists():
    try:
        _loaded_meta = json.loads(meta_path.read_text(encoding="utf-8"))
    except (OSError, ValueError) as e:  # ValueError covers JSON and decoding errors
        st.warning(f"Ignoring unreadable {meta_path}: {e}")
    else:
        if isinstance(_loaded_meta, dict):
            meta = _loaded_meta
        else:
            st.warning(f"Ignoring {meta_path}: expected a JSON object.")

if meta is not None:
    # "or" keeps the defaults when a key is present but null/empty.
    FEATURES = meta.get("features") or FEATURES
    TARGET = meta.get("target") or TARGET
else:
    _inf = infer_features_from_model(model)
    if _inf:
        FEATURES = _inf
181
+
182
+ # =========================
183
+ # Plotly helpers (no titles, white background, safe margins)
184
+ # =========================
185
def _apply_plotly_base_layout(fig, *, top=40, left=60):
    """Apply the shared look to *fig* and return it.

    Sets margins (``top``/``left`` are configurable, right and bottom are
    fixed), white paper/plot backgrounds and font sizes, and enables
    automargin on both axes.
    """
    margins = {"l": left, "r": 10, "t": top, "b": 40}
    fig.update_layout(
        margin=margins,
        paper_bgcolor="#ffffff",
        plot_bgcolor="#ffffff",
        font={"size": 12},
    )
    # Same axis typography for x and y.
    for update_axes in (fig.update_xaxes, fig.update_yaxes):
        update_axes(automargin=True, title_font={"size": 12}, tickfont={"size": 11})
    return fig
195
+
196
def cross_plotly(actual, pred):
    """Build an actual-vs-predicted scatter with a dashed 1:1 reference line.

    *actual* and *pred* must expose ``.min()`` / ``.max()``. The reference
    line spans the combined value range padded by 3 % (a unit range is used
    for the padding when the range is degenerate). The x axis is anchored to
    the y axis at a 1:1 scale ratio.
    """
    import plotly.graph_objects as go

    low = float(np.nanmin([actual.min(), pred.min()]))
    high = float(np.nanmax([actual.max(), pred.max()]))
    span = high - low if high > low else 1.0
    pad = 0.03 * span
    diagonal = [low - pad, high + pad]
    grid = {"showgrid": True, "gridcolor": "rgba(0,0,0,0.12)", "zeroline": False}

    points = go.Scatter(
        x=actual,
        y=pred,
        mode="markers",
        marker={"size": 6, "color": COLORS["pred"]},
        hovertemplate="Actual: %{x:.2f}<br>Pred: %{y:.2f}<extra></extra>",
        showlegend=False,
        name="Points",
    )
    reference = go.Scatter(
        x=diagonal,
        y=diagonal,
        mode="lines",
        line={"dash": "dash", "width": 1.5, "color": COLORS["ref"]},
        hoverinfo="skip",
        showlegend=False,
    )

    fig = go.Figure()
    fig.add_trace(points)
    fig.add_trace(reference)

    _apply_plotly_base_layout(fig, top=10, left=60)
    fig.update_xaxes(
        title_text="Actual UCS", title_standoff=10,
        scaleanchor="y", scaleratio=1, **grid,
    )
    fig.update_yaxes(title_text="Predicted UCS", title_standoff=10, **grid)
    return fig
227
 
228
  def track_plotly(df, include_actual=True):
229
+ import plotly.graph_objects as go
230
  depth_col = next((c for c in df.columns if 'depth' in str(c).lower()), None)
231
  if depth_col is not None:
232
  y = df[depth_col]; y_label = depth_col
233
  else:
234
  y = np.arange(1, len(df) + 1); y_label = "Point Index"
235
+
236
  fig = go.Figure()
237
  fig.add_trace(go.Scatter(
238
  x=df["UCS_Pred"], y=y, mode="lines",
 
247
  name="UCS (actual)",
248
  hovertemplate="UCS (actual): %{x:.2f}<br>"+y_label+": %{y}<extra></extra>"
249
  ))
250
+
251
+ _apply_plotly_base_layout(fig, top=60, left=70)
252
+ fig.update_layout(
253
+ legend=dict(orientation="h", yanchor="bottom", y=1.02, x=0),
254
+ height=650
255
+ )
256
+ fig.update_xaxes(
257
+ title_text="UCS", side="top", title_standoff=12,
258
+ showgrid=True, gridcolor="rgba(0,0,0,0.12)"
259
+ )
260
+ fig.update_yaxes(
261
+ title_text=y_label, autorange="reversed", title_standoff=10,
262
+ showgrid=True, gridcolor="rgba(0,0,0,0.12)"
263
+ )
264
  return fig
265
 
266
def make_index_tracks_plotly(df: pd.DataFrame, cols: list[str]):
    """Plot each requested column as a side-by-side track against point index.

    Only names from *cols* that exist in *df* are drawn, one subplot per
    column sharing a reversed y axis (1-based row index). When none of the
    columns exist, a small placeholder figure carrying a message is returned.
    """
    from plotly.subplots import make_subplots
    import plotly.graph_objects as go

    present = [name for name in cols if name in df.columns]
    if not present:
        placeholder = go.Figure()
        placeholder.add_annotation(text="No selected columns in sheet", showarrow=False, x=0.5, y=0.5)
        placeholder.update_xaxes(visible=False)
        placeholder.update_yaxes(visible=False)
        placeholder.update_layout(
            height=200,
            margin={"l": 10, "r": 10, "t": 10, "b": 10},
            paper_bgcolor="#ffffff",
            plot_bgcolor="#ffffff",
        )
        return placeholder

    grid = {"showgrid": True, "gridcolor": "rgba(0,0,0,0.12)"}
    small_ticks = {"size": 10}

    # IMPORTANT: shared_yaxes (not shared_y)
    fig = make_subplots(rows=1, cols=len(present), shared_yaxes=True, horizontal_spacing=0.05)
    point_index = np.arange(1, len(df) + 1)

    for position, column in enumerate(present, start=1):
        track = go.Scatter(
            x=df[column],
            y=point_index,
            mode="lines",
            line={"color": "#333333", "width": 1.2},
            hovertemplate=f"{column}: " + "%{x:.2f}<br>Index: %{y}<extra></extra>",
            showlegend=False,
            name=column,
        )
        fig.add_trace(track, row=1, col=position)
        fig.update_xaxes(
            title_text=column, side="top", title_standoff=10,
            tickfont=small_ticks, row=1, col=position, **grid,
        )

    # Axis title/orientation is set on the first subplot only.
    fig.update_yaxes(
        autorange="reversed", title_text="Point Index", title_standoff=10,
        tickfont=small_ticks, row=1, col=1, **grid,
    )
    fig.update_layout(
        height=650,
        margin={"l": 60, "r": 10, "t": 60, "b": 40},
        paper_bgcolor="#ffffff",
        plot_bgcolor="#ffffff",
        font={"size": 12},
    )
    return fig
314
 
315
+ # =========================
316
+ # Session state defaults
317
+ # =========================
318
ss = st.session_state

# Keys are only created when absent, so values survive Streamlit reruns.
_STATE_DEFAULTS = {
    "app_step": "dev",        # intro/dev/predict (the app starts at dev)
    "dev_bytes": None,        # raw uploaded bytes
    "dev_book": None,         # parsed workbook dict
    "dev_sheet_train": None,  # chosen train sheet
    "dev_sheet_test": None,   # chosen test sheet
    "dev_previewed": False,
    "dev_ran": False,
    "results": {},
    "train_ranges": None,
}
for _state_key, _state_default in _STATE_DEFAULTS.items():
    ss.setdefault(_state_key, _state_default)
328
+
329
+ # =========================
330
+ # Hero header
331
+ # =========================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
332
  st.markdown(
333
  f"""
334
  <div class="st-hero">
 
339
  </div>
340
  </div>
341
  """,
342
+ unsafe_allow_html=True,
343
  )
344
 
345
+ # =========================
346
+ # INTRO (kept for completeness – you said start in dev)
347
+ # =========================
348
# Intro screen: three-step explanation plus a button that moves to "dev".
if ss.app_step == "intro":
    st.header("Welcome!")
    _intro_steps = (
        "1. **Upload your data** to build the case and preview the performance of our model.",
        "2. **Run Model** to compute metrics and plots.",
        "3. **Proceed to Prediction** to validate on a new dataset and export results.",
    )
    st.markdown("\n".join(_intro_steps))
    _start_clicked = st.button("Start", type="primary")
    if _start_clicked:
        ss.app_step = "dev"
        st.rerun()
 
 
 
 
 
 
 
 
356
 
357
+ # =========================
358
+ # DEVELOPMENT
359
+ # =========================
360
# Development step: upload a workbook, preview it, run the model on the
# Train/Test sheets and show metrics and plots.
#
# State handling notes (Streamlit re-executes this script on every click):
# * The uploader keeps returning the same file on each rerun, so the
#   previewed/ran flags are reset only when the uploaded bytes change.
# * Button clicks only update session state and then trigger a rerun; the
#   preview and the results are rendered from that state, so they stay
#   visible and the "Proceed" button reflects the current state.
if ss.app_step == "dev":

    def _find_sheet(book, alts):
        """Return the key of *book* matching any name in *alts* case-insensitively, else None."""
        by_lower = {str(key).lower(): key for key in book}
        for wanted in alts:
            hit = by_lower.get(wanted.lower())
            if hit is not None:
                return hit
        return None

    def _split_metrics(frame):
        """Return R2/RMSE/MAE of the UCS_Pred column against TARGET for *frame*."""
        from sklearn.metrics import mean_absolute_error, r2_score

        actual, predicted = frame[TARGET], frame["UCS_Pred"]
        return {
            "R2": r2_score(actual, predicted),
            "RMSE": rmse(actual, predicted),
            "MAE": mean_absolute_error(actual, predicted),
        }

    def preview_modal(book: dict, feature_cols: list[str], *, expanded: bool = True):
        """Render the sheet picker with per-feature tracks and summary statistics."""
        if not book:
            return
        with st.expander("▼ Preview (tracks & summary)", expanded=expanded):
            sheetnames = list(book.keys())
            sh = st.selectbox("Sheet", options=sheetnames, index=0, key="preview_sheet_sel")
            df = book[sh]

            t1, t2 = st.tabs(["Tracks", "Summary"])
            with t1:
                st.plotly_chart(
                    make_index_tracks_plotly(df, feature_cols),
                    use_container_width=True, config={"displayModeBar": False},
                )
            with t2:
                # Only summarise feature columns that exist and are numeric;
                # indexing with absent names would raise KeyError.
                present = [c for c in feature_cols if c in df.columns]
                numeric = df[present].select_dtypes(include="number")
                if numeric.shape[1] == 0:
                    st.info("No numeric feature columns found in this sheet.")
                else:
                    stats = numeric.describe().T[["min", "max", "mean", "std"]].rename(
                        columns={"min": "Min", "max": "Max", "mean": "Mean", "std": "Std"}
                    )
                    st.dataframe(stats, use_container_width=True)

    def _render_split(tab, frame, metrics):
        """Render the metric tiles, cross-plot and track plot of one split inside *tab*."""
        with tab:
            c1, c2, c3 = st.columns([1, 1, 1])
            c1.metric("R²", f"{metrics['R2']:.4f}")
            c2.metric("RMSE", f"{metrics['RMSE']:.4f}")
            c3.metric("MAE", f"{metrics['MAE']:.4f}")
            left_col, right_col = st.columns([0.55, 0.45])
            with left_col:
                st.plotly_chart(
                    cross_plotly(frame[TARGET], frame["UCS_Pred"]),
                    use_container_width=True, config={"displayModeBar": False},
                )
            with right_col:
                st.plotly_chart(
                    track_plotly(frame, include_actual=True),
                    use_container_width=True, config={"displayModeBar": False},
                )

    # ---- Sidebar controls ----
    st.sidebar.header("Model Development Data")
    dev_file = st.sidebar.file_uploader("Replace data (Excel)", type=["xlsx", "xls"], key="dev_upload")

    # Cache the uploaded file in session state; reset progress only for a
    # genuinely new upload.
    if dev_file is not None:
        uploaded = dev_file.getvalue()
        if uploaded != ss.dev_bytes:
            try:
                ss.dev_book = parse_excel_bytes(uploaded)
                ss.dev_bytes = uploaded
            except Exception as e:
                st.sidebar.error(f"Failed to read workbook: {e}")
                ss.dev_book = None
                # Forget the bytes so the error is shown again on each rerun
                # until a readable file is supplied.
                ss.dev_bytes = None
            ss.dev_previewed = False
            ss.dev_ran = False

    # PREVIEW button (orange)
    st.sidebar.markdown("<div id='preview-btn'>", unsafe_allow_html=True)
    preview_click = st.sidebar.button("Preview data", use_container_width=True)
    st.sidebar.markdown("</div>", unsafe_allow_html=True)

    # RUN button (blue)
    run_click = st.sidebar.button("Run Model", use_container_width=True)

    # Proceed button (green; enabled after run)
    st.sidebar.markdown("<div id='proceed-btn'>", unsafe_allow_html=True)
    proceed_click = st.sidebar.button(
        "Proceed to Prediction ▶",
        use_container_width=True,
        disabled=not ss.dev_ran,
    )
    st.sidebar.markdown("</div>", unsafe_allow_html=True)

    if proceed_click and ss.dev_ran:
        ss.app_step = "predict"
        st.rerun()

    st.subheader("Model Development")
    # Placeholder keeps the helper above everything rendered below it.
    helper = st.empty()

    # ---- Click handling: update state, then rerun so the UI reflects it ----
    if preview_click:
        if ss.dev_book:
            ss.dev_previewed = True
            ss.dev_ran = False
            st.rerun()
        else:
            st.warning("Please upload an Excel file first.")

    if run_click:
        if not ss.dev_book:
            st.warning("Please upload and preview your data first.")
        else:
            # Try common sheet names, else fall back to sheet order.
            names = list(ss.dev_book.keys())
            sh_train = _find_sheet(ss.dev_book, ["Train", "Training", "training2", "train", "training"]) or names[0]
            sh_test = _find_sheet(ss.dev_book, ["Test", "Testing", "testing2", "test", "testing"]) or (
                names[1] if len(names) > 1 else names[0]
            )
            ss.dev_sheet_train, ss.dev_sheet_test = sh_train, sh_test

            df_tr = ss.dev_book[sh_train].copy()
            df_te = ss.dev_book[sh_test].copy()

            required = FEATURES + [TARGET]
            if ensure_required_columns(df_tr, required) and ensure_required_columns(df_te, required):
                df_tr["UCS_Pred"] = model.predict(df_tr[FEATURES])
                df_te["UCS_Pred"] = model.predict(df_te[FEATURES])

                ss.results["Train"] = df_tr
                ss.results["Test"] = df_te
                ss.results["metrics_train"] = _split_metrics(df_tr)
                ss.results["metrics_test"] = _split_metrics(df_te)

                # Training min/max per feature, used later for the
                # out-of-range check on validation data.
                tr_min = df_tr[FEATURES].min().to_dict()
                tr_max = df_tr[FEATURES].max().to_dict()
                ss.train_ranges = {f: (float(tr_min[f]), float(tr_max[f])) for f in FEATURES}

                ss.dev_ran = True
                # Rerun so the Proceed button is re-rendered as enabled.
                st.rerun()
            else:
                ss.dev_ran = False

    # ---- Helper message for the current stage ----
    if ss.dev_book is None:
        helper.markdown("<div class='helper-note'>Upload your data to build the case and preview the dataset.</div>", unsafe_allow_html=True)
    elif ss.dev_ran:
        helper.markdown("<div class='helper-note'>Case built ✓ — results are displayed below.</div>", unsafe_allow_html=True)
    elif not ss.dev_previewed:
        helper.markdown("<div class='helper-note'>Data loaded ✓ — click <b>Preview data</b> to review tracks and summary.</div>", unsafe_allow_html=True)
    else:
        helper.markdown("<div class='helper-note'>Previewed — now click <b>Run Model</b> to build the case.</div>", unsafe_allow_html=True)

    # ---- Preview (persists across reruns; collapsed once results exist) ----
    if ss.dev_previewed and ss.dev_book:
        preview_modal(ss.dev_book, FEATURES, expanded=not ss.dev_ran)

    # ---- Results ----
    if ss.dev_ran and ("Train" in ss.results or "Test" in ss.results):
        ttr, tte = st.tabs(["Training", "Testing"])
        if "Train" in ss.results:
            _render_split(ttr, ss.results["Train"], ss.results["metrics_train"])
        if "Test" in ss.results:
            _render_split(tte, ss.results["Test"], ss.results["metrics_test"])
524
+
525
+ # =========================
526
+ # PREDICTION
527
+ # =========================
528
# Prediction step: score a validation workbook, flag inputs outside the
# training ranges, show plots and offer an Excel export.
if ss.app_step == "predict":

    def export_workbook(sheets_dict, summary_df=None):
        """Serialize DataFrames into an in-memory .xlsx file and return its bytes.

        Each item of *sheets_dict* becomes a sheet (name truncated to 31
        characters); *summary_df*, when given, is written to a "Summary"
        sheet.

        Raises:
            RuntimeError: if openpyxl cannot be imported.
        """
        try:
            import openpyxl  # noqa: F401  (availability check only)
        except Exception as exc:
            raise RuntimeError("Export requires openpyxl. Please add it to requirements.txt.") from exc
        buf = io.BytesIO()
        with pd.ExcelWriter(buf, engine="openpyxl") as xw:
            for name, frame in sheets_dict.items():
                frame.to_excel(xw, sheet_name=name[:31], index=False)
            if summary_df is not None:
                summary_df.to_excel(xw, sheet_name="Summary", index=False)
        return buf.getvalue()

    st.sidebar.header("Prediction (Validation)")
    val_file = st.sidebar.file_uploader("Upload Validation Excel", type=["xlsx", "xls"], key="val_upload")
    predict_click = st.sidebar.button("Predict", use_container_width=True)
    back_click = st.sidebar.button("⬅ Back", use_container_width=True)

    if back_click:
        ss.app_step = "dev"
        st.rerun()

    st.subheader("Prediction")
    st.markdown("Upload a new dataset to generate UCS predictions and evaluate performance on unseen data.")
    st.success("Predictions ready ✓" if "Validate" in ss.results else "Waiting for input…")

    if predict_click and val_file is None:
        # Previously a click without a file did nothing at all.
        st.warning("Please upload a validation Excel file first.")
    elif predict_click:
        try:
            vbook = parse_excel_bytes(val_file.getvalue())
        except Exception as e:
            st.error(f"Could not read the Validation Excel: {e}")
            vbook = {}

        if vbook:
            # Pick first sheet by default
            vname = next(iter(vbook))
            df_val = vbook[vname].copy()

            if ensure_required_columns(df_val, FEATURES):
                df_val["UCS_Pred"] = model.predict(df_val[FEATURES])
                ss.results["Validate"] = df_val

                # Out-of-range check vs training ranges
                ranges = ss.train_ranges
                oor_table = None
                oor_pct = 0.0
                if ranges:
                    # One boolean column per feature: True where the value
                    # lies outside the training min-max.
                    viol = pd.DataFrame(
                        {f: (df_val[f] < ranges[f][0]) | (df_val[f] > ranges[f][1]) for f in FEATURES}
                    )
                    any_viol = viol.any(axis=1)
                    oor_pct = float(any_viol.mean() * 100.0)
                    if any_viol.any():
                        offenders = df_val.loc[any_viol, FEATURES].copy()
                        offenders["Violations"] = viol.loc[any_viol].apply(
                            lambda row: ", ".join(c for c, v in row.items() if v), axis=1
                        )
                        offenders.index = offenders.index + 1  # 1-based row labels
                        oor_table = offenders

                from sklearn.metrics import mean_absolute_error, r2_score

                # Metrics are only available when measured values are present.
                metrics_val = None
                if TARGET in df_val.columns:
                    metrics_val = {
                        "R2": r2_score(df_val[TARGET], df_val["UCS_Pred"]),
                        "RMSE": rmse(df_val[TARGET], df_val["UCS_Pred"]),
                        "MAE": mean_absolute_error(df_val[TARGET], df_val["UCS_Pred"]),
                    }
                ss.results["metrics_val"] = metrics_val
                ss.results["summary_val"] = {
                    "n_points": len(df_val),
                    "pred_min": float(df_val["UCS_Pred"].min()),
                    "pred_max": float(df_val["UCS_Pred"].max()),
                    "oor_pct": oor_pct,
                }
                ss.results["oor_table"] = oor_table
                # st.rerun() replaces the deprecated st.experimental_rerun()
                # and matches the other steps of this app.
                st.rerun()

    # Show prediction results
    if "Validate" in ss.results:
        validated = ss.results["Validate"]
        has_actual = TARGET in validated.columns
        sv = ss.results["summary_val"]
        oor_table = ss.results.get("oor_table")

        c1, c2, c3, c4 = st.columns(4)
        c1.metric("# points", f"{sv['n_points']}")
        c2.metric("Pred min", f"{sv['pred_min']:.2f}")
        c3.metric("Pred max", f"{sv['pred_max']:.2f}")
        c4.metric("OOR %", f"{sv['oor_pct']:.1f}%")
        if sv["oor_pct"] > 0:
            st.warning("Some validation rows contain inputs outside the Training min–max ranges. Review the table below.")

        left, right = st.columns([0.55, 0.45])
        with left:
            if has_actual:
                st.plotly_chart(
                    cross_plotly(validated[TARGET], validated["UCS_Pred"]),
                    use_container_width=True, config={"displayModeBar": False},
                )
            else:
                st.info("Actual UCS values are not available in the validation data. Cross-plot cannot be generated.")
        with right:
            st.plotly_chart(
                track_plotly(validated, include_actual=has_actual),
                use_container_width=True, config={"displayModeBar": False},
            )

        if oor_table is not None:
            st.write("*Out-of-range rows (vs. Training min–max):*")
            st.dataframe(oor_table, use_container_width=True)

        # Export
        st.markdown("---")
        sheets_to_save = {"Validate_with_pred": validated}
        rows = []
        for name, key in [("Train", "metrics_train"), ("Test", "metrics_test"), ("Validate", "metrics_val")]:
            m = ss.results.get(key)
            if m:
                rows.append({"Split": name, **{k: round(v, 6) for k, v in m.items()}})
        summary_df = pd.DataFrame(rows) if rows else None
        try:
            data_bytes = export_workbook(sheets_to_save, summary_df)
            st.download_button(
                "Export Validation Results to Excel",
                data=data_bytes, file_name="UCS_Validation_Results.xlsx",
                mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
            )
        except RuntimeError as e:
            st.warning(str(e))
644
 
645
+ # =========================
646
+ # Footer
647
+ # =========================
648
# Footer: separator rule followed by two centered grey lines.
st.markdown("---")
_footer_div = "<div style='text-align:center; color:#6b7280;'>"
_footer_html = (
    f"{_footer_div}ST_GeoMech_UCS © Smart Thinking</div>"
    f"{_footer_div}Visit our Website: "
    "<a href='https://www.smartthinking.com.sa' target='_blank'>smartthinking.com.sa</a>"
    "</div>"
)
st.markdown(_footer_html, unsafe_allow_html=True)