UCS2014 commited on
Commit
fb43ee5
·
verified ·
1 Parent(s): eef4094

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +78 -108
app.py CHANGED
@@ -6,7 +6,7 @@ import pandas as pd
6
  import numpy as np
7
  import joblib
8
 
9
- # Matplotlib for PREVIEW modal and CROSS-PLOT (static)
10
  import matplotlib
11
  matplotlib.use("Agg")
12
  import matplotlib.pyplot as plt
@@ -26,58 +26,43 @@ MODEL_FALLBACKS = [MODELS_DIR / "model.joblib", MODELS_DIR / "model.pkl"]
26
  COLORS = {"pred": "#1f77b4", "actual": "#f2b702", "ref": "#5a5a5a"}
27
 
28
  # ---- Plot sizing controls ----
29
- CROSS_W = 420 # px (explicit, not None)
30
- CROSS_H = 420 # px
31
- TRACK_W = 320 # px (slightly narrower to avoid crowding)
32
- TRACK_H = 740 # px
33
  FONT_SZ = 13
34
- PLOT_COLS = [36, 6, 28] # wider spacer so plots never bump into each other
35
- CROSS_NUDGE = 0.0 # keep 0 unless you really want an inner pad
36
 
37
  # =========================
38
  # Page / CSS
39
  # =========================
40
  st.set_page_config(page_title="ST_GeoMech_UCS", page_icon="logo.png", layout="wide")
41
- st.markdown("""
42
- <style>
43
-   /* Reusable logo style */
44
-   .brand-logo { width: 16px; height: auto; object-fit: contain; }
45
 
46
-   /* Sidebar header layout */
47
-   .sidebar-header { display:flex; align-items:center; gap:12px; }
48
-   .sidebar-header .text h1 { font-size: 1.05rem; margin:0; line-height:1.1; }
49
-   .sidebar-header .text .tag { font-size: .85rem; color:#6b7280; margin:2px 0 0; }
50
- </style>
51
- """, unsafe_allow_html=True)
52
- # Hide file-uploader helper text; keep only Browse button
53
  st.markdown("""
54
  <style>
55
  /* Older builds (helper wrapped in a Markdown container) */
56
- section[data-testid="stFileUploader"] div[data-testid="stMarkdownContainer"] { display:none !important; }
57
- /* 1.31–1.34 style (first child inside the dropzone is the helper row) */
58
- section[data-testid="stFileUploader"] [data-testid="stFileUploaderDropzone"] > div:first-child { display:none !important; }
59
- /* 1.35+ explicit helper container */
60
- section[data-testid="stFileUploader"] [data-testid="stFileUploaderInstructions"] { display:none !important; }
61
- /* Fallback: any paragraph/small helper inside uploader */
62
- section[data-testid="stFileUploader"] p, section[data-testid="stFileUploader"] small { display:none !important; }
63
-
64
- /* Center headers & cells in all st.dataframe tables */
65
- [data-testid="stDataFrame"] table td,
66
- [data-testid="stDataFrame"] table th {
67
- text-align: center !important;
68
- vertical-align: middle !important;
69
- }
70
  </style>
71
  """, unsafe_allow_html=True)
72
 
 
 
 
 
 
 
73
  # =========================
74
  # Password gate
75
  # =========================
76
  def inline_logo(path="logo.png") -> str:
77
  try:
78
  p = Path(path)
79
- if not p.exists():
80
- return ""
81
  return f"data:image/png;base64,{base64.b64encode(p.read_bytes()).decode('ascii')}"
82
  except Exception:
83
  return ""
@@ -95,7 +80,7 @@ def add_password_gate() -> None:
95
  if st.session_state.get("auth_ok", False):
96
  return
97
 
98
- st.sidebar.image("logo.png", use_column_width=False)
99
  st.sidebar.markdown("### ST_GeoMech_UCS\nSmart Thinking • Secure Access")
100
  pwd = st.sidebar.text_input("Access key", type="password", placeholder="••••••••")
