UCS2014 commited on
Commit
a03a1b1
·
verified ·
1 Parent(s): 0df59b8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -47
app.py CHANGED
@@ -25,18 +25,18 @@ MODEL_FALLBACKS = [MODELS_DIR / "model.joblib", MODELS_DIR / "model.pkl"]
25
  COLORS = {"pred": "#1f77b4", "actual": "#f2b702", "ref": "#5a5a5a"}
26
 
27
  # ---- Plot sizing controls (edit here) ----
28
- CROSS_W = 500; CROSS_H = 500 # square cross-plot
29
- TRACK_W = 400; TRACK_H = 950 # log-strip style (tall, slightly wider)
30
  FONT_SZ = 13
31
- PLOT_COLS = [14, 0.3, 10] # 3-column band: left • spacer • right
 
 
32
 
33
  # =========================
34
  # Page / CSS
35
  # =========================
36
  st.set_page_config(page_title="ST_GeoMech_UCS", page_icon="logo.png", layout="wide")
37
  st.markdown("<style>header, footer{visibility:hidden !important;}</style>", unsafe_allow_html=True)
38
- # helper class to right-align the cross-plot inside its column
39
- st.markdown("<style>.align-right{display:flex;justify-content:flex-end;width:100%;}</style>", unsafe_allow_html=True)
40
  st.markdown(
41
  """
42
  <style>
@@ -147,7 +147,14 @@ def parse_excel(data_bytes: bytes):
147
  return {sh: xl.parse(sh) for sh in xl.sheet_names}
148
 
149
  def read_book_bytes(b: bytes): return parse_excel(b) if b else {}
150
- def ensure_cols(df, cols): return not [c for c in cols if c not in df.columns] or False
 
 
 
 
 
 
 
151
  def find_sheet(book, names):
152
  low2orig = {k.lower(): k for k in book.keys()}
153
  for nm in names:
@@ -160,15 +167,6 @@ def _nice_tick0(xmin: float, step: int = 100) -> float:
160
  return xmin
161
  return step * math.floor(xmin / step)
162
 
163
- def _nice_dtick(xmin: float, xmax: float) -> int:
164
- width = max(xmax - xmin, 1.0)
165
- candidates = [50, 100, 200, 250, 500, 1000]
166
- for dt in candidates:
167
- n = width / dt
168
- if 5 <= n <= 12:
169
- return dt
170
- return 100
171
-
172
  # ---------- Plot builders ----------
173
  def cross_plot(actual, pred):
174
  a = pd.Series(actual).astype(float)
@@ -226,8 +224,7 @@ def track_plot(df, include_actual=True):
226
  x_lo, x_hi = float(x_series.min()), float(x_series.max())
227
  x_pad = 0.03 * (x_hi - x_lo if x_hi > x_lo else 1.0)
228
  xmin, xmax = x_lo - x_pad, x_hi + x_pad
229
- tick0 = _nice_tick0(xmin, step=100)
230
- dtick = _nice_dtick(xmin, xmax)
231
 
232
  fig = go.Figure()
233
  fig.add_trace(go.Scatter(
@@ -257,7 +254,7 @@ def track_plot(df, include_actual=True):
257
  fig.update_xaxes(
258
  title_text="<b>UCS (psi)</b>", side="top", range=[xmin, xmax],
259
  ticks="outside", tickformat=",.0f",
260
- tickmode="linear", tick0=tick0, dtick=dtick,
261
  showline=True, linewidth=1.2, linecolor="#444", mirror=True,
262
  showgrid=True, gridcolor="rgba(0,0,0,0.12)", automargin=True
263
  )
@@ -412,7 +409,16 @@ if st.session_state.app_step == "dev":
412
  df0 = next(iter(tmp.values()))
413
  st.sidebar.caption(f"**Data loaded:** {st.session_state.dev_file_name} • {df0.shape[0]} rows × {df0.shape[1]} cols")
414
 
415
- # ---- Pinned helper at the very top of the page (before preview call) ----
 
 
 
 
 
 
 
 
 
416
  helper_top = st.container()
417
  with helper_top:
418
  st.subheader("Case Building (Development)")
@@ -423,15 +429,6 @@ if st.session_state.app_step == "dev":
423
  else:
424
  st.write("**Upload your data to build a case, then run the model to review development performance.**")
425
 
426
- # preview modal call AFTER helper, so helper stays pinned above
427
- if st.sidebar.button("Preview data", use_container_width=True, disabled=not st.session_state.dev_file_loaded):
428
- preview_modal(read_book_bytes(st.session_state.dev_file_bytes))
429
- st.session_state.dev_preview = True
430
-
431
- run = st.sidebar.button("Run Model", type="primary", use_container_width=True)
432
- if st.sidebar.button("Proceed to Validation ▶", use_container_width=True): st.session_state.app_step="validate"; st.rerun()
433
- if st.sidebar.button("Proceed to Prediction ▶", use_container_width=True): st.session_state.app_step="predict"; st.rerun()
434
-
435
  if run and st.session_state.dev_file_bytes:
436
  book = read_book_bytes(st.session_state.dev_file_bytes)
437
  sh_train = find_sheet(book, ["Train","Training","training2","train","training"])
@@ -457,13 +454,13 @@ if st.session_state.app_step == "dev":
457
  c1.metric("R²", f"{m['R2']:.4f}"); c2.metric("RMSE", f"{m['RMSE']:.4f}"); c3.metric("MAE", f"{m['MAE']:.4f}")
458
  left, spacer, right = st.columns(PLOT_COLS)
459
  with left:
460
- st.markdown("<div class='align-right'>", unsafe_allow_html=True)
461
- st.plotly_chart(
462
- cross_plot(df[TARGET], df["UCS_Pred"]),
463
- use_container_width=False,
464
- config={"displayModeBar": False, "scrollZoom": True}
465
- )
466
- st.markdown("</div>", unsafe_allow_html=True)
467
  with right:
468
  st.plotly_chart(
469
  track_plot(df, include_actual=True),
@@ -489,16 +486,15 @@ if st.session_state.app_step == "validate":
489
  if book:
490
  df0 = next(iter(book.values()))
491
  st.sidebar.caption(f"**Data loaded:** {up.name} • {df0.shape[0]} rows × {df0.shape[1]} cols")
492
-
493
- st.subheader("Validate the Model")
494
- st.write("Upload a dataset with the same **features** and **UCS** to evaluate performance.")
495
-
496
  if st.sidebar.button("Preview data", use_container_width=True, disabled=(up is None)):
497
  preview_modal(read_book_bytes(up.getvalue()))
498
  go_btn = st.sidebar.button("Predict", type="primary", use_container_width=True)
499
  if st.sidebar.button("⬅ Back to Case Building", use_container_width=True): st.session_state.app_step="dev"; st.rerun()
500
  if st.sidebar.button("Proceed to Prediction ▶", use_container_width=True): st.session_state.app_step="predict"; st.rerun()
501
 
 
 
 
502
  if go_btn and up is not None:
503
  book = read_book_bytes(up.getvalue())
504
  name = find_sheet(book, ["Validation","Validate","validation2","Val","val"]) or list(book.keys())[0]
@@ -519,25 +515,26 @@ if st.session_state.app_step == "validate":
519
  st.session_state.results["oor_tbl"]=tbl
520
 
521
  if "Validate" in st.session_state.results:
522
- m = st.session_state.results["m_val"]; sv = st.session_state.results["sv_val"]
523
  c1,c2,c3 = st.columns(3)
524
  c1.metric("R²", f"{m['R2']:.4f}"); c2.metric("RMSE", f"{m['RMSE']:.4f}"); c3.metric("MAE", f"{m['MAE']:.4f}")
525
 
526
  left, spacer, right = st.columns(PLOT_COLS)
527
  with left:
528
- st.markdown("<div class='align-right'>", unsafe_allow_html=True)
529
- st.plotly_chart(
530
- cross_plot(st.session_state.results["Validate"][TARGET],
531
- st.session_state.results["Validate"]["UCS_Pred"]),
532
- use_container_width=False, config={"displayModeBar": False, "scrollZoom": True}
533
- )
534
- st.markdown("</div>", unsafe_allow_html=True)
535
  with right:
536
  st.plotly_chart(
537
  track_plot(st.session_state.results["Validate"], include_actual=True),
538
  use_container_width=False, config={"displayModeBar": False, "scrollZoom": True}
539
  )
540
 
 
541
  if sv["oor"] > 0: st.warning("Some inputs fall outside **training min–max** ranges.")
542
  if st.session_state.results["oor_tbl"] is not None:
543
  st.write("*Out-of-range rows (vs. Training min–max):*")
 
25
  COLORS = {"pred": "#1f77b4", "actual": "#f2b702", "ref": "#5a5a5a"}
26
 
27
  # ---- Plot sizing controls (edit here) ----
28
+ CROSS_W = 500; CROSS_H = 500 # square cross-plot (Build + Validate)
29
+ TRACK_W = 400; TRACK_H = 950 # log-strip style (all pages)
30
  FONT_SZ = 13
31
+ PLOT_COLS = [14, 0.3, 10] # 3-column band: left • spacer • right (Build + Validate)
32
+ CROSS_NUDGE = 1.2 # push cross-plot to the RIGHT inside its band:
33
+ # inner columns [CROSS_NUDGE : 1] → bigger = more right
34
 
35
  # =========================
36
  # Page / CSS
37
  # =========================
38
  st.set_page_config(page_title="ST_GeoMech_UCS", page_icon="logo.png", layout="wide")
39
  st.markdown("<style>header, footer{visibility:hidden !important;}</style>", unsafe_allow_html=True)
 
 
40
  st.markdown(
41
  """
42
  <style>
 
147
  return {sh: xl.parse(sh) for sh in xl.sheet_names}
148
 
149
  def read_book_bytes(b: bytes): return parse_excel(b) if b else {}
150
+
151
+ def ensure_cols(df, cols):
152
+ miss = [c for c in cols if c not in df.columns]
153
+ if miss:
154
+ st.error(f"Missing columns: {miss}\nFound: {list(df.columns)}")
155
+ return False
156
+ return True
157
+
158
  def find_sheet(book, names):
159
  low2orig = {k.lower(): k for k in book.keys()}
160
  for nm in names:
 
167
  return xmin
168
  return step * math.floor(xmin / step)
169
 
 
 
 
 
 
 
 
 
 
170
  # ---------- Plot builders ----------
171
  def cross_plot(actual, pred):
172
  a = pd.Series(actual).astype(float)
 
224
  x_lo, x_hi = float(x_series.min()), float(x_series.max())
225
  x_pad = 0.03 * (x_hi - x_lo if x_hi > x_lo else 1.0)
226
  xmin, xmax = x_lo - x_pad, x_hi + x_pad
227
+ tick0 = _nice_tick0(xmin, step=100) # sensible first tick at left border
 
228
 
229
  fig = go.Figure()
230
  fig.add_trace(go.Scatter(
 
254
  fig.update_xaxes(
255
  title_text="<b>UCS (psi)</b>", side="top", range=[xmin, xmax],
256
  ticks="outside", tickformat=",.0f",
257
+ tickmode="auto", tick0=tick0,
258
  showline=True, linewidth=1.2, linecolor="#444", mirror=True,
259
  showgrid=True, gridcolor="rgba(0,0,0,0.12)", automargin=True
260
  )
 
409
  df0 = next(iter(tmp.values()))
410
  st.sidebar.caption(f"**Data loaded:** {st.session_state.dev_file_name} • {df0.shape[0]} rows × {df0.shape[1]} cols")
411
 
412
+ if st.sidebar.button("Preview data", use_container_width=True, disabled=not st.session_state.dev_file_loaded):
413
+ preview_modal(read_book_bytes(st.session_state.dev_file_bytes))
414
+ st.session_state.dev_preview = True
415
+
416
+ run = st.sidebar.button("Run Model", type="primary", use_container_width=True)
417
+ # always available nav
418
+ if st.sidebar.button("Proceed to Validation ▶", use_container_width=True): st.session_state.app_step="validate"; st.rerun()
419
+ if st.sidebar.button("Proceed to Prediction ▶", use_container_width=True): st.session_state.app_step="predict"; st.rerun()
420
+
421
+ # ---- Pinned helper at the very top of the page ----
422
  helper_top = st.container()
423
  with helper_top:
424
  st.subheader("Case Building (Development)")
 
429
  else:
430
  st.write("**Upload your data to build a case, then run the model to review development performance.**")
431
 
 
 
 
 
 
 
 
 
 
432
  if run and st.session_state.dev_file_bytes:
433
  book = read_book_bytes(st.session_state.dev_file_bytes)
434
  sh_train = find_sheet(book, ["Train","Training","training2","train","training"])
 
454
  c1.metric("R²", f"{m['R2']:.4f}"); c2.metric("RMSE", f"{m['RMSE']:.4f}"); c3.metric("MAE", f"{m['MAE']:.4f}")
455
  left, spacer, right = st.columns(PLOT_COLS)
456
  with left:
457
+ pad, plotcol = left.columns([CROSS_NUDGE, 1]) # shift cross-plot right inside its band
458
+ with plotcol:
459
+ st.plotly_chart(
460
+ cross_plot(df[TARGET], df["UCS_Pred"]),
461
+ use_container_width=False,
462
+ config={"displayModeBar": False, "scrollZoom": True}
463
+ )
464
  with right:
465
  st.plotly_chart(
466
  track_plot(df, include_actual=True),
 
486
  if book:
487
  df0 = next(iter(book.values()))
488
  st.sidebar.caption(f"**Data loaded:** {up.name} • {df0.shape[0]} rows × {df0.shape[1]} cols")
 
 
 
 
489
  if st.sidebar.button("Preview data", use_container_width=True, disabled=(up is None)):
490
  preview_modal(read_book_bytes(up.getvalue()))
491
  go_btn = st.sidebar.button("Predict", type="primary", use_container_width=True)
492
  if st.sidebar.button("⬅ Back to Case Building", use_container_width=True): st.session_state.app_step="dev"; st.rerun()
493
  if st.sidebar.button("Proceed to Prediction ▶", use_container_width=True): st.session_state.app_step="predict"; st.rerun()
494
 
495
+ st.subheader("Validate the Model")
496
+ st.write("Upload a dataset with the same **features** and **UCS** to evaluate performance.")
497
+
498
  if go_btn and up is not None:
499
  book = read_book_bytes(up.getvalue())
500
  name = find_sheet(book, ["Validation","Validate","validation2","Val","val"]) or list(book.keys())[0]
 
515
  st.session_state.results["oor_tbl"]=tbl
516
 
517
  if "Validate" in st.session_state.results:
518
+ m = st.session_state.results["m_val"]
519
  c1,c2,c3 = st.columns(3)
520
  c1.metric("R²", f"{m['R2']:.4f}"); c2.metric("RMSE", f"{m['RMSE']:.4f}"); c3.metric("MAE", f"{m['MAE']:.4f}")
521
 
522
  left, spacer, right = st.columns(PLOT_COLS)
523
  with left:
524
+ pad, plotcol = left.columns([CROSS_NUDGE, 1]) # same nudge
525
+ with plotcol:
526
+ st.plotly_chart(
527
+ cross_plot(st.session_state.results["Validate"][TARGET],
528
+ st.session_state.results["Validate"]["UCS_Pred"]),
529
+ use_container_width=False, config={"displayModeBar": False, "scrollZoom": True}
530
+ )
531
  with right:
532
  st.plotly_chart(
533
  track_plot(st.session_state.results["Validate"], include_actual=True),
534
  use_container_width=False, config={"displayModeBar": False, "scrollZoom": True}
535
  )
536
 
537
+ sv = st.session_state.results["sv_val"]
538
  if sv["oor"] > 0: st.warning("Some inputs fall outside **training min–max** ranges.")
539
  if st.session_state.results["oor_tbl"] is not None:
540
  st.write("*Out-of-range rows (vs. Training min–max):*")