UCS2014 commited on
Commit
339bb57
·
verified ·
1 Parent(s): 4d2d27f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -23
app.py CHANGED
@@ -16,6 +16,13 @@ MODELS_DIR = Path("models")
16
  DEFAULT_MODEL = MODELS_DIR / "ucs_rf.joblib"
17
  MODEL_FALLBACKS = [MODELS_DIR / "model.joblib", MODELS_DIR / "model.pkl"]
18
 
 
 
 
 
 
 
 
19
  # =========================
20
  # Page / Theme
21
  # =========================
@@ -79,13 +86,15 @@ def find_sheet(book, names):
79
  if nm.lower() in low2orig: return low2orig[nm.lower()]
80
  return None
81
 
82
- def cross_plot(actual, pred, title, size=(4.6, 4.6)):
 
83
  fig, ax = plt.subplots(figsize=size, dpi=100)
84
- ax.scatter(actual, pred, s=14, alpha=0.8)
85
  lo = float(np.nanmin([actual.min(), pred.min()]))
86
  hi = float(np.nanmax([actual.max(), pred.max()]))
87
  pad = 0.03 * (hi - lo if hi > lo else 1.0)
88
- ax.plot([lo - pad, hi + pad], [lo - pad, hi + pad], '--', lw=1.2, color=(0.35, 0.35, 0.35))
 
89
  ax.set_xlim(lo - pad, hi + pad)
90
  ax.set_ylim(lo - pad, hi + pad)
91
  ax.set_aspect('equal', 'box') # perfect 1:1
@@ -94,31 +103,38 @@ def cross_plot(actual, pred, title, size=(4.6, 4.6)):
94
  return fig
95
 
96
  def depth_or_index_track(df, title=None, include_actual=True):
 
 
 
 
97
  # Find depth-like column if available
98
- depth_col = None
99
- for c in df.columns:
100
- if 'depth' in str(c).lower():
101
- depth_col = c; break
102
 
103
- fig_h = 7.4 if depth_col is not None else 7.0 # taller track; still fits most screens
104
- fig, ax = plt.subplots(figsize=(6.0, fig_h), dpi=100)
 
 
105
 
106
  if depth_col is not None:
107
- ax.plot(df["UCS_Pred"], df[depth_col], '--', lw=1.6, label="UCS_Pred")
 
108
  if include_actual and TARGET in df.columns:
109
- ax.plot(df[TARGET], df[depth_col], '-', lw=2.0, alpha=0.85, label="UCS (actual)")
 
110
  ax.set_ylabel(depth_col); ax.set_xlabel("UCS")
111
  ax.xaxis.set_label_position('top'); ax.xaxis.tick_top(); ax.invert_yaxis()
112
  else:
113
  idx = np.arange(1, len(df) + 1)
114
- ax.plot(df["UCS_Pred"], idx, '--', lw=1.6, label="UCS_Pred")
 
115
  if include_actual and TARGET in df.columns:
116
- ax.plot(df[TARGET], idx, '-', lw=2.0, alpha=0.85, label="UCS (actual)")
 
117
  ax.set_ylabel("Point Index"); ax.set_xlabel("UCS")
118
  ax.xaxis.set_label_position('top'); ax.xaxis.tick_top(); ax.invert_yaxis()
119
 
120
  ax.grid(True, linestyle=":", alpha=0.4)
121
- if title: ax.set_title(title, pad=8) # no title if None/empty
122
  ax.legend(loc="best")
123
  return fig
124
 
@@ -272,7 +288,7 @@ if st.session_state.app_step == "dev":
272
  on_click=(lambda: st.session_state.update(app_step="predict")) if st.session_state.dev_ready else None,
273
  )
274
 
275
- # ---- Header + helper sentence positioned under the header (your request) ----
276
  st.subheader("Model Development")
277
  st.write("Upload your data to train the model and review the development performance.")
278
 
@@ -323,11 +339,11 @@ if st.session_state.app_step == "dev":
323
  df = st.session_state.results["Train"]; m = st.session_state.results["metrics_train"]
324
  c1,c2,c3 = st.columns(3)
325
  c1.metric("R²", f"{m['R2']:.4f}"); c2.metric("RMSE", f"{m['RMSE']:.4f}"); c3.metric("MAE", f"{m['MAE']:.4f}")
326
- left,right = st.columns([1,1])
 
327
  with left:
328
  st.pyplot(cross_plot(df[TARGET], df["UCS_Pred"], "Training: Actual vs Predicted"), use_container_width=True)
329
  with right:
330
- # no title on the track (cleaner)
331
  st.pyplot(depth_or_index_track(df, title=None, include_actual=True), use_container_width=True)
332
 
333
  if "Test" in st.session_state.results:
@@ -335,7 +351,7 @@ if st.session_state.app_step == "dev":
335
  df = st.session_state.results["Test"]; m = st.session_state.results["metrics_test"]
