Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -6,7 +6,7 @@ import pandas as pd
|
|
| 6 |
import numpy as np
|
| 7 |
import joblib
|
| 8 |
|
| 9 |
-
# Matplotlib for PREVIEW modal and CROSS-PLOT (static)
|
| 10 |
import matplotlib
|
| 11 |
matplotlib.use("Agg")
|
| 12 |
import matplotlib.pyplot as plt
|
|
@@ -26,58 +26,43 @@ MODEL_FALLBACKS = [MODELS_DIR / "model.joblib", MODELS_DIR / "model.pkl"]
|
|
| 26 |
COLORS = {"pred": "#1f77b4", "actual": "#f2b702", "ref": "#5a5a5a"}
|
| 27 |
|
| 28 |
# ---- Plot sizing controls ----
|
| 29 |
-
CROSS_W = 420
|
| 30 |
-
CROSS_H = 420
|
| 31 |
-
|
| 32 |
-
TRACK_H = 740 # px
|
| 33 |
FONT_SZ = 13
|
| 34 |
-
PLOT_COLS = [36, 6, 28] # wider spacer so plots never bump into each other
|
| 35 |
-
CROSS_NUDGE = 0.0 # keep 0 unless you really want an inner pad
|
| 36 |
|
| 37 |
# =========================
|
| 38 |
# Page / CSS
|
| 39 |
# =========================
|
| 40 |
st.set_page_config(page_title="ST_GeoMech_UCS", page_icon="logo.png", layout="wide")
|
| 41 |
-
st.markdown("""
|
| 42 |
-
<style>
|
| 43 |
-
/* Reusable logo style */
|
| 44 |
-
.brand-logo { width: 16px; height: auto; object-fit: contain; }
|
| 45 |
|
| 46 |
-
|
| 47 |
-
.sidebar-header { display:flex; align-items:center; gap:12px; }
|
| 48 |
-
.sidebar-header .text h1 { font-size: 1.05rem; margin:0; line-height:1.1; }
|
| 49 |
-
.sidebar-header .text .tag { font-size: .85rem; color:#6b7280; margin:2px 0 0; }
|
| 50 |
-
</style>
|
| 51 |
-
""", unsafe_allow_html=True)
|
| 52 |
-
# Hide file-uploader helper text; keep only Browse button
|
| 53 |
st.markdown("""
|
| 54 |
<style>
|
| 55 |
/* Older builds (helper wrapped in a Markdown container) */
|
| 56 |
-
section[data-testid="stFileUploader"] div[data-testid="stMarkdownContainer"]
|
| 57 |
-
/* 1.31–1.34
|
| 58 |
-
section[data-testid="stFileUploader"] [data-testid="stFileUploaderDropzone"] > div:first-child
|
| 59 |
-
/* 1.35+ explicit helper container */
|
| 60 |
-
section[data-testid="stFileUploader"] [data-testid="stFileUploaderInstructions"]
|
| 61 |
-
/* Fallback: any paragraph/small
|
| 62 |
-
section[data-testid="stFileUploader"] p, section[data-testid="stFileUploader"] small
|
| 63 |
-
|
| 64 |
-
/* Center headers & cells in all st.dataframe tables */
|
| 65 |
-
[data-testid="stDataFrame"] table td,
|
| 66 |
-
[data-testid="stDataFrame"] table th {
|
| 67 |
-
text-align: center !important;
|
| 68 |
-
vertical-align: middle !important;
|
| 69 |
-
}
|
| 70 |
</style>
|
| 71 |
""", unsafe_allow_html=True)
|
| 72 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
# =========================
|
| 74 |
# Password gate
|
| 75 |
# =========================
|
| 76 |
def inline_logo(path="logo.png") -> str:
|
| 77 |
try:
|
| 78 |
p = Path(path)
|
| 79 |
-
if not p.exists():
|
| 80 |
-
return ""
|
| 81 |
return f"data:image/png;base64,{base64.b64encode(p.read_bytes()).decode('ascii')}"
|
| 82 |
except Exception:
|
| 83 |
return ""
|
|
@@ -95,7 +80,7 @@ def add_password_gate() -> None:
|
|
| 95 |
if st.session_state.get("auth_ok", False):
|
| 96 |
return
|
| 97 |
|
| 98 |
-
st.sidebar.image("logo.png", use_column_width=
|
| 99 |
st.sidebar.markdown("### ST_GeoMech_UCS\nSmart Thinking • Secure Access")
|
| 100 |
pwd = st.sidebar.text_input("Access key", type="password", placeholder="••••••••")
|
| 101 |
if st.sidebar.button("Unlock", type="primary"):
|
|
@@ -111,25 +96,13 @@ add_password_gate()
|
|
| 111 |
# =========================
|
| 112 |
# Utilities
|
| 113 |
# =========================
|
| 114 |
-
try:
|
| 115 |
-
dialog = st.dialog
|
| 116 |
-
except AttributeError:
|
| 117 |
-
def dialog(title):
|
| 118 |
-
def deco(fn):
|
| 119 |
-
def wrapper(*args, **kwargs):
|
| 120 |
-
with st.expander(title, expanded=True):
|
| 121 |
-
return fn(*args, **kwargs)
|
| 122 |
-
return wrapper
|
| 123 |
-
return deco
|
| 124 |
-
|
| 125 |
def rmse(y_true, y_pred) -> float:
|
| 126 |
return float(np.sqrt(mean_squared_error(y_true, y_pred)))
|
| 127 |
|
| 128 |
def pearson_r(y_true, y_pred) -> float:
|
| 129 |
a = np.asarray(y_true, dtype=float)
|
| 130 |
p = np.asarray(y_pred, dtype=float)
|
| 131 |
-
if a.size < 2:
|
| 132 |
-
return float("nan")
|
| 133 |
return float(np.corrcoef(a, p)[0, 1])
|
| 134 |
|
| 135 |
@st.cache_resource(show_spinner=False)
|
|
@@ -142,8 +115,7 @@ def parse_excel(data_bytes: bytes):
|
|
| 142 |
xl = pd.ExcelFile(bio)
|
| 143 |
return {sh: xl.parse(sh) for sh in xl.sheet_names}
|
| 144 |
|
| 145 |
-
def read_book_bytes(b: bytes):
|
| 146 |
-
return parse_excel(b) if b else {}
|
| 147 |
|
| 148 |
def ensure_cols(df, cols):
|
| 149 |
miss = [c for c in cols if c not in df.columns]
|
|
@@ -155,13 +127,22 @@ def ensure_cols(df, cols):
|
|
| 155 |
def find_sheet(book, names):
|
| 156 |
low2orig = {k.lower(): k for k in book.keys()}
|
| 157 |
for nm in names:
|
| 158 |
-
if nm.lower() in low2orig:
|
| 159 |
-
return low2orig[nm.lower()]
|
| 160 |
return None
|
| 161 |
|
| 162 |
def _nice_tick0(xmin: float, step: int = 100) -> float:
|
| 163 |
return step * math.floor(xmin / step) if np.isfinite(xmin) else xmin
|
| 164 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
# =========================
|
| 166 |
# Cross plot (Matplotlib, fixed limits & ticks)
|
| 167 |
# =========================
|
|
@@ -172,7 +153,6 @@ def cross_plot_static(actual, pred):
|
|
| 172 |
fixed_min, fixed_max = 6000, 10000
|
| 173 |
ticks = np.arange(fixed_min, fixed_max + 1, 1000)
|
| 174 |
|
| 175 |
-
# fixed px size = (figsize * dpi)
|
| 176 |
dpi = 110
|
| 177 |
fig, ax = plt.subplots(
|
| 178 |
figsize=(CROSS_W / dpi, CROSS_H / dpi),
|
|
@@ -180,43 +160,32 @@ def cross_plot_static(actual, pred):
|
|
| 180 |
constrained_layout=False
|
| 181 |
)
|
| 182 |
|
| 183 |
-
# points
|
| 184 |
ax.scatter(a, p, s=16, c=COLORS["pred"], alpha=0.9, linewidths=0)
|
| 185 |
-
|
| 186 |
-
# 1:1 diagonal
|
| 187 |
ax.plot([fixed_min, fixed_max], [fixed_min, fixed_max],
|
| 188 |
linestyle="--", linewidth=1.2, color=COLORS["ref"])
|
| 189 |
|
| 190 |
-
# identical axes limits + ticks
|
| 191 |
ax.set_xlim(fixed_min, fixed_max)
|
| 192 |
ax.set_ylim(fixed_min, fixed_max)
|
| 193 |
ax.set_xticks(ticks)
|
| 194 |
ax.set_yticks(ticks)
|
| 195 |
-
|
| 196 |
-
# equal aspect → true 45°
|
| 197 |
ax.set_aspect("equal", adjustable="box")
|
| 198 |
|
| 199 |
-
# thousands formatting
|
| 200 |
fmt = FuncFormatter(lambda x, _: f"{int(x):,}")
|
| 201 |
ax.xaxis.set_major_formatter(fmt)
|
| 202 |
ax.yaxis.set_major_formatter(fmt)
|
| 203 |
|
| 204 |
-
# labels + ticks (smaller so they don't dominate)
|
| 205 |
ax.set_xlabel("Actual UCS (psi)", fontweight="bold", fontsize=12)
|
| 206 |
ax.set_ylabel("Predicted UCS (psi)", fontweight="bold", fontsize=12)
|
| 207 |
ax.tick_params(labelsize=10)
|
| 208 |
|
| 209 |
-
# grid & frame
|
| 210 |
ax.grid(True, linestyle=":", alpha=0.3)
|
| 211 |
for spine in ax.spines.values():
|
| 212 |
spine.set_linewidth(1.1)
|
| 213 |
spine.set_color("#444")
|
| 214 |
|
| 215 |
-
# moderate margins; keeps labels readable but not huge
|
| 216 |
fig.subplots_adjust(left=0.16, bottom=0.16, right=0.98, top=0.98)
|
| 217 |
return fig
|
| 218 |
|
| 219 |
-
|
| 220 |
# =========================
|
| 221 |
# Track plot (Plotly)
|
| 222 |
# =========================
|
|
@@ -256,8 +225,9 @@ def track_plot(df, include_actual=True):
|
|
| 256 |
))
|
| 257 |
|
| 258 |
fig.update_layout(
|
| 259 |
-
|
| 260 |
-
|
|
|
|
| 261 |
font=dict(size=FONT_SZ),
|
| 262 |
legend=dict(
|
| 263 |
x=0.98, y=0.05, xanchor="right", yanchor="bottom",
|
|
@@ -266,14 +236,14 @@ def track_plot(df, include_actual=True):
|
|
| 266 |
legend_title_text=""
|
| 267 |
)
|
| 268 |
fig.update_xaxes(
|
| 269 |
-
title_text="<b>UCS (psi)</b>", title_font=dict(size=
|
| 270 |
side="top", range=[xmin, xmax],
|
| 271 |
ticks="outside", tickformat=",.0f", tickmode="auto", tick0=tick0,
|
| 272 |
showline=True, linewidth=1.2, linecolor="#444", mirror=True,
|
| 273 |
showgrid=True, gridcolor="rgba(0,0,0,0.12)", automargin=True
|
| 274 |
)
|
| 275 |
fig.update_yaxes(
|
| 276 |
-
title_text=f"<b>{ylab}</b>", title_font=dict(size=
|
| 277 |
range=y_range, ticks="outside",
|
| 278 |
showline=True, linewidth=1.2, linecolor="#444", mirror=True,
|
| 279 |
showgrid=True, gridcolor="rgba(0,0,0,0.12)", automargin=True
|
|
@@ -286,8 +256,7 @@ def preview_tracks(df: pd.DataFrame, cols: list[str]):
|
|
| 286 |
n = len(cols)
|
| 287 |
if n == 0:
|
| 288 |
fig, ax = plt.subplots(figsize=(4, 2))
|
| 289 |
-
ax.text(0.5,0.5,"No selected columns",ha="center",va="center")
|
| 290 |
-
ax.axis("off")
|
| 291 |
return fig
|
| 292 |
fig, axes = plt.subplots(1, n, figsize=(2.2*n, 7.0), sharey=True, dpi=100)
|
| 293 |
if n == 1: axes = [axes]
|
|
@@ -300,6 +269,18 @@ def preview_tracks(df: pd.DataFrame, cols: list[str]):
|
|
| 300 |
axes[0].set_ylabel("Point Index")
|
| 301 |
return fig
|
| 302 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 303 |
@dialog("Preview data")
|
| 304 |
def preview_modal(book: dict[str, pd.DataFrame]):
|
| 305 |
if not book:
|
|
@@ -310,12 +291,13 @@ def preview_modal(book: dict[str, pd.DataFrame]):
|
|
| 310 |
with t:
|
| 311 |
df = book[name]
|
| 312 |
t1, t2 = st.tabs(["Tracks", "Summary"])
|
| 313 |
-
with t1:
|
|
|
|
| 314 |
with t2:
|
| 315 |
-
tbl = df[FEATURES]
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
|
| 320 |
# =========================
|
| 321 |
# Load model
|
|
@@ -369,7 +351,7 @@ st.session_state.setdefault("dev_preview",False)
|
|
| 369 |
# =========================
|
| 370 |
# Branding in Sidebar
|
| 371 |
# =========================
|
| 372 |
-
st.sidebar.image("logo.png", use_column_width=
|
| 373 |
st.sidebar.markdown(
|
| 374 |
"<div style='font-weight:800;font-size:1.2rem;'>ST_GeoMech_UCS</div>"
|
| 375 |
"<div style='color:#667085;'>Real-Time UCS Tracking While Drilling</div>",
|
|
@@ -413,11 +395,9 @@ if st.session_state.app_step == "dev":
|
|
| 413 |
st.session_state.dev_preview = True
|
| 414 |
|
| 415 |
run = st.sidebar.button("Run Model", type="primary", use_container_width=True)
|
| 416 |
-
# nav
|
| 417 |
if st.sidebar.button("Proceed to Validation ▶", use_container_width=True): st.session_state.app_step="validate"; st.rerun()
|
| 418 |
if st.sidebar.button("Proceed to Prediction ▶", use_container_width=True): st.session_state.app_step="predict"; st.rerun()
|
| 419 |
|
| 420 |
-
# ---- Pinned helper at the very top ----
|
| 421 |
helper_top = st.container()
|
| 422 |
with helper_top:
|
| 423 |
st.subheader("Case Building")
|
|
@@ -460,18 +440,14 @@ if st.session_state.app_step == "dev":
|
|
| 460 |
c1,c2,c3 = st.columns(3)
|
| 461 |
c1.metric("R", f"{m['R']:.2f}"); c2.metric("RMSE", f"{m['RMSE']:.2f}"); c3.metric("MAE", f"{m['MAE']:.2f}")
|
| 462 |
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
plotcol = left
|
| 469 |
-
with plotcol:
|
| 470 |
-
st.pyplot(cross_plot_static(df[TARGET], df["UCS_Pred"]), use_container_width=False)
|
| 471 |
-
with right:
|
| 472 |
st.plotly_chart(
|
| 473 |
track_plot(df, include_actual=True),
|
| 474 |
-
use_container_width=
|
| 475 |
config={"displayModeBar": False, "scrollZoom": True}
|
| 476 |
)
|
| 477 |
|
|
@@ -517,11 +493,8 @@ if st.session_state.app_step == "validate":
|
|
| 517 |
if any_viol.any():
|
| 518 |
tbl = df.loc[any_viol, FEATURES].copy()
|
| 519 |
for c in FEATURES:
|
| 520 |
-
if pd.api.types.is_numeric_dtype(tbl[c]):
|
| 521 |
-
|
| 522 |
-
tbl["Violations"] = pd.DataFrame({f:(df[f]<ranges[f][0])|(df[f]>ranges[f][1]) for f in FEATURES}).loc[any_viol].apply(
|
| 523 |
-
lambda r:", ".join([c for c,v in r.items() if v]), axis=1
|
| 524 |
-
)
|
| 525 |
st.session_state.results["m_val"]={
|
| 526 |
"R": pearson_r(df[TARGET], df["UCS_Pred"]),
|
| 527 |
"RMSE": rmse(df[TARGET], df["UCS_Pred"]),
|
|
@@ -535,28 +508,24 @@ if st.session_state.app_step == "validate":
|
|
| 535 |
c1,c2,c3 = st.columns(3)
|
| 536 |
c1.metric("R", f"{m['R']:.2f}"); c2.metric("RMSE", f"{m['RMSE']:.2f}"); c3.metric("MAE", f"{m['MAE']:.2f}")
|
| 537 |
|
| 538 |
-
|
| 539 |
-
|
| 540 |
-
pad, plotcol = left.columns([CROSS_NUDGE, 1])
|
| 541 |
-
else:
|
| 542 |
-
plotcol = left
|
| 543 |
-
with plotcol:
|
| 544 |
st.pyplot(
|
| 545 |
cross_plot_static(st.session_state.results["Validate"][TARGET],
|
| 546 |
st.session_state.results["Validate"]["UCS_Pred"]),
|
| 547 |
-
use_container_width=
|
| 548 |
)
|
| 549 |
-
with
|
| 550 |
st.plotly_chart(
|
| 551 |
track_plot(st.session_state.results["Validate"], include_actual=True),
|
| 552 |
-
use_container_width=
|
| 553 |
)
|
| 554 |
|
| 555 |
sv = st.session_state.results["sv_val"]
|
| 556 |
if sv["oor"] > 0: st.warning("Some inputs fall outside **training min–max** ranges.")
|
| 557 |
if st.session_state.results["oor_tbl"] is not None:
|
| 558 |
st.write("*Out-of-range rows (vs. Training min–max):*")
|
| 559 |
-
|
| 560 |
|
| 561 |
# =========================
|
| 562 |
# PREDICTION (no actual UCS)
|
|
@@ -599,8 +568,9 @@ if st.session_state.app_step == "predict":
|
|
| 599 |
|
| 600 |
if "PredictOnly" in st.session_state.results:
|
| 601 |
df = st.session_state.results["PredictOnly"]; sv = st.session_state.results["sv_pred"]
|
| 602 |
-
|
| 603 |
-
|
|
|
|
| 604 |
table = pd.DataFrame({
|
| 605 |
"Metric": ["# points","Pred min","Pred max","Pred mean","Pred std","OOR %"],
|
| 606 |
"Value": [sv["n"],
|
|
@@ -611,12 +581,12 @@ if st.session_state.app_step == "predict":
|
|
| 611 |
f'{sv["oor"]:.1f}%']
|
| 612 |
})
|
| 613 |
st.success("Predictions ready ✓")
|
| 614 |
-
|
| 615 |
st.caption("**★ OOR** = % of rows whose input features fall outside the training min–max range.")
|
| 616 |
-
with
|
| 617 |
st.plotly_chart(
|
| 618 |
track_plot(df, include_actual=False),
|
| 619 |
-
use_container_width=
|
| 620 |
)
|
| 621 |
|
| 622 |
# =========================
|
|
|
|
| 6 |
import numpy as np
|
| 7 |
import joblib
|
| 8 |
|
| 9 |
+
# Matplotlib for PREVIEW modal and for the CROSS-PLOT (static)
|
| 10 |
import matplotlib
|
| 11 |
matplotlib.use("Agg")
|
| 12 |
import matplotlib.pyplot as plt
|
|
|
|
| 26 |
COLORS = {"pred": "#1f77b4", "actual": "#f2b702", "ref": "#5a5a5a"}
|
| 27 |
|
| 28 |
# ---- Plot sizing controls ----
|
| 29 |
+
CROSS_W = 420 # px (matplotlib figure size; Streamlit will still scale)
|
| 30 |
+
CROSS_H = 420
|
| 31 |
+
TRACK_H = 740 # px (plotly height; width auto-fits column)
|
|
|
|
| 32 |
FONT_SZ = 13
|
|
|
|
|
|
|
| 33 |
|
| 34 |
# =========================
|
| 35 |
# Page / CSS
|
| 36 |
# =========================
|
| 37 |
st.set_page_config(page_title="ST_GeoMech_UCS", page_icon="logo.png", layout="wide")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
+
# Hide uploader helper text ("Drag and drop file here", limits, etc.)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
st.markdown("""
|
| 41 |
<style>
|
| 42 |
/* Older builds (helper wrapped in a Markdown container) */
|
| 43 |
+
section[data-testid="stFileUploader"] div[data-testid="stMarkdownContainer"]{display:none !important;}
|
| 44 |
+
/* 1.31–1.34: helper is the first child in the dropzone */
|
| 45 |
+
section[data-testid="stFileUploader"] [data-testid="stFileUploaderDropzone"] > div:first-child{display:none !important;}
|
| 46 |
+
/* 1.35+: explicit helper container */
|
| 47 |
+
section[data-testid="stFileUploader"] [data-testid="stFileUploaderInstructions"]{display:none !important;}
|
| 48 |
+
/* Fallback: any paragraph/small text inside the uploader */
|
| 49 |
+
section[data-testid="stFileUploader"] p, section[data-testid="stFileUploader"] small{display:none !important;}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
</style>
|
| 51 |
""", unsafe_allow_html=True)
|
| 52 |
|
| 53 |
+
# Center text in all pandas Styler tables (headers + cells)
|
| 54 |
+
TABLE_CENTER_CSS = [
|
| 55 |
+
dict(selector="th", props=[("text-align", "center")]),
|
| 56 |
+
dict(selector="td", props=[("text-align", "center")]),
|
| 57 |
+
]
|
| 58 |
+
|
| 59 |
# =========================
|
| 60 |
# Password gate
|
| 61 |
# =========================
|
| 62 |
def inline_logo(path="logo.png") -> str:
|
| 63 |
try:
|
| 64 |
p = Path(path)
|
| 65 |
+
if not p.exists(): return ""
|
|
|
|
| 66 |
return f"data:image/png;base64,{base64.b64encode(p.read_bytes()).decode('ascii')}"
|
| 67 |
except Exception:
|
| 68 |
return ""
|
|
|
|
| 80 |
if st.session_state.get("auth_ok", False):
|
| 81 |
return
|
| 82 |
|
| 83 |
+
st.sidebar.image("logo.png", use_column_width=True)
|
| 84 |
st.sidebar.markdown("### ST_GeoMech_UCS\nSmart Thinking • Secure Access")
|
| 85 |
pwd = st.sidebar.text_input("Access key", type="password", placeholder="••••••••")
|
| 86 |
if st.sidebar.button("Unlock", type="primary"):
|
|
|
|
| 96 |
# =========================
|
| 97 |
# Utilities
|
| 98 |
# =========================
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
def rmse(y_true, y_pred) -> float:
|
| 100 |
return float(np.sqrt(mean_squared_error(y_true, y_pred)))
|
| 101 |
|
| 102 |
def pearson_r(y_true, y_pred) -> float:
|
| 103 |
a = np.asarray(y_true, dtype=float)
|
| 104 |
p = np.asarray(y_pred, dtype=float)
|
| 105 |
+
if a.size < 2: return float("nan")
|
|
|
|
| 106 |
return float(np.corrcoef(a, p)[0, 1])
|
| 107 |
|
| 108 |
@st.cache_resource(show_spinner=False)
|
|
|
|
| 115 |
xl = pd.ExcelFile(bio)
|
| 116 |
return {sh: xl.parse(sh) for sh in xl.sheet_names}
|
| 117 |
|
| 118 |
+
def read_book_bytes(b: bytes): return parse_excel(b) if b else {}
|
|
|
|
| 119 |
|
| 120 |
def ensure_cols(df, cols):
|
| 121 |
miss = [c for c in cols if c not in df.columns]
|
|
|
|
| 127 |
def find_sheet(book, names):
|
| 128 |
low2orig = {k.lower(): k for k in book.keys()}
|
| 129 |
for nm in names:
|
| 130 |
+
if nm.lower() in low2orig: return low2orig[nm.lower()]
|
|
|
|
| 131 |
return None
|
| 132 |
|
| 133 |
def _nice_tick0(xmin: float, step: int = 100) -> float:
|
| 134 |
return step * math.floor(xmin / step) if np.isfinite(xmin) else xmin
|
| 135 |
|
| 136 |
+
def df_centered_rounded(df: pd.DataFrame, hide_index=True):
|
| 137 |
+
"""Round numeric columns to 2 decimals and center headers & cells."""
|
| 138 |
+
out = df.copy()
|
| 139 |
+
numcols = out.select_dtypes(include=[np.number]).columns
|
| 140 |
+
out[numcols] = out[numcols].round(2)
|
| 141 |
+
styler = (out.style
|
| 142 |
+
.set_properties(**{"text-align": "center"})
|
| 143 |
+
.set_table_styles(TABLE_CENTER_CSS))
|
| 144 |
+
st.dataframe(styler, use_container_width=True, hide_index=hide_index)
|
| 145 |
+
|
| 146 |
# =========================
|
| 147 |
# Cross plot (Matplotlib, fixed limits & ticks)
|
| 148 |
# =========================
|
|
|
|
| 153 |
fixed_min, fixed_max = 6000, 10000
|
| 154 |
ticks = np.arange(fixed_min, fixed_max + 1, 1000)
|
| 155 |
|
|
|
|
| 156 |
dpi = 110
|
| 157 |
fig, ax = plt.subplots(
|
| 158 |
figsize=(CROSS_W / dpi, CROSS_H / dpi),
|
|
|
|
| 160 |
constrained_layout=False
|
| 161 |
)
|
| 162 |
|
|
|
|
| 163 |
ax.scatter(a, p, s=16, c=COLORS["pred"], alpha=0.9, linewidths=0)
|
|
|
|
|
|
|
| 164 |
ax.plot([fixed_min, fixed_max], [fixed_min, fixed_max],
|
| 165 |
linestyle="--", linewidth=1.2, color=COLORS["ref"])
|
| 166 |
|
|
|
|
| 167 |
ax.set_xlim(fixed_min, fixed_max)
|
| 168 |
ax.set_ylim(fixed_min, fixed_max)
|
| 169 |
ax.set_xticks(ticks)
|
| 170 |
ax.set_yticks(ticks)
|
|
|
|
|
|
|
| 171 |
ax.set_aspect("equal", adjustable="box")
|
| 172 |
|
|
|
|
| 173 |
fmt = FuncFormatter(lambda x, _: f"{int(x):,}")
|
| 174 |
ax.xaxis.set_major_formatter(fmt)
|
| 175 |
ax.yaxis.set_major_formatter(fmt)
|
| 176 |
|
|
|
|
| 177 |
ax.set_xlabel("Actual UCS (psi)", fontweight="bold", fontsize=12)
|
| 178 |
ax.set_ylabel("Predicted UCS (psi)", fontweight="bold", fontsize=12)
|
| 179 |
ax.tick_params(labelsize=10)
|
| 180 |
|
|
|
|
| 181 |
ax.grid(True, linestyle=":", alpha=0.3)
|
| 182 |
for spine in ax.spines.values():
|
| 183 |
spine.set_linewidth(1.1)
|
| 184 |
spine.set_color("#444")
|
| 185 |
|
|
|
|
| 186 |
fig.subplots_adjust(left=0.16, bottom=0.16, right=0.98, top=0.98)
|
| 187 |
return fig
|
| 188 |
|
|
|
|
| 189 |
# =========================
|
| 190 |
# Track plot (Plotly)
|
| 191 |
# =========================
|
|
|
|
| 225 |
))
|
| 226 |
|
| 227 |
fig.update_layout(
|
| 228 |
+
height=TRACK_H, width=None, # width automatically fits the column
|
| 229 |
+
paper_bgcolor="#fff", plot_bgcolor="#fff",
|
| 230 |
+
margin=dict(l=64, r=16, t=36, b=48), hovermode="closest",
|
| 231 |
font=dict(size=FONT_SZ),
|
| 232 |
legend=dict(
|
| 233 |
x=0.98, y=0.05, xanchor="right", yanchor="bottom",
|
|
|
|
| 236 |
legend_title_text=""
|
| 237 |
)
|
| 238 |
fig.update_xaxes(
|
| 239 |
+
title_text="<b>UCS (psi)</b>", title_font=dict(size=14),
|
| 240 |
side="top", range=[xmin, xmax],
|
| 241 |
ticks="outside", tickformat=",.0f", tickmode="auto", tick0=tick0,
|
| 242 |
showline=True, linewidth=1.2, linecolor="#444", mirror=True,
|
| 243 |
showgrid=True, gridcolor="rgba(0,0,0,0.12)", automargin=True
|
| 244 |
)
|
| 245 |
fig.update_yaxes(
|
| 246 |
+
title_text=f"<b>{ylab}</b>", title_font=dict(size=14),
|
| 247 |
range=y_range, ticks="outside",
|
| 248 |
showline=True, linewidth=1.2, linecolor="#444", mirror=True,
|
| 249 |
showgrid=True, gridcolor="rgba(0,0,0,0.12)", automargin=True
|
|
|
|
| 256 |
n = len(cols)
|
| 257 |
if n == 0:
|
| 258 |
fig, ax = plt.subplots(figsize=(4, 2))
|
| 259 |
+
ax.text(0.5,0.5,"No selected columns",ha="center",va="center"); ax.axis("off")
|
|
|
|
| 260 |
return fig
|
| 261 |
fig, axes = plt.subplots(1, n, figsize=(2.2*n, 7.0), sharey=True, dpi=100)
|
| 262 |
if n == 1: axes = [axes]
|
|
|
|
| 269 |
axes[0].set_ylabel("Point Index")
|
| 270 |
return fig
|
| 271 |
|
| 272 |
+
# Modal wrapper (Streamlit compatibility)
|
| 273 |
+
try:
|
| 274 |
+
dialog = st.dialog
|
| 275 |
+
except AttributeError:
|
| 276 |
+
def dialog(title):
|
| 277 |
+
def deco(fn):
|
| 278 |
+
def wrapper(*args, **kwargs):
|
| 279 |
+
with st.expander(title, expanded=True):
|
| 280 |
+
return fn(*args, **kwargs)
|
| 281 |
+
return wrapper
|
| 282 |
+
return deco
|
| 283 |
+
|
| 284 |
@dialog("Preview data")
|
| 285 |
def preview_modal(book: dict[str, pd.DataFrame]):
|
| 286 |
if not book:
|
|
|
|
| 291 |
with t:
|
| 292 |
df = book[name]
|
| 293 |
t1, t2 = st.tabs(["Tracks", "Summary"])
|
| 294 |
+
with t1:
|
| 295 |
+
st.pyplot(preview_tracks(df, FEATURES), use_container_width=True)
|
| 296 |
with t2:
|
| 297 |
+
tbl = (df[FEATURES]
|
| 298 |
+
.agg(['min','max','mean','std'])
|
| 299 |
+
.T.rename(columns={"min":"Min","max":"Max","mean":"Mean","std":"Std"}))
|
| 300 |
+
df_centered_rounded(tbl.reset_index(names="Feature"))
|
| 301 |
|
| 302 |
# =========================
|
| 303 |
# Load model
|
|
|
|
| 351 |
# =========================
|
| 352 |
# Branding in Sidebar
|
| 353 |
# =========================
|
| 354 |
+
st.sidebar.image("logo.png", use_column_width=True)
|
| 355 |
st.sidebar.markdown(
|
| 356 |
"<div style='font-weight:800;font-size:1.2rem;'>ST_GeoMech_UCS</div>"
|
| 357 |
"<div style='color:#667085;'>Real-Time UCS Tracking While Drilling</div>",
|
|
|
|
| 395 |
st.session_state.dev_preview = True
|
| 396 |
|
| 397 |
run = st.sidebar.button("Run Model", type="primary", use_container_width=True)
|
|
|
|
| 398 |
if st.sidebar.button("Proceed to Validation ▶", use_container_width=True): st.session_state.app_step="validate"; st.rerun()
|
| 399 |
if st.sidebar.button("Proceed to Prediction ▶", use_container_width=True): st.session_state.app_step="predict"; st.rerun()
|
| 400 |
|
|
|
|
| 401 |
helper_top = st.container()
|
| 402 |
with helper_top:
|
| 403 |
st.subheader("Case Building")
|
|
|
|
| 440 |
c1,c2,c3 = st.columns(3)
|
| 441 |
c1.metric("R", f"{m['R']:.2f}"); c2.metric("RMSE", f"{m['RMSE']:.2f}"); c3.metric("MAE", f"{m['MAE']:.2f}")
|
| 442 |
|
| 443 |
+
# NEW: 2-column layout, big gap, no nested columns, no 0-weights.
|
| 444 |
+
col_cross, col_track = st.columns([3, 2], gap="large")
|
| 445 |
+
with col_cross:
|
| 446 |
+
st.pyplot(cross_plot_static(df[TARGET], df["UCS_Pred"]), use_container_width=True)
|
| 447 |
+
with col_track:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 448 |
st.plotly_chart(
|
| 449 |
track_plot(df, include_actual=True),
|
| 450 |
+
use_container_width=True,
|
| 451 |
config={"displayModeBar": False, "scrollZoom": True}
|
| 452 |
)
|
| 453 |
|
|
|
|
| 493 |
if any_viol.any():
|
| 494 |
tbl = df.loc[any_viol, FEATURES].copy()
|
| 495 |
for c in FEATURES:
|
| 496 |
+
if pd.api.types.is_numeric_dtype(tbl[c]): tbl[c] = tbl[c].round(2)
|
| 497 |
+
tbl["Violations"] = pd.DataFrame({f:(df[f]<ranges[f][0])|(df[f]>ranges[f][1]) for f in FEATURES}).loc[any_viol].apply(lambda r:", ".join([c for c,v in r.items() if v]), axis=1)
|
|
|
|
|
|
|
|
|
|
| 498 |
st.session_state.results["m_val"]={
|
| 499 |
"R": pearson_r(df[TARGET], df["UCS_Pred"]),
|
| 500 |
"RMSE": rmse(df[TARGET], df["UCS_Pred"]),
|
|
|
|
| 508 |
c1,c2,c3 = st.columns(3)
|
| 509 |
c1.metric("R", f"{m['R']:.2f}"); c2.metric("RMSE", f"{m['RMSE']:.2f}"); c3.metric("MAE", f"{m['MAE']:.2f}")
|
| 510 |
|
| 511 |
+
col_cross, col_track = st.columns([3, 2], gap="large")
|
| 512 |
+
with col_cross:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 513 |
st.pyplot(
|
| 514 |
cross_plot_static(st.session_state.results["Validate"][TARGET],
|
| 515 |
st.session_state.results["Validate"]["UCS_Pred"]),
|
| 516 |
+
use_container_width=True
|
| 517 |
)
|
| 518 |
+
with col_track:
|
| 519 |
st.plotly_chart(
|
| 520 |
track_plot(st.session_state.results["Validate"], include_actual=True),
|
| 521 |
+
use_container_width=True, config={"displayModeBar": False, "scrollZoom": True}
|
| 522 |
)
|
| 523 |
|
| 524 |
sv = st.session_state.results["sv_val"]
|
| 525 |
if sv["oor"] > 0: st.warning("Some inputs fall outside **training min–max** ranges.")
|
| 526 |
if st.session_state.results["oor_tbl"] is not None:
|
| 527 |
st.write("*Out-of-range rows (vs. Training min–max):*")
|
| 528 |
+
df_centered_rounded(st.session_state.results["oor_tbl"])
|
| 529 |
|
| 530 |
# =========================
|
| 531 |
# PREDICTION (no actual UCS)
|
|
|
|
| 568 |
|
| 569 |
if "PredictOnly" in st.session_state.results:
|
| 570 |
df = st.session_state.results["PredictOnly"]; sv = st.session_state.results["sv_pred"]
|
| 571 |
+
|
| 572 |
+
col_left, col_right = st.columns([2,3], gap="large")
|
| 573 |
+
with col_left:
|
| 574 |
table = pd.DataFrame({
|
| 575 |
"Metric": ["# points","Pred min","Pred max","Pred mean","Pred std","OOR %"],
|
| 576 |
"Value": [sv["n"],
|
|
|
|
| 581 |
f'{sv["oor"]:.1f}%']
|
| 582 |
})
|
| 583 |
st.success("Predictions ready ✓")
|
| 584 |
+
df_centered_rounded(table, hide_index=True)
|
| 585 |
st.caption("**★ OOR** = % of rows whose input features fall outside the training min–max range.")
|
| 586 |
+
with col_right:
|
| 587 |
st.plotly_chart(
|
| 588 |
track_plot(df, include_actual=False),
|
| 589 |
+
use_container_width=True, config={"displayModeBar": False, "scrollZoom": True}
|
| 590 |
)
|
| 591 |
|
| 592 |
# =========================
|