Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -25,18 +25,18 @@ MODEL_FALLBACKS = [MODELS_DIR / "model.joblib", MODELS_DIR / "model.pkl"]
|
|
| 25 |
COLORS = {"pred": "#1f77b4", "actual": "#f2b702", "ref": "#5a5a5a"}
|
| 26 |
|
| 27 |
# ---- Plot sizing controls (edit here) ----
|
| 28 |
-
CROSS_W = 500; CROSS_H = 500 # square cross-plot
|
| 29 |
-
TRACK_W = 400; TRACK_H = 950 # log-strip style (
|
| 30 |
FONT_SZ = 13
|
| 31 |
-
PLOT_COLS = [14, 0.3, 10] # 3-column band: left • spacer • right
|
|
|
|
|
|
|
| 32 |
|
| 33 |
# =========================
|
| 34 |
# Page / CSS
|
| 35 |
# =========================
|
| 36 |
st.set_page_config(page_title="ST_GeoMech_UCS", page_icon="logo.png", layout="wide")
|
| 37 |
st.markdown("<style>header, footer{visibility:hidden !important;}</style>", unsafe_allow_html=True)
|
| 38 |
-
# helper class to right-align the cross-plot inside its column
|
| 39 |
-
st.markdown("<style>.align-right{display:flex;justify-content:flex-end;width:100%;}</style>", unsafe_allow_html=True)
|
| 40 |
st.markdown(
|
| 41 |
"""
|
| 42 |
<style>
|
|
@@ -147,7 +147,14 @@ def parse_excel(data_bytes: bytes):
|
|
| 147 |
return {sh: xl.parse(sh) for sh in xl.sheet_names}
|
| 148 |
|
| 149 |
def read_book_bytes(b: bytes): return parse_excel(b) if b else {}
|
| 150 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
def find_sheet(book, names):
|
| 152 |
low2orig = {k.lower(): k for k in book.keys()}
|
| 153 |
for nm in names:
|
|
@@ -160,15 +167,6 @@ def _nice_tick0(xmin: float, step: int = 100) -> float:
|
|
| 160 |
return xmin
|
| 161 |
return step * math.floor(xmin / step)
|
| 162 |
|
| 163 |
-
def _nice_dtick(xmin: float, xmax: float) -> int:
|
| 164 |
-
width = max(xmax - xmin, 1.0)
|
| 165 |
-
candidates = [50, 100, 200, 250, 500, 1000]
|
| 166 |
-
for dt in candidates:
|
| 167 |
-
n = width / dt
|
| 168 |
-
if 5 <= n <= 12:
|
| 169 |
-
return dt
|
| 170 |
-
return 100
|
| 171 |
-
|
| 172 |
# ---------- Plot builders ----------
|
| 173 |
def cross_plot(actual, pred):
|
| 174 |
a = pd.Series(actual).astype(float)
|
|
@@ -226,8 +224,7 @@ def track_plot(df, include_actual=True):
|
|
| 226 |
x_lo, x_hi = float(x_series.min()), float(x_series.max())
|
| 227 |
x_pad = 0.03 * (x_hi - x_lo if x_hi > x_lo else 1.0)
|
| 228 |
xmin, xmax = x_lo - x_pad, x_hi + x_pad
|
| 229 |
-
tick0 = _nice_tick0(xmin, step=100)
|
| 230 |
-
dtick = _nice_dtick(xmin, xmax)
|
| 231 |
|
| 232 |
fig = go.Figure()
|
| 233 |
fig.add_trace(go.Scatter(
|
|
@@ -257,7 +254,7 @@ def track_plot(df, include_actual=True):
|
|
| 257 |
fig.update_xaxes(
|
| 258 |
title_text="<b>UCS (psi)</b>", side="top", range=[xmin, xmax],
|
| 259 |
ticks="outside", tickformat=",.0f",
|
| 260 |
-
tickmode="
|
| 261 |
showline=True, linewidth=1.2, linecolor="#444", mirror=True,
|
| 262 |
showgrid=True, gridcolor="rgba(0,0,0,0.12)", automargin=True
|
| 263 |
)
|
|
@@ -412,7 +409,16 @@ if st.session_state.app_step == "dev":
|
|
| 412 |
df0 = next(iter(tmp.values()))
|
| 413 |
st.sidebar.caption(f"**Data loaded:** {st.session_state.dev_file_name} • {df0.shape[0]} rows × {df0.shape[1]} cols")
|
| 414 |
|
| 415 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 416 |
helper_top = st.container()
|
| 417 |
with helper_top:
|
| 418 |
st.subheader("Case Building (Development)")
|
|
@@ -423,15 +429,6 @@ if st.session_state.app_step == "dev":
|
|
| 423 |
else:
|
| 424 |
st.write("**Upload your data to build a case, then run the model to review development performance.**")
|
| 425 |
|
| 426 |
-
# preview modal call AFTER helper, so helper stays pinned above
|
| 427 |
-
if st.sidebar.button("Preview data", use_container_width=True, disabled=not st.session_state.dev_file_loaded):
|
| 428 |
-
preview_modal(read_book_bytes(st.session_state.dev_file_bytes))
|
| 429 |
-
st.session_state.dev_preview = True
|
| 430 |
-
|
| 431 |
-
run = st.sidebar.button("Run Model", type="primary", use_container_width=True)
|
| 432 |
-
if st.sidebar.button("Proceed to Validation ▶", use_container_width=True): st.session_state.app_step="validate"; st.rerun()
|
| 433 |
-
if st.sidebar.button("Proceed to Prediction ▶", use_container_width=True): st.session_state.app_step="predict"; st.rerun()
|
| 434 |
-
|
| 435 |
if run and st.session_state.dev_file_bytes:
|
| 436 |
book = read_book_bytes(st.session_state.dev_file_bytes)
|
| 437 |
sh_train = find_sheet(book, ["Train","Training","training2","train","training"])
|
|
@@ -457,13 +454,13 @@ if st.session_state.app_step == "dev":
|
|
| 457 |
c1.metric("R²", f"{m['R2']:.4f}"); c2.metric("RMSE", f"{m['RMSE']:.4f}"); c3.metric("MAE", f"{m['MAE']:.4f}")
|
| 458 |
left, spacer, right = st.columns(PLOT_COLS)
|
| 459 |
with left:
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
|
| 467 |
with right:
|
| 468 |
st.plotly_chart(
|
| 469 |
track_plot(df, include_actual=True),
|
|
@@ -489,16 +486,15 @@ if st.session_state.app_step == "validate":
|
|
| 489 |
if book:
|
| 490 |
df0 = next(iter(book.values()))
|
| 491 |
st.sidebar.caption(f"**Data loaded:** {up.name} • {df0.shape[0]} rows × {df0.shape[1]} cols")
|
| 492 |
-
|
| 493 |
-
st.subheader("Validate the Model")
|
| 494 |
-
st.write("Upload a dataset with the same **features** and **UCS** to evaluate performance.")
|
| 495 |
-
|
| 496 |
if st.sidebar.button("Preview data", use_container_width=True, disabled=(up is None)):
|
| 497 |
preview_modal(read_book_bytes(up.getvalue()))
|
| 498 |
go_btn = st.sidebar.button("Predict", type="primary", use_container_width=True)
|
| 499 |
if st.sidebar.button("⬅ Back to Case Building", use_container_width=True): st.session_state.app_step="dev"; st.rerun()
|
| 500 |
if st.sidebar.button("Proceed to Prediction ▶", use_container_width=True): st.session_state.app_step="predict"; st.rerun()
|
| 501 |
|
|
|
|
|
|
|
|
|
|
| 502 |
if go_btn and up is not None:
|
| 503 |
book = read_book_bytes(up.getvalue())
|
| 504 |
name = find_sheet(book, ["Validation","Validate","validation2","Val","val"]) or list(book.keys())[0]
|
|
@@ -519,25 +515,26 @@ if st.session_state.app_step == "validate":
|
|
| 519 |
st.session_state.results["oor_tbl"]=tbl
|
| 520 |
|
| 521 |
if "Validate" in st.session_state.results:
|
| 522 |
-
m = st.session_state.results["m_val"]
|
| 523 |
c1,c2,c3 = st.columns(3)
|
| 524 |
c1.metric("R²", f"{m['R2']:.4f}"); c2.metric("RMSE", f"{m['RMSE']:.4f}"); c3.metric("MAE", f"{m['MAE']:.4f}")
|
| 525 |
|
| 526 |
left, spacer, right = st.columns(PLOT_COLS)
|
| 527 |
with left:
|
| 528 |
-
|
| 529 |
-
|
| 530 |
-
|
| 531 |
-
|
| 532 |
-
|
| 533 |
-
|
| 534 |
-
|
| 535 |
with right:
|
| 536 |
st.plotly_chart(
|
| 537 |
track_plot(st.session_state.results["Validate"], include_actual=True),
|
| 538 |
use_container_width=False, config={"displayModeBar": False, "scrollZoom": True}
|
| 539 |
)
|
| 540 |
|
|
|
|
| 541 |
if sv["oor"] > 0: st.warning("Some inputs fall outside **training min–max** ranges.")
|
| 542 |
if st.session_state.results["oor_tbl"] is not None:
|
| 543 |
st.write("*Out-of-range rows (vs. Training min–max):*")
|
|
|
|
| 25 |
COLORS = {"pred": "#1f77b4", "actual": "#f2b702", "ref": "#5a5a5a"}
|
| 26 |
|
| 27 |
# ---- Plot sizing controls (edit here) ----
|
| 28 |
+
CROSS_W = 500; CROSS_H = 500 # square cross-plot (Build + Validate)
|
| 29 |
+
TRACK_W = 400; TRACK_H = 950 # log-strip style (all pages)
|
| 30 |
FONT_SZ = 13
|
| 31 |
+
PLOT_COLS = [14, 0.3, 10] # 3-column band: left • spacer • right (Build + Validate)
|
| 32 |
+
CROSS_NUDGE = 1.2 # push cross-plot to the RIGHT inside its band:
|
| 33 |
+
# inner columns [CROSS_NUDGE : 1] → bigger = more right
|
| 34 |
|
| 35 |
# =========================
|
| 36 |
# Page / CSS
|
| 37 |
# =========================
|
| 38 |
st.set_page_config(page_title="ST_GeoMech_UCS", page_icon="logo.png", layout="wide")
|
| 39 |
st.markdown("<style>header, footer{visibility:hidden !important;}</style>", unsafe_allow_html=True)
|
|
|
|
|
|
|
| 40 |
st.markdown(
|
| 41 |
"""
|
| 42 |
<style>
|
|
|
|
| 147 |
return {sh: xl.parse(sh) for sh in xl.sheet_names}
|
| 148 |
|
| 149 |
def read_book_bytes(b: bytes): return parse_excel(b) if b else {}
|
| 150 |
+
|
| 151 |
+
def ensure_cols(df, cols):
|
| 152 |
+
miss = [c for c in cols if c not in df.columns]
|
| 153 |
+
if miss:
|
| 154 |
+
st.error(f"Missing columns: {miss}\nFound: {list(df.columns)}")
|
| 155 |
+
return False
|
| 156 |
+
return True
|
| 157 |
+
|
| 158 |
def find_sheet(book, names):
|
| 159 |
low2orig = {k.lower(): k for k in book.keys()}
|
| 160 |
for nm in names:
|
|
|
|
| 167 |
return xmin
|
| 168 |
return step * math.floor(xmin / step)
|
| 169 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
# ---------- Plot builders ----------
|
| 171 |
def cross_plot(actual, pred):
|
| 172 |
a = pd.Series(actual).astype(float)
|
|
|
|
| 224 |
x_lo, x_hi = float(x_series.min()), float(x_series.max())
|
| 225 |
x_pad = 0.03 * (x_hi - x_lo if x_hi > x_lo else 1.0)
|
| 226 |
xmin, xmax = x_lo - x_pad, x_hi + x_pad
|
| 227 |
+
tick0 = _nice_tick0(xmin, step=100) # sensible first tick at left border
|
|
|
|
| 228 |
|
| 229 |
fig = go.Figure()
|
| 230 |
fig.add_trace(go.Scatter(
|
|
|
|
| 254 |
fig.update_xaxes(
|
| 255 |
title_text="<b>UCS (psi)</b>", side="top", range=[xmin, xmax],
|
| 256 |
ticks="outside", tickformat=",.0f",
|
| 257 |
+
tickmode="auto", tick0=tick0,
|
| 258 |
showline=True, linewidth=1.2, linecolor="#444", mirror=True,
|
| 259 |
showgrid=True, gridcolor="rgba(0,0,0,0.12)", automargin=True
|
| 260 |
)
|
|
|
|
| 409 |
df0 = next(iter(tmp.values()))
|
| 410 |
st.sidebar.caption(f"**Data loaded:** {st.session_state.dev_file_name} • {df0.shape[0]} rows × {df0.shape[1]} cols")
|
| 411 |
|
| 412 |
+
if st.sidebar.button("Preview data", use_container_width=True, disabled=not st.session_state.dev_file_loaded):
|
| 413 |
+
preview_modal(read_book_bytes(st.session_state.dev_file_bytes))
|
| 414 |
+
st.session_state.dev_preview = True
|
| 415 |
+
|
| 416 |
+
run = st.sidebar.button("Run Model", type="primary", use_container_width=True)
|
| 417 |
+
# always available nav
|
| 418 |
+
if st.sidebar.button("Proceed to Validation ▶", use_container_width=True): st.session_state.app_step="validate"; st.rerun()
|
| 419 |
+
if st.sidebar.button("Proceed to Prediction ▶", use_container_width=True): st.session_state.app_step="predict"; st.rerun()
|
| 420 |
+
|
| 421 |
+
# ---- Pinned helper at the very top of the page ----
|
| 422 |
helper_top = st.container()
|
| 423 |
with helper_top:
|
| 424 |
st.subheader("Case Building (Development)")
|
|
|
|
| 429 |
else:
|
| 430 |
st.write("**Upload your data to build a case, then run the model to review development performance.**")
|
| 431 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 432 |
if run and st.session_state.dev_file_bytes:
|
| 433 |
book = read_book_bytes(st.session_state.dev_file_bytes)
|
| 434 |
sh_train = find_sheet(book, ["Train","Training","training2","train","training"])
|
|
|
|
| 454 |
c1.metric("R²", f"{m['R2']:.4f}"); c2.metric("RMSE", f"{m['RMSE']:.4f}"); c3.metric("MAE", f"{m['MAE']:.4f}")
|
| 455 |
left, spacer, right = st.columns(PLOT_COLS)
|
| 456 |
with left:
|
| 457 |
+
pad, plotcol = left.columns([CROSS_NUDGE, 1]) # shift cross-plot right inside its band
|
| 458 |
+
with plotcol:
|
| 459 |
+
st.plotly_chart(
|
| 460 |
+
cross_plot(df[TARGET], df["UCS_Pred"]),
|
| 461 |
+
use_container_width=False,
|
| 462 |
+
config={"displayModeBar": False, "scrollZoom": True}
|
| 463 |
+
)
|
| 464 |
with right:
|
| 465 |
st.plotly_chart(
|
| 466 |
track_plot(df, include_actual=True),
|
|
|
|
| 486 |
if book:
|
| 487 |
df0 = next(iter(book.values()))
|
| 488 |
st.sidebar.caption(f"**Data loaded:** {up.name} • {df0.shape[0]} rows × {df0.shape[1]} cols")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 489 |
if st.sidebar.button("Preview data", use_container_width=True, disabled=(up is None)):
|
| 490 |
preview_modal(read_book_bytes(up.getvalue()))
|
| 491 |
go_btn = st.sidebar.button("Predict", type="primary", use_container_width=True)
|
| 492 |
if st.sidebar.button("⬅ Back to Case Building", use_container_width=True): st.session_state.app_step="dev"; st.rerun()
|
| 493 |
if st.sidebar.button("Proceed to Prediction ▶", use_container_width=True): st.session_state.app_step="predict"; st.rerun()
|
| 494 |
|
| 495 |
+
st.subheader("Validate the Model")
|
| 496 |
+
st.write("Upload a dataset with the same **features** and **UCS** to evaluate performance.")
|
| 497 |
+
|
| 498 |
if go_btn and up is not None:
|
| 499 |
book = read_book_bytes(up.getvalue())
|
| 500 |
name = find_sheet(book, ["Validation","Validate","validation2","Val","val"]) or list(book.keys())[0]
|
|
|
|
| 515 |
st.session_state.results["oor_tbl"]=tbl
|
| 516 |
|
| 517 |
if "Validate" in st.session_state.results:
|
| 518 |
+
m = st.session_state.results["m_val"]
|
| 519 |
c1,c2,c3 = st.columns(3)
|
| 520 |
c1.metric("R²", f"{m['R2']:.4f}"); c2.metric("RMSE", f"{m['RMSE']:.4f}"); c3.metric("MAE", f"{m['MAE']:.4f}")
|
| 521 |
|
| 522 |
left, spacer, right = st.columns(PLOT_COLS)
|
| 523 |
with left:
|
| 524 |
+
pad, plotcol = left.columns([CROSS_NUDGE, 1]) # same nudge
|
| 525 |
+
with plotcol:
|
| 526 |
+
st.plotly_chart(
|
| 527 |
+
cross_plot(st.session_state.results["Validate"][TARGET],
|
| 528 |
+
st.session_state.results["Validate"]["UCS_Pred"]),
|
| 529 |
+
use_container_width=False, config={"displayModeBar": False, "scrollZoom": True}
|
| 530 |
+
)
|
| 531 |
with right:
|
| 532 |
st.plotly_chart(
|
| 533 |
track_plot(st.session_state.results["Validate"], include_actual=True),
|
| 534 |
use_container_width=False, config={"displayModeBar": False, "scrollZoom": True}
|
| 535 |
)
|
| 536 |
|
| 537 |
+
sv = st.session_state.results["sv_val"]
|
| 538 |
if sv["oor"] > 0: st.warning("Some inputs fall outside **training min–max** ranges.")
|
| 539 |
if st.session_state.results["oor_tbl"] is not None:
|
| 540 |
st.write("*Out-of-range rows (vs. Training min–max):*")
|