UCS2014 committed on
Commit
3852c3c
·
verified ·
1 Parent(s): fd24fec

Rename app_revised_with_r_value.py to app.py

Browse files
Files changed (2) hide show
  1. app.py +687 -0
  2. app_revised_with_r_value.py +0 -26
app.py ADDED
@@ -0,0 +1,687 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import io, json, os, base64, math
3
+ from pathlib import Path
4
+ import streamlit as st
5
+ import pandas as pd
6
+ import numpy as np
7
+ import joblib
8
+
9
+ # matplotlib only for PREVIEW modal
10
+ import matplotlib
11
+ matplotlib.use("Agg")
12
+ import matplotlib.pyplot as plt
13
+
14
+ import plotly.graph_objects as go
15
+ from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
16
+
17
+ # =========================
18
+ # Constants (simple & robust)
19
+ # =========================
20
+ FEATURES = ["Q, gpm", "SPP(psi)", "T (kft.lbf)", "WOB (klbf)", "ROP (ft/h)"]
21
+ TARGET = "UCS"
22
+ MODELS_DIR = Path("models")
23
+ DEFAULT_MODEL = MODELS_DIR / "ucs_rf.joblib"
24
+ MODEL_FALLBACKS = [MODELS_DIR / "model.joblib", MODELS_DIR / "model.pkl"]
25
+ COLORS = {"pred": "#1f77b4", "actual": "#f2b702", "ref": "#5a5a5a"}
26
+
27
+ # ---- Plot sizing controls (edit here) ----
28
+ CROSS_W = 450; CROSS_H = 450 # square cross-plot (Build + Validate)
29
+ TRACK_W = 400; TRACK_H = 950 # log-strip style (all pages)
30
+ FONT_SZ = 15
31
+ PLOT_COLS = [30, 1, 20] # 3-column band: left • spacer • right (Build + Validate)
32
+ CROSS_NUDGE = 0.02 # push cross-plot to the RIGHT inside its band
33
+
34
+ # =========================
35
+ # Page / CSS
36
+ # =========================
37
+ st.set_page_config(page_title="ST_GeoMech_UCS", page_icon="logo.png", layout="wide")
38
+ st.markdown("""
39
+ <style>
40
+ /* ✅ Hide the helper text in file uploader */
41
+ section[data-testid="stFileUploader"] div[data-testid="stMarkdownContainer"] {
42
+ display: none !important;
43
+ }
44
+ </style>
45
+ """, unsafe_allow_html=True)
46
+ st.markdown("<style>header, footer{visibility:hidden !important;}</style>", unsafe_allow_html=True)
47
+ st.markdown(
48
+ """
49
+ <style>
50
+ .stApp { background:#fff; }
51
+ section[data-testid="stSidebar"] { background:#F6F9FC; }
52
+ .block-container { padding-top:.5rem; padding-bottom:.5rem; }
53
+ .stButton>button { background:#007bff; color:#fff; font-weight:600; border-radius:8px; border:none; }
54
+ .stButton>button:hover { background:#0056b3; }
55
+ .st-hero { display:flex; align-items:center; gap:16px; padding-top: 4px; }
56
+ .st-hero .brand { width:110px; height:110px; object-fit:contain; }
57
+ .st-hero h1 { margin:0; line-height:1.05; }
58
+ .st-hero .tagline { margin:2px 0 0 2px; color:#6b7280; font-size:1.05rem; font-style:italic; }
59
+ [data-testid="stBlock"]{ margin-top:0 !important; }
60
+ </style>
61
+ """,
62
+ unsafe_allow_html=True
63
+ )
64
+
65
+ # =========================
66
+ # Password gate (define first, then call)
67
+ # =========================
68
def inline_logo(path="logo.png") -> str:
    """Return the image at *path* as a base64 ``data:`` URI for inline HTML.

    Args:
        path: Location of the logo file; defaults to ``logo.png`` relative to
            the working directory.

    Returns:
        A ``data:image/png;base64,...`` string, or ``""`` when the file is
        missing or cannot be read, so the result can always be interpolated
        into an ``<img src=...>`` attribute.
    """
    try:
        # EAFP: read directly instead of exists()-then-read, which can race
        # with the file disappearing between the two calls.
        raw = Path(path).read_bytes()
    except (OSError, TypeError, ValueError):
        # Missing/unreadable file or an unusable path value. The logo is
        # decorative, so degrade to an empty src rather than failing the page.
        return ""
    return f"data:image/png;base64,{base64.b64encode(raw).decode('ascii')}"
