UCS2014 commited on
Commit
32095d5
·
verified ·
1 Parent(s): 7aae2b0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +101 -156
app.py CHANGED
@@ -12,10 +12,11 @@ matplotlib.use("Agg")
12
  import matplotlib.pyplot as plt
13
 
14
  import plotly.graph_objects as go
15
- from sklearn.metrics import mean_squared_error, mean_absolute_error
 
16
 
17
  # =========================
18
- # Constants
19
  # =========================
20
  FEATURES = ["Q, gpm", "SPP(psi)", "T (kft.lbf)", "WOB (klbf)", "ROP (ft/h)"]
21
  TARGET = "UCS"
@@ -25,62 +26,40 @@ MODEL_FALLBACKS = [MODELS_DIR / "model.joblib", MODELS_DIR / "model.pkl"]
25
  COLORS = {"pred": "#1f77b4", "actual": "#f2b702", "ref": "#5a5a5a"}
26
 
27
  # ---- Plot sizing controls (edit here) ----
28
- CROSS_W = 420; CROSS_H = 420 # square cross-plot (original look)
29
  TRACK_W = 400; TRACK_H = 950 # log-strip style (all pages)
30
- FONT_SZ = 13
31
- PLOT_COLS = [14, 0.5, 10] # 3-column band: left • spacer • right
32
- CROSS_NUDGE = 0.5 # inner columns [CROSS_NUDGE : 1] bigger = more right
 
33
 
34
  # =========================
35
  # Page / CSS
36
  # =========================
37
  st.set_page_config(page_title="ST_GeoMech_UCS", page_icon="logo.png", layout="wide")
38
  st.markdown("<style>header, footer{visibility:hidden !important;}</style>", unsafe_allow_html=True)
39
-
40
  st.markdown(
41
  """
42
  <style>
43
  .stApp { background:#fff; }
44
  section[data-testid="stSidebar"] { background:#F6F9FC; }
45
  .block-container { padding-top:.5rem; padding-bottom:.5rem; }
46
-
47
- /* Buttons look */
48
  .stButton>button { background:#007bff; color:#fff; font-weight:600; border-radius:8px; border:none; }
49
  .stButton>button:hover { background:#0056b3; }
50
-
51
- /* Brand header */
52
- .st-hero { display:flex; align-items:center; gap:16px; padding-top:4px; }
53
  .st-hero .brand { width:110px; height:110px; object-fit:contain; }
54
  .st-hero h1 { margin:0; line-height:1.05; }
55
  .st-hero .tagline { margin:2px 0 0 2px; color:#6b7280; font-size:1.05rem; font-style:italic; }
56
-
57
  [data-testid="stBlock"]{ margin-top:0 !important; }
58
-
59
- /* Remove drag & drop + limit lines — keep Browse button */
60
- [data-testid="stFileUploadDropzone"] p,
61
- [data-testid="stFileUploadDropzone"] [data-testid="stFileUploaderInstructions"],
62
- [data-testid="stFileUploadDropzone"] [data-testid="stCaptionContainer"]{
63
- display:none !important;
64
- }
65
-
66
- /* Pinned title/helper area */
67
- .pinned-top{
68
- position:sticky; top:0; z-index:999; background:#fff; padding-top:4px;
69
- }
70
-
71
- /* Center every table cell we render via HTML */
72
- .table-center table { margin-left:auto; margin-right:auto; border-collapse:collapse; }
73
- .table-center table th, .table-center table td {
74
- text-align:center !important; padding:6px 10px; border:1px solid #e5e7eb;
75
- }
76
- .table-center table thead th { background:#f8fafc; }
77
  </style>
78
  """,
79
  unsafe_allow_html=True
80
  )
81
 
82
  # =========================
83
- # Password gate
84
  # =========================
85
  def inline_logo(path="logo.png") -> str:
86
  try:
@@ -147,15 +126,21 @@ add_password_gate()
147
  # =========================
148
  # Utilities
149
  # =========================
150
- def rmse(y_true, y_pred) -> float:
151
- return float(np.sqrt(mean_squared_error(y_true, y_pred)))
 
 
 
 
 
 
 
 
 
 
152
 
153
- def corrcoef_safe(y_true, y_pred) -> float:
154
- a = pd.Series(y_true, dtype=float)
155
- b = pd.Series(y_pred, dtype=float)
156
- m = np.isfinite(a) & np.isfinite(b)
157
- if not m.any(): return float("nan")
158
- return float(np.corrcoef(a[m], b[m])[0, 1])
159
 
160
  @st.cache_resource(show_spinner=False)
161
  def load_model(model_path: str):
@@ -182,36 +167,25 @@ def find_sheet(book, names):
182
  if nm.lower() in low2orig: return low2orig[nm.lower()]
183
  return None
184
 
