Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -16,6 +16,13 @@ MODELS_DIR = Path("models")
|
|
| 16 |
DEFAULT_MODEL = MODELS_DIR / "ucs_rf.joblib"
|
| 17 |
MODEL_FALLBACKS = [MODELS_DIR / "model.joblib", MODELS_DIR / "model.pkl"]
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
# =========================
|
| 20 |
# Page / Theme
|
| 21 |
# =========================
|
|
@@ -79,13 +86,15 @@ def find_sheet(book, names):
|
|
| 79 |
if nm.lower() in low2orig: return low2orig[nm.lower()]
|
| 80 |
return None
|
| 81 |
|
| 82 |
-
def cross_plot(actual, pred, title, size=(
|
|
|
|
| 83 |
fig, ax = plt.subplots(figsize=size, dpi=100)
|
| 84 |
-
ax.scatter(actual, pred, s=14, alpha=0.
|
| 85 |
lo = float(np.nanmin([actual.min(), pred.min()]))
|
| 86 |
hi = float(np.nanmax([actual.max(), pred.max()]))
|
| 87 |
pad = 0.03 * (hi - lo if hi > lo else 1.0)
|
| 88 |
-
ax.plot([lo - pad, hi + pad], [lo - pad, hi + pad],
|
|
|
|
| 89 |
ax.set_xlim(lo - pad, hi + pad)
|
| 90 |
ax.set_ylim(lo - pad, hi + pad)
|
| 91 |
ax.set_aspect('equal', 'box') # perfect 1:1
|
|
@@ -94,31 +103,38 @@ def cross_plot(actual, pred, title, size=(4.6, 4.6)):
|
|
| 94 |
return fig
|
| 95 |
|
| 96 |
def depth_or_index_track(df, title=None, include_actual=True):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
# Find depth-like column if available
|
| 98 |
-
depth_col = None
|
| 99 |
-
for c in df.columns:
|
| 100 |
-
if 'depth' in str(c).lower():
|
| 101 |
-
depth_col = c; break
|
| 102 |
|
| 103 |
-
|
| 104 |
-
|
|
|
|
|
|
|
| 105 |
|
| 106 |
if depth_col is not None:
|
| 107 |
-
ax.plot(df["UCS_Pred"], df[depth_col],
|
|
|
|
| 108 |
if include_actual and TARGET in df.columns:
|
| 109 |
-
ax.plot(df[TARGET], df[depth_col],
|
|
|
|
| 110 |
ax.set_ylabel(depth_col); ax.set_xlabel("UCS")
|
| 111 |
ax.xaxis.set_label_position('top'); ax.xaxis.tick_top(); ax.invert_yaxis()
|
| 112 |
else:
|
| 113 |
idx = np.arange(1, len(df) + 1)
|
| 114 |
-
ax.plot(df["UCS_Pred"], idx,
|
|
|
|
| 115 |
if include_actual and TARGET in df.columns:
|
| 116 |
-
ax.plot(df[TARGET], idx,
|
|
|
|
| 117 |
ax.set_ylabel("Point Index"); ax.set_xlabel("UCS")
|
| 118 |
ax.xaxis.set_label_position('top'); ax.xaxis.tick_top(); ax.invert_yaxis()
|
| 119 |
|
| 120 |
ax.grid(True, linestyle=":", alpha=0.4)
|
| 121 |
-
if title: ax.set_title(title, pad=8) # no title
|
| 122 |
ax.legend(loc="best")
|
| 123 |
return fig
|
| 124 |
|
|
@@ -272,7 +288,7 @@ if st.session_state.app_step == "dev":
|
|
| 272 |
on_click=(lambda: st.session_state.update(app_step="predict")) if st.session_state.dev_ready else None,
|
| 273 |
)
|
| 274 |
|
| 275 |
-
#
|
| 276 |
st.subheader("Model Development")
|
| 277 |
st.write("Upload your data to train the model and review the development performance.")
|
| 278 |
|
|
@@ -323,11 +339,11 @@ if st.session_state.app_step == "dev":
|
|
| 323 |
df = st.session_state.results["Train"]; m = st.session_state.results["metrics_train"]
|
| 324 |
c1,c2,c3 = st.columns(3)
|
| 325 |
c1.metric("R²", f"{m['R2']:.4f}"); c2.metric("RMSE", f"{m['RMSE']:.4f}"); c3.metric("MAE", f"{m['MAE']:.4f}")
|
| 326 |
-
|
|
|
|
| 327 |
with left:
|
| 328 |
st.pyplot(cross_plot(df[TARGET], df["UCS_Pred"], "Training: Actual vs Predicted"), use_container_width=True)
|
| 329 |
with right:
|
| 330 |
-
# no title on the track (cleaner)
|
| 331 |
st.pyplot(depth_or_index_track(df, title=None, include_actual=True), use_container_width=True)
|
| 332 |
|
| 333 |
if "Test" in st.session_state.results:
|
|
@@ -335,7 +351,7 @@ if st.session_state.app_step == "dev":
|
|
| 335 |
df = st.session_state.results["Test"]; m = st.session_state.results["metrics_test"]
|
| 336 |
c1,c2,c3 = st.columns(3)
|
| 337 |
c1.metric("R²", f"{m['R2']:.4f}"); c2.metric("RMSE", f"{m['RMSE']:.4f}"); c3.metric("MAE", f"{m['MAE']:.4f}")
|
| 338 |
-
left,right = st.columns([
|
| 339 |
with left:
|
| 340 |
st.pyplot(cross_plot(df[TARGET], df["UCS_Pred"], "Testing: Actual vs Predicted"), use_container_width=True)
|
| 341 |
with right:
|
|
@@ -412,21 +428,36 @@ if st.session_state.app_step == "predict":
|
|
| 412 |
st.subheader("Validation Results")
|
| 413 |
sv = st.session_state.results["summary_val"]; oor_table = st.session_state.results.get("oor_table")
|
| 414 |
|
| 415 |
-
#
|
| 416 |
if sv["oor_pct"] > 0:
|
| 417 |
st.warning("Some validation inputs fall outside the **training min–max** ranges. Interpret predictions with caution.")
|
| 418 |
|
| 419 |
c1,c2,c3,c4 = st.columns(4)
|
| 420 |
c1.metric("points", f"{sv['n_points']}"); c2.metric("Pred min", f"{sv['pred_min']:.2f}")
|
| 421 |
c3.metric("Pred max", f"{sv['pred_max']:.2f}"); c4.metric("OOR %", f"{sv['oor_pct']:.1f}%")
|
| 422 |
-
|
|
|
|
| 423 |
with left:
|
| 424 |
if TARGET in st.session_state.results["Validate"].columns:
|
| 425 |
-
st.pyplot(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 426 |
else:
|
| 427 |
st.info("Actual UCS values are not available in the validation data. Cross-plot cannot be generated.")
|
| 428 |
with right:
|
| 429 |
-
st.pyplot(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 430 |
|
| 431 |
if oor_table is not None:
|
| 432 |
st.write("*Out-of-range rows (vs. Training min–max):*")
|
|
@@ -451,4 +482,4 @@ if st.session_state.app_step == "predict":
|
|
| 451 |
# Footer
|
| 452 |
# =========================
|
| 453 |
st.markdown("---")
|
| 454 |
-
st.markdown("<div style='text-align:center; color:#6b7280;'>ST_GeoMech_UCS • © Smart Thinking</div>", unsafe_allow_html=True)
|
|
|
|
| 16 |
DEFAULT_MODEL = MODELS_DIR / "ucs_rf.joblib"
|
| 17 |
MODEL_FALLBACKS = [MODELS_DIR / "model.joblib", MODELS_DIR / "model.pkl"]
|
| 18 |
|
| 19 |
+
# Colors for plots
|
| 20 |
+
COLORS = {
|
| 21 |
+
"pred": "#1f77b4", # blue (predicted)
|
| 22 |
+
"actual": "#f2b702", # yellow (actual)
|
| 23 |
+
"ref": "#5a5a5a" # grey 1:1 line
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
# =========================
|
| 27 |
# Page / Theme
|
| 28 |
# =========================
|
|
|
|
| 86 |
if nm.lower() in low2orig: return low2orig[nm.lower()]
|
| 87 |
return None
|
| 88 |
|
| 89 |
+
def cross_plot(actual, pred, title, size=(3.9, 3.9)):
|
| 90 |
+
"""Compact, square cross-plot with a 1:1 reference line."""
|
| 91 |
fig, ax = plt.subplots(figsize=size, dpi=100)
|
| 92 |
+
ax.scatter(actual, pred, s=14, alpha=0.85, color=COLORS["pred"])
|
| 93 |
lo = float(np.nanmin([actual.min(), pred.min()]))
|
| 94 |
hi = float(np.nanmax([actual.max(), pred.max()]))
|
| 95 |
pad = 0.03 * (hi - lo if hi > lo else 1.0)
|
| 96 |
+
ax.plot([lo - pad, hi + pad], [lo - pad, hi + pad],
|
| 97 |
+
'--', lw=1.2, color=COLORS["ref"])
|
| 98 |
ax.set_xlim(lo - pad, hi + pad)
|
| 99 |
ax.set_ylim(lo - pad, hi + pad)
|
| 100 |
ax.set_aspect('equal', 'box') # perfect 1:1
|
|
|
|
| 103 |
return fig
|
| 104 |
|
| 105 |
def depth_or_index_track(df, title=None, include_actual=True):
|
| 106 |
+
"""
|
| 107 |
+
Narrow, tall track: predicted solid blue; actual dotted yellow.
|
| 108 |
+
Works for either Depth on Y or Index on Y.
|
| 109 |
+
"""
|
| 110 |
# Find depth-like column if available
|
| 111 |
+
depth_col = next((c for c in df.columns if 'depth' in str(c).lower()), None)
|
|
|
|
|
|
|
|
|
|
| 112 |
|
| 113 |
+
# Narrow width, tall height for logging look
|
| 114 |
+
fig_w = 3.1
|
| 115 |
+
fig_h = 7.6 if depth_col is not None else 7.2
|
| 116 |
+
fig, ax = plt.subplots(figsize=(fig_w, fig_h), dpi=100)
|
| 117 |
|
| 118 |
if depth_col is not None:
|
| 119 |
+
ax.plot(df["UCS_Pred"], df[depth_col],
|
| 120 |
+
'-', lw=1.8, color=COLORS["pred"], label="UCS_Pred")
|
| 121 |
if include_actual and TARGET in df.columns:
|
| 122 |
+
ax.plot(df[TARGET], df[depth_col],
|
| 123 |
+
':', lw=2.0, color=COLORS["actual"], alpha=0.95, label="UCS (actual)")
|
| 124 |
ax.set_ylabel(depth_col); ax.set_xlabel("UCS")
|
| 125 |
ax.xaxis.set_label_position('top'); ax.xaxis.tick_top(); ax.invert_yaxis()
|
| 126 |
else:
|
| 127 |
idx = np.arange(1, len(df) + 1)
|
| 128 |
+
ax.plot(df["UCS_Pred"], idx,
|
| 129 |
+
'-', lw=1.8, color=COLORS["pred"], label="UCS_Pred")
|
| 130 |
if include_actual and TARGET in df.columns:
|
| 131 |
+
ax.plot(df[TARGET], idx,
|
| 132 |
+
':', lw=2.0, color=COLORS["actual"], alpha=0.95, label="UCS (actual)")
|
| 133 |
ax.set_ylabel("Point Index"); ax.set_xlabel("UCS")
|
| 134 |
ax.xaxis.set_label_position('top'); ax.xaxis.tick_top(); ax.invert_yaxis()
|
| 135 |
|
| 136 |
ax.grid(True, linestyle=":", alpha=0.4)
|
| 137 |
+
if title: ax.set_title(title, pad=8) # keep no title by passing None
|
| 138 |
ax.legend(loc="best")
|
| 139 |
return fig
|
| 140 |
|
|
|
|
| 288 |
on_click=(lambda: st.session_state.update(app_step="predict")) if st.session_state.dev_ready else None,
|
| 289 |
)
|
| 290 |
|
| 291 |
+
# Header + helper sentence under the header
|
| 292 |
st.subheader("Model Development")
|
| 293 |
st.write("Upload your data to train the model and review the development performance.")
|
| 294 |
|
|
|
|
| 339 |
df = st.session_state.results["Train"]; m = st.session_state.results["metrics_train"]
|
| 340 |
c1,c2,c3 = st.columns(3)
|
| 341 |
c1.metric("R²", f"{m['R2']:.4f}"); c2.metric("RMSE", f"{m['RMSE']:.4f}"); c3.metric("MAE", f"{m['MAE']:.4f}")
|
| 342 |
+
# Narrow track column for log-look
|
| 343 |
+
left, right = st.columns([0.9, 0.55])
|
| 344 |
with left:
|
| 345 |
st.pyplot(cross_plot(df[TARGET], df["UCS_Pred"], "Training: Actual vs Predicted"), use_container_width=True)
|
| 346 |
with right:
|
|
|
|
| 347 |
st.pyplot(depth_or_index_track(df, title=None, include_actual=True), use_container_width=True)
|
| 348 |
|
| 349 |
if "Test" in st.session_state.results:
|
|
|
|
| 351 |
df = st.session_state.results["Test"]; m = st.session_state.results["metrics_test"]
|
| 352 |
c1,c2,c3 = st.columns(3)
|
| 353 |
c1.metric("R²", f"{m['R2']:.4f}"); c2.metric("RMSE", f"{m['RMSE']:.4f}"); c3.metric("MAE", f"{m['MAE']:.4f}")
|
| 354 |
+
left, right = st.columns([0.9, 0.55])
|
| 355 |
with left:
|
| 356 |
st.pyplot(cross_plot(df[TARGET], df["UCS_Pred"], "Testing: Actual vs Predicted"), use_container_width=True)
|
| 357 |
with right:
|
|
|
|
| 428 |
st.subheader("Validation Results")
|
| 429 |
sv = st.session_state.results["summary_val"]; oor_table = st.session_state.results.get("oor_table")
|
| 430 |
|
| 431 |
+
# Show OOR warning above the plots when applicable
|
| 432 |
if sv["oor_pct"] > 0:
|
| 433 |
st.warning("Some validation inputs fall outside the **training min–max** ranges. Interpret predictions with caution.")
|
| 434 |
|
| 435 |
c1,c2,c3,c4 = st.columns(4)
|
| 436 |
c1.metric("points", f"{sv['n_points']}"); c2.metric("Pred min", f"{sv['pred_min']:.2f}")
|
| 437 |
c3.metric("Pred max", f"{sv['pred_max']:.2f}"); c4.metric("OOR %", f"{sv['oor_pct']:.1f}%")
|
| 438 |
+
|
| 439 |
+
left, right = st.columns([0.9, 0.55]) # slim log-look track
|
| 440 |
with left:
|
| 441 |
if TARGET in st.session_state.results["Validate"].columns:
|
| 442 |
+
st.pyplot(
|
| 443 |
+
cross_plot(
|
| 444 |
+
st.session_state.results["Validate"][TARGET],
|
| 445 |
+
st.session_state.results["Validate"]["UCS_Pred"],
|
| 446 |
+
"Validation: Actual vs Predicted"
|
| 447 |
+
),
|
| 448 |
+
use_container_width=True
|
| 449 |
+
)
|
| 450 |
else:
|
| 451 |
st.info("Actual UCS values are not available in the validation data. Cross-plot cannot be generated.")
|
| 452 |
with right:
|
| 453 |
+
st.pyplot(
|
| 454 |
+
depth_or_index_track(
|
| 455 |
+
st.session_state.results["Validate"],
|
| 456 |
+
title=None,
|
| 457 |
+
include_actual=(TARGET in st.session_state.results["Validate"].columns)
|
| 458 |
+
),
|
| 459 |
+
use_container_width=True
|
| 460 |
+
)
|
| 461 |
|
| 462 |
if oor_table is not None:
|
| 463 |
st.write("*Out-of-range rows (vs. Training min–max):*")
|
|
|
|
| 482 |
# Footer
|
| 483 |
# =========================
|
| 484 |
st.markdown("---")
|
| 485 |
+
st.markdown("<div style='text-align:center; color:#6b7280;'>ST_GeoMech_UCS • © Smart Thinking</div>", unsafe_allow_html=True)
|