Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
# app.py
|
| 2 |
import io, json, os, base64, math
|
| 3 |
from pathlib import Path
|
| 4 |
import streamlit as st
|
|
@@ -26,9 +25,9 @@ MODEL_FALLBACKS = [MODELS_DIR / "model.joblib", MODELS_DIR / "model.pkl"]
|
|
| 26 |
COLORS = {"pred": "#1f77b4", "actual": "#f2b702", "ref": "#5a5a5a"}
|
| 27 |
|
| 28 |
# ---- Plot sizing controls ----
|
| 29 |
-
CROSS_W = 450
|
| 30 |
CROSS_H = 450
|
| 31 |
-
TRACK_H = 740
|
| 32 |
FONT_SZ = 13
|
| 33 |
BOLD_FONT = "Arial Black, Arial, sans-serif" # used for bold axis titles & ticks
|
| 34 |
|
|
@@ -36,15 +35,20 @@ BOLD_FONT = "Arial Black, Arial, sans-serif" # used for bold axis titles & tick
|
|
| 36 |
# Page / CSS
|
| 37 |
# =========================
|
| 38 |
st.set_page_config(page_title="ST_GeoMech_UCS", page_icon="logo.png", layout="wide")
|
|
|
|
|
|
|
| 39 |
st.markdown("""
|
| 40 |
<style>
|
| 41 |
-
/* Reusable logo style */
|
| 42 |
.brand-logo { width: 16px; height: auto; object-fit: contain; }
|
| 43 |
-
|
| 44 |
-
/* Sidebar header layout */
|
| 45 |
.sidebar-header { display:flex; align-items:center; gap:12px; }
|
| 46 |
.sidebar-header .text h1 { font-size: 1.05rem; margin:0; line-height:1.1; }
|
| 47 |
.sidebar-header .text .tag { font-size: .85rem; color:#6b7280; margin:2px 0 0; }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
</style>
|
| 49 |
""", unsafe_allow_html=True)
|
| 50 |
|
|
@@ -62,6 +66,26 @@ section[data-testid="stFileUploader"] p, section[data-testid="stFileUploader"] s
|
|
| 62 |
</style>
|
| 63 |
""", unsafe_allow_html=True)
|
| 64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
# Center text in all pandas Styler tables (headers + cells)
|
| 66 |
TABLE_CENTER_CSS = [
|
| 67 |
dict(selector="th", props=[("text-align", "center")]),
|
|
@@ -92,8 +116,16 @@ def add_password_gate() -> None:
|
|
| 92 |
if st.session_state.get("auth_ok", False):
|
| 93 |
return
|
| 94 |
|
| 95 |
-
st.sidebar.image("logo.png", use_column_width=False)
|
| 96 |
-
st.sidebar.markdown("### ST_GeoMech_UCS\nSmart Thinking • Secure Access")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
pwd = st.sidebar.text_input("Access key", type="password", placeholder="••••••••")
|
| 98 |
if st.sidebar.button("Unlock", type="primary"):
|
| 99 |
if pwd == required:
|
|
@@ -146,13 +178,15 @@ def _nice_tick0(xmin: float, step: int = 100) -> float:
|
|
| 146 |
return step * math.floor(xmin / step) if np.isfinite(xmin) else xmin
|
| 147 |
|
| 148 |
def df_centered_rounded(df: pd.DataFrame, hide_index=True):
|
| 149 |
-
"""
|
| 150 |
out = df.copy()
|
| 151 |
numcols = out.select_dtypes(include=[np.number]).columns
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
|
|
|
|
|
|
| 156 |
st.dataframe(styler, use_container_width=True, hide_index=hide_index)
|
| 157 |
|
| 158 |
# =========================
|
|
@@ -160,7 +194,7 @@ def df_centered_rounded(df: pd.DataFrame, hide_index=True):
|
|
| 160 |
# =========================
|
| 161 |
def cross_plot_static(actual, pred):
|
| 162 |
a = pd.Series(actual, dtype=float)
|
| 163 |
-
p = pd.Series(pred,
|
| 164 |
|
| 165 |
fixed_min, fixed_max = 6000, 10000
|
| 166 |
ticks = np.arange(fixed_min, fixed_max + 1, 1000)
|
|
@@ -180,15 +214,15 @@ def cross_plot_static(actual, pred):
|
|
| 180 |
ax.set_ylim(fixed_min, fixed_max)
|
| 181 |
ax.set_xticks(ticks)
|
| 182 |
ax.set_yticks(ticks)
|
| 183 |
-
ax.set_aspect("equal", adjustable="box")
|
| 184 |
|
| 185 |
fmt = FuncFormatter(lambda x, _: f"{int(x):,}")
|
| 186 |
ax.xaxis.set_major_formatter(fmt)
|
| 187 |
ax.yaxis.set_major_formatter(fmt)
|
| 188 |
|
| 189 |
-
ax.set_xlabel("Actual UCS (psi)", fontweight="bold", fontsize=12)
|
| 190 |
-
ax.set_ylabel("Predicted UCS (psi)", fontweight="bold", fontsize=12)
|
| 191 |
-
ax.tick_params(labelsize=10)
|
| 192 |
|
| 193 |
ax.grid(True, linestyle=":", alpha=0.3)
|
| 194 |
for spine in ax.spines.values():
|
|
@@ -199,7 +233,7 @@ def cross_plot_static(actual, pred):
|
|
| 199 |
return fig
|
| 200 |
|
| 201 |
# =========================
|
| 202 |
-
# Track plot (Plotly)
|
| 203 |
# =========================
|
| 204 |
def track_plot(df, include_actual=True):
|
| 205 |
depth_col = next((c for c in df.columns if 'depth' in str(c).lower()), None)
|
|
@@ -240,7 +274,7 @@ def track_plot(df, include_actual=True):
|
|
| 240 |
height=TRACK_H, width=None, # width auto-fits the column
|
| 241 |
paper_bgcolor="#fff", plot_bgcolor="#fff",
|
| 242 |
margin=dict(l=64, r=16, t=36, b=48), hovermode="closest",
|
| 243 |
-
font=dict(size=FONT_SZ),
|
| 244 |
legend=dict(
|
| 245 |
x=0.98, y=0.05, xanchor="right", yanchor="bottom",
|
| 246 |
bgcolor="rgba(255,255,255,0.75)", bordercolor="#ccc", borderwidth=1
|
|
@@ -248,11 +282,11 @@ def track_plot(df, include_actual=True):
|
|
| 248 |
legend_title_text=""
|
| 249 |
)
|
| 250 |
|
| 251 |
-
# Bold axis titles &
|
| 252 |
fig.update_xaxes(
|
| 253 |
title_text="UCS (psi)",
|
| 254 |
-
title_font=dict(size=16, family=BOLD_FONT),
|
| 255 |
-
tickfont=dict(size=
|
| 256 |
side="top",
|
| 257 |
range=[xmin, xmax],
|
| 258 |
ticks="outside",
|
|
@@ -262,11 +296,10 @@ def track_plot(df, include_actual=True):
|
|
| 262 |
showline=True, linewidth=1.2, linecolor="#444", mirror=True,
|
| 263 |
showgrid=True, gridcolor="rgba(0,0,0,0.12)", automargin=True
|
| 264 |
)
|
| 265 |
-
|
| 266 |
fig.update_yaxes(
|
| 267 |
title_text=ylab,
|
| 268 |
-
title_font=dict(size=16, family=BOLD_FONT),
|
| 269 |
-
tickfont=dict(size=
|
| 270 |
range=y_range,
|
| 271 |
ticks="outside",
|
| 272 |
showline=True, linewidth=1.2, linecolor="#444", mirror=True,
|
|
@@ -275,7 +308,7 @@ def track_plot(df, include_actual=True):
|
|
| 275 |
|
| 276 |
return fig
|
| 277 |
|
| 278 |
-
# ---------- Preview modal (matplotlib)
|
| 279 |
def preview_tracks(df: pd.DataFrame, cols: list[str]):
|
| 280 |
cols = [c for c in cols if c in df.columns]
|
| 281 |
n = len(cols)
|
|
@@ -376,11 +409,19 @@ st.session_state.setdefault("dev_preview",False)
|
|
| 376 |
# =========================
|
| 377 |
# Branding in Sidebar
|
| 378 |
# =========================
|
| 379 |
-
st.sidebar.image("logo.png", use_column_width=False)
|
| 380 |
-
st.sidebar.markdown(
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 384 |
)
|
| 385 |
|
| 386 |
# =========================
|
|
@@ -391,7 +432,7 @@ if st.session_state.app_step == "intro":
|
|
| 391 |
st.markdown("This software is developed by *Smart Thinking AI-Solutions Team* to estimate UCS from drilling data.")
|
| 392 |
st.subheader("How It Works")
|
| 393 |
st.markdown(
|
| 394 |
-
"1) **Upload your data to build the case and preview the performance of our model.**
|
| 395 |
"2) Click **Run Model** to compute metrics and plots. \n"
|
| 396 |
"3) **Proceed to Validation** (with actual UCS) or **Proceed to Prediction** (no UCS)."
|
| 397 |
)
|
|
@@ -465,7 +506,7 @@ if st.session_state.app_step == "dev":
|
|
| 465 |
c1,c2,c3 = st.columns(3)
|
| 466 |
c1.metric("R", f"{m['R']:.2f}"); c2.metric("RMSE", f"{m['RMSE']:.2f}"); c3.metric("MAE", f"{m['MAE']:.2f}")
|
| 467 |
|
| 468 |
-
#
|
| 469 |
col_cross, col_track = st.columns([3, 2], gap="large")
|
| 470 |
with col_cross:
|
| 471 |
st.pyplot(cross_plot_static(df[TARGET], df["UCS_Pred"]), use_container_width=True)
|
|
@@ -519,9 +560,7 @@ if st.session_state.app_step == "validate":
|
|
| 519 |
tbl = df.loc[any_viol, FEATURES].copy()
|
| 520 |
for c in FEATURES:
|
| 521 |
if pd.api.types.is_numeric_dtype(tbl[c]): tbl[c] = tbl[c].round(2)
|
| 522 |
-
tbl["Violations"] = pd.DataFrame({f:(df[f]<ranges[f][0])|(df[f]>ranges[f][1]) for f in FEATURES}).loc[any_viol].apply(
|
| 523 |
-
lambda r:", ".join([c for c,v in r.items() if v]), axis=1
|
| 524 |
-
)
|
| 525 |
st.session_state.results["m_val"]={
|
| 526 |
"R": pearson_r(df[TARGET], df["UCS_Pred"]),
|
| 527 |
"RMSE": rmse(df[TARGET], df["UCS_Pred"]),
|
|
@@ -628,4 +667,4 @@ st.markdown(
|
|
| 628 |
</div>
|
| 629 |
""",
|
| 630 |
unsafe_allow_html=True
|
| 631 |
-
)
|
|
|
|
|
|
|
| 1 |
import io, json, os, base64, math
|
| 2 |
from pathlib import Path
|
| 3 |
import streamlit as st
|
|
|
|
| 25 |
COLORS = {"pred": "#1f77b4", "actual": "#f2b702", "ref": "#5a5a5a"}
|
| 26 |
|
| 27 |
# ---- Plot sizing controls ----
|
| 28 |
+
CROSS_W = 450 # px (matplotlib figure size; Streamlit will still scale)
|
| 29 |
CROSS_H = 450
|
| 30 |
+
TRACK_H = 740 # px (plotly height; width auto-fits column)
|
| 31 |
FONT_SZ = 13
|
| 32 |
BOLD_FONT = "Arial Black, Arial, sans-serif" # used for bold axis titles & ticks
|
| 33 |
|
|
|
|
| 35 |
# Page / CSS
|
| 36 |
# =========================
|
| 37 |
st.set_page_config(page_title="ST_GeoMech_UCS", page_icon="logo.png", layout="wide")
|
| 38 |
+
|
| 39 |
+
# General CSS (logo helpers etc.)
|
| 40 |
st.markdown("""
|
| 41 |
<style>
|
|
|
|
| 42 |
.brand-logo { width: 16px; height: auto; object-fit: contain; }
|
|
|
|
|
|
|
| 43 |
.sidebar-header { display:flex; align-items:center; gap:12px; }
|
| 44 |
.sidebar-header .text h1 { font-size: 1.05rem; margin:0; line-height:1.1; }
|
| 45 |
.sidebar-header .text .tag { font-size: .85rem; color:#6b7280; margin:2px 0 0; }
|
| 46 |
+
.centered-container {
|
| 47 |
+
display: flex;
|
| 48 |
+
flex-direction: column;
|
| 49 |
+
align-items: center;
|
| 50 |
+
text-align: center;
|
| 51 |
+
}
|
| 52 |
</style>
|
| 53 |
""", unsafe_allow_html=True)
|
| 54 |
|
|
|
|
| 66 |
</style>
|
| 67 |
""", unsafe_allow_html=True)
|
| 68 |
|
| 69 |
+
# Make the Preview expander title & tabs sticky (pinned to the top)
|
| 70 |
+
st.markdown("""
|
| 71 |
+
<style>
|
| 72 |
+
div[data-testid="stExpander"] > details > summary {
|
| 73 |
+
position: sticky;
|
| 74 |
+
top: 0;
|
| 75 |
+
z-index: 10;
|
| 76 |
+
background: #fff;
|
| 77 |
+
border-bottom: 1px solid #eee;
|
| 78 |
+
}
|
| 79 |
+
div[data-testid="stExpander"] div[data-baseweb="tab-list"] {
|
| 80 |
+
position: sticky;
|
| 81 |
+
top: 42px; /* adjust if your expander header height differs */
|
| 82 |
+
z-index: 9;
|
| 83 |
+
background: #fff;
|
| 84 |
+
padding-top: 6px;
|
| 85 |
+
}
|
| 86 |
+
</style>
|
| 87 |
+
""", unsafe_allow_html=True)
|
| 88 |
+
|
| 89 |
# Center text in all pandas Styler tables (headers + cells)
|
| 90 |
TABLE_CENTER_CSS = [
|
| 91 |
dict(selector="th", props=[("text-align", "center")]),
|
|
|
|
| 116 |
if st.session_state.get("auth_ok", False):
|
| 117 |
return
|
| 118 |
|
| 119 |
+
# st.sidebar.image("logo.png", use_column_width=False)
|
| 120 |
+
# st.sidebar.markdown("### ST_GeoMech_UCS\nSmart Thinking • Secure Access")
|
| 121 |
+
st.sidebar.markdown(f"""
|
| 122 |
+
<div class="centered-container">
|
| 123 |
+
<img src="{inline_logo('logo.png')}" style="width: 16px; height: auto; object-fit: contain;">
|
| 124 |
+
<div style='font-weight:800;font-size:1.2rem; margin-top: 10px;'>ST_GeoMech_UCS</div>
|
| 125 |
+
<div style='color:#667085;'>Smart Thinking • Secure Access</div>
|
| 126 |
+
</div>
|
| 127 |
+
""", unsafe_allow_html=True
|
| 128 |
+
)
|
| 129 |
pwd = st.sidebar.text_input("Access key", type="password", placeholder="••••••••")
|
| 130 |
if st.sidebar.button("Unlock", type="primary"):
|
| 131 |
if pwd == required:
|
|
|
|
| 178 |
return step * math.floor(xmin / step) if np.isfinite(xmin) else xmin
|
| 179 |
|
| 180 |
def df_centered_rounded(df: pd.DataFrame, hide_index=True):
|
| 181 |
+
"""Center headers & cells; format numeric columns to 2 decimals."""
|
| 182 |
out = df.copy()
|
| 183 |
numcols = out.select_dtypes(include=[np.number]).columns
|
| 184 |
+
styler = (
|
| 185 |
+
out.style
|
| 186 |
+
.format({c: "{:.2f}" for c in numcols})
|
| 187 |
+
.set_properties(**{"text-align": "center"})
|
| 188 |
+
.set_table_styles(TABLE_CENTER_CSS)
|
| 189 |
+
)
|
| 190 |
st.dataframe(styler, use_container_width=True, hide_index=hide_index)
|
| 191 |
|
| 192 |
# =========================
|
|
|
|
| 194 |
# =========================
|
| 195 |
def cross_plot_static(actual, pred):
|
| 196 |
a = pd.Series(actual, dtype=float)
|
| 197 |
+
p = pd.Series(pred, dtype=float)
|
| 198 |
|
| 199 |
fixed_min, fixed_max = 6000, 10000
|
| 200 |
ticks = np.arange(fixed_min, fixed_max + 1, 1000)
|
|
|
|
| 214 |
ax.set_ylim(fixed_min, fixed_max)
|
| 215 |
ax.set_xticks(ticks)
|
| 216 |
ax.set_yticks(ticks)
|
| 217 |
+
ax.set_aspect("equal", adjustable="box") # true 45°
|
| 218 |
|
| 219 |
fmt = FuncFormatter(lambda x, _: f"{int(x):,}")
|
| 220 |
ax.xaxis.set_major_formatter(fmt)
|
| 221 |
ax.yaxis.set_major_formatter(fmt)
|
| 222 |
|
| 223 |
+
ax.set_xlabel("Actual UCS (psi)", fontweight="bold", fontsize=12, color="black")
|
| 224 |
+
ax.set_ylabel("Predicted UCS (psi)", fontweight="bold", fontsize=12, color="black")
|
| 225 |
+
ax.tick_params(labelsize=10, colors="black")
|
| 226 |
|
| 227 |
ax.grid(True, linestyle=":", alpha=0.3)
|
| 228 |
for spine in ax.spines.values():
|
|
|
|
| 233 |
return fig
|
| 234 |
|
| 235 |
# =========================
|
| 236 |
+
# Track plot (Plotly)
|
| 237 |
# =========================
|
| 238 |
def track_plot(df, include_actual=True):
|
| 239 |
depth_col = next((c for c in df.columns if 'depth' in str(c).lower()), None)
|
|
|
|
| 274 |
height=TRACK_H, width=None, # width auto-fits the column
|
| 275 |
paper_bgcolor="#fff", plot_bgcolor="#fff",
|
| 276 |
margin=dict(l=64, r=16, t=36, b=48), hovermode="closest",
|
| 277 |
+
font=dict(size=FONT_SZ, color="#000"),
|
| 278 |
legend=dict(
|
| 279 |
x=0.98, y=0.05, xanchor="right", yanchor="bottom",
|
| 280 |
bgcolor="rgba(255,255,255,0.75)", bordercolor="#ccc", borderwidth=1
|
|
|
|
| 282 |
legend_title_text=""
|
| 283 |
)
|
| 284 |
|
| 285 |
+
# Bold, black axis titles & ticks
|
| 286 |
fig.update_xaxes(
|
| 287 |
title_text="UCS (psi)",
|
| 288 |
+
title_font=dict(size=16, family=BOLD_FONT, color="#000"),
|
| 289 |
+
tickfont=dict(size=11, family=BOLD_FONT, color="#000"),
|
| 290 |
side="top",
|
| 291 |
range=[xmin, xmax],
|
| 292 |
ticks="outside",
|
|
|
|
| 296 |
showline=True, linewidth=1.2, linecolor="#444", mirror=True,
|
| 297 |
showgrid=True, gridcolor="rgba(0,0,0,0.12)", automargin=True
|
| 298 |
)
|
|
|
|
| 299 |
fig.update_yaxes(
|
| 300 |
title_text=ylab,
|
| 301 |
+
title_font=dict(size=16, family=BOLD_FONT, color="#000"),
|
| 302 |
+
tickfont=dict(size=11, family=BOLD_FONT, color="#000"),
|
| 303 |
range=y_range,
|
| 304 |
ticks="outside",
|
| 305 |
showline=True, linewidth=1.2, linecolor="#444", mirror=True,
|
|
|
|
| 308 |
|
| 309 |
return fig
|
| 310 |
|
| 311 |
+
# ---------- Preview modal (matplotlib) ----------
|
| 312 |
def preview_tracks(df: pd.DataFrame, cols: list[str]):
|
| 313 |
cols = [c for c in cols if c in df.columns]
|
| 314 |
n = len(cols)
|
|
|
|
| 409 |
# =========================
|
| 410 |
# Branding in Sidebar
|
| 411 |
# =========================
|
| 412 |
+
# st.sidebar.image("logo.png", use_column_width=False)
|
| 413 |
+
# st.sidebar.markdown(
|
| 414 |
+
# "<div style='font-weight:800;font-size:1.2rem;'>ST_GeoMech_UCS</div>"
|
| 415 |
+
# "<div style='color:#667085;'>Real-Time UCS Tracking While Drilling</div>",
|
| 416 |
+
# unsafe_allow_html=True
|
| 417 |
+
# )
|
| 418 |
+
st.sidebar.markdown(f"""
|
| 419 |
+
<div class="centered-container">
|
| 420 |
+
<img src="{inline_logo('logo.png')}" style="width: 16px; height: auto; object-fit: contain;">
|
| 421 |
+
<div style='font-weight:800;font-size:1.2rem;'>ST_GeoMech_UCS</div>
|
| 422 |
+
<div style='color:#667085;'>Real-Time UCS Tracking While Drilling</div>
|
| 423 |
+
</div>
|
| 424 |
+
""", unsafe_allow_html=True
|
| 425 |
)
|
| 426 |
|
| 427 |
# =========================
|
|
|
|
| 432 |
st.markdown("This software is developed by *Smart Thinking AI-Solutions Team* to estimate UCS from drilling data.")
|
| 433 |
st.subheader("How It Works")
|
| 434 |
st.markdown(
|
| 435 |
+
"1) **Upload your data to build the case and preview the performance of our model.** \n"
|
| 436 |
"2) Click **Run Model** to compute metrics and plots. \n"
|
| 437 |
"3) **Proceed to Validation** (with actual UCS) or **Proceed to Prediction** (no UCS)."
|
| 438 |
)
|
|
|
|
| 506 |
c1,c2,c3 = st.columns(3)
|
| 507 |
c1.metric("R", f"{m['R']:.2f}"); c2.metric("RMSE", f"{m['RMSE']:.2f}"); c3.metric("MAE", f"{m['MAE']:.2f}")
|
| 508 |
|
| 509 |
+
# 2-column layout, big gap (prevents overlap)
|
| 510 |
col_cross, col_track = st.columns([3, 2], gap="large")
|
| 511 |
with col_cross:
|
| 512 |
st.pyplot(cross_plot_static(df[TARGET], df["UCS_Pred"]), use_container_width=True)
|
|
|
|
| 560 |
tbl = df.loc[any_viol, FEATURES].copy()
|
| 561 |
for c in FEATURES:
|
| 562 |
if pd.api.types.is_numeric_dtype(tbl[c]): tbl[c] = tbl[c].round(2)
|
| 563 |
+
tbl["Violations"] = pd.DataFrame({f:(df[f]<ranges[f][0])|(df[f]>ranges[f][1]) for f in FEATURES}).loc[any_viol].apply(lambda r:", ".join([c for c,v in r.items() if v]), axis=1)
|
|
|
|
|
|
|
| 564 |
st.session_state.results["m_val"]={
|
| 565 |
"R": pearson_r(df[TARGET], df["UCS_Pred"]),
|
| 566 |
"RMSE": rmse(df[TARGET], df["UCS_Pred"]),
|
|
|
|
| 667 |
</div>
|
| 668 |
""",
|
| 669 |
unsafe_allow_html=True
|
| 670 |
+
)
|