185
- def _nice_step(lo: float, hi: float, target_ticks: int = 6) -> float:
186
- rng = max(hi - lo, 1.0)
187
- raw = rng / max(target_ticks, 1)
188
- mag = 10 ** math.floor(math.log10(raw))
189
- for m in [1, 2, 2.5, 5, 10]:
190
- step = m * mag
191
- if raw <= step:
192
- return step
193
- return mag * 10
194
-
195
- def _nice_tick0(xmin: float, step: float) -> float:
196
- if not np.isfinite(xmin): return xmin
197
  return step * math.floor(xmin / step)
198
 
199
- def html_table_center(df: pd.DataFrame, index: bool = False):
200
- html = df.to_html(index=index, classes="table-center")
201
- st.markdown(html, unsafe_allow_html=True)
202
-
203
  # ---------- Plot builders ----------
204
  def cross_plot(actual, pred):
205
  a = pd.Series(actual).astype(float)
206
  p = pd.Series(pred).astype(float)
207
-
208
- # Symmetric / identical axis range & ticks
209
  lo = float(np.nanmin([a.min(), p.min()]))
210
  hi = float(np.nanmax([a.max(), p.max()]))
211
  pad = 0.03 * (hi - lo if hi > lo else 1.0)
212
- lo -= pad; hi += pad
213
- step = _nice_step(lo, hi, target_ticks=6)
214
- tick0 = _nice_tick0(lo, step)
 
 
 
215
 
216
  fig = go.Figure()
217
  fig.add_trace(go.Scatter(
@@ -220,27 +194,25 @@ def cross_plot(actual, pred):
220
  hovertemplate="Actual: %{x:.0f}<br>Pred: %{y:.0f}<extra></extra>",
221
  showlegend=False
222
  ))
223
- # 45° reference
224
  fig.add_trace(go.Scatter(
225
- x=[lo, hi], y=[lo, hi], mode="lines",
226
  line=dict(color=COLORS["ref"], width=1.2, dash="dash"),
227
  hoverinfo="skip", showlegend=False
228
  ))
229
-
230
  fig.update_layout(
231
  width=CROSS_W, height=CROSS_H, paper_bgcolor="#fff", plot_bgcolor="#fff",
232
  margin=dict(l=64, r=18, t=10, b=48), hovermode="closest",
233
  font=dict(size=FONT_SZ)
234
  )
235
- axis_common = dict(
236
- range=[lo, hi], ticks="outside", tickformat=",.0f",
237
- tick0=tick0, dtick=step,
238
- showline=True, linewidth=1.2, linecolor="#444", mirror=True,
239
- showgrid=True, gridcolor="rgba(0,0,0,0.12)", automargin=True
240
- )
241
- fig.update_xaxes(title_text="<b>Actual UCS (psi)</b>", **axis_common)
242
- fig.update_yaxes(title_text="<b>Predicted UCS (psi)</b>", **axis_common,
243
- scaleanchor="x", scaleratio=1)
244
  return fig
245
 
246
  def track_plot(df, include_actual=True):
@@ -249,21 +221,21 @@ def track_plot(df, include_actual=True):
249
  y = pd.Series(df[depth_col]).astype(float)
250
  ylab = depth_col
251
  y_min, y_max = float(y.min()), float(y.max())
252
- y_range = [y_max, y_min]
253
  else:
254
  y = pd.Series(np.arange(1, len(df) + 1))
255
  ylab = "Point Index"
256
  y_min, y_max = float(y.min()), float(y.max())
257
  y_range = [y_max, y_min]
258
 
 
259
  x_series = pd.Series(df.get("UCS_Pred", pd.Series(dtype=float))).astype(float)
260
  if include_actual and TARGET in df.columns:
261
  x_series = pd.concat([x_series, pd.Series(df[TARGET]).astype(float)], ignore_index=True)
262
  x_lo, x_hi = float(x_series.min()), float(x_series.max())
263
  x_pad = 0.03 * (x_hi - x_lo if x_hi > x_lo else 1.0)
264
  xmin, xmax = x_lo - x_pad, x_hi + x_pad
265
- x_step = _nice_step(xmin, xmax, target_ticks=6)
266
- tick0 = _nice_tick0(xmin, x_step)
267
 
268
  fig = go.Figure()