75
+
76
def add_password_gate() -> None:
    """Halt the script until the visitor has supplied the configured access key.

    The expected key is read from ``st.secrets["APP_PASSWORD"]`` and, when
    that is absent or empty, from the ``APP_PASSWORD`` environment variable.

    Behaviour:
        * No key configured: render a "Protected Area" notice explaining how
          to set one, then ``st.stop()``.
        * ``st.session_state["auth_ok"]`` already true: return immediately so
          the rest of the page renders.
        * Otherwise: render the key prompt. A correct key sets ``auth_ok`` and
          reruns the script; anything else shows an error. ``st.stop()`` is
          called in every non-authenticated path.
    """
    import hmac  # local import: only needed for the constant-time comparison

    try:
        required = st.secrets.get("APP_PASSWORD", "")
    except Exception:
        # Accessing st.secrets can raise when no secrets are configured.
        required = ""
    # Consult the environment whenever the secret is missing OR empty; the
    # previous code only did so when st.secrets itself raised.
    # str() so a numeric secret (e.g. an unquoted TOML value) still compares.
    required = str(required or os.environ.get("APP_PASSWORD", ""))

    def _render_banner(title: str, body_html: str) -> None:
        """Render the shared logo/title header followed by a section notice."""
        st.markdown(
            '<div style="display:flex;align-items:center;gap:14px;margin:8px 0 6px 0;">'
            f'<img src="{inline_logo()}" style="width:56px;height:56px;object-fit:contain"/>'
            "<div>"
            '<div style="font-size:1.9rem;font-weight:800;">ST_GeoMech_UCS</div>'
            '<div style="color:#667085;">Smart Thinking • Secure Access</div>'
            "</div>"
            "</div>"
            f'<div style="font-size:1.25rem;font-weight:700;margin:8px 0 4px 0;">{title}</div>'
            f'<div style="color:#6b7280;margin-bottom:14px;">{body_html}</div>',
            unsafe_allow_html=True,
        )

    if not required:
        _render_banner(
            "Protected Area",
            "Set <code>APP_PASSWORD</code> in <b>Settings → Secrets</b> "
            "(or environment) and restart.",
        )
        st.stop()

    if st.session_state.get("auth_ok", False):
        return

    _render_banner("Protected", "Please enter your access key to continue.")

    pwd = st.text_input("Access key", type="password", placeholder="••••••••")
    if st.button("Unlock", type="primary"):
        # compare_digest avoids leaking how many leading characters matched.
        # Encode to bytes because it rejects non-ASCII ``str`` arguments.
        if hmac.compare_digest(pwd.encode("utf-8"), required.encode("utf-8")):
            st.session_state.auth_ok = True
            st.rerun()
        else:
            st.error("Incorrect key.")
    st.stop()
127
+
128
+ add_password_gate()
129
+
130
+ # =========================
131
+ # Utilities
132
+ # =========================
133
# Resolve a ``dialog`` decorator that works across Streamlit versions:
# prefer ``st.dialog``, then the older ``st.experimental_dialog`` name, and
# finally fall back to rendering the decorated function inside an expander.
dialog = getattr(st, "dialog", None) or getattr(st, "experimental_dialog", None)
if dialog is None:
    def dialog(title):
        """Fallback decorator factory used when Streamlit has no modal dialogs.

        The decorated function's output is rendered inside an expanded
        ``st.expander`` labelled *title* instead of a modal.
        """
        import functools  # local import keeps the shim self-contained

        def deco(fn):
            @functools.wraps(fn)  # keep the wrapped function's name/docstring
            def wrapper(*args, **kwargs):
                with st.expander(title, expanded=True):
                    return fn(*args, **kwargs)
            return wrapper
        return deco
143
+
144
def rmse(y_true, y_pred):
    """Return the root-mean-squared error of *y_pred* against *y_true* as a float."""
    mse = mean_squared_error(y_true, y_pred)
    return float(np.sqrt(mse))
146
+
147
def r_value(y_true, y_pred):
    """Pearson correlation coefficient (R) between actual and predicted values.

    Pairs in which either value is NaN or infinite are dropped before the
    coefficient is computed.

    Args:
        y_true: Array-like of reference values.
        y_pred: Array-like of predictions, same number of elements as *y_true*.

    Returns:
        R as a float, or ``nan`` when fewer than two finite pairs remain or
        when either series is constant (the coefficient is undefined there).
    """
    # ravel() so (n,) and (n, 1) inputs can be mixed without the mask below
    # broadcasting into an (n, n) array.
    y_true = np.asarray(y_true, dtype=float).ravel()
    y_pred = np.asarray(y_pred, dtype=float).ravel()
    mask = np.isfinite(y_true) & np.isfinite(y_pred)
    if mask.sum() < 2:
        return float("nan")
    truth, guess = y_true[mask], y_pred[mask]
    # A zero-variance series would make np.corrcoef divide by zero (emitting
    # a RuntimeWarning) before producing nan; return nan explicitly instead.
    if truth.max() == truth.min() or guess.max() == guess.min():
        return float("nan")
    return float(np.corrcoef(truth, guess)[0, 1])
155
+
156
@st.cache_resource(show_spinner=False)
def load_model(model_path: str):
    """Load the serialized estimator at *model_path* via joblib.

    Wrapped in ``st.cache_resource`` so the load is cached by Streamlit
    rather than repeated on every script rerun.
    """
    # NOTE(review): joblib files are pickle-based, so only load model files
    # that come from a trusted source.
    estimator = joblib.load(model_path)
    return estimator
159
+
160
@st.cache_data(show_spinner=False)
def parse_excel(data_bytes: bytes):
    """Parse an in-memory Excel workbook into ``{sheet_name: DataFrame}``.

    Args:
        data_bytes: Raw bytes of the uploaded workbook.

    Returns:
        A dict with one DataFrame per sheet, keyed by sheet name, in
        workbook order. The result is cached by ``st.cache_data``.
    """
    # The context manager closes the underlying reader once every sheet has
    # been parsed; previously the ExcelFile was left open.
    with pd.ExcelFile(io.BytesIO(data_bytes)) as workbook:
        return {sheet: workbook.parse(sheet) for sheet in workbook.sheet_names}
165
+
166
def read_book_bytes(b: bytes):
    """Return the parsed workbook for *b*, or an empty dict when *b* is empty."""
    if not b:
        return {}
    return parse_excel(b)