336
  c1,c2,c3 = st.columns(3)
337
  c1.metric("R²", f"{m['R2']:.4f}"); c2.metric("RMSE", f"{m['RMSE']:.4f}"); c3.metric("MAE", f"{m['MAE']:.4f}")
338
- left,right = st.columns([1,1])
339
  with left:
340
  st.pyplot(cross_plot(df[TARGET], df["UCS_Pred"], "Testing: Actual vs Predicted"), use_container_width=True)
341
  with right:
@@ -412,21 +428,36 @@ if st.session_state.app_step == "predict":
412
  st.subheader("Validation Results")
413
  sv = st.session_state.results["summary_val"]; oor_table = st.session_state.results.get("oor_table")
414
 
415
- # ---- NEW: show OOR warning above the plots when applicable ----
416
  if sv["oor_pct"] > 0:
417
  st.warning("Some validation inputs fall outside the **training min–max** ranges. Interpret predictions with caution.")
418
 
419
  c1,c2,c3,c4 = st.columns(4)
420
  c1.metric("points", f"{sv['n_points']}"); c2.metric("Pred min", f"{sv['pred_min']:.2f}")
421
  c3.metric("Pred max", f"{sv['pred_max']:.2f}"); c4.metric("OOR %", f"{sv['oor_pct']:.1f}%")
422
- left,right = st.columns([1,1])
 
423
  with left:
424
  if TARGET in st.session_state.results["Validate"].columns:
425
- st.pyplot(cross_plot(st.session_state.results["Validate"][TARGET], st.session_state.results["Validate"]["UCS_Pred"], "Validation: Actual vs Predicted"), use_container_width=True)
 
 
 
 
 
 
 
426
  else:
427
  st.info("Actual UCS values are not available in the validation data. Cross-plot cannot be generated.")
428
  with right:
429
- st.pyplot(depth_or_index_track(st.session_state.results["Validate"], title=None, include_actual=(TARGET in st.session_state.results["Validate"].columns)), use_container_width=True)
 
 
 
 
 
 
 
430
 
431
  if oor_table is not None:
432
  st.write("*Out-of-range rows (vs. Training min–max):*")
@@ -451,4 +482,4 @@ if st.session_state.app_step == "predict":
451
  # Footer
452
  # =========================
453
  st.markdown("---")
454
- st.markdown("<div style='text-align:center; color:#6b7280;'>ST_GeoMech_UCS • © Smart Thinking</div>", unsafe_allow_html=True)
 
16
  DEFAULT_MODEL = MODELS_DIR / "ucs_rf.joblib"
17
  MODEL_FALLBACKS = [MODELS_DIR / "model.joblib", MODELS_DIR / "model.pkl"]
18
 
19
+ # Colors for plots
20
+ COLORS = {
21
+ "pred": "#1f77b4", # blue (predicted)
22
+ "actual": "#f2b702", # yellow (actual)
23
+ "ref": "#5a5a5a" # grey 1:1 line
24
+ }
25
+
26
  # =========================
27
  # Page / Theme
28
  # =========================
 
86
  if nm.lower() in low2orig: return low2orig[nm.lower()]
87
  return None
88
 
89
+ def cross_plot(actual, pred, title, size=(3.9, 3.9)):
90
+ """Compact, square cross-plot with a 1:1 reference line."""
91
  fig, ax = plt.subplots(figsize=size, dpi=100)
92
+ ax.scatter(actual, pred, s=14, alpha=0.85, color=COLORS["pred"])
93
  lo = float(np.nanmin([actual.min(), pred.min()]))
94
  hi = float(np.nanmax([actual.max(), pred.max()]))
95
  pad = 0.03 * (hi - lo if hi > lo else 1.0)
96
+ ax.plot([lo - pad, hi + pad], [lo - pad, hi + pad],
97
+ '--', lw=1.2, color=COLORS["ref"])
98
  ax.set_xlim(lo - pad, hi + pad)
99
  ax.set_ylim(lo - pad, hi + pad)
100
  ax.set_aspect('equal', 'box') # perfect 1:1
 
103
  return fig
104
 
105
  def depth_or_index_track(df, title=None, include_actual=True):
106
+ """
107
+ Narrow, tall track: predicted solid blue; actual dotted yellow.
108
+ Works for either Depth on Y or Index on Y.
109
+ """
110
  # Find depth-like column if available
111
+ depth_col = next((c for c in df.columns if 'depth' in str(c).lower()), None)
 
 
 
112
 
113
+ # Narrow width, tall height for logging look
114
+ fig_w = 3.1
115
+ fig_h = 7.6 if depth_col is not None else 7.2
116
+ fig, ax = plt.subplots(figsize=(fig_w, fig_h), dpi=100)
117
 
118
  if depth_col is not None:
119
+ ax.plot(df["UCS_Pred"], df[depth_col],
120
+ '-', lw=1.8, color=COLORS["pred"], label="UCS_Pred")
121
  if include_actual and TARGET in df.columns:
122
+ ax.plot(df[TARGET], df[depth_col],
123
+ ':', lw=2.0, color=COLORS["actual"], alpha=0.95, label="UCS (actual)")
124
  ax.set_ylabel(depth_col); ax.set_xlabel("UCS")
125
  ax.xaxis.set_label_position('top'); ax.xaxis.tick_top(); ax.invert_yaxis()
126
  else:
127
  idx = np.arange(1, len(df) + 1)
128
+ ax.plot(df["UCS_Pred"], idx,
129
+ '-', lw=1.8, color=COLORS["pred"], label="UCS_Pred")
130
  if include_actual and TARGET in df.columns:
131
+ ax.plot(df[TARGET], idx,
132
+ ':', lw=2.0, color=COLORS["actual"], alpha=0.95, label="UCS (actual)")
133
  ax.set_ylabel("Point Index"); ax.set_xlabel("UCS")
134
  ax.xaxis.set_label_position('top'); ax.xaxis.tick_top(); ax.invert_yaxis()
135
 
136
  ax.grid(True, linestyle=":", alpha=0.4)
137
+ if title: ax.set_title(title, pad=8) # keep no title by passing None
138
  ax.legend(loc="best")
139
  return fig
140
 
 
288
  on_click=(lambda: st.session_state.update(app_step="predict")) if st.session_state.dev_ready else None,
289
  )
290
 
291
+ # Header + helper sentence under the header
292
  st.subheader("Model Development")
293
  st.write("Upload your data to train the model and review the development performance.")
294
 
 
339
  df = st.session_state.results["Train"]; m = st.session_state.results["metrics_train"]
340
  c1,c2,c3 = st.columns(3)
341
  c1.metric("R²", f"{m['R2']:.4f}"); c2.metric("RMSE", f"{m['RMSE']:.4f}"); c3.metric("MAE", f"{m['MAE']:.4f}")
342
+ # Narrow track column for log-look
343
+ left, right = st.columns([0.9, 0.55])
344
  with left:
345
  st.pyplot(cross_plot(df[TARGET], df["UCS_Pred"], "Training: Actual vs Predicted"), use_container_width=True)
346
  with right:
 
347
  st.pyplot(depth_or_index_track(df, title=None, include_actual=True), use_container_width=True)
348
 
349
  if "Test" in st.session_state.results:
 
351
  df = st.session_state.results["Test"]; m = st.session_state.results["metrics_test"]
352
  c1,c2,c3 = st.columns(3)
353
  c1.metric("R²", f"{m['R2']:.4f}"); c2.metric("RMSE", f"{m['RMSE']:.4f}"); c3.metric("MAE", f"{m['MAE']:.4f}")
354
+ left, right = st.columns([0.9, 0.55])
355
  with left:
356
  st.pyplot(cross_plot(df[TARGET], df["UCS_Pred"], "Testing: Actual vs Predicted"), use_container_width=True)
357
  with right:
 
428
  st.subheader("Validation Results")
429
  sv = st.session_state.results["summary_val"]; oor_table = st.session_state.results.get("oor_table")
430
 
431
+ # Show OOR warning above the plots when applicable
432
  if sv["oor_pct"] > 0:
433
  st.warning("Some validation inputs fall outside the **training min–max** ranges. Interpret predictions with caution.")
434
 
435
  c1,c2,c3,c4 = st.columns(4)
436
  c1.metric("points", f"{sv['n_points']}"); c2.metric("Pred min", f"{sv['pred_min']:.2f}")
437
  c3.metric("Pred max", f"{sv['pred_max']:.2f}"); c4.metric("OOR %", f"{sv['oor_pct']:.1f}%")
438
+
439
+ left, right = st.columns([0.9, 0.55]) # slim log-look track
440
  with left:
441
  if TARGET in st.session_state.results["Validate"].columns:
442
+ st.pyplot(
443
+ cross_plot(
444
+ st.session_state.results["Validate"][TARGET],
445
+ st.session_state.results["Validate"]["UCS_Pred"],
446
+ "Validation: Actual vs Predicted"
447
+ ),
448
+ use_container_width=True
449
+ )
450
  else:
451
  st.info("Actual UCS values are not available in the validation data. Cross-plot cannot be generated.")
452
  with right:
453
+ st.pyplot(
454
+ depth_or_index_track(
455
+ st.session_state.results["Validate"],
456
+ title=None,
457
+ include_actual=(TARGET in st.session_state.results["Validate"].columns)
458
+ ),
459
+ use_container_width=True
460
+ )
461
 
462
  if oor_table is not None:
463
  st.write("*Out-of-range rows (vs. Training min–max):*")
 
482
  # Footer
483
  # =========================
484
  st.markdown("---")
485
+ st.markdown("<div style='text-align:center; color:#6b7280;'>ST_GeoMech_UCS • © Smart Thinking</div>", unsafe_allow_html=True)