269
  fig.add_trace(go.Scatter(
@@ -293,7 +265,7 @@ def track_plot(df, include_actual=True):
293
  fig.update_xaxes(
294
  title_text="<b>UCS (psi)</b>", side="top", range=[xmin, xmax],
295
  ticks="outside", tickformat=",.0f",
296
- tickmode="linear", tick0=tick0, dtick=x_step,
297
  showline=True, linewidth=1.2, linecolor="#444", mirror=True,
298
  showgrid=True, gridcolor="rgba(0,0,0,0.12)", automargin=True
299
  )
@@ -312,7 +284,8 @@ def preview_tracks(df: pd.DataFrame, cols: list[str]):
312
  if n == 0:
313
  fig, ax = plt.subplots(figsize=(4, 2))
314
  ax.text(0.5,0.5,"No selected columns",ha="center",va="center")
315
- ax.axis("off"); return fig
 
316
  fig, axes = plt.subplots(1, n, figsize=(2.2*n, 7.0), sharey=True, dpi=100)
317
  if n == 1: axes = [axes]
318
  idx = np.arange(1, len(df) + 1)
@@ -347,13 +320,11 @@ def preview_modal(book: dict[str, pd.DataFrame]):
347
  t1, t2 = st.tabs(["Tracks", "Summary"])
348
  with t1: st.pyplot(preview_tracks(df, FEATURES), use_container_width=True)
349
  with t2:
350
- tbl = df[FEATURES].agg(['min','max','mean','std']).T.rename(
351
- columns={"min":"Min","max":"Max","mean":"Mean","std":"Std"}
352
- ).reset_index(names="Feature")
353
- html_table_center(tbl, index=False)
354
 
355
  # =========================
356
- # Load model
357
  # =========================
358
  def ensure_model() -> Path|None:
359
  for p in [DEFAULT_MODEL, *MODEL_FALLBACKS]:
@@ -425,7 +396,7 @@ if st.session_state.app_step == "intro":
425
  st.markdown("This software is developed by *Smart Thinking AI-Solutions Team* to estimate UCS from drilling data.")
426
  st.subheader("How It Works")
427
  st.markdown(
428
- "1) **Upload your data to build the case and preview the performance of our model.** \n"
429
  "2) Click **Run Model** to compute metrics and plots. \n"
430
  "3) **Proceed to Validation** (with actual UCS) or **Proceed to Prediction** (no UCS)."
431
  )
@@ -449,29 +420,26 @@ if st.session_state.app_step == "dev":
449
  df0 = next(iter(tmp.values()))
450
  st.sidebar.caption(f"**Data loaded:** {st.session_state.dev_file_name} • {df0.shape[0]} rows × {df0.shape[1]} cols")
451
 
452
- # ---- Pinned title/helper FIRST (so it never appears below preview) ----
453
- st.markdown('<div class="pinned-top">', unsafe_allow_html=True)
454
- st.subheader("Case Building")
455
- if st.session_state.dev_file_loaded and st.session_state.dev_preview:
456
- st.info("Previewed ✓ — now click **Run Model**.")
457
- elif st.session_state.dev_file_loaded:
458
- st.info("📄 **Preview uploaded data** using the sidebar button, then click **Run Model**.")
459
- else:
460
- st.write("**Upload your data to build a case, then run the model to review development performance.**")
461
- st.markdown('</div>', unsafe_allow_html=True)
462
-
463
- # Preview button ALWAYS enabled
464
- if st.sidebar.button("Preview data", use_container_width=True):
465
- if not st.session_state.dev_file_loaded:
466
- st.warning("Upload an Excel file first, then preview.")
467
- else:
468
- preview_modal(read_book_bytes(st.session_state.dev_file_bytes))
469
- st.session_state.dev_preview = True
470
 
471
  run = st.sidebar.button("Run Model", type="primary", use_container_width=True)
 
472
  if st.sidebar.button("Proceed to Validation ▶", use_container_width=True): st.session_state.app_step="validate"; st.rerun()
473
  if st.sidebar.button("Proceed to Prediction ▶", use_container_width=True): st.session_state.app_step="predict"; st.rerun()
474
 
 
 
 
 
 
 
 
 
 
 
 
475
  if run and st.session_state.dev_file_bytes:
476
  book = read_book_bytes(st.session_state.dev_file_bytes)
477
  sh_train = find_sheet(book, ["Train","Training","training2","train","training"])
@@ -481,21 +449,12 @@ if st.session_state.app_step == "dev":
481
  tr = book[sh_train].copy(); te = book[sh_test].copy()
482
  if not (ensure_cols(tr, FEATURES+[TARGET]) and ensure_cols(te, FEATURES+[TARGET])):
483
  st.error("Missing required columns."); st.stop()
484
-
485
  tr["UCS_Pred"] = model.predict(tr[FEATURES])
486
  te["UCS_Pred"] = model.predict(te[FEATURES])
487
 
488
  st.session_state.results["Train"]=tr; st.session_state.results["Test"]=te
489
- st.session_state.results["m_train"]={
490
- "R": corrcoef_safe(tr[TARGET], tr["UCS_Pred"]),
491
- "RMSE": rmse(tr[TARGET], tr["UCS_Pred"]),
492
- "MAE": mean_absolute_error(tr[TARGET], tr["UCS_Pred"])
493
- }
494
- st.session_state.results["m_test"] ={
495
- "R": corrcoef_safe(te[TARGET], te["UCS_Pred"]),
496
- "RMSE": rmse(te[TARGET], te["UCS_Pred"]),
497
- "MAE": mean_absolute_error(te[TARGET], te["UCS_Pred"])
498
- }
499
 
500
  tr_min = tr[FEATURES].min().to_dict(); tr_max = tr[FEATURES].max().to_dict()
501
  st.session_state.train_ranges = {f:(float(tr_min[f]), float(tr_max[f])) for f in FEATURES}
@@ -506,13 +465,19 @@ if st.session_state.app_step == "dev":
506
  c1.metric("R", f"{m['R']:.4f}"); c2.metric("RMSE", f"{m['RMSE']:.4f}"); c3.metric("MAE", f"{m['MAE']:.4f}")
507
  left, spacer, right = st.columns(PLOT_COLS)
508
  with left:
509
- pad, plotcol = left.columns([CROSS_NUDGE, 1])
510
  with plotcol:
511
- st.plotly_chart(cross_plot(df[TARGET], df["UCS_Pred"]),
512
- use_container_width=False, config={"displayModeBar": False, "scrollZoom": True})
 
 
 
513
  with right:
514
- st.plotly_chart(track_plot(df, include_actual=True),
515
- use_container_width=False, config={"displayModeBar": False, "scrollZoom": True})
 
 
 
516
 
517
  if "Train" in st.session_state.results or "Test" in st.session_state.results:
518
  tab1, tab2 = st.tabs(["Training", "Testing"])
@@ -532,22 +497,15 @@ if st.session_state.app_step == "validate":
532
  if book:
533
  df0 = next(iter(book.values()))
534
  st.sidebar.caption(f"**Data loaded:** {up.name} • {df0.shape[0]} rows × {df0.shape[1]} cols")
535
-
536
- # pinned title/helper first
537
- st.markdown('<div class="pinned-top">', unsafe_allow_html=True)
538
- st.subheader("Validate the Model")
539
- st.write("Upload a dataset with the same **features** and **UCS** to evaluate performance.")
540
- st.markdown('</div>', unsafe_allow_html=True)
541
-
542
- if st.sidebar.button("Preview data", use_container_width=True):
543
- if up is None:
544
- st.warning("Upload an Excel file first, then preview.")
545
- else:
546
- preview_modal(read_book_bytes(up.getvalue()))
547
  go_btn = st.sidebar.button("Predict", type="primary", use_container_width=True)
548
  if st.sidebar.button("⬅ Back to Case Building", use_container_width=True): st.session_state.app_step="dev"; st.rerun()
549
  if st.sidebar.button("Proceed to Prediction ▶", use_container_width=True): st.session_state.app_step="predict"; st.rerun()
550
 
 
 
 
551
  if go_btn and up is not None:
552
  book = read_book_bytes(up.getvalue())
553
  name = find_sheet(book, ["Validation","Validate","validation2","Val","val"]) or list(book.keys())[0]
@@ -563,13 +521,8 @@ if st.session_state.app_step == "validate":
563
  if any_viol.any():
564
  tbl = df.loc[any_viol, FEATURES].copy()
565
  tbl["Violations"] = pd.DataFrame({f:(df[f]<ranges[f][0])|(df[f]>ranges[f][1]) for f in FEATURES}).loc[any_viol].apply(lambda r:", ".join([c for c,v in r.items() if v]), axis=1)
566
- st.session_state.results["m_val"]={
567
- "R": corrcoef_safe(df[TARGET], df["UCS_Pred"]),
568
- "RMSE": rmse(df[TARGET], df["UCS_Pred"]),
569
- "MAE": mean_absolute_error(df[TARGET], df["UCS_Pred"])
570
- }
571
- st.session_state.results["sv_val"]={"n":len(df),"pred_min":float(df["UCS_Pred"].min()),
572
- "pred_max":float(df["UCS_Pred"].max()),"oor":oor_pct}
573
  st.session_state.results["oor_tbl"]=tbl
574
 
575
  if "Validate" in st.session_state.results:
@@ -579,7 +532,7 @@ if st.session_state.app_step == "validate":
579
 
580
  left, spacer, right = st.columns(PLOT_COLS)
581
  with left:
582
- pad, plotcol = left.columns([CROSS_NUDGE, 1])
583
  with plotcol:
584
  st.plotly_chart(
585
  cross_plot(st.session_state.results["Validate"][TARGET],
@@ -596,7 +549,7 @@ if st.session_state.app_step == "validate":
596
  if sv["oor"] > 0: st.warning("Some inputs fall outside **training min–max** ranges.")
597
  if st.session_state.results["oor_tbl"] is not None:
598
  st.write("*Out-of-range rows (vs. Training min–max):*")
599
- html_table_center(st.session_state.results["oor_tbl"].reset_index(drop=True), index=False)
600
 
601
  # =========================
602
  # PREDICTION (no actual UCS)
@@ -609,20 +562,13 @@ if st.session_state.app_step == "predict":
609
  if book:
610
  df0 = next(iter(book.values()))
611
  st.sidebar.caption(f"**Data loaded:** {up.name} • {df0.shape[0]} rows × {df0.shape[1]} cols")
 
 
 
 
612
 
613
- # pinned title/helper first
614
- st.markdown('<div class="pinned-top">', unsafe_allow_html=True)
615
  st.subheader("Prediction")
616
  st.write("Upload a dataset with the feature columns (no **UCS**).")
617
- st.markdown('</div>', unsafe_allow_html=True)
618
-
619
- if st.sidebar.button("Preview data", use_container_width=True):
620
- if up is None:
621
- st.warning("Upload an Excel file first, then preview.")
622
- else:
623
- preview_modal(read_book_bytes(up.getvalue()))
624
- go_btn = st.sidebar.button("Predict", type="primary", use_container_width=True)
625
- if st.sidebar.button("⬅ Back to Case Building", use_container_width=True): st.session_state.app_step="dev"; st.rerun()
626
 
627
  if go_btn and up is not None:
628
  book = read_book_bytes(up.getvalue()); name = list(book.keys())[0]
@@ -649,13 +595,12 @@ if st.session_state.app_step == "predict":
649
 
650
  left, spacer, right = st.columns(PLOT_COLS)
651
  with left:
652
- st.success("Predictions ready ✓")
653
  table = pd.DataFrame({
654
  "Metric": ["# points","Pred min","Pred max","Pred mean","Pred std","OOR %"],
655
- "Value": [sv["n"], sv["pred_min"], sv["pred_max"],
656
- sv["pred_mean"], sv["pred_std"], f'{sv["oor"]:.1f}%']
657
  })
658
- html_table_center(table, index=False)
 
659
  st.caption("**★ OOR** = % of rows whose input features fall outside the training min–max range.")
660
  with right:
661
  st.plotly_chart(
@@ -675,4 +620,4 @@ st.markdown(
675
  </div>
676
  """,
677
  unsafe_allow_html=True
678
- )
 
12
  import matplotlib.pyplot as plt
13
 
14
  import plotly.graph_objects as go
15
+ from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
16
+ from scipy.stats import pearsonr
17
 
18
  # =========================
19
+ # Constants (simple & robust)
20
  # =========================
21
  FEATURES = ["Q, gpm", "SPP(psi)", "T (kft.lbf)", "WOB (klbf)", "ROP (ft/h)"]
22
  TARGET = "UCS"
 
26
  COLORS = {"pred": "#1f77b4", "actual": "#f2b702", "ref": "#5a5a5a"}
27
 
28
  # ---- Plot sizing controls (edit here) ----
29
+ CROSS_W = 500; CROSS_H = 500 # square cross-plot (Build + Validate)
30
  TRACK_W = 400; TRACK_H = 950 # log-strip style (all pages)
31
+ FONT_SZ = 13
32
+ PLOT_COLS = [14, 0.5, 10] # 3-column band: left • spacer • right (Build + Validate)
33
+ CROSS_NUDGE = 0.5 # push cross-plot to the RIGHT inside its band:
34
+ # inner columns [CROSS_NUDGE : 1] → bigger = more right
35
 
36
  # =========================
37
  # Page / CSS
38
  # =========================
39
  st.set_page_config(page_title="ST_GeoMech_UCS", page_icon="logo.png", layout="wide")
40
  st.markdown("<style>header, footer{visibility:hidden !important;}</style>", unsafe_allow_html=True)
 
41
  st.markdown(
42
  """
43
  <style>
44
  .stApp { background:#fff; }
45
  section[data-testid="stSidebar"] { background:#F6F9FC; }
46
  .block-container { padding-top:.5rem; padding-bottom:.5rem; }
 
 
47
  .stButton>button { background:#007bff; color:#fff; font-weight:600; border-radius:8px; border:none; }
48
  .stButton>button:hover { background:#0056b3; }
49
+ .st-hero { display:flex; align-items:center; gap:16px; padding-top: 4px; }
 
 
50
  .st-hero .brand { width:110px; height:110px; object-fit:contain; }
51
  .st-hero h1 { margin:0; line-height:1.05; }
52
  .st-hero .tagline { margin:2px 0 0 2px; color:#6b7280; font-size:1.05rem; font-style:italic; }
 
53
  [data-testid="stBlock"]{ margin-top:0 !important; }
54
+ /* Center align text in table cells */
55
+ .st-emotion-cache-1wq06yv { text-align: center; }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  </style>
57
  """,
58
  unsafe_allow_html=True
59
  )
60
 
61
  # =========================
62
+ # Password gate (define first, then call)
63
  # =========================
64
  def inline_logo(path="logo.png") -> str:
65
  try:
 
126
  # =========================
127
  # Utilities
128
  # =========================
129
+ try:
130
+ dialog = st.dialog
131
+ except AttributeError:
132
+ def dialog(title):
133
+ def deco(fn):
134
+ def wrapper(*args, **kwargs):
135
+ with st.expander(title, expanded=True):
136
+ return fn(*args, **kwargs)
137
+ return wrapper
138
+ return deco
139
+
140
+ def rmse(y_true, y_pred): return float(np.sqrt(mean_squared_error(y_true, y_pred)))
141
 
142
+ def correlation_coefficient(y_true, y_pred):
143
+ return pearsonr(y_true, y_pred)[0]
 
 
 
 
144
 
145
  @st.cache_resource(show_spinner=False)
146
  def load_model(model_path: str):
 
167
  if nm.lower() in low2orig: return low2orig[nm.lower()]
168
  return None
169
 
170
+ def _nice_tick0(xmin: float, step: int = 100) -> float:
171
+ """Round xmin down to a sensible multiple so the first tick sits at the left edge."""
172
+ if not np.isfinite(xmin):
173
+ return xmin
 
 
 
 
 
 
 
 
174
  return step * math.floor(xmin / step)
175
 
 
 
 
 
176
  # ---------- Plot builders ----------
177
  def cross_plot(actual, pred):
178
  a = pd.Series(actual).astype(float)
179
  p = pd.Series(pred).astype(float)
 
 
180
  lo = float(np.nanmin([a.min(), p.min()]))
181
  hi = float(np.nanmax([a.max(), p.max()]))
182
  pad = 0.03 * (hi - lo if hi > lo else 1.0)
183
+ x0, x1 = lo - pad, hi + pad
184
+
185
+ # Get the global min and max of all data points for consistent scaling
186
+ all_values = pd.concat([a, p]).dropna()
187
+ global_min = all_values.min()
188
+ global_max = all_values.max()
189
 
190
  fig = go.Figure()
191
  fig.add_trace(go.Scatter(
 
194
  hovertemplate="Actual: %{x:.0f}<br>Pred: %{y:.0f}<extra></extra>",
195
  showlegend=False
196
  ))
 
197
  fig.add_trace(go.Scatter(
198
+ x=[global_min, global_max], y=[global_min, global_max], mode="lines",
199
  line=dict(color=COLORS["ref"], width=1.2, dash="dash"),
200
  hoverinfo="skip", showlegend=False
201
  ))
 
202
  fig.update_layout(
203
  width=CROSS_W, height=CROSS_H, paper_bgcolor="#fff", plot_bgcolor="#fff",
204
  margin=dict(l=64, r=18, t=10, b=48), hovermode="closest",
205
  font=dict(size=FONT_SZ)
206
  )
207
+ fig.update_xaxes(title_text="<b>Actual UCS (psi)</b>", range=[global_min, global_max],
208
+ ticks="outside", tickformat=",.0f",
209
+ showline=True, linewidth=1.2, linecolor="#444", mirror=True,
210
+ showgrid=True, gridcolor="rgba(0,0,0,0.12)", automargin=True)
211
+ fig.update_yaxes(title_text="<b>Predicted UCS (psi)</b>", range=[global_min, global_max],
212
+ ticks="outside", tickformat=",.0f",
213
+ showline=True, linewidth=1.2, linecolor="#444", mirror=True,
214
+ showgrid=True, gridcolor="rgba(0,0,0,0.12)",
215
+ scaleanchor="x", scaleratio=1, automargin=True)
216
  return fig
217
 
218
  def track_plot(df, include_actual=True):
 
221
  y = pd.Series(df[depth_col]).astype(float)
222
  ylab = depth_col
223
  y_min, y_max = float(y.min()), float(y.max())
224
+ y_range = [y_max, y_min] # reversed for log profile style
225
  else:
226
  y = pd.Series(np.arange(1, len(df) + 1))
227
  ylab = "Point Index"
228
  y_min, y_max = float(y.min()), float(y.max())
229
  y_range = [y_max, y_min]
230
 
231
+ # X (UCS) range & ticks
232
  x_series = pd.Series(df.get("UCS_Pred", pd.Series(dtype=float))).astype(float)
233
  if include_actual and TARGET in df.columns:
234
  x_series = pd.concat([x_series, pd.Series(df[TARGET]).astype(float)], ignore_index=True)
235
  x_lo, x_hi = float(x_series.min()), float(x_series.max())
236
  x_pad = 0.03 * (x_hi - x_lo if x_hi > x_lo else 1.0)
237
  xmin, xmax = x_lo - x_pad, x_hi + x_pad
238
+ tick0 = _nice_tick0(xmin, step=100) # sensible first tick at left border
 
239
 
240
  fig = go.Figure()
241
  fig.add_trace(go.Scatter(
 
265
  fig.update_xaxes(
266
  title_text="<b>UCS (psi)</b>", side="top", range=[xmin, xmax],
267
  ticks="outside", tickformat=",.0f",
268
+ tickmode="auto", tick0=tick0,
269
  showline=True, linewidth=1.2, linecolor="#444", mirror=True,
270
  showgrid=True, gridcolor="rgba(0,0,0,0.12)", automargin=True
271
  )
 
284
  if n == 0:
285
  fig, ax = plt.subplots(figsize=(4, 2))
286
  ax.text(0.5,0.5,"No selected columns",ha="center",va="center")
287
+ ax.axis("off")
288
+ return fig
289
  fig, axes = plt.subplots(1, n, figsize=(2.2*n, 7.0), sharey=True, dpi=100)
290
  if n == 1: axes = [axes]
291
  idx = np.arange(1, len(df) + 1)
 
320
  t1, t2 = st.tabs(["Tracks", "Summary"])
321
  with t1: st.pyplot(preview_tracks(df, FEATURES), use_container_width=True)
322
  with t2:
323
+ tbl = df[FEATURES].agg(['min','max','mean','std']).T.rename(columns={"min":"Min","max":"Max","mean":"Mean","std":"Std"})
324
+ st.dataframe(tbl.reset_index(names="Feature"), use_container_width=True)
 
 
325
 
326
  # =========================
327
+ # Load model (simple)
328
  # =========================
329
  def ensure_model() -> Path|None:
330
  for p in [DEFAULT_MODEL, *MODEL_FALLBACKS]:
 
396
  st.markdown("This software is developed by *Smart Thinking AI-Solutions Team* to estimate UCS from drilling data.")
397
  st.subheader("How It Works")
398
  st.markdown(
399
+ "1) **Upload your data to build the case and preview the performance of our model.** \n"
400
  "2) Click **Run Model** to compute metrics and plots. \n"
401
  "3) **Proceed to Validation** (with actual UCS) or **Proceed to Prediction** (no UCS)."
402
  )
 
420
  df0 = next(iter(tmp.values()))
421
  st.sidebar.caption(f"**Data loaded:** {st.session_state.dev_file_name} • {df0.shape[0]} rows × {df0.shape[1]} cols")
422
 
423
+ if st.sidebar.button("Preview data", use_container_width=True, disabled=not st.session_state.dev_file_loaded):
424
+ preview_modal(read_book_bytes(st.session_state.dev_file_bytes))
425
+ st.session_state.dev_preview = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
426
 
427
  run = st.sidebar.button("Run Model", type="primary", use_container_width=True)
428
+ # always available nav
429
  if st.sidebar.button("Proceed to Validation ▶", use_container_width=True): st.session_state.app_step="validate"; st.rerun()
430
  if st.sidebar.button("Proceed to Prediction ▶", use_container_width=True): st.session_state.app_step="predict"; st.rerun()
431
 
432
+ # ---- Pinned helper at the very top of the page ----
433
+ helper_top = st.container()
434
+ with helper_top:
435
+ st.subheader("Case Building")
436
+ if st.session_state.dev_file_loaded and st.session_state.dev_preview:
437
+ st.info("Previewed ✓ — now click **Run Model**.")
438
+ elif st.session_state.dev_file_loaded:
439
+ st.info("📄 **Preview uploaded data** using the sidebar button, then click **Run Model**.")
440
+ else:
441
+ st.write("**Upload your data to build a case, then run the model to review development performance.**")
442
+
443
  if run and st.session_state.dev_file_bytes:
444
  book = read_book_bytes(st.session_state.dev_file_bytes)
445
  sh_train = find_sheet(book, ["Train","Training","training2","train","training"])
 
449
  tr = book[sh_train].copy(); te = book[sh_test].copy()
450
  if not (ensure_cols(tr, FEATURES+[TARGET]) and ensure_cols(te, FEATURES+[TARGET])):
451
  st.error("Missing required columns."); st.stop()
 
452
  tr["UCS_Pred"] = model.predict(tr[FEATURES])
453
  te["UCS_Pred"] = model.predict(te[FEATURES])
454
 
455
  st.session_state.results["Train"]=tr; st.session_state.results["Test"]=te
456
+ st.session_state.results["m_train"]={"R":correlation_coefficient(tr[TARGET],tr["UCS_Pred"]), "RMSE":rmse(tr[TARGET],tr["UCS_Pred"]), "MAE":mean_absolute_error(tr[TARGET],tr["UCS_Pred"])}
457
+ st.session_state.results["m_test"] ={"R":correlation_coefficient(te[TARGET],te["UCS_Pred"]), "RMSE":rmse(te[TARGET],te["UCS_Pred"]), "MAE":mean_absolute_error(te[TARGET],te["UCS_Pred"])}
 
 
 
 
 
 
 
 
458
 
459
  tr_min = tr[FEATURES].min().to_dict(); tr_max = tr[FEATURES].max().to_dict()
460
  st.session_state.train_ranges = {f:(float(tr_min[f]), float(tr_max[f])) for f in FEATURES}
 
465
  c1.metric("R", f"{m['R']:.4f}"); c2.metric("RMSE", f"{m['RMSE']:.4f}"); c3.metric("MAE", f"{m['MAE']:.4f}")
466
  left, spacer, right = st.columns(PLOT_COLS)
467
  with left:
468
+ pad, plotcol = left.columns([CROSS_NUDGE, 1]) # shift cross-plot right inside its band
469
  with plotcol:
470
+ st.plotly_chart(
471
+ cross_plot(df[TARGET], df["UCS_Pred"]),
472
+ use_container_width=False,
473
+ config={"displayModeBar": False, "scrollZoom": True}
474
+ )
475
  with right:
476
+ st.plotly_chart(
477
+ track_plot(df, include_actual=True),
478
+ use_container_width=False,
479
+ config={"displayModeBar": False, "scrollZoom": True}
480
+ )
481
 
482
  if "Train" in st.session_state.results or "Test" in st.session_state.results:
483
  tab1, tab2 = st.tabs(["Training", "Testing"])
 
497
  if book:
498
  df0 = next(iter(book.values()))
499
  st.sidebar.caption(f"**Data loaded:** {up.name} • {df0.shape[0]} rows × {df0.shape[1]} cols")
500
+ if st.sidebar.button("Preview data", use_container_width=True, disabled=(up is None)):
501
+ preview_modal(read_book_bytes(up.getvalue()))
 
 
 
 
 
 
 
 
 
 
502
  go_btn = st.sidebar.button("Predict", type="primary", use_container_width=True)
503
  if st.sidebar.button("⬅ Back to Case Building", use_container_width=True): st.session_state.app_step="dev"; st.rerun()
504
  if st.sidebar.button("Proceed to Prediction ▶", use_container_width=True): st.session_state.app_step="predict"; st.rerun()
505
 
506
+ st.subheader("Validate the Model")
507
+ st.write("Upload a dataset with the same **features** and **UCS** to evaluate performance.")
508
+
509
  if go_btn and up is not None:
510
  book = read_book_bytes(up.getvalue())
511
  name = find_sheet(book, ["Validation","Validate","validation2","Val","val"]) or list(book.keys())[0]
 
521
  if any_viol.any():
522
  tbl = df.loc[any_viol, FEATURES].copy()
523
  tbl["Violations"] = pd.DataFrame({f:(df[f]<ranges[f][0])|(df[f]>ranges[f][1]) for f in FEATURES}).loc[any_viol].apply(lambda r:", ".join([c for c,v in r.items() if v]), axis=1)
524
+ st.session_state.results["m_val"]={"R":correlation_coefficient(df[TARGET],df["UCS_Pred"]), "RMSE":rmse(df[TARGET],df["UCS_Pred"]), "MAE":mean_absolute_error(df[TARGET],df["UCS_Pred"])}
525
+ st.session_state.results["sv_val"]={"n":len(df),"pred_min":float(df["UCS_Pred"].min()),"pred_max":float(df["UCS_Pred"].max()),"oor":oor_pct}
 
 
 
 
 
526
  st.session_state.results["oor_tbl"]=tbl
527
 
528
  if "Validate" in st.session_state.results:
 
532
 
533
  left, spacer, right = st.columns(PLOT_COLS)
534
  with left:
535
+ pad, plotcol = left.columns([CROSS_NUDGE, 1]) # same nudge
536
  with plotcol:
537
  st.plotly_chart(
538
  cross_plot(st.session_state.results["Validate"][TARGET],
 
549
  if sv["oor"] > 0: st.warning("Some inputs fall outside **training min–max** ranges.")
550
  if st.session_state.results["oor_tbl"] is not None:
551
  st.write("*Out-of-range rows (vs. Training min–max):*")
552
+ st.dataframe(st.session_state.results["oor_tbl"], use_container_width=True)
553
 
554
  # =========================
555
  # PREDICTION (no actual UCS)
 
562
  if book:
563
  df0 = next(iter(book.values()))
564
  st.sidebar.caption(f"**Data loaded:** {up.name} • {df0.shape[0]} rows × {df0.shape[1]} cols")
565
+ if st.sidebar.button("Preview data", use_container_width=True, disabled=(up is None)):
566
+ preview_modal(read_book_bytes(up.getvalue()))
567
+ go_btn = st.sidebar.button("Predict", type="primary", use_container_width=True)
568
+ if st.sidebar.button("⬅ Back to Case Building", use_container_width=True): st.session_state.app_step="dev"; st.rerun()
569
 
 
 
570
  st.subheader("Prediction")
571
  st.write("Upload a dataset with the feature columns (no **UCS**).")
 
 
 
 
 
 
 
 
 
572
 
573
  if go_btn and up is not None:
574
  book = read_book_bytes(up.getvalue()); name = list(book.keys())[0]
 
595
 
596
  left, spacer, right = st.columns(PLOT_COLS)
597
  with left:
 
598
  table = pd.DataFrame({
599
  "Metric": ["# points","Pred min","Pred max","Pred mean","Pred std","OOR %"],
600
+ "Value": [sv["n"], sv["pred_min"], sv["pred_max"], sv["pred_mean"], sv["pred_std"], f'{sv["oor"]:.1f}%']
 
601
  })
602
+ st.success("Predictions ready ✓")
603
+ st.dataframe(table, use_container_width=True, hide_index=True)
604
  st.caption("**★ OOR** = % of rows whose input features fall outside the training min–max range.")
605
  with right:
606
  st.plotly_chart(
 
620
  </div>
621
  """,
622
  unsafe_allow_html=True
623
+ )