168
+
169
def ensure_cols(df, cols):
    """Check that every name in *cols* is a column of *df*.

    Returns ``True`` when nothing is missing. Otherwise shows a Streamlit
    error listing the missing and the available columns and returns ``False``.
    """
    absent = [name for name in cols if name not in df.columns]
    if not absent:
        return True
    st.error(f"Missing columns: {absent}\nFound: {list(df.columns)}")
    return False
175
+
176
def find_sheet(book, names):
    """Return the key of *book* that matches one of the candidate *names*.

    Matching is case-insensitive. Candidates are tried in the order given;
    exact (case-insensitive) matches across all candidates are preferred, and
    only if none exists are matches that ignore surrounding whitespace tried,
    so a sheet called ``" Train "`` is still found for ``"Train"``.

    Args:
        book: Mapping of sheet name to sheet data.
        names: Candidate sheet names, highest priority first.

    Returns:
        The original key from *book*, or ``None`` when nothing matches.
    """
    exact = {}
    loose = {}
    for key in book:
        label = str(key)
        exact[label.lower()] = key  # same precedence as the original lookup
        loose.setdefault(label.strip().casefold(), key)

    # Pass 1: the original behaviour, a plain case-insensitive match.
    for candidate in names:
        wanted = str(candidate).lower()
        if wanted in exact:
            return exact[wanted]

    # Pass 2: tolerate stray leading/trailing whitespace in sheet names.
    for candidate in names:
        wanted = str(candidate).strip().casefold()
        if wanted in loose:
            return loose[wanted]
    return None
182
+
183
+ def _nice_tick0(xmin: float, step: int = 100) -> float:
184
+ return step * math.floor(xmin / step) if np.isfinite(xmin) else xmin
185
+
186
+ # ---------- cross_plot ----------
187
def cross_plot(actual, pred):
    """Build a square actual-vs-predicted UCS scatter with a 1:1 reference line.

    Args:
        actual: Array-like of measured UCS values (x-axis).
        pred: Array-like of predicted UCS values (y-axis).

    Returns:
        A ``plotly.graph_objects.Figure`` sized ``CROSS_W`` x ``CROSS_H`` whose
        two axes share the same padded range, so the dashed diagonal is the
        line of perfect prediction.
    """
    a = pd.Series(actual).astype(float)
    p = pd.Series(pred).astype(float)

    # Shared extents from the finite values only. With empty or all-NaN input
    # the old min/max produced NaN axis ranges, so fall back to a unit range.
    values = pd.concat([a, p], ignore_index=True)
    values = values[np.isfinite(values)]
    if values.empty:
        lo, hi = 0.0, 1.0
    else:
        lo, hi = float(values.min()), float(values.max())
    pad = 0.03 * (hi - lo if hi > lo else 1.0)
    x0, x1 = lo - pad, hi + pad

    fig = go.Figure()

    # Scatter points
    fig.add_trace(go.Scatter(
        x=a, y=p, mode="markers",
        marker=dict(size=6, color=COLORS["pred"]),
        hovertemplate="Actual: %{x:.0f}<br>Pred: %{y:.0f}<extra></extra>",
        showlegend=False,
    ))

    # 1:1 diagonal
    fig.add_trace(go.Scatter(
        x=[x0, x1], y=[x0, x1], mode="lines",
        line=dict(color=COLORS["ref"], width=1.2, dash="dash"),
        hoverinfo="skip", showlegend=False,
    ))

    fig.update_layout(
        width=CROSS_W, height=CROSS_H,
        paper_bgcolor="#fff", plot_bgcolor="#fff",
        margin=dict(l=64, r=18, t=10, b=48),
        hovermode="closest",
        font=dict(size=FONT_SZ),
    )

    # Styling common to both axes, defined once so they cannot drift apart.
    axis_style = dict(
        title_font=dict(size=18, family="Arial", color="#000"),
        range=[x0, x1],
        ticks="outside",
        tickformat=",.0f",
        showline=True, linewidth=1.2, linecolor="#444", mirror=True,
        showgrid=True, gridcolor="rgba(0,0,0,0.12)",
        automargin=True,
    )
    # Lock the aspect ratio (x anchored to y) to keep the 45° line accurate.
    fig.update_xaxes(
        title_text="<b>Actual UCS (psi)</b>",
        scaleanchor="y", scaleratio=1,
        **axis_style,
    )
    fig.update_yaxes(
        title_text="<b>Predicted UCS (psi)</b>",
        **axis_style,
    )
    return fig