101
  if st.sidebar.button("Unlock", type="primary"):
@@ -111,25 +96,13 @@ add_password_gate()
111
  # =========================
112
  # Utilities
113
  # =========================
114
- try:
115
- dialog = st.dialog
116
- except AttributeError:
117
- def dialog(title):
118
- def deco(fn):
119
- def wrapper(*args, **kwargs):
120
- with st.expander(title, expanded=True):
121
- return fn(*args, **kwargs)
122
- return wrapper
123
- return deco
124
-
125
  def rmse(y_true, y_pred) -> float:
126
  return float(np.sqrt(mean_squared_error(y_true, y_pred)))
127
 
128
  def pearson_r(y_true, y_pred) -> float:
129
  a = np.asarray(y_true, dtype=float)
130
  p = np.asarray(y_pred, dtype=float)
131
- if a.size < 2:
132
- return float("nan")
133
  return float(np.corrcoef(a, p)[0, 1])
134
 
135
  @st.cache_resource(show_spinner=False)
@@ -142,8 +115,7 @@ def parse_excel(data_bytes: bytes):
142
  xl = pd.ExcelFile(bio)
143
  return {sh: xl.parse(sh) for sh in xl.sheet_names}
144
 
145
- def read_book_bytes(b: bytes):
146
- return parse_excel(b) if b else {}
147
 
148
  def ensure_cols(df, cols):
149
  miss = [c for c in cols if c not in df.columns]
@@ -155,13 +127,22 @@ def ensure_cols(df, cols):
155
  def find_sheet(book, names):
156
  low2orig = {k.lower(): k for k in book.keys()}
157
  for nm in names:
158
- if nm.lower() in low2orig:
159
- return low2orig[nm.lower()]
160
  return None
161
 
162
  def _nice_tick0(xmin: float, step: int = 100) -> float:
163
  return step * math.floor(xmin / step) if np.isfinite(xmin) else xmin
164
 
 
 
 
 
 
 
 
 
 
 
165
  # =========================
166
  # Cross plot (Matplotlib, fixed limits & ticks)
167
  # =========================
@@ -172,7 +153,6 @@ def cross_plot_static(actual, pred):
172
  fixed_min, fixed_max = 6000, 10000
173
  ticks = np.arange(fixed_min, fixed_max + 1, 1000)
174
 
175
- # fixed px size = (figsize * dpi)
176
  dpi = 110
177
  fig, ax = plt.subplots(
178
  figsize=(CROSS_W / dpi, CROSS_H / dpi),
@@ -180,43 +160,32 @@ def cross_plot_static(actual, pred):
180
  constrained_layout=False
181
  )
182
 
183
- # points
184
  ax.scatter(a, p, s=16, c=COLORS["pred"], alpha=0.9, linewidths=0)
185
-
186
- # 1:1 diagonal
187
  ax.plot([fixed_min, fixed_max], [fixed_min, fixed_max],
188
  linestyle="--", linewidth=1.2, color=COLORS["ref"])
189
 
190
- # identical axes limits + ticks
191
  ax.set_xlim(fixed_min, fixed_max)
192
  ax.set_ylim(fixed_min, fixed_max)
193
  ax.set_xticks(ticks)
194
  ax.set_yticks(ticks)
195
-
196
- # equal aspect → true 45°
197
  ax.set_aspect("equal", adjustable="box")
198
 
199
- # thousands formatting
200
  fmt = FuncFormatter(lambda x, _: f"{int(x):,}")
201
  ax.xaxis.set_major_formatter(fmt)
202
  ax.yaxis.set_major_formatter(fmt)
203
 
204
- # labels + ticks (smaller so they don't dominate)
205
  ax.set_xlabel("Actual UCS (psi)", fontweight="bold", fontsize=12)
206
  ax.set_ylabel("Predicted UCS (psi)", fontweight="bold", fontsize=12)
207
  ax.tick_params(labelsize=10)
