UCS2014 committed on
Commit
148589b
·
verified ·
1 Parent(s): a68ea4d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +203 -109
app.py CHANGED
@@ -1,4 +1,3 @@
1
-
2
  import io, json, os, base64
3
  from pathlib import Path
4
  import streamlit as st
@@ -45,7 +44,7 @@ st.markdown(
45
  # Small helpers
46
  # =========================
47
  def _get_model_url():
48
- # Avoid Streamlit secrets error when secrets.toml is absent
49
  try:
50
  return (st.secrets.get("MODEL_URL", "") or os.environ.get("MODEL_URL", "") or "").strip()
51
  except Exception:
@@ -72,24 +71,30 @@ def parse_excel(data_bytes: bytes):
72
  return {sh: xl.parse(sh) for sh in xl.sheet_names}
73
 
74
  def read_book(upload):
75
- if upload is None: return {}
76
- try: return parse_excel(upload.getvalue())
 
 
77
  except Exception as e:
78
- st.error(f"Failed to read Excel: {e}"); return {}
 
79
 
80
  def find_sheet(book, names):
81
  low2orig = {k.lower(): k for k in book.keys()}
82
  for nm in names:
83
- if nm.lower() in low2orig: return low2orig[nm.lower()]
 
84
  return None
85
 
86
- def cross_plot(actual, pred, title, size=(5.6,5.6)):
87
  fig, ax = plt.subplots(figsize=size)
88
  ax.scatter(actual, pred, s=16, alpha=0.7)
89
  lo = float(np.nanmin([actual.min(), pred.min()]))
90
  hi = float(np.nanmax([actual.max(), pred.max()]))
91
- ax.plot([lo,hi], [lo,hi], '--')
92
- ax.set_xlabel("Actual UCS"); ax.set_ylabel("Predicted UCS"); ax.set_title(title)
 
 
93
  ax.grid(True, ls=":", alpha=0.4)
94
  return fig
95
 
@@ -98,26 +103,36 @@ def depth_or_index_track(df, title, include_actual=True):
98
  depth_col = None
99
  for c in df.columns:
100
  if 'depth' in str(c).lower():
101
- depth_col = c; break
 
102
  fig, ax = plt.subplots(figsize=(5.8, 7.5))
103
  if depth_col is not None:
104
  ax.plot(df["UCS_Pred"], df[depth_col], label="UCS_Pred")
105
  if include_actual and TARGET in df.columns:
106
  ax.plot(df[TARGET], df[depth_col], alpha=0.7, label="UCS (actual)")
107
- ax.set_ylabel(depth_col); ax.set_xlabel("UCS")
108
- ax.xaxis.set_label_position('top'); ax.xaxis.tick_top(); ax.invert_yaxis()
 
 
 
109
  else:
110
  idx = np.arange(1, len(df) + 1)
111
  ax.plot(df["UCS_Pred"], idx, label="UCS_Pred")
112
  if include_actual and TARGET in df.columns:
113
  ax.plot(df[TARGET], idx, alpha=0.7, label="UCS (actual)")
114
- ax.set_ylabel("Point Index"); ax.set_xlabel("UCS")
115
- ax.xaxis.set_label_position('top'); ax.xaxis.tick_top(); ax.invert_yaxis()
116
- ax.grid(True, linestyle=":", alpha=0.4); ax.set_title(title, pad=12); ax.legend()
 
 
 
 
 
117
  return fig
118
 
119
  def export_workbook(sheets_dict, summary_df=None):
120
- try: import openpyxl
 
121
  except Exception:
122
  raise RuntimeError("Export requires openpyxl. Please add it to requirements or install it.")
123
  buf = io.BytesIO()
@@ -129,26 +144,31 @@ def export_workbook(sheets_dict, summary_df=None):
129
  return buf.getvalue()
130
 
131
  def toast(msg):
132
- try: st.toast(msg)
133
- except Exception: st.info(msg)
 
 
134
 
135
  def infer_features_from_model(m):
136
  try:
137
  if hasattr(m, "feature_names_in_") and len(getattr(m, "feature_names_in_")):
138
  return [str(x) for x in m.feature_names_in_]
139
- except Exception: pass
 
140
  try:
141
  if hasattr(m, "steps") and len(m.steps):
142
  last = m.steps[-1][1]
143
  if hasattr(last, "feature_names_in_") and len(last.feature_names_in_):
144
  return [str(x) for x in last.feature_names_in_]
145
- except Exception: pass
 
146
  return None
147
 
148
  def inline_logo(path="logo.png") -> str:
149
  try:
150
  p = Path(path)
151
- if not p.exists(): return ""
 
152
  return f"data:image/png;base64,{base64.b64encode(p.read_bytes()).decode('ascii')}"
153
  except Exception:
154
  return ""
@@ -159,43 +179,45 @@ def inline_logo(path="logo.png") -> str:
159
  MODEL_URL = _get_model_url()
160
 
161
  def ensure_model_present() -> Path:
 
162
  # Check local paths first
163
  for p in [DEFAULT_MODEL, *MODEL_FALLBACKS]:
164
- if p.exists():
165
  return p
166
- # Download if MODEL_URL provided
167
- if MODEL_URL:
168
- try:
169
- import requests
170
- except Exception:
171
- st.error("requests is required to download the model. Add 'requests' to requirements.txt.")
172
- return None
173
- try:
174
- DEFAULT_MODEL.parent.mkdir(parents=True, exist_ok=True)
175
- with requests.get(MODEL_URL, stream=True) as r:
 
176
  r.raise_for_status()
177
  with open(DEFAULT_MODEL, "wb") as f:
178
- for chunk in r.iter_content(chunk_size=1<<20):
179
- f.write(chunk)
180
- return DEFAULT_MODEL
181
- except Exception as e:
182
- st.error(f"Failed to download model from MODEL_URL. {e}")
183
- return None
 
184
 
185
  model_path = ensure_model_present()
186
  if not model_path:
187
  st.error("Model not found. Upload models/ucs_rf.joblib (or set MODEL_URL in Settings → Variables).")
188
  st.stop()
189
 
190
- # Load model
191
- try:
192
  model = load_model(str(model_path))
193
  except Exception as e:
194
  st.error(f"Failed to load model: {model_path}\n{e}")
195
  st.stop()
196
 
197
-
198
- # Meta overrides
199
  meta_path = MODELS_DIR / "meta.json"
200
  if meta_path.exists():
201
  try:
@@ -206,14 +228,18 @@ if meta_path.exists():
206
  pass
207
  else:
208
  infer = infer_features_from_model(model)
209
- if infer: FEATURES = infer
 
210
 
211
  # =========================
212
  # Session state
213
  # =========================
214
- if "app_step" not in st.session_state: st.session_state.app_step = "intro"
215
- if "results" not in st.session_state: st.session_state.results = {}
216
- if "train_ranges" not in st.session_state: st.session_state.train_ranges = None
 
 
 
217
 
218
  # =========================
219
  # Hero header (logo + title)
@@ -232,31 +258,39 @@ st.markdown(
232
  )
233
 
234
  # =========================
235
- # INTRO PAGE (as requested)
236
  # =========================
237
  if st.session_state.app_step == "intro":
238
  st.header("Welcome!")
239
- st.markdown("This software is developed by *Smart Thinking AI-Solutions Team* to predict the UCS of the underlying formations while drilling using the drilling data.")
 
 
 
240
  st.subheader("Required Input Columns")
241
- st.markdown("- Q, gpm — Flow rate (gallons per minute)
242
- - SPP(psi)Stand pipe pressure
243
- - T (kft.lbf) — Torque (thousand foot-pounds)
244
- - WOB (klbf) — Weight on bit
245
- - ROP (ft/h) — Rate of penetration")
 
 
246
  st.subheader("How It Works")
247
- st.markdown("1. *Upload the Model Development Data.* This should contain your training and testing sets.
248
- 2. Click *Run Model* to view metrics, cross-plots, and a track plot.
249
- 3. Click *Go to Prediction* and upload a new dataset to get predictions.
250
- 4. *Export* everything to Excel for further analysis.")
 
 
251
  if st.button("Start Showcase", type="primary", key="start_showcase"):
252
- st.session_state.app_step = "dev"; st.rerun()
 
253
 
254
  # =========================
255
  # MODEL DEVELOPMENT (Train/Test)
256
  # =========================
257
  if st.session_state.app_step == "dev":
258
  st.sidebar.header("Model Development Data")
259
- train_test_file = st.sidebar.file_uploader("Upload Train/Test Excel", type=["xlsx","xls"], key="dev_upload")
260
  run_btn = st.sidebar.button("Run Model", type="primary", use_container_width=True)
261
  if "Train" in st.session_state.results or "Test" in st.session_state.results:
262
  st.sidebar.button("Go to Prediction ▶", use_container_width=True, on_click=lambda: st.session_state.update(app_step="predict"))
@@ -265,71 +299,96 @@ if st.session_state.app_step == "dev":
265
  if run_btn and train_test_file is not None:
266
  with st.status("Processing…", expanded=False) as status:
267
  book = read_book(train_test_file)
268
- if not book: status.update(label="Failed to read workbook.", state="error"); st.stop()
 
 
269
  status.update(label="Workbook read ✓")
270
 
271
- sh_train = find_sheet(book, ["Train","Training","training2","train","training"])
272
- sh_test = find_sheet(book, ["Test","Testing","testing2","test","testing"])
273
  if sh_train is None or sh_test is None:
274
- status.update(label="Workbook must include Train/Training/training2 and Test/Testing/testing2.", state="error"); st.stop()
 
275
 
276
- df_tr = book[sh_train].copy(); df_te = book[sh_test].copy()
 
277
  if not (ensure_cols(df_tr, FEATURES + [TARGET]) and ensure_cols(df_te, FEATURES + [TARGET])):
278
- status.update(label="Missing required columns.", state="error"); st.stop()
 
 
279
  status.update(label="Columns validated ✓")
280
  status.update(label="Predicting…")
281
 
282
  df_tr["UCS_Pred"] = model.predict(df_tr[FEATURES])
283
  df_te["UCS_Pred"] = model.predict(df_te[FEATURES])
284
- st.session_state.results["Train"] = df_tr; st.session_state.results["Test"] = df_te
 
285
 
286
  st.session_state.results["metrics_train"] = {
287
  "R2": r2_score(df_tr[TARGET], df_tr["UCS_Pred"]),
288
  "RMSE": rmse(df_tr[TARGET], df_tr["UCS_Pred"]),
289
  "MAE": mean_absolute_error(df_tr[TARGET], df_tr["UCS_Pred"]),
290
  }
291
- st.session_state.results["metrics_test"] = {
292
  "R2": r2_score(df_te[TARGET], df_te["UCS_Pred"]),
293
  "RMSE": rmse(df_te[TARGET], df_te["UCS_Pred"]),
294
  "MAE": mean_absolute_error(df_te[TARGET], df_te["UCS_Pred"]),
295
  }
296
 
297
- tr_min = df_tr[FEATURES].min().to_dict(); tr_max = df_tr[FEATURES].max().to_dict()
298
- st.session_state.train_ranges = {f:(float(tr_min[f]), float(tr_max[f])) for f in FEATURES}
 
299
 
300
- status.update(label="Done ✓", state="complete"); toast("Model run complete 🚀")
 
301
 
302
  if "Train" in st.session_state.results or "Test" in st.session_state.results:
303
  tab1, tab2 = st.tabs(["Training", "Testing"])
304
  if "Train" in st.session_state.results:
305
  with tab1:
306
- df = st.session_state.results["Train"]; m = st.session_state.results["metrics_train"]
307
- c1,c2,c3 = st.columns(3); c1.metric("R²", f"{m['R2']:.4f}"); c2.metric("RMSE", f"{m['RMSE']:.4f}"); c3.metric("MAE", f"{m['MAE']:.4f}")
308
- left,right = st.columns(2)
309
- with left: st.pyplot(cross_plot(df[TARGET], df["UCS_Pred"], "Training: Actual vs Predicted"), use_container_width=True)
310
- with right: st.pyplot(depth_or_index_track(df, "Training: Depth/Index Track", include_actual=True), use_container_width=True)
 
 
 
 
 
 
311
  if "Test" in st.session_state.results:
312
  with tab2:
313
- df = st.session_state.results["Test"]; m = st.session_state.results["metrics_test"]
314
- c1,c2,c3 = st.columns(3); c1.metric("R²", f"{m['R2']:.4f}"); c2.metric("RMSE", f"{m['RMSE']:.4f}"); c3.metric("MAE", f"{m['MAE']:.4f}")
315
- left,right = st.columns(2)
316
- with left: st.pyplot(cross_plot(df[TARGET], df["UCS_Pred"], "Testing: Actual vs Predicted"), use_container_width=True)
317
- with right: st.pyplot(depth_or_index_track(df, "Testing: Depth/Index Track", include_actual=True), use_container_width=True)
 
 
 
 
 
 
318
 
319
  st.markdown("---")
320
- sheets = {}; rows = []
 
321
  if "Train" in st.session_state.results:
322
  sheets["Train_with_pred"] = st.session_state.results["Train"]
323
- rows.append({"Split":"Train", **{k:round(v,6) for k,v in st.session_state.results["metrics_train"].items()}})
324
  if "Test" in st.session_state.results:
325
  sheets["Test_with_pred"] = st.session_state.results["Test"]
326
- rows.append({"Split":"Test", **{k:round(v,6) for k,v in st.session_state.results["metrics_test"].items()}})
327
  summary_df = pd.DataFrame(rows) if rows else None
328
  try:
329
  data_bytes = export_workbook(sheets, summary_df)
330
- st.download_button("Export Train/Test Results to Excel",
331
- data=data_bytes, file_name="UCS_Dev_Results.xlsx",
332
- mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet")
 
 
 
333
  except RuntimeError as e:
334
  st.warning(str(e))
335
 
@@ -338,7 +397,7 @@ if st.session_state.app_step == "dev":
338
  # =========================
339
  if st.session_state.app_step == "predict":
340
  st.sidebar.header("Prediction (Validation)")
341
- validation_file = st.sidebar.file_uploader("Upload Validation Excel", type=["xlsx","xls"], key="val_upload")
342
  predict_btn = st.sidebar.button("Predict", type="primary", use_container_width=True)
343
  st.sidebar.button("⬅ Back", on_click=lambda: st.session_state.update(app_step="dev"), use_container_width=True)
344
 
@@ -348,55 +407,83 @@ if st.session_state.app_step == "predict":
348
  if predict_btn and validation_file is not None:
349
  with st.status("Predicting…", expanded=False) as status:
350
  vbook = read_book(validation_file)
351
- if not vbook: status.update(label="Could not read the Validation Excel.", state="error"); st.stop()
 
 
352
  status.update(label="Workbook read ✓")
353
- vname = find_sheet(vbook, ["Validation","Validate","validation2","Val","val"]) or list(vbook.keys())[0]
354
  df_val = vbook[vname].copy()
355
- if not ensure_cols(df_val, FEATURES): status.update(label="Missing required columns.", state="error"); st.stop()
 
 
356
  status.update(label="Columns validated ✓")
357
  df_val["UCS_Pred"] = model.predict(df_val[FEATURES])
358
  st.session_state.results["Validate"] = df_val
359
 
360
- ranges = st.session_state.train_ranges; oor_table = None; oor_pct = 0.0
 
 
 
361
  if ranges:
362
  viol = {f: (df_val[f] < ranges[f][0]) | (df_val[f] > ranges[f][1]) for f in FEATURES}
363
- any_viol = pd.DataFrame(viol).any(axis=1); oor_pct = float(any_viol.mean()*100.0)
 
364
  if any_viol.any():
365
  offenders = df_val.loc[any_viol, FEATURES].copy()
366
- offenders["Violations"] = pd.DataFrame(viol).loc[any_viol].apply(lambda r: ", ".join([c for c,v in r.items() if v]), axis=1)
367
- offenders.index = offenders.index + 1; oor_table = offenders
 
 
 
368
 
369
  metrics_val = None
370
  if TARGET in df_val.columns:
371
  metrics_val = {
372
  "R2": r2_score(df_val[TARGET], df_val["UCS_Pred"]),
373
  "RMSE": rmse(df_val[TARGET], df_val["UCS_Pred"]),
374
- "MAE": mean_absolute_error(df_val[TARGET], df_val["UCS_Pred"])
375
  }
376
  st.session_state.results["metrics_val"] = metrics_val
377
  st.session_state.results["summary_val"] = {
378
  "n_points": len(df_val),
379
  "pred_min": float(df_val["UCS_Pred"].min()),
380
  "pred_max": float(df_val["UCS_Pred"].max()),
381
- "oor_pct": oor_pct
382
  }
383
  st.session_state.results["oor_table"] = oor_table
384
  status.update(label="Predictions ready ✓", state="complete")
385
 
386
  if "Validate" in st.session_state.results:
387
  st.subheader("Validation Results")
388
- sv = st.session_state.results["summary_val"]; oor_table = st.session_state.results.get("oor_table")
389
- c1,c2,c3,c4 = st.columns(4)
390
- c1.metric("points", f"{sv['n_points']}"); c2.metric("Pred min", f"{sv['pred_min']:.2f}")
391
- c3.metric("Pred max", f"{sv['pred_max']:.2f}"); c4.metric("OOR %", f"{sv['oor_pct']:.1f}%")
392
- left,right = st.columns(2)
 
 
 
393
  with left:
394
  if TARGET in st.session_state.results["Validate"].columns:
395
- st.pyplot(cross_plot(st.session_state.results["Validate"][TARGET], st.session_state.results["Validate"]["UCS_Pred"], "Validation: Actual vs Predicted"), use_container_width=True)
 
 
 
 
 
 
 
396
  else:
397
  st.info("Actual UCS values are not available in the validation data. Cross-plot cannot be generated.")
398
  with right:
399
- st.pyplot(depth_or_index_track(st.session_state.results["Validate"], "Validation: Depth/Index Track", include_actual=(TARGET in st.session_state.results["Validate"].columns)), use_container_width=True)
 
 
 
 
 
 
 
400
  if oor_table is not None:
401
  st.write("*Out-of-range rows (vs. Training min–max):*")
402
  st.dataframe(oor_table, use_container_width=True)
@@ -404,15 +491,19 @@ if st.session_state.app_step == "predict":
404
  st.markdown("---")
405
  sheets = {"Validate_with_pred": st.session_state.results["Validate"]}
406
  rows = []
407
- for name, key in [("Train","metrics_train"), ("Test","metrics_test"), ("Validate","metrics_val")]:
408
  m = st.session_state.results.get(key)
409
- if m: rows.append({"Split": name, **{k: round(v,6) for k,v in m.items()}})
 
410
  summary_df = pd.DataFrame(rows) if rows else None
411
  try:
412
  data_bytes = export_workbook(sheets, summary_df)
413
- st.download_button("Export Validation Results to Excel",
414
- data=data_bytes, file_name="UCS_Validation_Results.xlsx",
415
- mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet")
 
 
 
416
  except RuntimeError as e:
417
  st.warning(str(e))
418
 
@@ -420,4 +511,7 @@ if st.session_state.app_step == "predict":
420
  # Footer
421
  # =========================
422
  st.markdown("---")
423
- st.markdown("<div style='text-align:center; color:#6b7280;'>ST_GeoMech_UCS • © Smart Thinking</div>", unsafe_allow_html=True)
 
 
 
 
 
1
  import io, json, os, base64
2
  from pathlib import Path
3
  import streamlit as st
 
44
  # Small helpers
45
  # =========================
46
  def _get_model_url():
47
+ """Avoid Streamlit secrets error when secrets.toml is absent."""
48
  try:
49
  return (st.secrets.get("MODEL_URL", "") or os.environ.get("MODEL_URL", "") or "").strip()
50
  except Exception:
 
71
  return {sh: xl.parse(sh) for sh in xl.sheet_names}
72
 
73
  def read_book(upload):
74
+ if upload is None:
75
+ return {}
76
+ try:
77
+ return parse_excel(upload.getvalue())
78
  except Exception as e:
79
+ st.error(f"Failed to read Excel: {e}")
80
+ return {}
81
 
82
  def find_sheet(book, names):
83
  low2orig = {k.lower(): k for k in book.keys()}
84
  for nm in names:
85
+ if nm.lower() in low2orig:
86
+ return low2orig[nm.lower()]
87
  return None
88
 
89
+ def cross_plot(actual, pred, title, size=(5.6, 5.6)):
90
  fig, ax = plt.subplots(figsize=size)
91
  ax.scatter(actual, pred, s=16, alpha=0.7)
92
  lo = float(np.nanmin([actual.min(), pred.min()]))
93
  hi = float(np.nanmax([actual.max(), pred.max()]))
94
+ ax.plot([lo, hi], [lo, hi], '--')
95
+ ax.set_xlabel("Actual UCS")
96
+ ax.set_ylabel("Predicted UCS")
97
+ ax.set_title(title)
98
  ax.grid(True, ls=":", alpha=0.4)
99
  return fig
100
 
 
103
  depth_col = None
104
  for c in df.columns:
105
  if 'depth' in str(c).lower():
106
+ depth_col = c
107
+ break
108
  fig, ax = plt.subplots(figsize=(5.8, 7.5))
109
  if depth_col is not None:
110
  ax.plot(df["UCS_Pred"], df[depth_col], label="UCS_Pred")
111
  if include_actual and TARGET in df.columns:
112
  ax.plot(df[TARGET], df[depth_col], alpha=0.7, label="UCS (actual)")
113
+ ax.set_ylabel(depth_col)
114
+ ax.set_xlabel("UCS")
115
+ ax.xaxis.set_label_position('top')
116
+ ax.xaxis.tick_top()
117
+ ax.invert_yaxis()
118
  else:
119
  idx = np.arange(1, len(df) + 1)
120
  ax.plot(df["UCS_Pred"], idx, label="UCS_Pred")
121
  if include_actual and TARGET in df.columns:
122
  ax.plot(df[TARGET], idx, alpha=0.7, label="UCS (actual)")
123
+ ax.set_ylabel("Point Index")
124
+ ax.set_xlabel("UCS")
125
+ ax.xaxis.set_label_position('top')
126
+ ax.xaxis.tick_top()
127
+ ax.invert_yaxis()
128
+ ax.grid(True, linestyle=":", alpha=0.4)
129
+ ax.set_title(title, pad=12)
130
+ ax.legend()
131
  return fig
132
 
133
  def export_workbook(sheets_dict, summary_df=None):
134
+ try:
135
+ import openpyxl # noqa
136
  except Exception:
137
  raise RuntimeError("Export requires openpyxl. Please add it to requirements or install it.")
138
  buf = io.BytesIO()
 
144
  return buf.getvalue()
145
 
146
  def toast(msg):
147
+ try:
148
+ st.toast(msg)
149
+ except Exception:
150
+ st.info(msg)
151
 
152
  def infer_features_from_model(m):
153
  try:
154
  if hasattr(m, "feature_names_in_") and len(getattr(m, "feature_names_in_")):
155
  return [str(x) for x in m.feature_names_in_]
156
+ except Exception:
157
+ pass
158
  try:
159
  if hasattr(m, "steps") and len(m.steps):
160
  last = m.steps[-1][1]
161
  if hasattr(last, "feature_names_in_") and len(last.feature_names_in_):
162
  return [str(x) for x in last.feature_names_in_]
163
+ except Exception:
164
+ pass
165
  return None
166
 
167
  def inline_logo(path="logo.png") -> str:
168
  try:
169
  p = Path(path)
170
+ if not p.exists():
171
+ return ""
172
  return f"data:image/png;base64,{base64.b64encode(p.read_bytes()).decode('ascii')}"
173
  except Exception:
174
  return ""
 
179
  MODEL_URL = _get_model_url()
180
 
181
  def ensure_model_present() -> Path:
182
+ """Return a local model path, trying local files first, then (optionally) downloading with timeout."""
183
  # Check local paths first
184
  for p in [DEFAULT_MODEL, *MODEL_FALLBACKS]:
185
+ if p.exists() and p.stat().st_size > 0:
186
  return p
187
+
188
+ # If no URL set, we cannot download
189
+ if not MODEL_URL:
190
+ return None
191
+
192
+ # Try to download with a short timeout so startup can't hang
193
+ try:
194
+ import requests # only when needed
195
+ DEFAULT_MODEL.parent.mkdir(parents=True, exist_ok=True)
196
+ with st.status("Downloading model…", expanded=False):
197
+ with requests.get(MODEL_URL, stream=True, timeout=30) as r:
198
  r.raise_for_status()
199
  with open(DEFAULT_MODEL, "wb") as f:
200
+ for chunk in r.iter_content(chunk_size=1 << 20):
201
+ if chunk:
202
+ f.write(chunk)
203
+ return DEFAULT_MODEL
204
+ except Exception as e:
205
+ st.error(f"Failed to download model from MODEL_URL: {e}")
206
+ return None
207
 
208
  model_path = ensure_model_present()
209
  if not model_path:
210
  st.error("Model not found. Upload models/ucs_rf.joblib (or set MODEL_URL in Settings → Variables).")
211
  st.stop()
212
 
213
+ # Load model (fix: correct try/except block)
214
+ try:
215
  model = load_model(str(model_path))
216
  except Exception as e:
217
  st.error(f"Failed to load model: {model_path}\n{e}")
218
  st.stop()
219
 
220
+ # Meta overrides or inference
 
221
  meta_path = MODELS_DIR / "meta.json"
222
  if meta_path.exists():
223
  try:
 
228
  pass
229
  else:
230
  infer = infer_features_from_model(model)
231
+ if infer:
232
+ FEATURES = infer
233
 
234
  # =========================
235
  # Session state
236
  # =========================
237
+ if "app_step" not in st.session_state:
238
+ st.session_state.app_step = "intro"
239
+ if "results" not in st.session_state:
240
+ st.session_state.results = {}
241
+ if "train_ranges" not in st.session_state:
242
+ st.session_state.train_ranges = None
243
 
244
  # =========================
245
  # Hero header (logo + title)
 
258
  )
259
 
260
  # =========================
261
+ # INTRO PAGE
262
  # =========================
263
  if st.session_state.app_step == "intro":
264
  st.header("Welcome!")
265
+ st.markdown(
266
+ "This software is developed by *Smart Thinking AI-Solutions Team* "
267
+ "to predict the UCS of the underlying formations while drilling using the drilling data."
268
+ )
269
  st.subheader("Required Input Columns")
270
+ st.markdown(
271
+ "- Q, gpm — Flow rate (gallons per minute) \n"
272
+ "- SPP(psi) — Stand pipe pressure \n"
273
+ "- T (kft.lbf) — Torque (thousand foot-pounds) \n"
274
+ "- WOB (klbf) — Weight on bit \n"
275
+ "- ROP (ft/h) — Rate of penetration"
276
+ )
277
  st.subheader("How It Works")
278
+ st.markdown(
279
+ "1. *Upload the Model Development Data.* This should contain your training and testing sets.\n"
280
+ "2. Click *Run Model* to view metrics, cross-plots, and a track plot.\n"
281
+ "3. Click *Go to Prediction* and upload a new dataset to get predictions.\n"
282
+ "4. *Export* everything to Excel for further analysis."
283
+ )
284
  if st.button("Start Showcase", type="primary", key="start_showcase"):
285
+ st.session_state.app_step = "dev"
286
+ st.rerun()
287
 
288
  # =========================
289
  # MODEL DEVELOPMENT (Train/Test)
290
  # =========================
291
  if st.session_state.app_step == "dev":
292
  st.sidebar.header("Model Development Data")
293
+ train_test_file = st.sidebar.file_uploader("Upload Train/Test Excel", type=["xlsx", "xls"], key="dev_upload")
294
  run_btn = st.sidebar.button("Run Model", type="primary", use_container_width=True)
295
  if "Train" in st.session_state.results or "Test" in st.session_state.results:
296
  st.sidebar.button("Go to Prediction ▶", use_container_width=True, on_click=lambda: st.session_state.update(app_step="predict"))
 
299
  if run_btn and train_test_file is not None:
300
  with st.status("Processing…", expanded=False) as status:
301
  book = read_book(train_test_file)
302
+ if not book:
303
+ status.update(label="Failed to read workbook.", state="error")
304
+ st.stop()
305
  status.update(label="Workbook read ✓")
306
 
307
+ sh_train = find_sheet(book, ["Train", "Training", "training2", "train", "training"])
308
+ sh_test = find_sheet(book, ["Test", "Testing", "testing2", "test", "testing"])
309
  if sh_train is None or sh_test is None:
310
+ status.update(label="Workbook must include Train/Training/training2 and Test/Testing/testing2.", state="error")
311
+ st.stop()
312
 
313
+ df_tr = book[sh_train].copy()
314
+ df_te = book[sh_test].copy()
315
  if not (ensure_cols(df_tr, FEATURES + [TARGET]) and ensure_cols(df_te, FEATURES + [TARGET])):
316
+ status.update(label="Missing required columns.", state="error")
317
+ st.stop()
318
+
319
  status.update(label="Columns validated ✓")
320
  status.update(label="Predicting…")
321
 
322
  df_tr["UCS_Pred"] = model.predict(df_tr[FEATURES])
323
  df_te["UCS_Pred"] = model.predict(df_te[FEATURES])
324
+ st.session_state.results["Train"] = df_tr
325
+ st.session_state.results["Test"] = df_te
326
 
327
  st.session_state.results["metrics_train"] = {
328
  "R2": r2_score(df_tr[TARGET], df_tr["UCS_Pred"]),
329
  "RMSE": rmse(df_tr[TARGET], df_tr["UCS_Pred"]),
330
  "MAE": mean_absolute_error(df_tr[TARGET], df_tr["UCS_Pred"]),
331
  }
332
+ st.session_state.results["metrics_test"] = {
333
  "R2": r2_score(df_te[TARGET], df_te["UCS_Pred"]),
334
  "RMSE": rmse(df_te[TARGET], df_te["UCS_Pred"]),
335
  "MAE": mean_absolute_error(df_te[TARGET], df_te["UCS_Pred"]),
336
  }
337
 
338
+ tr_min = df_tr[FEATURES].min().to_dict()
339
+ tr_max = df_tr[FEATURES].max().to_dict()
340
+ st.session_state.train_ranges = {f: (float(tr_min[f]), float(tr_max[f])) for f in FEATURES}
341
 
342
+ status.update(label="Done ✓", state="complete")
343
+ toast("Model run complete 🚀")
344
 
345
  if "Train" in st.session_state.results or "Test" in st.session_state.results:
346
  tab1, tab2 = st.tabs(["Training", "Testing"])
347
  if "Train" in st.session_state.results:
348
  with tab1:
349
+ df = st.session_state.results["Train"]
350
+ m = st.session_state.results["metrics_train"]
351
+ c1, c2, c3 = st.columns(3)
352
+ c1.metric("R²", f"{m['R2']:.4f}")
353
+ c2.metric("RMSE", f"{m['RMSE']:.4f}")
354
+ c3.metric("MAE", f"{m['MAE']:.4f}")
355
+ left, right = st.columns(2)
356
+ with left:
357
+ st.pyplot(cross_plot(df[TARGET], df["UCS_Pred"], "Training: Actual vs Predicted"), use_container_width=True)
358
+ with right:
359
+ st.pyplot(depth_or_index_track(df, "Training: Depth/Index Track", include_actual=True), use_container_width=True)
360
  if "Test" in st.session_state.results:
361
  with tab2:
362
+ df = st.session_state.results["Test"]
363
+ m = st.session_state.results["metrics_test"]
364
+ c1, c2, c3 = st.columns(3)
365
+ c1.metric("R²", f"{m['R2']:.4f}")
366
+ c2.metric("RMSE", f"{m['RMSE']:.4f}")
367
+ c3.metric("MAE", f"{m['MAE']:.4f}")
368
+ left, right = st.columns(2)
369
+ with left:
370
+ st.pyplot(cross_plot(df[TARGET], df["UCS_Pred"], "Testing: Actual vs Predicted"), use_container_width=True)
371
+ with right:
372
+ st.pyplot(depth_or_index_track(df, "Testing: Depth/Index Track", include_actual=True), use_container_width=True)
373
 
374
  st.markdown("---")
375
+ sheets = {}
376
+ rows = []
377
  if "Train" in st.session_state.results:
378
  sheets["Train_with_pred"] = st.session_state.results["Train"]
379
+ rows.append({"Split": "Train", **{k: round(v, 6) for k, v in st.session_state.results["metrics_train"].items()}})
380
  if "Test" in st.session_state.results:
381
  sheets["Test_with_pred"] = st.session_state.results["Test"]
382
+ rows.append({"Split": "Test", **{k: round(v, 6) for k, v in st.session_state.results["metrics_test"].items()}})
383
  summary_df = pd.DataFrame(rows) if rows else None
384
  try:
385
  data_bytes = export_workbook(sheets, summary_df)
386
+ st.download_button(
387
+ "Export Train/Test Results to Excel",
388
+ data=data_bytes,
389
+ file_name="UCS_Dev_Results.xlsx",
390
+ mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
391
+ )
392
  except RuntimeError as e:
393
  st.warning(str(e))
394
 
 
397
  # =========================
398
  if st.session_state.app_step == "predict":
399
  st.sidebar.header("Prediction (Validation)")
400
+ validation_file = st.sidebar.file_uploader("Upload Validation Excel", type=["xlsx", "xls"], key="val_upload")
401
  predict_btn = st.sidebar.button("Predict", type="primary", use_container_width=True)
402
  st.sidebar.button("⬅ Back", on_click=lambda: st.session_state.update(app_step="dev"), use_container_width=True)
403
 
 
407
  if predict_btn and validation_file is not None:
408
  with st.status("Predicting…", expanded=False) as status:
409
  vbook = read_book(validation_file)
410
+ if not vbook:
411
+ status.update(label="Could not read the Validation Excel.", state="error")
412
+ st.stop()
413
  status.update(label="Workbook read ✓")
414
+ vname = find_sheet(vbook, ["Validation", "Validate", "validation2", "Val", "val"]) or list(vbook.keys())[0]
415
  df_val = vbook[vname].copy()
416
+ if not ensure_cols(df_val, FEATURES):
417
+ status.update(label="Missing required columns.", state="error")
418
+ st.stop()
419
  status.update(label="Columns validated ✓")
420
  df_val["UCS_Pred"] = model.predict(df_val[FEATURES])
421
  st.session_state.results["Validate"] = df_val
422
 
423
+ # OOR check against training min–max
424
+ ranges = st.session_state.train_ranges
425
+ oor_table = None
426
+ oor_pct = 0.0
427
  if ranges:
428
  viol = {f: (df_val[f] < ranges[f][0]) | (df_val[f] > ranges[f][1]) for f in FEATURES}
429
+ any_viol = pd.DataFrame(viol).any(axis=1)
430
+ oor_pct = float(any_viol.mean() * 100.0)
431
  if any_viol.any():
432
  offenders = df_val.loc[any_viol, FEATURES].copy()
433
+ offenders["Violations"] = pd.DataFrame(viol).loc[any_viol].apply(
434
+ lambda r: ", ".join([c for c, v in r.items() if v]), axis=1
435
+ )
436
+ offenders.index = offenders.index + 1
437
+ oor_table = offenders
438
 
439
  metrics_val = None
440
  if TARGET in df_val.columns:
441
  metrics_val = {
442
  "R2": r2_score(df_val[TARGET], df_val["UCS_Pred"]),
443
  "RMSE": rmse(df_val[TARGET], df_val["UCS_Pred"]),
444
+ "MAE": mean_absolute_error(df_val[TARGET], df_val["UCS_Pred"]),
445
  }
446
  st.session_state.results["metrics_val"] = metrics_val
447
  st.session_state.results["summary_val"] = {
448
  "n_points": len(df_val),
449
  "pred_min": float(df_val["UCS_Pred"].min()),
450
  "pred_max": float(df_val["UCS_Pred"].max()),
451
+ "oor_pct": oor_pct,
452
  }
453
  st.session_state.results["oor_table"] = oor_table
454
  status.update(label="Predictions ready ✓", state="complete")
455
 
456
  if "Validate" in st.session_state.results:
457
  st.subheader("Validation Results")
458
+ sv = st.session_state.results["summary_val"]
459
+ oor_table = st.session_state.results.get("oor_table")
460
+ c1, c2, c3, c4 = st.columns(4)
461
+ c1.metric("points", f"{sv['n_points']}")
462
+ c2.metric("Pred min", f"{sv['pred_min']:.2f}")
463
+ c3.metric("Pred max", f"{sv['pred_max']:.2f}")
464
+ c4.metric("OOR %", f"{sv['oor_pct']:.1f}%")
465
+ left, right = st.columns(2)
466
  with left:
467
  if TARGET in st.session_state.results["Validate"].columns:
468
+ st.pyplot(
469
+ cross_plot(
470
+ st.session_state.results["Validate"][TARGET],
471
+ st.session_state.results["Validate"]["UCS_Pred"],
472
+ "Validation: Actual vs Predicted",
473
+ ),
474
+ use_container_width=True,
475
+ )
476
  else:
477
  st.info("Actual UCS values are not available in the validation data. Cross-plot cannot be generated.")
478
  with right:
479
+ st.pyplot(
480
+ depth_or_index_track(
481
+ st.session_state.results["Validate"],
482
+ "Validation: Depth/Index Track",
483
+ include_actual=(TARGET in st.session_state.results["Validate"].columns),
484
+ ),
485
+ use_container_width=True,
486
+ )
487
  if oor_table is not None:
488
  st.write("*Out-of-range rows (vs. Training min–max):*")
489
  st.dataframe(oor_table, use_container_width=True)
 
491
  st.markdown("---")
492
  sheets = {"Validate_with_pred": st.session_state.results["Validate"]}
493
  rows = []
494
+ for name, key in [("Train", "metrics_train"), ("Test", "metrics_test"), ("Validate", "metrics_val")]:
495
  m = st.session_state.results.get(key)
496
+ if m:
497
+ rows.append({"Split": name, **{k: round(v, 6) for k, v in m.items()}})
498
  summary_df = pd.DataFrame(rows) if rows else None
499
  try:
500
  data_bytes = export_workbook(sheets, summary_df)
501
+ st.download_button(
502
+ "Export Validation Results to Excel",
503
+ data=data_bytes,
504
+ file_name="UCS_Validation_Results.xlsx",
505
+ mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
506
+ )
507
  except RuntimeError as e:
508
  st.warning(str(e))
509
 
 
511
  # Footer
512
  # =========================
513
  st.markdown("---")
514
+ st.markdown(
515
+ "<div style='text-align:center; color:#6b7280;'>ST_GeoMech_UCS • © Smart Thinking</div>",
516
+ unsafe_allow_html=True,
517
+ )