245
+
246
+ # ---------- track_plot ----------
247
def track_plot(df, include_actual=True):
    """Build a log-strip style track of UCS against depth (or point index).

    The vertical axis uses the first column whose name contains "depth"
    (case-insensitive); without one, a 1-based point index is used. The axis
    is reversed so values increase downwards.

    Args:
        df: DataFrame holding ``UCS_Pred`` and, optionally, the ``TARGET``
            column and a depth column.
        include_actual: When true and ``TARGET`` is present, overlay the
            actual UCS curve.

    Returns:
        A ``plotly.graph_objects.Figure`` sized ``TRACK_W`` x ``TRACK_H``.
        A missing ``UCS_Pred`` column is tolerated: that trace is omitted
        instead of raising ``KeyError``.
    """
    depth_col = next((c for c in df.columns if "depth" in str(c).lower()), None)
    if depth_col is not None:
        y = pd.Series(df[depth_col]).astype(float)
        ylab = str(depth_col)  # str() so the hover-template concatenation is safe
    else:
        y = pd.Series(np.arange(1, len(df) + 1))
        ylab = "Point Index"

    # Reversed y range (max at the bottom). With no finite values, let Plotly
    # autorange in reversed mode rather than passing NaN bounds.
    y_finite = y[np.isfinite(y)]
    if y_finite.empty:
        y_axis = dict(autorange="reversed")
    else:
        y_axis = dict(range=[float(y_finite.max()), float(y_finite.min())])

    has_pred = "UCS_Pred" in df.columns
    has_actual = include_actual and TARGET in df.columns

    # x extents come from exactly the curves that will be drawn.
    curves = []
    if has_pred:
        curves.append(pd.Series(df["UCS_Pred"]).astype(float))
    if has_actual:
        curves.append(pd.Series(df[TARGET]).astype(float))
    x_all = pd.concat(curves, ignore_index=True) if curves else pd.Series(dtype=float)
    x_all = x_all[np.isfinite(x_all)]
    if x_all.empty:
        x_lo, x_hi = 0.0, 1.0
    else:
        x_lo, x_hi = float(x_all.min()), float(x_all.max())
    x_pad = 0.03 * (x_hi - x_lo if x_hi > x_lo else 1.0)
    xmin, xmax = x_lo - x_pad, x_hi + x_pad
    tick0 = _nice_tick0(xmin, step=100)

    fig = go.Figure()

    if has_pred:
        fig.add_trace(go.Scatter(
            x=df["UCS_Pred"], y=y, mode="lines",
            line=dict(color=COLORS["pred"], width=1.8),
            name="UCS_Pred",
            hovertemplate="UCS_Pred: %{x:.0f}<br>" + ylab + ": %{y}<extra></extra>",
        ))

    if has_actual:
        fig.add_trace(go.Scatter(
            x=df[TARGET], y=y, mode="lines",
            line=dict(color=COLORS["actual"], width=2.0, dash="dot"),
            name="UCS (actual)",
            hovertemplate="UCS (actual): %{x:.0f}<br>" + ylab + ": %{y}<extra></extra>",
        ))

    fig.update_layout(
        width=TRACK_W, height=TRACK_H,
        paper_bgcolor="#fff", plot_bgcolor="#fff",
        margin=dict(l=72, r=18, t=36, b=48),
        hovermode="closest",
        font=dict(size=FONT_SZ),
        legend=dict(
            x=0.98, y=0.05, xanchor="right", yanchor="bottom",
            bgcolor="rgba(255,255,255,0.75)", bordercolor="#ccc", borderwidth=1,
        ),
        legend_title_text="",
    )

    fig.update_xaxes(
        title_text="<b>UCS (psi)</b>",
        title_font=dict(size=18, family="Arial", color="#000"),
        side="top", range=[xmin, xmax],
        tick0=tick0, tickmode="auto", tickformat=",.0f",
        ticks="outside",
        showline=True, linewidth=1.2, linecolor="#444", mirror=True,
        showgrid=True, gridcolor="rgba(0,0,0,0.12)", automargin=True,
    )

    fig.update_yaxes(
        title_text=f"<b>{ylab}</b>",
        title_font=dict(size=18, family="Arial", color="#000"),
        ticks="outside",
        showline=True, linewidth=1.2, linecolor="#444", mirror=True,
        showgrid=True, gridcolor="rgba(0,0,0,0.12)", automargin=True,
        **y_axis,
    )

    return fig
317
+
318
+ # ---------- Preview modal (matplotlib) ----------
319
def preview_tracks(df: pd.DataFrame, cols: list[str]):
    """Draw one vertical track per requested column for the preview dialog.

    Args:
        df: Sheet to preview.
        cols: Column names to plot; names not present in *df* are ignored.

    Returns:
        A Matplotlib figure with one subplot per available column, sharing a
        downward-increasing "Point Index" axis. If none of *cols* exists, a
        small placeholder figure reading "No selected columns" is returned.
        The caller owns the figure and should close it after rendering.
    """
    cols = [c for c in cols if c in df.columns]
    n = len(cols)
    if n == 0:
        fig, ax = plt.subplots(figsize=(4, 2))
        ax.text(0.5, 0.5, "No selected columns", ha="center", va="center")
        ax.axis("off")
        return fig

    # squeeze=False always yields a 2-D array of axes, so n == 1 needs no
    # special case.
    fig, grid = plt.subplots(
        1, n, figsize=(2.2 * n, 7.0), sharey=True, dpi=100, squeeze=False
    )
    axes = list(grid[0])
    idx = np.arange(1, len(df) + 1)
    for ax, col in zip(axes, cols):
        ax.plot(df[col], idx, "-", lw=1.4, color="#333")
        ax.set_xlabel(col)
        ax.xaxis.set_label_position("top")
        ax.xaxis.tick_top()
        ax.grid(True, linestyle=":", alpha=0.3)
        for spine in ax.spines.values():
            spine.set_visible(True)

    # The y-axis is shared and invert_yaxis() toggles orientation. Inverting
    # once per subplot cancelled itself out for an even number of tracks, so
    # invert exactly once here.
    axes[0].invert_yaxis()
    axes[0].set_ylabel("Point Index")
    return fig