208
 
209
- # grid & frame
210
  ax.grid(True, linestyle=":", alpha=0.3)
211
  for spine in ax.spines.values():
212
  spine.set_linewidth(1.1)
213
  spine.set_color("#444")
214
 
215
- # moderate margins; keeps labels readable but not huge
216
  fig.subplots_adjust(left=0.16, bottom=0.16, right=0.98, top=0.98)
217
  return fig
218
 
219
-
220
  # =========================
221
  # Track plot (Plotly)
222
  # =========================
@@ -256,8 +225,9 @@ def track_plot(df, include_actual=True):
256
  ))
257
 
258
  fig.update_layout(
259
- width=TRACK_W, height=TRACK_H, paper_bgcolor="#fff", plot_bgcolor="#fff",
260
- margin=dict(l=72, r=18, t=36, b=48), hovermode="closest",
 
261
  font=dict(size=FONT_SZ),
262
  legend=dict(
263
  x=0.98, y=0.05, xanchor="right", yanchor="bottom",
@@ -266,14 +236,14 @@ def track_plot(df, include_actual=True):
266
  legend_title_text=""
267
  )
268
  fig.update_xaxes(
269
- title_text="<b>UCS (psi)</b>", title_font=dict(size=16),
270
  side="top", range=[xmin, xmax],
271
  ticks="outside", tickformat=",.0f", tickmode="auto", tick0=tick0,
272
  showline=True, linewidth=1.2, linecolor="#444", mirror=True,
273
  showgrid=True, gridcolor="rgba(0,0,0,0.12)", automargin=True
274
  )
275
  fig.update_yaxes(
276
- title_text=f"<b>{ylab}</b>", title_font=dict(size=16),
277
  range=y_range, ticks="outside",
278
  showline=True, linewidth=1.2, linecolor="#444", mirror=True,
279
  showgrid=True, gridcolor="rgba(0,0,0,0.12)", automargin=True
@@ -286,8 +256,7 @@ def preview_tracks(df: pd.DataFrame, cols: list[str]):
286
  n = len(cols)
287
  if n == 0:
288
  fig, ax = plt.subplots(figsize=(4, 2))
289
- ax.text(0.5,0.5,"No selected columns",ha="center",va="center")
290
- ax.axis("off")
291
  return fig
292
  fig, axes = plt.subplots(1, n, figsize=(2.2*n, 7.0), sharey=True, dpi=100)
293
  if n == 1: axes = [axes]
@@ -300,6 +269,18 @@ def preview_tracks(df: pd.DataFrame, cols: list[str]):
300
  axes[0].set_ylabel("Point Index")
301
  return fig
302
 
 
 
 
 
 
 
 
 
 
 
 
 
303
  @dialog("Preview data")
304
  def preview_modal(book: dict[str, pd.DataFrame]):
305
  if not book:
@@ -310,12 +291,13 @@ def preview_modal(book: dict[str, pd.DataFrame]):
310
  with t:
311
  df = book[name]
312
  t1, t2 = st.tabs(["Tracks", "Summary"])
313
- with t1: st.pyplot(preview_tracks(df, FEATURES), use_container_width=True)
 
314
  with t2:
315
- tbl = df[FEATURES].agg(['min','max','mean','std']).T.rename(
316
- columns={"min":"Min","max":"Max","mean":"Mean","std":"Std"}
317
- ).round(2)
318
- st.dataframe(tbl.reset_index(names="Feature"), use_container_width=True)
319
 
320
  # =========================
321
  # Load model
@@ -369,7 +351,7 @@ st.session_state.setdefault("dev_preview",False)
369
  # =========================
370
  # Branding in Sidebar
371
  # =========================
372
- st.sidebar.image("logo.png", use_column_width=False)
373
  st.sidebar.markdown(
374
  "<div style='font-weight:800;font-size:1.2rem;'>ST_GeoMech_UCS</div>"
375
  "<div style='color:#667085;'>Real-Time UCS Tracking While Drilling</div>",
@@ -413,11 +395,9 @@ if st.session_state.app_step == "dev":
413
  st.session_state.dev_preview = True
414
 
415
  run = st.sidebar.button("Run Model", type="primary", use_container_width=True)
416
- # nav
417
  if st.sidebar.button("Proceed to Validation ▶", use_container_width=True): st.session_state.app_step="validate"; st.rerun()
418
  if st.sidebar.button("Proceed to Prediction ▶", use_container_width=True): st.session_state.app_step="predict"; st.rerun()
419
 
420
- # ---- Pinned helper at the very top ----
421
  helper_top = st.container()
422
  with helper_top:
423
  st.subheader("Case Building")
@@ -460,18 +440,14 @@ if st.session_state.app_step == "dev":
460
  c1,c2,c3 = st.columns(3)
461
  c1.metric("R", f"{m['R']:.2f}"); c2.metric("RMSE", f"{m['RMSE']:.2f}"); c3.metric("MAE", f"{m['MAE']:.2f}")
462
 
463
- left, spacer, right = st.columns(PLOT_COLS)
464
- # Robust nudge: only make the inner columns if weight > 0
465
- if CROSS_NUDGE and CROSS_NUDGE > 0:
466
- pad, plotcol = left.columns([CROSS_NUDGE, 1])
467
- else:
468
- plotcol = left
469
- with plotcol:
470
- st.pyplot(cross_plot_static(df[TARGET], df["UCS_Pred"]), use_container_width=False)
471
- with right:
472
  st.plotly_chart(
473
  track_plot(df, include_actual=True),
474
- use_container_width=False,
475
  config={"displayModeBar": False, "scrollZoom": True}
476
  )
477
 
@@ -517,11 +493,8 @@ if st.session_state.app_step == "validate":
517
  if any_viol.any():
518
  tbl = df.loc[any_viol, FEATURES].copy()
519
  for c in FEATURES:
520
- if pd.api.types.is_numeric_dtype(tbl[c]):
521
- tbl[c] = tbl[c].round(2)
522
- tbl["Violations"] = pd.DataFrame({f:(df[f]<ranges[f][0])|(df[f]>ranges[f][1]) for f in FEATURES}).loc[any_viol].apply(
523
- lambda r:", ".join([c for c,v in r.items() if v]), axis=1
524
- )
525
  st.session_state.results["m_val"]={
526
  "R": pearson_r(df[TARGET], df["UCS_Pred"]),
527
  "RMSE": rmse(df[TARGET], df["UCS_Pred"]),
@@ -535,28 +508,24 @@ if st.session_state.app_step == "validate":
535
  c1,c2,c3 = st.columns(3)
536
  c1.metric("R", f"{m['R']:.2f}"); c2.metric("RMSE", f"{m['RMSE']:.2f}"); c3.metric("MAE", f"{m['MAE']:.2f}")
537
 
538
- left, spacer, right = st.columns(PLOT_COLS)
539
- if CROSS_NUDGE and CROSS_NUDGE > 0:
540
- pad, plotcol = left.columns([CROSS_NUDGE, 1])
541
- else:
542
- plotcol = left
543
- with plotcol:
544
  st.pyplot(
545
  cross_plot_static(st.session_state.results["Validate"][TARGET],
546
  st.session_state.results["Validate"]["UCS_Pred"]),
547
- use_container_width=False
548
  )
549
- with right:
550
  st.plotly_chart(
551
  track_plot(st.session_state.results["Validate"], include_actual=True),
552
- use_container_width=False, config={"displayModeBar": False, "scrollZoom": True}
553
  )
554
 
555
  sv = st.session_state.results["sv_val"]
556
  if sv["oor"] > 0: st.warning("Some inputs fall outside **training min–max** ranges.")
557
  if st.session_state.results["oor_tbl"] is not None:
558
  st.write("*Out-of-range rows (vs. Training min–max):*")
559
- st.dataframe(st.session_state.results["oor_tbl"], use_container_width=True)
560
 
561
  # =========================
562
  # PREDICTION (no actual UCS)
@@ -599,8 +568,9 @@ if st.session_state.app_step == "predict":
599
 
600
  if "PredictOnly" in st.session_state.results:
601
  df = st.session_state.results["PredictOnly"]; sv = st.session_state.results["sv_pred"]
602
- left, spacer, right = st.columns(PLOT_COLS)
603
- with left:
 
604
  table = pd.DataFrame({
605
  "Metric": ["# points","Pred min","Pred max","Pred mean","Pred std","OOR %"],
606
  "Value": [sv["n"],
@@ -611,12 +581,12 @@ if st.session_state.app_step == "predict":
611
  f'{sv["oor"]:.1f}%']
612
  })
613
  st.success("Predictions ready ✓")
614
- st.dataframe(table, use_container_width=True, hide_index=True)
615
  st.caption("**★ OOR** = % of rows whose input features fall outside the training min–max range.")
616
- with right:
617
  st.plotly_chart(
618
  track_plot(df, include_actual=False),
619
- use_container_width=False, config={"displayModeBar": False, "scrollZoom": True}
620
  )
621
 
622
  # =========================
 
6
  import numpy as np
7
  import joblib
8
 
9
+ # Matplotlib for PREVIEW modal and for the CROSS-PLOT (static)
10
  import matplotlib
11
  matplotlib.use("Agg")
12
  import matplotlib.pyplot as plt
 
26
  COLORS = {"pred": "#1f77b4", "actual": "#f2b702", "ref": "#5a5a5a"}
27
 
28
  # ---- Plot sizing controls ----
29
+ CROSS_W = 420 # px (matplotlib figure size; Streamlit will still scale)
30
+ CROSS_H = 420
31
+ TRACK_H = 740 # px (plotly height; width auto-fits column)
 
32
  FONT_SZ = 13
 
 
33
 
34
  # =========================
35
  # Page / CSS
36
  # =========================
37
  st.set_page_config(page_title="ST_GeoMech_UCS", page_icon="logo.png", layout="wide")
 
 
 
 
38
 
39
+ # Hide uploader helper text ("Drag and drop file here", limits, etc.)
 
 
 
 
 
 
40
  st.markdown("""
41
  <style>
42
  /* Older builds (helper wrapped in a Markdown container) */
43
+ section[data-testid="stFileUploader"] div[data-testid="stMarkdownContainer"]{display:none !important;}
44
+ /* 1.31–1.34: helper is the first child in the dropzone */
45
+ section[data-testid="stFileUploader"] [data-testid="stFileUploaderDropzone"] > div:first-child{display:none !important;}
46
+ /* 1.35+: explicit helper container */
47
+ section[data-testid="stFileUploader"] [data-testid="stFileUploaderInstructions"]{display:none !important;}
48
+ /* Fallback: any paragraph/small text inside the uploader */
49
+ section[data-testid="stFileUploader"] p, section[data-testid="stFileUploader"] small{display:none !important;}
 
 
 
 
 
 
 
50
  </style>
51
  """, unsafe_allow_html=True)
52
 
53
+ # Center text in all pandas Styler tables (headers + cells)
54
+ TABLE_CENTER_CSS = [
55
+ dict(selector="th", props=[("text-align", "center")]),
56
+ dict(selector="td", props=[("text-align", "center")]),
57
+ ]
58
+
59
  # =========================
60
  # Password gate
61
  # =========================
62
  def inline_logo(path="logo.png") -> str:
63
  try:
64
  p = Path(path)
65
+ if not p.exists(): return ""
 
66
  return f"data:image/png;base64,{base64.b64encode(p.read_bytes()).decode('ascii')}"
67
  except Exception:
68
  return ""
 
80
  if st.session_state.get("auth_ok", False):
81
  return
82
 
83
+ st.sidebar.image("logo.png", use_column_width=True)
84
  st.sidebar.markdown("### ST_GeoMech_UCS\nSmart Thinking • Secure Access")
85
  pwd = st.sidebar.text_input("Access key", type="password", placeholder="••••••••")
86
  if st.sidebar.button("Unlock", type="primary"):
 
96
  # =========================
97
  # Utilities
98
  # =========================
 
 
 
 
 
 
 
 
 
 
 
99
  def rmse(y_true, y_pred) -> float:
100
  return float(np.sqrt(mean_squared_error(y_true, y_pred)))
101
 
102
  def pearson_r(y_true, y_pred) -> float:
103
  a = np.asarray(y_true, dtype=float)
104
  p = np.asarray(y_pred, dtype=float)
105
+ if a.size < 2: return float("nan")
 
106
  return float(np.corrcoef(a, p)[0, 1])
107
 
108
  @st.cache_resource(show_spinner=False)
 
115
  xl = pd.ExcelFile(bio)
116
  return {sh: xl.parse(sh) for sh in xl.sheet_names}
117
 
118
+ def read_book_bytes(b: bytes): return parse_excel(b) if b else {}
 
119
 
120
  def ensure_cols(df, cols):
121
  miss = [c for c in cols if c not in df.columns]
 
127
  def find_sheet(book, names):
128
  low2orig = {k.lower(): k for k in book.keys()}
129
  for nm in names:
130
+ if nm.lower() in low2orig: return low2orig[nm.lower()]
 
131
  return None
132
 
133
  def _nice_tick0(xmin: float, step: int = 100) -> float:
134
  return step * math.floor(xmin / step) if np.isfinite(xmin) else xmin
135
 
136
+ def df_centered_rounded(df: pd.DataFrame, hide_index=True):
137
+ """Round numeric columns to 2 decimals and center headers & cells."""
138
+ out = df.copy()
139
+ numcols = out.select_dtypes(include=[np.number]).columns
140
+ out[numcols] = out[numcols].round(2)
141
+ styler = (out.style
142
+ .set_properties(**{"text-align": "center"})
143
+ .set_table_styles(TABLE_CENTER_CSS))
144
+ st.dataframe(styler, use_container_width=True, hide_index=hide_index)
145
+
146
  # =========================
147
  # Cross plot (Matplotlib, fixed limits & ticks)
148
  # =========================
 
153
  fixed_min, fixed_max = 6000, 10000
154
  ticks = np.arange(fixed_min, fixed_max + 1, 1000)
155
 
 
156
  dpi = 110
157
  fig, ax = plt.subplots(
158
  figsize=(CROSS_W / dpi, CROSS_H / dpi),
 
160
  constrained_layout=False
161
  )
162
 
 
163
  ax.scatter(a, p, s=16, c=COLORS["pred"], alpha=0.9, linewidths=0)
 
 
164
  ax.plot([fixed_min, fixed_max], [fixed_min, fixed_max],
165
  linestyle="--", linewidth=1.2, color=COLORS["ref"])
166
 
 
167
  ax.set_xlim(fixed_min, fixed_max)
168
  ax.set_ylim(fixed_min, fixed_max)
169
  ax.set_xticks(ticks)
170
  ax.set_yticks(ticks)
 
 
171
  ax.set_aspect("equal", adjustable="box")
172
 
 
173
  fmt = FuncFormatter(lambda x, _: f"{int(x):,}")
174
  ax.xaxis.set_major_formatter(fmt)
175
  ax.yaxis.set_major_formatter(fmt)
176
 
 
177
  ax.set_xlabel("Actual UCS (psi)", fontweight="bold", fontsize=12)
178
  ax.set_ylabel("Predicted UCS (psi)", fontweight="bold", fontsize=12)
179
  ax.tick_params(labelsize=10)
180
 
 
181
  ax.grid(True, linestyle=":", alpha=0.3)
182
  for spine in ax.spines.values():
183
  spine.set_linewidth(1.1)
184
  spine.set_color("#444")
185
 
 
186
  fig.subplots_adjust(left=0.16, bottom=0.16, right=0.98, top=0.98)
187
  return fig
188
 
 
189
  # =========================
190
  # Track plot (Plotly)
191
  # =========================
 
225
  ))
226
 
227
  fig.update_layout(
228
+ height=TRACK_H, width=None, # width automatically fits the column
229
+ paper_bgcolor="#fff", plot_bgcolor="#fff",
230
+ margin=dict(l=64, r=16, t=36, b=48), hovermode="closest",
231
  font=dict(size=FONT_SZ),
232
  legend=dict(
233
  x=0.98, y=0.05, xanchor="right", yanchor="bottom",
 
236
  legend_title_text=""
237
  )
238
  fig.update_xaxes(
239
+ title_text="<b>UCS (psi)</b>", title_font=dict(size=14),
240
  side="top", range=[xmin, xmax],
241
  ticks="outside", tickformat=",.0f", tickmode="auto", tick0=tick0,
242
  showline=True, linewidth=1.2, linecolor="#444", mirror=True,
243
  showgrid=True, gridcolor="rgba(0,0,0,0.12)", automargin=True
244
  )
245
  fig.update_yaxes(
246
+ title_text=f"<b>{ylab}</b>", title_font=dict(size=14),
247
  range=y_range, ticks="outside",
248
  showline=True, linewidth=1.2, linecolor="#444", mirror=True,
249
  showgrid=True, gridcolor="rgba(0,0,0,0.12)", automargin=True
 
256
  n = len(cols)
257
  if n == 0:
258
  fig, ax = plt.subplots(figsize=(4, 2))
259
+ ax.text(0.5,0.5,"No selected columns",ha="center",va="center"); ax.axis("off")
 
260
  return fig
261
  fig, axes = plt.subplots(1, n, figsize=(2.2*n, 7.0), sharey=True, dpi=100)
262
  if n == 1: axes = [axes]
 
269
  axes[0].set_ylabel("Point Index")
270
  return fig
271
 
272
+ # Modal wrapper (Streamlit compatibility)
273
+ try:
274
+ dialog = st.dialog
275
+ except AttributeError:
276
+ def dialog(title):
277
+ def deco(fn):
278
+ def wrapper(*args, **kwargs):
279
+ with st.expander(title, expanded=True):
280
+ return fn(*args, **kwargs)
281
+ return wrapper
282
+ return deco
283
+
284
  @dialog("Preview data")
285
  def preview_modal(book: dict[str, pd.DataFrame]):
286
  if not book:
 
291
  with t:
292
  df = book[name]
293
  t1, t2 = st.tabs(["Tracks", "Summary"])
294
+ with t1:
295
+ st.pyplot(preview_tracks(df, FEATURES), use_container_width=True)
296
  with t2:
297
+ tbl = (df[FEATURES]
298
+ .agg(['min','max','mean','std'])
299
+ .T.rename(columns={"min":"Min","max":"Max","mean":"Mean","std":"Std"}))
300
+ df_centered_rounded(tbl.reset_index(names="Feature"))
301
 
302
  # =========================
303
  # Load model
 
351
  # =========================
352
  # Branding in Sidebar
353
  # =========================
354
+ st.sidebar.image("logo.png", use_column_width=True)
355
  st.sidebar.markdown(
356
  "<div style='font-weight:800;font-size:1.2rem;'>ST_GeoMech_UCS</div>"
357
  "<div style='color:#667085;'>Real-Time UCS Tracking While Drilling</div>",
 
395
  st.session_state.dev_preview = True
396
 
397
  run = st.sidebar.button("Run Model", type="primary", use_container_width=True)
 
398
  if st.sidebar.button("Proceed to Validation ▶", use_container_width=True): st.session_state.app_step="validate"; st.rerun()
399
  if st.sidebar.button("Proceed to Prediction ▶", use_container_width=True): st.session_state.app_step="predict"; st.rerun()
400
 
 
401
  helper_top = st.container()
402
  with helper_top:
403
  st.subheader("Case Building")
 
440
  c1,c2,c3 = st.columns(3)
441
  c1.metric("R", f"{m['R']:.2f}"); c2.metric("RMSE", f"{m['RMSE']:.2f}"); c3.metric("MAE", f"{m['MAE']:.2f}")
442
 
443
+ # NEW: 2-column layout, big gap, no nested columns, no 0-weights.
444
+ col_cross, col_track = st.columns([3, 2], gap="large")
445
+ with col_cross:
446
+ st.pyplot(cross_plot_static(df[TARGET], df["UCS_Pred"]), use_container_width=True)
447
+ with col_track:
 
 
 
 
448
  st.plotly_chart(
449
  track_plot(df, include_actual=True),
450
+ use_container_width=True,
451
  config={"displayModeBar": False, "scrollZoom": True}
452
  )
453
 
 
493
  if any_viol.any():
494
  tbl = df.loc[any_viol, FEATURES].copy()
495
  for c in FEATURES:
496
+ if pd.api.types.is_numeric_dtype(tbl[c]): tbl[c] = tbl[c].round(2)
497
+ tbl["Violations"] = pd.DataFrame({f:(df[f]<ranges[f][0])|(df[f]>ranges[f][1]) for f in FEATURES}).loc[any_viol].apply(lambda r:", ".join([c for c,v in r.items() if v]), axis=1)
 
 
 
498
  st.session_state.results["m_val"]={
499
  "R": pearson_r(df[TARGET], df["UCS_Pred"]),
500
  "RMSE": rmse(df[TARGET], df["UCS_Pred"]),
 
508
  c1,c2,c3 = st.columns(3)
509
  c1.metric("R", f"{m['R']:.2f}"); c2.metric("RMSE", f"{m['RMSE']:.2f}"); c3.metric("MAE", f"{m['MAE']:.2f}")
510
 
511
+ col_cross, col_track = st.columns([3, 2], gap="large")
512
+ with col_cross:
 
 
 
 
513
  st.pyplot(
514
  cross_plot_static(st.session_state.results["Validate"][TARGET],
515
  st.session_state.results["Validate"]["UCS_Pred"]),
516
+ use_container_width=True
517
  )
518
+ with col_track:
519
  st.plotly_chart(
520
  track_plot(st.session_state.results["Validate"], include_actual=True),
521
+ use_container_width=True, config={"displayModeBar": False, "scrollZoom": True}
522
  )
523
 
524
  sv = st.session_state.results["sv_val"]
525
  if sv["oor"] > 0: st.warning("Some inputs fall outside **training min–max** ranges.")
526
  if st.session_state.results["oor_tbl"] is not None:
527
  st.write("*Out-of-range rows (vs. Training min–max):*")
528
+ df_centered_rounded(st.session_state.results["oor_tbl"])
529
 
530
  # =========================
531
  # PREDICTION (no actual UCS)
 
568
 
569
  if "PredictOnly" in st.session_state.results:
570
  df = st.session_state.results["PredictOnly"]; sv = st.session_state.results["sv_pred"]
571
+
572
+ col_left, col_right = st.columns([2,3], gap="large")
573
+ with col_left:
574
  table = pd.DataFrame({
575
  "Metric": ["# points","Pred min","Pred max","Pred mean","Pred std","OOR %"],
576
  "Value": [sv["n"],
 
581
  f'{sv["oor"]:.1f}%']
582
  })
583
  st.success("Predictions ready ✓")
584
+ df_centered_rounded(table, hide_index=True)
585
  st.caption("**★ OOR** = % of rows whose input features fall outside the training min–max range.")
586
+ with col_right:
587
  st.plotly_chart(
588
  track_plot(df, include_actual=False),
589
+ use_container_width=True, config={"displayModeBar": False, "scrollZoom": True}
590
  )
591
 
592
  # =========================