UCS2014 commited on
Commit
cc91bfb
·
verified ·
1 Parent(s): 91dd3ad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -62
app.py CHANGED
@@ -115,7 +115,8 @@ st.markdown("""
115
  def inline_logo(path="logo.png") -> str:
116
  try:
117
  p = Path(path)
118
- if not p.exists(): return ""
 
119
  return f"data:image/png;base64,{base64.b64encode(p.read_bytes()).decode('ascii')}"
120
  except Exception:
121
  return ""
@@ -194,7 +195,8 @@ def ensure_cols(df, cols):
194
  def find_sheet(book, names):
195
  low2orig = {k.lower(): k for k in book.keys()}
196
  for nm in names:
197
- if nm.lower() in low2orig: return low2orig[nm.lower()]
 
198
  return None
199
 
200
  def _nice_tick0(xmin: float, step: int = 5) -> float:
@@ -233,9 +235,8 @@ def to_actual_series(df: pd.DataFrame, target_col: str, actual_col_hint: str, tr
233
  return pd.Series(df["GR"], dtype=float)
234
  raise ValueError("Cannot find actual GR column or target to invert.")
235
 
236
- # === NEW: Excel export helpers =================================================
237
  def _excel_engine() -> str:
238
- """Prefer xlsxwriter for better formatting; fall back to openpyxl if missing."""
239
  try:
240
  import xlsxwriter # noqa: F401
241
  return "xlsxwriter"
@@ -243,7 +244,6 @@ def _excel_engine() -> str:
243
  return "openpyxl"
244
 
245
  def _excel_safe_name(name: str) -> str:
246
- """Excel sheet names: max 31 chars, no []:*?/\\."""
247
  bad = '[]:*?/\\'
248
  safe = ''.join('_' if ch in bad else ch for ch in str(name))
249
  return safe[:31]
@@ -273,10 +273,6 @@ def _train_ranges_df(ranges: dict[str, tuple[float, float]]) -> pd.DataFrame:
273
  return _round_numeric(df)
274
 
275
  def build_export_workbook() -> tuple[bytes|None, str|None, list[str]]:
276
- """
277
- Build a multi-sheet Excel workbook (as bytes) from what's currently in session state.
278
- Returns: (bytes_or_None, filename_or_None, [sheet_names])
279
- """
280
  res = st.session_state.get("results", {})
281
  if not res:
282
  return None, None, []
@@ -379,7 +375,6 @@ def render_export_button(key: str = "export_main") -> None:
379
  help="Exports all available results, metrics, summaries, OOR, training ranges, and info.",
380
  key=key,
381
  )
382
- # ================================================================================
383
 
384
  # =========================
385
  # Cross plot (Matplotlib) — auto limits for GR
@@ -524,10 +519,8 @@ def track_plot(df, include_actual=True, pred_col="GR_Pred", actual_col="GR"):
524
 
525
  return fig
526
 
527
-
528
  # ---------- Preview modal (matplotlib) — y-axis reversed ----------
529
  def preview_tracks(df: pd.DataFrame, cols: list[str]):
530
- # keep only columns that exist
531
  cols = [c for c in cols if c in df.columns]
532
  n = len(cols)
533
  if n == 0:
@@ -535,7 +528,6 @@ def preview_tracks(df: pd.DataFrame, cols: list[str]):
535
  ax.text(0.5, 0.5, "No selected columns", ha="center", va="center"); ax.axis("off")
536
  return fig
537
 
538
- # use Depth if present, else 1..N
539
  depth_col = next((c for c in df.columns if 'depth' in str(c).lower()), None)
540
  if depth_col is not None:
541
  y = pd.Series(df[depth_col], dtype=float)
@@ -544,7 +536,6 @@ def preview_tracks(df: pd.DataFrame, cols: list[str]):
544
  y = pd.Series(np.arange(1, len(df) + 1), dtype=float)
545
  ylab = "Point Index"
546
 
547
- # IMPORTANT: don't share y so inversion always applies
548
  fig, axes = plt.subplots(1, n, figsize=(2.4 * n, 7.0), dpi=100, sharey=False)
549
  if n == 1:
550
  axes = [axes]
@@ -554,11 +545,8 @@ def preview_tracks(df: pd.DataFrame, cols: list[str]):
554
  ax.set_xlabel(col)
555
  ax.xaxis.set_label_position('top')
556
  ax.xaxis.tick_top()
557
-
558
- # Reverse y-axis universally: shallow at top, deep at bottom
559
  ax.set_ylim(float(y.min()), float(y.max()))
560
  ax.invert_yaxis()
561
-
562
  ax.grid(True, linestyle=":", alpha=0.3)
563
  for s in ax.spines.values():
564
  s.set_visible(True)
@@ -566,7 +554,6 @@ def preview_tracks(df: pd.DataFrame, cols: list[str]):
566
  axes[0].set_ylabel(ylab)
567
  return fig
568
 
569
-
570
  # Modal wrapper (Streamlit compatibility)
571
  try:
572
  dialog = st.dialog
@@ -584,9 +571,11 @@ except AttributeError:
584
  # =========================
585
  def ensure_model() -> Path|None:
586
  for p in [DEFAULT_MODEL, *MODEL_FALLBACKS]:
587
- if p.exists() and p.stat().st_size > 0: return p
 
588
  url = os.environ.get("MODEL_URL", "")
589
- if not url: return None
 
590
  try:
591
  import requests
592
  DEFAULT_MODEL.parent.mkdir(parents=True, exist_ok=True)
@@ -594,7 +583,8 @@ def ensure_model() -> Path|None:
594
  r.raise_for_status()
595
  with open(DEFAULT_MODEL, "wb") as f:
596
  for chunk in r.iter_content(1<<20):
597
- if chunk: f.write(chunk)
 
598
  return DEFAULT_MODEL
599
  except Exception:
600
  return None
@@ -699,8 +689,10 @@ if st.session_state.app_step == "dev":
699
  st.session_state.dev_preview = True
700
 
701
  run = st.sidebar.button("Run Model", type="primary", use_container_width=True)
702
- if st.sidebar.button("Proceed to Validation ▶", use_container_width=True): st.session_state.app_step="validate"; st.rerun()
703
- if st.sidebar.button("Proceed to Prediction ▶", use_container_width=True): st.session_state.app_step="predict"; st.rerun()
 
 
704
 
705
  # Sticky helper
706
  if st.session_state.dev_file_loaded and st.session_state.dev_preview:
@@ -719,7 +711,8 @@ if st.session_state.app_step == "dev":
719
  st.stop()
720
  tr = normalize_df(book[sh_train].copy()); te = normalize_df(book[sh_test].copy())
721
  if not (ensure_cols(tr, FEATURES) and ensure_cols(te, FEATURES)):
722
- st.markdown('<div class="st-message-box st-error">Missing required feature columns.</div>', unsafe_allow_html=True); st.stop()
 
723
 
724
  # predictions (handle log targets)
725
  tr_pred_raw = model.predict(tr[FEATURES])
@@ -747,42 +740,43 @@ if st.session_state.app_step == "dev":
747
  st.session_state.train_ranges = {f:(float(tr_min[f]), float(tr_max[f])) for f in FEATURES}
748
  st.markdown('<div class="st-message-box st-success">Case has been built and results are displayed below.</div>', unsafe_allow_html=True)
749
 
 
750
  def _dev_block(df, m):
751
- c1, c2, c3 = st.columns(3)
752
- c1.metric("R", f"{m['R']:.3f}")
753
- c2.metric("RMSE", f"{m['RMSE']:.3f}")
754
- c3.metric("MAE", f"{m['MAE']:.3f}")
755
-
756
- st.markdown(
757
- """
758
- <div style='text-align:left;font-size:0.8em;color:#6b7280;margin-top:-16px;margin-bottom:8px;'>
759
- <strong>R:</strong> Pearson Correlation Coefficient<br>
760
- <strong>RMSE:</strong> Root Mean Square Error<br>
761
- <strong>MAE:</strong> Mean Absolute Error
762
- </div>
763
- """,
764
- unsafe_allow_html=True,
765
- )
766
 
767
- col_track, col_cross = st.columns([2, 3], gap="large")
768
- with col_track:
769
- st.plotly_chart(
770
- track_plot(df, include_actual=True, pred_col="GR_Pred", actual_col="GR_Actual"),
771
- use_container_width=False,
772
- config={"displayModeBar": False, "scrollZoom": True},
 
 
 
773
  )
774
- with col_cross:
775
- st.pyplot(cross_plot_static(df["GR_Actual"], df["GR_Pred"]), use_container_width=False)
776
-
777
-
778
- if "Train" in st.session_state.results or "Test" in st.session_state.results:
779
- tab1, tab2 = st.tabs(["Training", "Testing"])
780
- if "Train" in st.session_state.results:
781
- with tab1:
782
- _dev_block(st.session_state.results["Train"], st.session_state.results["m_train"])
783
- if "Test" in st.session_state.results:
784
- with tab2:
785
- _dev_block(st.session_state.results["Test"], st.session_state.results["m_test"])
 
 
 
 
 
 
 
 
786
  # =========================
787
  # VALIDATION (with actual GR)
788
  # =========================
@@ -797,8 +791,10 @@ if st.session_state.app_step == "validate":
797
  if st.sidebar.button("Preview data", use_container_width=True, disabled=(up is None)):
798
  st.session_state.show_preview_modal = True
799
  go_btn = st.sidebar.button("Predict & Validate", type="primary", use_container_width=True)
800
- if st.sidebar.button("⬅ Back to Case Building", use_container_width=True): st.session_state.app_step="dev"; st.rerun()
801
- if st.sidebar.button("Proceed to Prediction ▶", use_container_width=True): st.session_state.app_step="predict"; st.rerun()
 
 
802
 
803
  sticky_header("Validate the Model", "Upload a dataset with the same **features** and **GR** to evaluate performance.")
804
 
@@ -827,7 +823,8 @@ if st.session_state.app_step == "validate":
827
  if any_viol.any():
828
  tbl = df.loc[any_viol, FEATURES].copy()
829
  for c in FEATURES:
830
- if pd.api.types.is_numeric_dtype(tbl[c]): tbl[c] = tbl[c].round(2)
 
831
  tbl["Violations"] = pd.DataFrame({f:(df[f]<ranges[f][0])|(df[f]>ranges[f][1]) for f in FEATURES}).loc[any_viol].apply(lambda r:", ".join([c for c,v in r.items() if v]), axis=1)
832
 
833
  st.session_state.results["m_val"]={
@@ -867,7 +864,8 @@ if st.session_state.app_step == "validate":
867
  )
868
 
869
  sv = st.session_state.results["sv_val"]
870
- if sv["oor"] > 0: st.markdown('<div class="st-message-box st-warning">Some inputs fall outside **training min–max** ranges.</div>', unsafe_allow_html=True)
 
871
  if st.session_state.results["oor_tbl"] is not None:
872
  st.write("*Out-of-range rows (vs. Training min–max):*")
873
  df_centered_rounded(st.session_state.results["oor_tbl"])
@@ -905,7 +903,8 @@ if st.session_state.app_step == "predict":
905
  if any_viol.any():
906
  oor_tbl = df.loc[any_viol, FEATURES].copy()
907
  for c in FEATURES:
908
- if pd.api.types.is_numeric_dtype(oor_tbl[c]): oor_tbl[c] = oor_tbl[c].round(2)
 
909
  oor_tbl["Violations"] = pd.DataFrame({f:(df[f]<ranges[f][0])|(df[f]>ranges[f][1]) for f in FEATURES}).loc[any_viol].apply(lambda r:", ".join([c for c,v in r.items() if v]), axis=1)
910
  st.session_state.results["sv_pred"]={
911
  "n":len(df),
 
115
  def inline_logo(path="logo.png") -> str:
116
  try:
117
  p = Path(path)
118
+ if not p.exists():
119
+ return ""
120
  return f"data:image/png;base64,{base64.b64encode(p.read_bytes()).decode('ascii')}"
121
  except Exception:
122
  return ""
 
195
  def find_sheet(book, names):
196
  low2orig = {k.lower(): k for k in book.keys()}
197
  for nm in names:
198
+ if nm.lower() in low2orig:
199
+ return low2orig[nm.lower()]
200
  return None
201
 
202
  def _nice_tick0(xmin: float, step: int = 5) -> float:
 
235
  return pd.Series(df["GR"], dtype=float)
236
  raise ValueError("Cannot find actual GR column or target to invert.")
237
 
238
+ # === Excel export helpers =================================================
239
  def _excel_engine() -> str:
 
240
  try:
241
  import xlsxwriter # noqa: F401
242
  return "xlsxwriter"
 
244
  return "openpyxl"
245
 
246
  def _excel_safe_name(name: str) -> str:
 
247
  bad = '[]:*?/\\'
248
  safe = ''.join('_' if ch in bad else ch for ch in str(name))
249
  return safe[:31]
 
273
  return _round_numeric(df)
274
 
275
  def build_export_workbook() -> tuple[bytes|None, str|None, list[str]]:
 
 
 
 
276
  res = st.session_state.get("results", {})
277
  if not res:
278
  return None, None, []
 
375
  help="Exports all available results, metrics, summaries, OOR, training ranges, and info.",
376
  key=key,
377
  )
 
378
 
379
  # =========================
380
  # Cross plot (Matplotlib) — auto limits for GR
 
519
 
520
  return fig
521
 
 
522
  # ---------- Preview modal (matplotlib) — y-axis reversed ----------
523
  def preview_tracks(df: pd.DataFrame, cols: list[str]):
 
524
  cols = [c for c in cols if c in df.columns]
525
  n = len(cols)
526
  if n == 0:
 
528
  ax.text(0.5, 0.5, "No selected columns", ha="center", va="center"); ax.axis("off")
529
  return fig
530
 
 
531
  depth_col = next((c for c in df.columns if 'depth' in str(c).lower()), None)
532
  if depth_col is not None:
533
  y = pd.Series(df[depth_col], dtype=float)
 
536
  y = pd.Series(np.arange(1, len(df) + 1), dtype=float)
537
  ylab = "Point Index"
538
 
 
539
  fig, axes = plt.subplots(1, n, figsize=(2.4 * n, 7.0), dpi=100, sharey=False)
540
  if n == 1:
541
  axes = [axes]
 
545
  ax.set_xlabel(col)
546
  ax.xaxis.set_label_position('top')
547
  ax.xaxis.tick_top()
 
 
548
  ax.set_ylim(float(y.min()), float(y.max()))
549
  ax.invert_yaxis()
 
550
  ax.grid(True, linestyle=":", alpha=0.3)
551
  for s in ax.spines.values():
552
  s.set_visible(True)
 
554
  axes[0].set_ylabel(ylab)
555
  return fig
556
 
 
557
  # Modal wrapper (Streamlit compatibility)
558
  try:
559
  dialog = st.dialog
 
571
  # =========================
572
  def ensure_model() -> Path|None:
573
  for p in [DEFAULT_MODEL, *MODEL_FALLBACKS]:
574
+ if p.exists() and p.stat().st_size > 0:
575
+ return p
576
  url = os.environ.get("MODEL_URL", "")
577
+ if not url:
578
+ return None
579
  try:
580
  import requests
581
  DEFAULT_MODEL.parent.mkdir(parents=True, exist_ok=True)
 
583
  r.raise_for_status()
584
  with open(DEFAULT_MODEL, "wb") as f:
585
  for chunk in r.iter_content(1<<20):
586
+ if chunk:
587
+ f.write(chunk)
588
  return DEFAULT_MODEL
589
  except Exception:
590
  return None
 
689
  st.session_state.dev_preview = True
690
 
691
  run = st.sidebar.button("Run Model", type="primary", use_container_width=True)
692
+ if st.sidebar.button("Proceed to Validation ▶", use_container_width=True):
693
+ st.session_state.app_step="validate"; st.rerun()
694
+ if st.sidebar.button("Proceed to Prediction ▶", use_container_width=True):
695
+ st.session_state.app_step="predict"; st.rerun()
696
 
697
  # Sticky helper
698
  if st.session_state.dev_file_loaded and st.session_state.dev_preview:
 
711
  st.stop()
712
  tr = normalize_df(book[sh_train].copy()); te = normalize_df(book[sh_test].copy())
713
  if not (ensure_cols(tr, FEATURES) and ensure_cols(te, FEATURES)):
714
+ st.markdown('<div class="st-message-box st-error">Missing required feature columns.</div>', unsafe_allow_html=True)
715
+ st.stop()
716
 
717
  # predictions (handle log targets)
718
  tr_pred_raw = model.predict(tr[FEATURES])
 
740
  st.session_state.train_ranges = {f:(float(tr_min[f]), float(tr_max[f])) for f in FEATURES}
741
  st.markdown('<div class="st-message-box st-success">Case has been built and results are displayed below.</div>', unsafe_allow_html=True)
742
 
743
+ # -------- Metrics + Plots (3 decimals here) --------
744
  def _dev_block(df, m):
745
+ c1, c2, c3 = st.columns(3)
746
+ c1.metric("R", f"{m['R']:.3f}")
747
+ c2.metric("RMSE", f"{m['RMSE']:.3f}")
748
+ c3.metric("MAE", f"{m['MAE']:.3f}")
 
 
 
 
 
 
 
 
 
 
 
749
 
750
+ st.markdown(
751
+ """
752
+ <div style='text-align:left;font-size:0.8em;color:#6b7280;margin-top:-16px;margin-bottom:8px;'>
753
+ <strong>R:</strong> Pearson Correlation Coefficient<br>
754
+ <strong>RMSE:</strong> Root Mean Square Error<br>
755
+ <strong>MAE:</strong> Mean Absolute Error
756
+ </div>
757
+ """,
758
+ unsafe_allow_html=True,
759
  )
760
+
761
+ col_track, col_cross = st.columns([2, 3], gap="large")
762
+ with col_track:
763
+ st.plotly_chart(
764
+ track_plot(df, include_actual=True, pred_col="GR_Pred", actual_col="GR_Actual"),
765
+ use_container_width=False,
766
+ config={"displayModeBar": False, "scrollZoom": True},
767
+ )
768
+ with col_cross:
769
+ st.pyplot(cross_plot_static(df["GR_Actual"], df["GR_Pred"]), use_container_width=False)
770
+
771
+ if "Train" in st.session_state.results or "Test" in st.session_state.results:
772
+ tab1, tab2 = st.tabs(["Training", "Testing"])
773
+ if "Train" in st.session_state.results:
774
+ with tab1:
775
+ _dev_block(st.session_state.results["Train"], st.session_state.results["m_train"])
776
+ if "Test" in st.session_state.results:
777
+ with tab2:
778
+ _dev_block(st.session_state.results["Test"], st.session_state.results["m_test"])
779
+
780
  # =========================
781
  # VALIDATION (with actual GR)
782
  # =========================
 
791
  if st.sidebar.button("Preview data", use_container_width=True, disabled=(up is None)):
792
  st.session_state.show_preview_modal = True
793
  go_btn = st.sidebar.button("Predict & Validate", type="primary", use_container_width=True)
794
+ if st.sidebar.button("⬅ Back to Case Building", use_container_width=True):
795
+ st.session_state.app_step="dev"; st.rerun()
796
+ if st.sidebar.button("Proceed to Prediction ▶", use_container_width=True):
797
+ st.session_state.app_step="predict"; st.rerun()
798
 
799
  sticky_header("Validate the Model", "Upload a dataset with the same **features** and **GR** to evaluate performance.")
800
 
 
823
  if any_viol.any():
824
  tbl = df.loc[any_viol, FEATURES].copy()
825
  for c in FEATURES:
826
+ if pd.api.types.is_numeric_dtype(tbl[c]):
827
+ tbl[c] = tbl[c].round(2)
828
  tbl["Violations"] = pd.DataFrame({f:(df[f]<ranges[f][0])|(df[f]>ranges[f][1]) for f in FEATURES}).loc[any_viol].apply(lambda r:", ".join([c for c,v in r.items() if v]), axis=1)
829
 
830
  st.session_state.results["m_val"]={
 
864
  )
865
 
866
  sv = st.session_state.results["sv_val"]
867
+ if sv["oor"] > 0:
868
+ st.markdown('<div class="st-message-box st-warning">Some inputs fall outside **training min–max** ranges.</div>', unsafe_allow_html=True)
869
  if st.session_state.results["oor_tbl"] is not None:
870
  st.write("*Out-of-range rows (vs. Training min–max):*")
871
  df_centered_rounded(st.session_state.results["oor_tbl"])
 
903
  if any_viol.any():
904
  oor_tbl = df.loc[any_viol, FEATURES].copy()
905
  for c in FEATURES:
906
+ if pd.api.types.is_numeric_dtype(oor_tbl[c]):
907
+ oor_tbl[c] = oor_tbl[c].round(2)
908
  oor_tbl["Violations"] = pd.DataFrame({f:(df[f]<ranges[f][0])|(df[f]>ranges[f][1]) for f in FEATURES}).loc[any_viol].apply(lambda r:", ".join([c for c,v in r.items() if v]), axis=1)
909
  st.session_state.results["sv_pred"]={
910
  "n":len(df),