337
+
338
# NOTE: the ``dialog`` compatibility shim is defined once in the "Utilities"
# section earlier in this module. An identical second copy used to sit here;
# it only re-bound the same name to the same behaviour and has been removed.
348
+
349
@dialog("Preview data")
def preview_modal(book: dict[str, pd.DataFrame]):
    """Show a per-sheet preview (track plots and summary statistics).

    Args:
        book: Mapping of sheet name to DataFrame, as produced by
            ``read_book_bytes``. An empty mapping shows an info message.

    Each sheet gets its own tab containing a "Tracks" tab (plots of the
    ``FEATURES`` columns that exist in the sheet) and a "Summary" tab
    (min/max/mean/std of those same columns).
    """
    if not book:
        st.info("No data loaded yet.")
        return

    names = list(book.keys())
    for sheet_tab, name in zip(st.tabs(names), names):
        with sheet_tab:
            df = book[name]
            # Summarise only the feature columns this sheet actually has;
            # indexing with the full FEATURES list raised KeyError when one
            # was missing, although the track plot already tolerates that.
            present = [c for c in FEATURES if c in df.columns]
            tracks_tab, summary_tab = st.tabs(["Tracks", "Summary"])
            with tracks_tab:
                fig = preview_tracks(df, FEATURES)
                st.pyplot(fig, use_container_width=True)
                # Close the figure once rendered so pyplot does not keep
                # accumulating open figures across previews.
                plt.close(fig)
            with summary_tab:
                if not present:
                    st.info("None of the expected feature columns are present in this sheet.")
                else:
                    tbl = (
                        df[present]
                        .agg(["min", "max", "mean", "std"])
                        .T
                        .rename(columns={"min": "Min", "max": "Max", "mean": "Mean", "std": "Std"})
                    )
                    st.dataframe(tbl.reset_index(names="Feature"), use_container_width=True)
363
+
364
+ # =========================
365
+ # Load model (simple)
366
+ # =========================
367
def ensure_model() -> Path|None:
    """Locate the model file on disk, downloading it if ``MODEL_URL`` is set.

    Lookup order: ``DEFAULT_MODEL``, then each path in ``MODEL_FALLBACKS``;
    the first existing, non-empty file wins. If none is found and the
    ``MODEL_URL`` environment variable is set, the file is downloaded to
    ``DEFAULT_MODEL``.

    Returns:
        Path to a usable model file, or ``None`` when no local file exists
        and the download is not configured or fails.
    """
    for candidate in [DEFAULT_MODEL, *MODEL_FALLBACKS]:
        if candidate.exists() and candidate.stat().st_size > 0:
            return candidate

    url = os.environ.get("MODEL_URL", "")
    if not url:
        return None

    # NOTE(security): the downloaded file is later loaded with joblib, which
    # is pickle-based. MODEL_URL must point at a trusted source.

    # Download to a sibling ".part" file and rename only on success.
    # Otherwise an interrupted transfer leaves a truncated file at
    # DEFAULT_MODEL, which the size check above accepts on later runs.
    partial = DEFAULT_MODEL.with_name(DEFAULT_MODEL.name + ".part")
    try:
        import requests
        DEFAULT_MODEL.parent.mkdir(parents=True, exist_ok=True)
        with requests.get(url, stream=True, timeout=30) as r:
            r.raise_for_status()
            with open(partial, "wb") as f:
                for chunk in r.iter_content(1 << 20):
                    if chunk:
                        f.write(chunk)
        if partial.stat().st_size == 0:
            raise OSError("downloaded model file is empty")
        os.replace(partial, DEFAULT_MODEL)  # atomic on the same filesystem
        return DEFAULT_MODEL
    except Exception:
        # Deliberately best-effort: the caller reports "Model not found".
        # Record the cause for operators instead of discarding it silently.
        import logging
        logging.getLogger(__name__).warning(
            "Model download from MODEL_URL failed", exc_info=True
        )
        try:
            partial.unlink(missing_ok=True)
        except OSError:
            pass
        return None
383
+
384
+ mpath = ensure_model()
385
+ if not mpath:
386
+ st.error("Model not found. Upload models/ucs_rf.joblib (or set MODEL_URL).")
387
+ st.stop()
388
+ try:
389
+ model = load_model(str(mpath))
390
+ except Exception as e:
391
+ st.error(f"Failed to load model: {e}")
392
+ st.stop()
393
+
394
+ meta_path = MODELS_DIR / "meta.json"
395
+ if meta_path.exists():
396
+ try:
397
+ meta = json.loads(meta_path.read_text(encoding="utf-8"))
398
+ FEATURES = meta.get("features", FEATURES); TARGET = meta.get("target", TARGET)
399
+ except Exception:
400
+ pass
401
+
402
+ # =========================
403
+ # Session state
404
+ # =========================
405
+ st.session_state.setdefault("app_step", "intro")
406
+ st.session_state.setdefault("results", {})
407
+ st.session_state.setdefault("train_ranges", None)
408
+ st.session_state.setdefault("dev_file_name","")
409
+ st.session_state.setdefault("dev_file_bytes",b"")
410
+ st.session_state.setdefault("dev_file_loaded",False)
411
+ st.session_state.setdefault("dev_preview",False)
412
+
413
+ # =========================
414
+ # Hero
415
+ # =========================
416
+ st.markdown(
417
+ f"""
418
+ <div class="st-hero">
419
+ <img src="{inline_logo()}" class="brand" />
420
+ <div>
421
+ <h1>ST_GeoMech_UCS</h1>
422
+ <div class="tagline">Real-Time UCS Tracking While Drilling</div>
423
+ </div>
424
+ </div>
425
+ """,
426
+ unsafe_allow_html=True,
427
+ )
428
+
429
+ # =========================
430
+ # INTRO
431
+ # =========================
432
+ if st.session_state.app_step == "intro":
433
+ st.header("Welcome!")
434
+ st.markdown("This software is developed by *Smart Thinking AI-Solutions Team* to estimate UCS from drilling data.")
435
+ st.subheader("How It Works")
436
+ st.markdown(
437
+ "1) **Upload your data to build the case and preview the performance of our model.** \n"
438
+ "2) Click **Run Model** to compute metrics and plots. \n"
439
+ "3) **Proceed to Validation** (with actual UCS) or **Proceed to Prediction** (no UCS)."
440
+ )
441
+ if st.button("Start Showcase", type="primary"):
442
+ st.session_state.app_step = "dev"; st.rerun()
443
+
444
+ # =========================
445
+ # CASE BUILDING
446
+ # =========================
447
+ if st.session_state.app_step == "dev":
448
+ st.sidebar.header("Case Building")
449
+ up = st.sidebar.file_uploader("Upload Train/Test Excel", type=["xlsx","xls"])
450
+ if up is not None:
451
+ st.session_state.dev_file_bytes = up.getvalue()
452
+ st.session_state.dev_file_name = up.name
453
+ st.session_state.dev_file_loaded = True
454
+ st.session_state.dev_preview = False
455
+ if st.session_state.dev_file_loaded:
456
+ tmp = read_book_bytes(st.session_state.dev_file_bytes)
457
+ if tmp:
458
+ df0 = next(iter(tmp.values()))
459
+ st.sidebar.caption(f"**Data loaded:** {st.session_state.dev_file_name} • {df0.shape[0]} rows × {df0.shape[1]} cols")
460
+
461
+ if st.sidebar.button("Preview data", use_container_width=True, disabled=not st.session_state.dev_file_loaded):
462
+ preview_modal(read_book_bytes(st.session_state.dev_file_bytes))
463
+ st.session_state.dev_preview = True
464
+
465
+ run = st.sidebar.button("Run Model", type="primary", use_container_width=True)
466
+ # always available nav
467
+ if st.sidebar.button("Proceed to Validation ▶", use_container_width=True): st.session_state.app_step="validate"; st.rerun()
468
+ if st.sidebar.button("Proceed to Prediction ▶", use_container_width=True): st.session_state.app_step="predict"; st.rerun()
469
+
470
+ # ---- Pinned helper at the very top of the page ----
471
+ helper_top = st.container()
472
+ with helper_top:
473
+ st.subheader("Case Building")
474
+ if st.session_state.dev_file_loaded and st.session_state.dev_preview:
475
+ st.info("Previewed ✓ — now click **Run Model**.")
476
+ elif st.session_state.dev_file_loaded:
477
+ st.info("📄 **Preview uploaded data** using the sidebar button, then click **Run Model**.")
478
+ else:
479
+ st.write("**Upload your data to build a case, then run the model to review development performance.**")
480
+
481
+ if run and st.session_state.dev_file_bytes:
482
+ book = read_book_bytes(st.session_state.dev_file_bytes)
483
+ sh_train = find_sheet(book, ["Train","Training","training2","train","training"])
484
+ sh_test = find_sheet(book, ["Test","Testing","testing2","test","testing"])
485
+ if sh_train is None or sh_test is None:
486
+ st.error("Workbook must include Train/Training/training2 and Test/Testing/testing2 sheets."); st.stop()
487
+ tr = book[sh_train].copy(); te = book[sh_test].copy()
488
+ if not (ensure_cols(tr, FEATURES+[TARGET]) and ensure_cols(te, FEATURES+[TARGET])):
489
+ st.error("Missing required columns."); st.stop()
490
+ tr["UCS_Pred"] = model.predict(tr[FEATURES])
491
+ te["UCS_Pred"] = model.predict(te[FEATURES])
492
+
493
+ # ---- metrics (R, RMSE, MAE) ----
494
+ st.session_state.results["Train"]=tr
495
+ st.session_state.results["Test"]=te
496
+ st.session_state.results["m_train"]={
497
+ "R": r_value(tr[TARGET], tr["UCS_Pred"]),
498
+ "RMSE": rmse(tr[TARGET], tr["UCS_Pred"]),
499
+ "MAE": mean_absolute_error(tr[TARGET], tr["UCS_Pred"])
500
+ }
501
+ st.session_state.results["m_test"]={
502
+ "R": r_value(te[TARGET], te["UCS_Pred"]),
503
+ "RMSE": rmse(te[TARGET], te["UCS_Pred"]),
504
+ "MAE": mean_absolute_error(te[TARGET], te["UCS_Pred"])
505
+ }
506
+
507
+ tr_min = tr[FEATURES].min().to_dict(); tr_max = tr[FEATURES].max().to_dict()
508
+ st.session_state.train_ranges = {f:(float(tr_min[f]), float(tr_max[f])) for f in FEATURES}
509
+ st.success("Case has been built and results are displayed below.")
510
+
511
+ def _dev_block(df, m):
512
+ c1,c2,c3 = st.columns(3)
513
+ c1.metric("R", f"{m['R']:.2f}")
514
+ c2.metric("RMSE", f"{m['RMSE']:.2f}")
515
+ c3.metric("MAE", f"{m['MAE']:.2f}")
516
+ left, spacer, right = st.columns(PLOT_COLS)
517
+ with left:
518
+ pad, plotcol = left.columns([CROSS_NUDGE, 1]) # shift cross-plot right inside its band
519
+ with plotcol:
520
+ st.plotly_chart(
521
+ cross_plot(df[TARGET], df["UCS_Pred"]),
522
+ use_container_width=False,
523
+ config={"displayModeBar": False, "scrollZoom": True}
524
+ )
525
+ with right:
526
+ st.plotly_chart(
527
+ track_plot(df, include_actual=True),
528
+ use_container_width=False,
529
+ config={"displayModeBar": False, "scrollZoom": True}
530
+ )
531
+
532
+ if "Train" in st.session_state.results or "Test" in st.session_state.results:
533
+ tab1, tab2 = st.tabs(["Training", "Testing"])
534
+ if "Train" in st.session_state.results:
535
+ with tab1: _dev_block(st.session_state.results["Train"], st.session_state.results["m_train"])
536
+ if "Test" in st.session_state.results:
537
+ with tab2: _dev_block(st.session_state.results["Test"], st.session_state.results["m_test"])
538
+
539
+ # =========================
540
+ # VALIDATION (with actual UCS)
541
+ # =========================
542
+ if st.session_state.app_step == "validate":
543
+ st.sidebar.header("Validate the Model")
544
+ up = st.sidebar.file_uploader("Upload Validation Excel", type=["xlsx","xls"])
545
+ if up is not None:
546
+ book = read_book_bytes(up.getvalue())
547
+ if book:
548
+ df0 = next(iter(book.values()))
549
+ st.sidebar.caption(f"**Data loaded:** {up.name} • {df0.shape[0]} rows × {df0.shape[1]} cols")
550
+ if st.sidebar.button("Preview data", use_container_width=True, disabled=(up is None)):
551
+ preview_modal(read_book_bytes(up.getvalue()))
552
+ go_btn = st.sidebar.button("Predict", type="primary", use_container_width=True)
553
+ if st.sidebar.button("⬅ Back to Case Building", use_container_width=True): st.session_state.app_step="dev"; st.rerun()
554
+ if st.sidebar.button("Proceed to Prediction ▶", use_container_width=True): st.session_state.app_step="predict"; st.rerun()
555
+
556
+ st.subheader("Validate the Model")
557
+ st.write("Upload a dataset with the same **features** and **UCS** to evaluate performance.")
558
+
559
+ if go_btn and up is not None:
560
+ book = read_book_bytes(up.getvalue())
561
+ name = find_sheet(book, ["Validation","Validate","validation2","Val","val"]) or list(book.keys())[0]
562
+ df = book[name].copy()
563
+ if not ensure_cols(df, FEATURES+[TARGET]): st.error("Missing required columns."); st.stop()
564
+ df["UCS_Pred"] = model.predict(df[FEATURES])
565
+ st.session_state.results["Validate"]=df
566
+
567
+ ranges = st.session_state.train_ranges; oor_pct = 0.0; tbl=None
568
+ if ranges:
569
+ any_viol = pd.DataFrame({f:(df[f]<ranges[f][0])|(df[f]>ranges[f][1]) for f in FEATURES}).any(axis=1)
570
+ oor_pct = float(any_viol.mean()*100.0)
571
+ if any_viol.any():
572
+ tbl = df.loc[any_viol, FEATURES].copy()
573
+ tbl["Violations"] = pd.DataFrame({f:(df[f]<ranges[f][0])|(df[f]>ranges[f][1]) for f in FEATURES}).loc[any_viol].apply(
574
+ lambda r: ", ".join([c for c,v in r.items() if v]),
575
+ axis=1
576
+ )
577
+ st.session_state.results["m_val"]={
578
+ "R": r_value(df[TARGET], df["UCS_Pred"]),
579
+ "RMSE": rmse(df[TARGET], df["UCS_Pred"]),
580
+ "MAE": mean_absolute_error(df[TARGET], df["UCS_Pred"])
581
+ }
582
+ st.session_state.results["sv_val"]={
583
+ "n":len(df),
584
+ "pred_min":float(df["UCS_Pred"].min()),
585
+ "pred_max":float(df["UCS_Pred"].max()),
586
+ "oor":oor_pct
587
+ }
588
+ st.session_state.results["oor_tbl"]=tbl
589
+
590
+ if "Validate" in st.session_state.results:
591
+ m = st.session_state.results["m_val"]
592
+ c1,c2,c3 = st.columns(3)
593
+ c1.metric("R", f"{m['R']:.2f}")
594
+ c2.metric("RMSE", f"{m['RMSE']:.2f}")
595
+ c3.metric("MAE", f"{m['MAE']:.2f}")
596
+
597
+ left, spacer, right = st.columns(PLOT_COLS)
598
+ with left:
599
+ pad, plotcol = left.columns([CROSS_NUDGE, 1])
600
+ with plotcol:
601
+ st.plotly_chart(
602
+ cross_plot(st.session_state.results["Validate"][TARGET],
603
+ st.session_state.results["Validate"]["UCS_Pred"]),
604
+ use_container_width=False, config={"displayModeBar": False, "scrollZoom": True}
605
+ )
606
+ with right:
607
+ st.plotly_chart(
608
+ track_plot(st.session_state.results["Validate"], include_actual=True),
609
+ use_container_width=False, config={"displayModeBar": False, "scrollZoom": True}
610
+ )
611
+
612
+ sv = st.session_state.results["sv_val"]
613
+ if sv["oor"] > 0: st.warning("Some inputs fall outside **training min–max** ranges.")
614
+ if st.session_state.results["oor_tbl"] is not None:
615
+ st.write("*Out-of-range rows (vs. Training min–max):*")
616
+ st.dataframe(st.session_state.results["oor_tbl"], use_container_width=True)
617
+
618
+ # =========================
619
+ # PREDICTION (no actual UCS)
620
+ # =========================
621
+ if st.session_state.app_step == "predict":
622
+ st.sidebar.header("Prediction (No Actual UCS)")
623
+ up = st.sidebar.file_uploader("Upload Prediction Excel", type=["xlsx","xls"])
624
+ if up is not None:
625
+ book = read_book_bytes(up.getvalue())
626
+ if book:
627
+ df0 = next(iter(book.values()))
628
+ st.sidebar.caption(f"**Data loaded:** {up.name} • {df0.shape[0]} rows × {df0.shape[1]} cols")
629
+ if st.sidebar.button("Preview data", use_container_width=True, disabled=(up is None)):
630
+ preview_modal(read_book_bytes(up.getvalue()))
631
+ go_btn = st.sidebar.button("Predict", type="primary", use_container_width=True)
632
+ if st.sidebar.button("⬅ Back to Case Building", use_container_width=True): st.session_state.app_step="dev"; st.rerun()
633
+
634
+ st.subheader("Prediction")
635
+ st.write("Upload a dataset with the feature columns (no **UCS**).")
636
+
637
+ if go_btn and up is not None:
638
+ book = read_book_bytes(up.getvalue()); name = list(book.keys())[0]
639
+ df = book[name].copy()
640
+ if not ensure_cols(df, FEATURES): st.error("Missing required columns."); st.stop()
641
+ df["UCS_Pred"] = model.predict(df[FEATURES])
642
+ st.session_state.results["PredictOnly"]=df
643
+
644
+ ranges = st.session_state.train_ranges; oor_pct = 0.0
645
+ if ranges:
646
+ any_viol = pd.DataFrame({f:(df[f]<ranges[f][0])|(df[f]>ranges[f][1]) for f in FEATURES}).any(axis=1)
647
+ oor_pct = float(any_viol.mean()*100.0)
648
+ st.session_state.results["sv_pred"]={
649
+ "n":len(df),
650
+ "pred_min":float(df["UCS_Pred"].min()),
651
+ "pred_max":float(df["UCS_Pred"].max()),
652
+ "pred_mean":float(df["UCS_Pred"].mean()),
653
+ "pred_std":float(df["UCS_Pred"].std(ddof=0)),
654
+ "oor":oor_pct
655
+ }
656
+
657
+ if "PredictOnly" in st.session_state.results:
658
+ df = st.session_state.results["PredictOnly"]; sv = st.session_state.results["sv_pred"]
659
+
660
+ left, spacer, right = st.columns(PLOT_COLS)
661
+ with left:
662
+ table = pd.DataFrame({
663
+ "Metric": ["# points","Pred min","Pred max","Pred mean","Pred std","OOR %"],
664
+ "Value": [sv["n"], sv["pred_min"], sv["pred_max"], sv["pred_mean"], sv["pred_std"], f'{sv["oor"]:.1f}%']
665
+ })
666
+ st.success("Predictions ready ✓")
667
+ st.dataframe(table, use_container_width=True, hide_index=True)
668
+ st.caption("**★ OOR** = % of rows whose input features fall outside the training min–max range.")
669
+ with right:
670
+ st.plotly_chart(
671
+ track_plot(df, include_actual=False),
672
+ use_container_width=False, config={"displayModeBar": False, "scrollZoom": True}
673
+ )
674
+
675
+ # =========================
676
+ # Footer
677
+ # =========================
678
+ st.markdown("---")
679
+ st.markdown(
680
+ """
681
+ <div style='text-align:center; color:#6b7280; line-height:1.6'>
682
+ ST_GeoMech_UCS • © Smart Thinking<br/>
683
+ <strong>Visit our website:</strong> <a href='https://www.smartthinking.com.sa' target='_blank'>smartthinking.com.sa</a>
684
+ </div>
685
+ """,
686
+ unsafe_allow_html=True
687
+ )
app_revised_with_r_value.py DELETED
@@ -1,26 +0,0 @@
1
- # Replace this function in your utils or metric section
2
- def r_value(y_true, y_pred):
3
- y_true = np.asarray(y_true)
4
- y_pred = np.asarray(y_pred)
5
- mask = np.isfinite(y_true) & np.isfinite(y_pred)
6
- return float(np.corrcoef(y_true[mask], y_pred[mask])[0, 1])
7
-
8
- # Replace your metrics assignment for train/test/validate like this:
9
- # Old:
10
- # "R2": r2_score(...)
11
- # New:
12
- "R": r_value(...)
13
-
14
- # In all metric displays like:
15
- # c1.metric("R²", f"{m['R2']:.4f}")
16
- # Change to:
17
- c1.metric("R", f"{m['R']:.2f}")
18
-
19
- # In all metric panels for Train, Test, Validate
20
- c1.metric("R", f"{m['R']:.2f}")
21
- c2.metric("RMSE", f"{m['RMSE']:.2f}")
22
- c3.metric("MAE", f"{m['MAE']:.2f}")
23
-
24
- # Also replace in the 'Validation' and 'Prediction' sections accordingly.
25
- # Ensure all table metrics and summary stats (pred_min, pred_max, etc.) use 2 decimal digits:
26
- f"{value:.2f}"