UCS2014 commited on
Commit
114e9ff
·
verified ·
1 Parent(s): 78a370f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +198 -212
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import io, json, os, base64, math
2
  from pathlib import Path
3
  import streamlit as st
@@ -15,90 +16,72 @@ import plotly.graph_objects as go
15
  from sklearn.metrics import mean_squared_error, mean_absolute_error
16
 
17
  # =========================
18
- # Constants
19
  # =========================
20
- FEATURES = ["Q, gpm", "SPP(psi)", "T (kft.lbf)", "WOB (klbf)", "ROP (ft/h)"]
21
- TARGET = "UCS"
 
 
 
 
 
 
22
  MODELS_DIR = Path("models")
23
- DEFAULT_MODEL = MODELS_DIR / "ucs_rf.joblib"
24
  MODEL_FALLBACKS = [MODELS_DIR / "model.joblib", MODELS_DIR / "model.pkl"]
 
25
  COLORS = {"pred": "#1f77b4", "actual": "#f2b702", "ref": "#5a5a5a"}
26
 
27
  # ---- Plot sizing controls ----
28
- CROSS_W = 350 # px (matplotlib figure size; Streamlit will still scale)
29
  CROSS_H = 350
30
- TRACK_H = 1000 # px (plotly height; width auto-fits column)
31
- # NEW: Add a TRACK_W variable to control the width
32
- TRACK_W = 500 # px (plotly width)
33
- FONT_SZ = 13
34
- BOLD_FONT = "Arial Black, Arial, sans-serif" # used for bold axis titles & ticks
35
 
36
  # =========================
37
  # Page / CSS
38
  # =========================
39
- st.set_page_config(page_title="ST_GeoMech_UCS", page_icon="logo.png", layout="wide")
40
 
41
- # General CSS (logo helpers etc.)
42
  st.markdown("""
43
  <style>
44
  .brand-logo { width: 200px; height: auto; object-fit: contain; }
45
  .sidebar-header { display:flex; align-items:center; gap:12px; }
46
  .sidebar-header .text h1 { font-size: 1.05rem; margin:0; line-height:1.1; }
47
  .sidebar-header .text .tag { font-size: .85rem; color:#6b7280; margin:2px 0 0; }
48
- .centered-container {
49
- display: flex;
50
- flex-direction: column;
51
- align-items: center;
52
- text-align: center;
53
- }
54
  </style>
55
  """, unsafe_allow_html=True)
56
 
57
- # CSS to make sticky headers work correctly by overriding Streamlit's overflow property
58
  st.markdown("""
59
  <style>
60
- /* This targets the main content area */
61
- .main .block-container {
62
- overflow: unset !important;
63
- }
64
-
65
- /* This targets the vertical block that holds all your elements */
66
- div[data-testid="stVerticalBlock"] {
67
- overflow: unset !important;
68
- }
69
  </style>
70
  """, unsafe_allow_html=True)
71
 
72
  # Hide uploader helper text ("Drag and drop file here", limits, etc.)
73
  st.markdown("""
74
  <style>
75
- /* Older builds (helper wrapped in a Markdown container) */
76
  section[data-testid="stFileUploader"] div[data-testid="stMarkdownContainer"]{display:none !important;}
77
- /* 1.31–1.34: helper is the first child in the dropzone */
78
  section[data-testid="stFileUploader"] [data-testid="stFileUploaderDropzone"] > div:first-child{display:none !important;}
79
- /* 1.35+: explicit helper container */
80
  section[data-testid="stFileUploader"] [data-testid="stFileUploaderInstructions"]{display:none !important;}
81
- /* Fallback: any paragraph/small text inside the uploader */
82
  section[data-testid="stFileUploader"] p, section[data-testid="stFileUploader"] small{display:none !important;}
83
  </style>
84
  """, unsafe_allow_html=True)
85
 
86
- # Make the Preview expander title & tabs sticky (pinned to the top)
87
  st.markdown("""
88
  <style>
89
  div[data-testid="stExpander"] > details > summary {
90
- position: sticky;
91
- top: 0;
92
- z-index: 10;
93
- background: #fff;
94
- border-bottom: 1px solid #eee;
95
  }
96
  div[data-testid="stExpander"] div[data-baseweb="tab-list"] {
97
- position: sticky;
98
- top: 42px; /* adjust if your expander header height differs */
99
- z-index: 9;
100
- background: #fff;
101
- padding-top: 6px;
102
  }
103
  </style>
104
  """, unsafe_allow_html=True)
@@ -109,31 +92,13 @@ TABLE_CENTER_CSS = [
109
  dict(selector="td", props=[("text-align", "center")]),
110
  ]
111
 
112
- # NEW: CSS for the message box
113
  st.markdown("""
114
  <style>
115
- .st-message-box {
116
- background-color: #f0f2f6;
117
- color: #333333;
118
- padding: 10px;
119
- border-radius: 10px;
120
- border: 1px solid #e6e9ef;
121
- }
122
- .st-message-box.st-success {
123
- background-color: #d4edda;
124
- color: #155724;
125
- border-color: #c3e6cb;
126
- }
127
- .st-message-box.st-warning {
128
- background-color: #fff3cd;
129
- color: #856404;
130
- border-color: #ffeeba;
131
- }
132
- .st-message-box.st-error {
133
- background-color: #f8d7da;
134
- color: #721c24;
135
- border-color: #f5c6cb;
136
- }
137
  </style>
138
  """, unsafe_allow_html=True)
139
 
@@ -164,7 +129,7 @@ def add_password_gate() -> None:
164
  st.sidebar.markdown(f"""
165
  <div class="centered-container">
166
  <img src="{inline_logo('logo.png')}" style="width: 200px; height: auto; object-fit: contain;">
167
- <div style='font-weight:800;font-size:1.2rem; margin-top: 10px;'>ST_GeoMech_UCS</div>
168
  <div style='color:#667085;'>Smart Thinking • Secure Access</div>
169
  </div>
170
  """, unsafe_allow_html=True
@@ -217,11 +182,10 @@ def find_sheet(book, names):
217
  if nm.lower() in low2orig: return low2orig[nm.lower()]
218
  return None
219
 
220
- def _nice_tick0(xmin: float, step: int = 100) -> float:
221
  return step * math.floor(xmin / step) if np.isfinite(xmin) else xmin
222
 
223
  def df_centered_rounded(df: pd.DataFrame, hide_index=True):
224
- """Center headers & cells; format numeric columns to 2 decimals."""
225
  out = df.copy()
226
  numcols = out.select_dtypes(include=[np.number]).columns
227
  styler = (
@@ -232,15 +196,57 @@ def df_centered_rounded(df: pd.DataFrame, hide_index=True):
232
  )
233
  st.dataframe(styler, use_container_width=True, hide_index=hide_index)
234
 
235
- # =========================
236
- # Cross plot (Matplotlib, fixed limits & ticks)
237
- # =========================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
  def cross_plot_static(actual, pred):
239
  a = pd.Series(actual, dtype=float)
240
- p = pd.Series(pred, dtype=float)
241
 
242
- fixed_min, fixed_max = 6000, 10000
243
- ticks = np.arange(fixed_min, fixed_max + 1, 1000)
 
 
 
244
 
245
  dpi = 110
246
  fig, ax = plt.subplots(
@@ -257,15 +263,15 @@ def cross_plot_static(actual, pred):
257
  ax.set_ylim(fixed_min, fixed_max)
258
  ax.set_xticks(ticks)
259
  ax.set_yticks(ticks)
260
- ax.set_aspect("equal", adjustable="box") # true 45°
261
 
262
  fmt = FuncFormatter(lambda x, _: f"{int(x):,}")
263
  ax.xaxis.set_major_formatter(fmt)
264
  ax.yaxis.set_major_formatter(fmt)
265
 
266
- ax.set_xlabel("Actual UCS (psi)", fontweight="bold", fontsize=10, color="black")
267
- ax.set_ylabel("Predicted UCS (psi)", fontweight="bold", fontsize=10, color="black")
268
- ax.tick_params(labelsize=6, colors="black")
269
 
270
  ax.grid(True, linestyle=":", alpha=0.3)
271
  for spine in ax.spines.values():
@@ -278,45 +284,43 @@ def cross_plot_static(actual, pred):
278
  # =========================
279
  # Track plot (Plotly)
280
  # =========================
281
- def track_plot(df, include_actual=True):
282
  depth_col = next((c for c in df.columns if 'depth' in str(c).lower()), None)
283
  if depth_col is not None:
284
  y = pd.Series(df[depth_col]).astype(float)
285
  ylab = depth_col
286
- y_range = [float(y.max()), float(y.min())] # reverse
287
  else:
288
  y = pd.Series(np.arange(1, len(df) + 1))
289
  ylab = "Point Index"
290
  y_range = [float(y.max()), float(y.min())]
291
 
292
- # X (UCS) range & ticks
293
- x_series = pd.Series(df.get("UCS_Pred", pd.Series(dtype=float))).astype(float)
294
- if include_actual and TARGET in df.columns:
295
- x_series = pd.concat([x_series, pd.Series(df[TARGET]).astype(float)], ignore_index=True)
296
  x_lo, x_hi = float(x_series.min()), float(x_series.max())
297
  x_pad = 0.03 * (x_hi - x_lo if x_hi > x_lo else 1.0)
298
  xmin, xmax = x_lo - x_pad, x_hi + x_pad
299
- tick0 = _nice_tick0(xmin, step=100)
300
 
301
  fig = go.Figure()
302
  fig.add_trace(go.Scatter(
303
- x=df["UCS_Pred"], y=y, mode="lines",
304
  line=dict(color=COLORS["pred"], width=1.8),
305
- name="UCS_Pred",
306
- hovertemplate="UCS_Pred: %{x:.0f}<br>"+ylab+": %{y}<extra></extra>"
307
  ))
308
- if include_actual and TARGET in df.columns:
309
  fig.add_trace(go.Scatter(
310
- x=df[TARGET], y=y, mode="lines",
311
  line=dict(color=COLORS["actual"], width=2.0, dash="dot"),
312
- name="UCS (actual)",
313
- hovertemplate="UCS (actual): %{x:.0f}<br>"+ylab+": %{y}<extra></extra>"
314
  ))
315
 
316
  fig.update_layout(
317
- height=TRACK_H,
318
- width=TRACK_W, # Set the width here
319
- autosize=False, # Disable autosizing to respect the width
320
  paper_bgcolor="#fff", plot_bgcolor="#fff",
321
  margin=dict(l=64, r=16, t=36, b=48), hovermode="closest",
322
  font=dict(size=FONT_SZ, color="#000"),
@@ -326,31 +330,23 @@ def track_plot(df, include_actual=True):
326
  ),
327
  legend_title_text=""
328
  )
329
-
330
- # Bold, black axis titles & ticks
331
  fig.update_xaxes(
332
- title_text="UCS (psi)",
333
  title_font=dict(size=20, family=BOLD_FONT, color="#000"),
334
  tickfont=dict(size=15, family=BOLD_FONT, color="#000"),
335
- side="top",
336
- range=[xmin, xmax],
337
- ticks="outside",
338
- tickformat=",.0f",
339
- tickmode="auto",
340
- tick0=tick0,
341
  showline=True, linewidth=1.2, linecolor="#444", mirror=True,
342
  showgrid=True, gridcolor="rgba(0,0,0,0.12)", automargin=True
343
  )
344
  fig.update_yaxes(
345
- title_text=ylab,
346
  title_font=dict(size=20, family=BOLD_FONT, color="#000"),
347
  tickfont=dict(size=15, family=BOLD_FONT, color="#000"),
348
- range=y_range,
349
- ticks="outside",
350
  showline=True, linewidth=1.2, linecolor="#444", mirror=True,
351
  showgrid=True, gridcolor="rgba(0,0,0,0.12)", automargin=True
352
  )
353
-
354
  return fig
355
 
356
  # ---------- Preview modal (matplotlib) ----------
@@ -384,25 +380,8 @@ except AttributeError:
384
  return wrapper
385
  return deco
386
 
387
- def preview_modal(book: dict[str, pd.DataFrame]):
388
- if not book:
389
- st.info("No data loaded yet."); return
390
- names = list(book.keys())
391
- tabs = st.tabs(names)
392
- for t, name in zip(tabs, names):
393
- with t:
394
- df = book[name]
395
- t1, t2 = st.tabs(["Tracks", "Summary"])
396
- with t1:
397
- st.pyplot(preview_tracks(df, FEATURES), use_container_width=True)
398
- with t2:
399
- tbl = (df[FEATURES]
400
- .agg(['min','max','mean','std'])
401
- .T.rename(columns={"min":"Min","max":"Max","mean":"Mean","std":"Std"}))
402
- df_centered_rounded(tbl.reset_index(names="Feature"))
403
-
404
  # =========================
405
- # Load model
406
  # =========================
407
  def ensure_model() -> Path|None:
408
  for p in [DEFAULT_MODEL, *MODEL_FALLBACKS]:
@@ -423,7 +402,7 @@ def ensure_model() -> Path|None:
423
 
424
  mpath = ensure_model()
425
  if not mpath:
426
- st.error("Model not found. Upload models/ucs_rf.joblib (or set MODEL_URL).")
427
  st.stop()
428
  try:
429
  model = load_model(str(mpath))
@@ -435,7 +414,10 @@ meta_path = MODELS_DIR / "meta.json"
435
  if meta_path.exists():
436
  try:
437
  meta = json.loads(meta_path.read_text(encoding="utf-8"))
438
- FEATURES = meta.get("features", FEATURES); TARGET = meta.get("target", TARGET)
 
 
 
439
  except Exception:
440
  pass
441
 
@@ -449,7 +431,7 @@ st.session_state.setdefault("dev_file_name","")
449
  st.session_state.setdefault("dev_file_bytes",b"")
450
  st.session_state.setdefault("dev_file_loaded",False)
451
  st.session_state.setdefault("dev_preview",False)
452
- st.session_state.setdefault("show_preview_modal", False) # New state variable
453
 
454
  # =========================
455
  # Branding in Sidebar
@@ -457,27 +439,20 @@ st.session_state.setdefault("show_preview_modal", False) # New state variable
457
  st.sidebar.markdown(f"""
458
  <div class="centered-container">
459
  <img src="{inline_logo('logo.png')}" style="width: 200px; height: auto; object-fit: contain;">
460
- <div style='font-weight:800;font-size:1.2rem;'>ST_GeoMech_UCS</div>
461
- <div style='color:#667085;'>Real-Time UCS Tracking While Drilling</div>
462
  </div>
463
  """, unsafe_allow_html=True
464
  )
465
 
466
- # =========================
467
- # Reusable Sticky Header Function
468
- # =========================
469
  def sticky_header(title, message):
470
  st.markdown(
471
  f"""
472
  <style>
473
  .sticky-container {{
474
- position: sticky;
475
- top: 0;
476
- background-color: white;
477
- z-index: 100;
478
- padding-top: 10px;
479
- padding-bottom: 10px;
480
- border-bottom: 1px solid #eee;
481
  }}
482
  </style>
483
  <div class="sticky-container">
@@ -493,12 +468,12 @@ def sticky_header(title, message):
493
  # =========================
494
  if st.session_state.app_step == "intro":
495
  st.header("Welcome!")
496
- st.markdown("This software is developed by *Smart Thinking AI-Solutions Team* to estimate UCS from drilling data.")
497
  st.subheader("How It Works")
498
  st.markdown(
499
- "1) **Upload your data to build the case and preview the performance of our model.** \n"
500
- "2) Click **Run Model** to compute metrics and plots. \n"
501
- "3) **Proceed to Validation** (with actual UCS) or **Proceed to Prediction** (no UCS)."
502
  )
503
  if st.button("Start Showcase", type="primary"):
504
  st.session_state.app_step = "dev"; st.rerun()
@@ -521,14 +496,14 @@ if st.session_state.app_step == "dev":
521
  st.sidebar.caption(f"**Data loaded:** {st.session_state.dev_file_name} • {df0.shape[0]} rows × {df0.shape[1]} cols")
522
 
523
  if st.sidebar.button("Preview data", use_container_width=True, disabled=not st.session_state.dev_file_loaded):
524
- st.session_state.show_preview_modal = True # Set state to show modal
525
  st.session_state.dev_preview = True
526
 
527
  run = st.sidebar.button("Run Model", type="primary", use_container_width=True)
528
  if st.sidebar.button("Proceed to Validation ▶", use_container_width=True): st.session_state.app_step="validate"; st.rerun()
529
  if st.sidebar.button("Proceed to Prediction ▶", use_container_width=True): st.session_state.app_step="predict"; st.rerun()
530
 
531
- # Apply sticky header
532
  if st.session_state.dev_file_loaded and st.session_state.dev_preview:
533
  sticky_header("Case Building", "Previewed ✓ — now click **Run Model**.")
534
  elif st.session_state.dev_file_loaded:
@@ -539,27 +514,34 @@ if st.session_state.app_step == "dev":
539
  if run and st.session_state.dev_file_bytes:
540
  book = read_book_bytes(st.session_state.dev_file_bytes)
541
  sh_train = find_sheet(book, ["Train","Training","training2","train","training"])
542
- sh_test = find_sheet(book, ["Test","Testing","testing2","test","testing"])
543
  if sh_train is None or sh_test is None:
544
- st.markdown('<div class="st-message-box st-error">Workbook must include Train/Training/training2 and Test/Testing/testing2 sheets.</div>', unsafe_allow_html=True)
545
  st.stop()
546
  tr = book[sh_train].copy(); te = book[sh_test].copy()
547
- if not (ensure_cols(tr, FEATURES+[TARGET]) and ensure_cols(te, FEATURES+[TARGET])):
548
- st.markdown('<div class="st-message-box st-error">Missing required columns.</div>', unsafe_allow_html=True)
549
- st.stop()
550
- tr["UCS_Pred"] = model.predict(tr[FEATURES])
551
- te["UCS_Pred"] = model.predict(te[FEATURES])
 
 
 
 
 
 
 
552
 
553
  st.session_state.results["Train"]=tr; st.session_state.results["Test"]=te
554
  st.session_state.results["m_train"]={
555
- "R": pearson_r(tr[TARGET], tr["UCS_Pred"]),
556
- "RMSE": rmse(tr[TARGET], tr["UCS_Pred"]),
557
- "MAE": mean_absolute_error(tr[TARGET], tr["UCS_Pred"])
558
  }
559
  st.session_state.results["m_test"]={
560
- "R": pearson_r(te[TARGET], te["UCS_Pred"]),
561
- "RMSE": rmse(te[TARGET], te["UCS_Pred"]),
562
- "MAE": mean_absolute_error(te[TARGET], te["UCS_Pred"])
563
  }
564
 
565
  tr_min = tr[FEATURES].min().to_dict(); tr_max = tr[FEATURES].max().to_dict()
@@ -571,37 +553,34 @@ if st.session_state.app_step == "dev":
571
  c1.metric("R", f"{m['R']:.2f}")
572
  c2.metric("RMSE", f"{m['RMSE']:.2f}")
573
  c3.metric("MAE", f"{m['MAE']:.2f}")
574
-
575
- # NEW: Footer for metric abbreviations
576
  st.markdown("""
577
- <div style='text-align: left; font-size: 0.8em; color: #6b7280; margin-top: -16px; margin-bottom: 8px;'>
578
  <strong>R:</strong> Pearson Correlation Coefficient<br>
579
  <strong>RMSE:</strong> Root Mean Square Error<br>
580
  <strong>MAE:</strong> Mean Absolute Error
581
  </div>
582
  """, unsafe_allow_html=True)
583
 
584
- # 2-column layout, big gap (prevents overlap)
585
  col_track, col_cross = st.columns([2, 3], gap="large")
586
  with col_track:
587
  st.plotly_chart(
588
- track_plot(df, include_actual=True),
589
- use_container_width=False, # Set to False to honor the width in track_plot()
 
590
  config={"displayModeBar": False, "scrollZoom": True}
591
  )
592
  with col_cross:
593
- st.pyplot(cross_plot_static(df[TARGET], df["UCS_Pred"]), use_container_width=False)
594
-
595
 
596
  if "Train" in st.session_state.results or "Test" in st.session_state.results:
597
  tab1, tab2 = st.tabs(["Training", "Testing"])
598
  if "Train" in st.session_state.results:
599
  with tab1: _dev_block(st.session_state.results["Train"], st.session_state.results["m_train"])
600
  if "Test" in st.session_state.results:
601
- with tab2: _dev_block(st.session_state.results["Test"], st.session_state.results["m_test"])
602
 
603
  # =========================
604
- # VALIDATION (with actual UCS)
605
  # =========================
606
  if st.session_state.app_step == "validate":
607
  st.sidebar.header("Validate the Model")
@@ -612,19 +591,27 @@ if st.session_state.app_step == "validate":
612
  df0 = next(iter(book.values()))
613
  st.sidebar.caption(f"**Data loaded:** {up.name} • {df0.shape[0]} rows × {df0.shape[1]} cols")
614
  if st.sidebar.button("Preview data", use_container_width=True, disabled=(up is None)):
615
- st.session_state.show_preview_modal = True # Set state to show modal
616
  go_btn = st.sidebar.button("Predict & Validate", type="primary", use_container_width=True)
617
  if st.sidebar.button("⬅ Back to Case Building", use_container_width=True): st.session_state.app_step="dev"; st.rerun()
618
  if st.sidebar.button("Proceed to Prediction ▶", use_container_width=True): st.session_state.app_step="predict"; st.rerun()
619
 
620
- sticky_header("Validate the Model", "Upload a dataset with the same **features** and **UCS** to evaluate performance.")
621
 
622
  if go_btn and up is not None:
623
  book = read_book_bytes(up.getvalue())
624
  name = find_sheet(book, ["Validation","Validate","validation2","Val","val"]) or list(book.keys())[0]
625
  df = book[name].copy()
626
- if not ensure_cols(df, FEATURES+[TARGET]): st.markdown('<div class="st-message-box st-error">Missing required columns.</div>', unsafe_allow_html=True); st.stop()
627
- df["UCS_Pred"] = model.predict(df[FEATURES])
 
 
 
 
 
 
 
 
628
  st.session_state.results["Validate"]=df
629
 
630
  ranges = st.session_state.train_ranges; oor_pct = 0.0; tbl=None
@@ -636,12 +623,13 @@ if st.session_state.app_step == "validate":
636
  for c in FEATURES:
637
  if pd.api.types.is_numeric_dtype(tbl[c]): tbl[c] = tbl[c].round(2)
638
  tbl["Violations"] = pd.DataFrame({f:(df[f]<ranges[f][0])|(df[f]>ranges[f][1]) for f in FEATURES}).loc[any_viol].apply(lambda r:", ".join([c for c,v in r.items() if v]), axis=1)
 
639
  st.session_state.results["m_val"]={
640
- "R": pearson_r(df[TARGET], df["UCS_Pred"]),
641
- "RMSE": rmse(df[TARGET], df["UCS_Pred"]),
642
- "MAE": mean_absolute_error(df[TARGET], df["UCS_Pred"])
643
  }
644
- st.session_state.results["sv_val"]={"n":len(df),"pred_min":float(df["UCS_Pred"].min()),"pred_max":float(df["UCS_Pred"].max()),"oor":oor_pct}
645
  st.session_state.results["oor_tbl"]=tbl
646
 
647
  if "Validate" in st.session_state.results:
@@ -650,27 +638,25 @@ if st.session_state.app_step == "validate":
650
  c1.metric("R", f"{m['R']:.2f}")
651
  c2.metric("RMSE", f"{m['RMSE']:.2f}")
652
  c3.metric("MAE", f"{m['MAE']:.2f}")
653
-
654
- # NEW: Footer for metric abbreviations
655
  st.markdown("""
656
- <div style='text-align: left; font-size: 0.8em; color: #6b7280; margin-top: -16px; margin-bottom: 8px;'>
657
  <strong>R:</strong> Pearson Correlation Coefficient<br>
658
  <strong>RMSE:</strong> Root Mean Square Error<br>
659
  <strong>MAE:</strong> Mean Absolute Error
660
  </div>
661
  """, unsafe_allow_html=True)
662
-
663
  col_track, col_cross = st.columns([2, 3], gap="large")
664
  with col_track:
665
  st.plotly_chart(
666
- track_plot(st.session_state.results["Validate"], include_actual=True),
667
- use_container_width=False, # Set to False to honor the width in track_plot()
668
- config={"displayModeBar": False, "scrollZoom": True}
669
  )
670
  with col_cross:
671
  st.pyplot(
672
- cross_plot_static(st.session_state.results["Validate"][TARGET],
673
- st.session_state.results["Validate"]["UCS_Pred"]),
674
  use_container_width=False
675
  )
676
 
@@ -681,10 +667,10 @@ if st.session_state.app_step == "validate":
681
  df_centered_rounded(st.session_state.results["oor_tbl"])
682
 
683
  # =========================
684
- # PREDICTION (no actual UCS)
685
  # =========================
686
  if st.session_state.app_step == "predict":
687
- st.sidebar.header("Prediction (No Actual UCS)")
688
  up = st.sidebar.file_uploader("Upload Prediction Excel", type=["xlsx","xls"])
689
  if up is not None:
690
  book = read_book_bytes(up.getvalue())
@@ -692,17 +678,19 @@ if st.session_state.app_step == "predict":
692
  df0 = next(iter(book.values()))
693
  st.sidebar.caption(f"**Data loaded:** {up.name} • {df0.shape[0]} rows × {df0.shape[1]} cols")
694
  if st.sidebar.button("Preview data", use_container_width=True, disabled=(up is None)):
695
- st.session_state.show_preview_modal = True # Set state to show modal
696
  go_btn = st.sidebar.button("Predict", type="primary", use_container_width=True)
697
  if st.sidebar.button("⬅ Back to Case Building", use_container_width=True): st.session_state.app_step="dev"; st.rerun()
698
 
699
- sticky_header("Prediction", "Upload a dataset with the feature columns (no **UCS**).")
700
 
701
  if go_btn and up is not None:
702
  book = read_book_bytes(up.getvalue()); name = list(book.keys())[0]
703
  df = book[name].copy()
704
- if not ensure_cols(df, FEATURES): st.markdown('<div class="st-message-box st-error">Missing required columns.</div>', unsafe_allow_html=True); st.stop()
705
- df["UCS_Pred"] = model.predict(df[FEATURES])
 
 
706
  st.session_state.results["PredictOnly"]=df
707
 
708
  ranges = st.session_state.train_ranges; oor_pct = 0.0
@@ -711,10 +699,10 @@ if st.session_state.app_step == "predict":
711
  oor_pct = float(any_viol.mean()*100.0)
712
  st.session_state.results["sv_pred"]={
713
  "n":len(df),
714
- "pred_min":float(df["UCS_Pred"].min()),
715
- "pred_max":float(df["UCS_Pred"].max()),
716
- "pred_mean":float(df["UCS_Pred"].mean()),
717
- "pred_std":float(df["UCS_Pred"].std(ddof=0)),
718
  "oor":oor_pct
719
  }
720
 
@@ -725,28 +713,27 @@ if st.session_state.app_step == "predict":
725
  with col_left:
726
  table = pd.DataFrame({
727
  "Metric": ["# points","Pred min","Pred max","Pred mean","Pred std","OOR %"],
728
- "Value": [sv["n"],
729
- round(sv["pred_min"],2),
730
- round(sv["pred_max"],2),
731
- round(sv["pred_mean"],2),
732
- round(sv["pred_std"],2),
733
- f'{sv["oor"]:.1f}%']
734
  })
735
  st.markdown('<div class="st-message-box st-success">Predictions ready ✓</div>', unsafe_allow_html=True)
736
  df_centered_rounded(table, hide_index=True)
737
  st.caption("**★ OOR** = % of rows whose input features fall outside the training min–max range.")
738
  with col_right:
739
  st.plotly_chart(
740
- track_plot(df, include_actual=False),
741
- use_container_width=False, # Set to False to honor the width in track_plot()
742
- config={"displayModeBar": False, "scrollZoom": True}
743
  )
744
 
745
  # =========================
746
- # Run preview modal after all other elements
747
  # =========================
748
  if st.session_state.show_preview_modal:
749
- # Get the correct book based on the current app step
750
  book_to_preview = {}
751
  if st.session_state.app_step == "dev":
752
  book_to_preview = read_book_bytes(st.session_state.dev_file_bytes)
@@ -770,9 +757,7 @@ if st.session_state.show_preview_modal:
770
  .agg(['min','max','mean','std'])
771
  .T.rename(columns={"min":"Min","max":"Max","mean":"Mean","std":"Std"}))
772
  df_centered_rounded(tbl.reset_index(names="Feature"))
773
- # Reset the state variable after the modal is displayed
774
  st.session_state.show_preview_modal = False
775
-
776
  # =========================
777
  # Footer
778
  # =========================
@@ -780,6 +765,7 @@ st.markdown("""
780
  <br><br><br>
781
  <hr>
782
  <div style='text-align:center;color:#6b7280;font-size:0.8em;'>
783
- © 2024 Smart Thinking AI-Solutions Team. All rights reserved.
 
784
  </div>
785
  """, unsafe_allow_html=True)
 
1
+ # app.py — ST_GR (Gamma Ray) app adapted from your UCS app, same flow & design
2
  import io, json, os, base64, math
3
  from pathlib import Path
4
  import streamlit as st
 
16
  from sklearn.metrics import mean_squared_error, mean_absolute_error
17
 
18
  # =========================
19
+ # Constants (GR)
20
  # =========================
21
+ APP_NAME = "ST_GR"
22
+ TAGLINE = "Gamma Ray Prediction"
23
+ # If meta.json is present, these will be overridden
24
+ FEATURES = ["Feat1","Feat2","Feat3","Feat4","Feat5","Feat6"] # 6 inputs (placeholder; meta.json wins)
25
+ TARGET = "log_GR" # typical training target; meta.json wins
26
+ TARGET_TRANSFORM = "log10" # "log10" | "ln" | "none" (meta.json wins)
27
+ ACTUAL_COL = "GR" # if present in sheets; if not, we'll derive from TARGET + transform
28
+
29
  MODELS_DIR = Path("models")
30
+ DEFAULT_MODEL = MODELS_DIR / "gr_rf.joblib"
31
  MODEL_FALLBACKS = [MODELS_DIR / "model.joblib", MODELS_DIR / "model.pkl"]
32
+
33
  COLORS = {"pred": "#1f77b4", "actual": "#f2b702", "ref": "#5a5a5a"}
34
 
35
  # ---- Plot sizing controls ----
36
+ CROSS_W = 350 # px (matplotlib figure size; Streamlit will still scale)
37
  CROSS_H = 350
38
+ TRACK_H = 1000 # px (plotly height)
39
+ TRACK_W = 500 # px (plotly width)
40
+ FONT_SZ = 13
41
+ BOLD_FONT = "Arial Black, Arial, sans-serif"
 
42
 
43
  # =========================
44
  # Page / CSS
45
  # =========================
46
+ st.set_page_config(page_title=APP_NAME, page_icon="logo.png", layout="wide")
47
 
48
+ # General CSS
49
  st.markdown("""
50
  <style>
51
  .brand-logo { width: 200px; height: auto; object-fit: contain; }
52
  .sidebar-header { display:flex; align-items:center; gap:12px; }
53
  .sidebar-header .text h1 { font-size: 1.05rem; margin:0; line-height:1.1; }
54
  .sidebar-header .text .tag { font-size: .85rem; color:#6b7280; margin:2px 0 0; }
55
+ .centered-container { display:flex; flex-direction:column; align-items:center; text-align:center; }
 
 
 
 
 
56
  </style>
57
  """, unsafe_allow_html=True)
58
 
59
+ # Allow sticky bits (preview expander header & tabs)
60
  st.markdown("""
61
  <style>
62
+ .main .block-container { overflow: unset !important; }
63
+ div[data-testid="stVerticalBlock"] { overflow: unset !important; }
 
 
 
 
 
 
 
64
  </style>
65
  """, unsafe_allow_html=True)
66
 
67
  # Hide uploader helper text ("Drag and drop file here", limits, etc.)
68
  st.markdown("""
69
  <style>
 
70
  section[data-testid="stFileUploader"] div[data-testid="stMarkdownContainer"]{display:none !important;}
 
71
  section[data-testid="stFileUploader"] [data-testid="stFileUploaderDropzone"] > div:first-child{display:none !important;}
 
72
  section[data-testid="stFileUploader"] [data-testid="stFileUploaderInstructions"]{display:none !important;}
 
73
  section[data-testid="stFileUploader"] p, section[data-testid="stFileUploader"] small{display:none !important;}
74
  </style>
75
  """, unsafe_allow_html=True)
76
 
77
+ # Sticky Preview expander & its tabs
78
  st.markdown("""
79
  <style>
80
  div[data-testid="stExpander"] > details > summary {
81
+ position: sticky; top: 0; z-index: 10; background: #fff; border-bottom: 1px solid #eee;
 
 
 
 
82
  }
83
  div[data-testid="stExpander"] div[data-baseweb="tab-list"] {
84
+ position: sticky; top: 42px; z-index: 9; background: #fff; padding-top: 6px;
 
 
 
 
85
  }
86
  </style>
87
  """, unsafe_allow_html=True)
 
92
  dict(selector="td", props=[("text-align", "center")]),
93
  ]
94
 
95
+ # Message box styles
96
  st.markdown("""
97
  <style>
98
+ .st-message-box { background:#f0f2f6; color:#333; padding:10px; border-radius:10px; border:1px solid #e6e9ef; }
99
+ .st-message-box.st-success { background:#d4edda; color:#155724; border-color:#c3e6cb; }
100
+ .st-message-box.st-warning { background:#fff3cd; color:#856404; border-color:#ffeeba; }
101
+ .st-message-box.st-error { background:#f8d7da; color:#721c24; border-color:#f5c6cb; }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  </style>
103
  """, unsafe_allow_html=True)
104
 
 
129
  st.sidebar.markdown(f"""
130
  <div class="centered-container">
131
  <img src="{inline_logo('logo.png')}" style="width: 200px; height: auto; object-fit: contain;">
132
+ <div style='font-weight:800;font-size:1.2rem;'>{APP_NAME}</div>
133
  <div style='color:#667085;'>Smart Thinking • Secure Access</div>
134
  </div>
135
  """, unsafe_allow_html=True
 
182
  if nm.lower() in low2orig: return low2orig[nm.lower()]
183
  return None
184
 
185
+ def _nice_tick0(xmin: float, step: int = 5) -> float:
186
  return step * math.floor(xmin / step) if np.isfinite(xmin) else xmin
187
 
188
  def df_centered_rounded(df: pd.DataFrame, hide_index=True):
 
189
  out = df.copy()
190
  numcols = out.select_dtypes(include=[np.number]).columns
191
  styler = (
 
196
  )
197
  st.dataframe(styler, use_container_width=True, hide_index=hide_index)
198
 
199
+ # --- target transform helpers (to support models trained on log(GR)) ---
200
+ def inverse_target(x: np.ndarray, transform: str) -> np.ndarray:
201
+ t = (transform or "none").lower()
202
+ if t in ["log10", "log_10", "log10()"]:
203
+ return np.power(10.0, x)
204
+ if t in ["ln", "log", "log_e", "natural"]:
205
+ return np.exp(x)
206
+ return x # "none"
207
+
208
+ def to_actual_series(df: pd.DataFrame, target_col: str, actual_col_hint: str, transform: str) -> pd.Series:
209
+ """
210
+ Return the 'actual GR' series (API).
211
+ If an explicit actual column exists, use it; else invert the target.
212
+ """
213
+ if actual_col_hint and actual_col_hint in df.columns:
214
+ return pd.Series(df[actual_col_hint], dtype=float)
215
+ # else, if target exists, invert:
216
+ if target_col in df.columns:
217
+ return pd.Series(inverse_target(np.asarray(df[target_col], dtype=float), transform), dtype=float)
218
+ # fallback: if a column named "GR" exists, use it
219
+ if "GR" in df.columns:
220
+ return pd.Series(df["GR"], dtype=float)
221
+ raise ValueError("Cannot find actual GR column or target to invert.")
222
+
223
+ # =========================
224
+ # Cross plot (Matplotlib) — auto limits for GR
225
+ # =========================
226
+ def _nice_bounds(arr_min, arr_max, n_ticks=5):
227
+ # pick a "nice" range and step for GR (typically 0–200+ API)
228
+ if not np.isfinite(arr_min) or not np.isfinite(arr_max):
229
+ return 0.0, 100.0, 20.0
230
+ span = arr_max - arr_min
231
+ if span <= 0:
232
+ return max(arr_min-5, 0), arr_max+5, 5.0
233
+ raw_step = span / max(n_ticks, 1)
234
+ mag = 10 ** math.floor(math.log10(raw_step))
235
+ steps = np.array([1, 2, 2.5, 5, 10]) * mag
236
+ step = steps[np.argmin(np.abs(steps - raw_step))]
237
+ lo = step * math.floor(arr_min / step)
238
+ hi = step * math.ceil(arr_max / step)
239
+ return float(lo), float(hi), float(step)
240
+
241
  def cross_plot_static(actual, pred):
242
  a = pd.Series(actual, dtype=float)
243
+ p = pd.Series(pred, dtype=float)
244
 
245
+ # auto bounds & ticks for GR
246
+ lo = min(a.min(), p.min())
247
+ hi = max(a.max(), p.max())
248
+ fixed_min, fixed_max, step = _nice_bounds(lo, hi, n_ticks=6)
249
+ ticks = np.arange(fixed_min, fixed_max + step, step)
250
 
251
  dpi = 110
252
  fig, ax = plt.subplots(
 
263
  ax.set_ylim(fixed_min, fixed_max)
264
  ax.set_xticks(ticks)
265
  ax.set_yticks(ticks)
266
+ ax.set_aspect("equal", adjustable="box") # true 1:1
267
 
268
  fmt = FuncFormatter(lambda x, _: f"{int(x):,}")
269
  ax.xaxis.set_major_formatter(fmt)
270
  ax.yaxis.set_major_formatter(fmt)
271
 
272
+ ax.set_xlabel("Actual GR (API)", fontweight="bold", fontsize=10, color="black")
273
+ ax.set_ylabel("Predicted GR (API)", fontweight="bold", fontsize=10, color="black")
274
+ ax.tick_params(labelsize=8, colors="black")
275
 
276
  ax.grid(True, linestyle=":", alpha=0.3)
277
  for spine in ax.spines.values():
 
284
  # =========================
285
  # Track plot (Plotly)
286
  # =========================
287
+ def track_plot(df, include_actual=True, pred_col="GR_Pred", actual_col="GR"):
288
  depth_col = next((c for c in df.columns if 'depth' in str(c).lower()), None)
289
  if depth_col is not None:
290
  y = pd.Series(df[depth_col]).astype(float)
291
  ylab = depth_col
292
+ y_range = [float(y.max()), float(y.min())] # reverse for logs
293
  else:
294
  y = pd.Series(np.arange(1, len(df) + 1))
295
  ylab = "Point Index"
296
  y_range = [float(y.max()), float(y.min())]
297
 
298
+ # X (GR) range & ticks
299
+ x_series = pd.Series(df.get(pred_col, pd.Series(dtype=float))).astype(float)
300
+ if include_actual and actual_col in df.columns:
301
+ x_series = pd.concat([x_series, pd.Series(df[actual_col]).astype(float)], ignore_index=True)
302
  x_lo, x_hi = float(x_series.min()), float(x_series.max())
303
  x_pad = 0.03 * (x_hi - x_lo if x_hi > x_lo else 1.0)
304
  xmin, xmax = x_lo - x_pad, x_hi + x_pad
305
+ tick0 = _nice_tick0(xmin, step=5)
306
 
307
  fig = go.Figure()
308
  fig.add_trace(go.Scatter(
309
+ x=df[pred_col], y=y, mode="lines",
310
  line=dict(color=COLORS["pred"], width=1.8),
311
+ name="GR_Pred",
312
+ hovertemplate="GR_Pred: %{x:.0f}<br>"+ylab+": %{y}<extra></extra>"
313
  ))
314
+ if include_actual and actual_col in df.columns:
315
  fig.add_trace(go.Scatter(
316
+ x=df[actual_col], y=y, mode="lines",
317
  line=dict(color=COLORS["actual"], width=2.0, dash="dot"),
318
+ name="GR (actual)",
319
+ hovertemplate="GR (actual): %{x:.0f}<br>"+ylab+": %{y}<extra></extra>"
320
  ))
321
 
322
  fig.update_layout(
323
+ height=TRACK_H, width=TRACK_W, autosize=False,
 
 
324
  paper_bgcolor="#fff", plot_bgcolor="#fff",
325
  margin=dict(l=64, r=16, t=36, b=48), hovermode="closest",
326
  font=dict(size=FONT_SZ, color="#000"),
 
330
  ),
331
  legend_title_text=""
332
  )
 
 
333
  fig.update_xaxes(
334
+ title_text="GR (API)",
335
  title_font=dict(size=20, family=BOLD_FONT, color="#000"),
336
  tickfont=dict(size=15, family=BOLD_FONT, color="#000"),
337
+ side="top", range=[xmin, xmax],
338
+ ticks="outside", tickformat=",.0f", tickmode="auto", tick0=tick0,
 
 
 
 
339
  showline=True, linewidth=1.2, linecolor="#444", mirror=True,
340
  showgrid=True, gridcolor="rgba(0,0,0,0.12)", automargin=True
341
  )
342
  fig.update_yaxes(
343
+ title_text=f"{ylab}",
344
  title_font=dict(size=20, family=BOLD_FONT, color="#000"),
345
  tickfont=dict(size=15, family=BOLD_FONT, color="#000"),
346
+ range=y_range, ticks="outside",
 
347
  showline=True, linewidth=1.2, linecolor="#444", mirror=True,
348
  showgrid=True, gridcolor="rgba(0,0,0,0.12)", automargin=True
349
  )
 
350
  return fig
351
 
352
  # ---------- Preview modal (matplotlib) ----------
 
380
  return wrapper
381
  return deco
382
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
383
  # =========================
384
+ # Load model + meta
385
  # =========================
386
  def ensure_model() -> Path|None:
387
  for p in [DEFAULT_MODEL, *MODEL_FALLBACKS]:
 
402
 
403
  mpath = ensure_model()
404
  if not mpath:
405
+ st.error("Model not found. Upload models/gr_rf.joblib (or set MODEL_URL).")
406
  st.stop()
407
  try:
408
  model = load_model(str(mpath))
 
414
  if meta_path.exists():
415
  try:
416
  meta = json.loads(meta_path.read_text(encoding="utf-8"))
417
+ FEATURES = meta.get("features", FEATURES)
418
+ TARGET = meta.get("target", TARGET)
419
+ TARGET_TRANSFORM = meta.get("target_transform", TARGET_TRANSFORM)
420
+ ACTUAL_COL = meta.get("actual_col", ACTUAL_COL)
421
  except Exception:
422
  pass
423
 
 
431
  st.session_state.setdefault("dev_file_bytes",b"")
432
  st.session_state.setdefault("dev_file_loaded",False)
433
  st.session_state.setdefault("dev_preview",False)
434
+ st.session_state.setdefault("show_preview_modal", False)
435
 
436
  # =========================
437
  # Branding in Sidebar
 
439
  st.sidebar.markdown(f"""
440
  <div class="centered-container">
441
  <img src="{inline_logo('logo.png')}" style="width: 200px; height: auto; object-fit: contain;">
442
+ <div style='font-weight:800;font-size:1.2rem;'>{APP_NAME}</div>
443
+ <div style='color:#667085;'>{TAGLINE}</div>
444
  </div>
445
  """, unsafe_allow_html=True
446
  )
447
 
448
+ # Reusable sticky header
 
 
449
  def sticky_header(title, message):
450
  st.markdown(
451
  f"""
452
  <style>
453
  .sticky-container {{
454
+ position: sticky; top: 0; background-color: white; z-index: 100;
455
+ padding-top: 10px; padding-bottom: 10px; border-bottom: 1px solid #eee;
 
 
 
 
 
456
  }}
457
  </style>
458
  <div class="sticky-container">
 
468
  # =========================
469
  if st.session_state.app_step == "intro":
470
  st.header("Welcome!")
471
+ st.markdown("This software is developed by *Smart Thinking AI-Solutions Team* to estimate Gamma Ray (GR) from input features.")
472
  st.subheader("How It Works")
473
  st.markdown(
474
+ "1) **Upload your data to build the case and preview the performance of our model.** \n"
475
+ "2) Click **Run Model** to compute metrics and plots. \n"
476
+ "3) **Proceed to Validation** (with actual GR) or **Proceed to Prediction** (no GR)."
477
  )
478
  if st.button("Start Showcase", type="primary"):
479
  st.session_state.app_step = "dev"; st.rerun()
 
496
  st.sidebar.caption(f"**Data loaded:** {st.session_state.dev_file_name} • {df0.shape[0]} rows × {df0.shape[1]} cols")
497
 
498
  if st.sidebar.button("Preview data", use_container_width=True, disabled=not st.session_state.dev_file_loaded):
499
+ st.session_state.show_preview_modal = True
500
  st.session_state.dev_preview = True
501
 
502
  run = st.sidebar.button("Run Model", type="primary", use_container_width=True)
503
  if st.sidebar.button("Proceed to Validation ▶", use_container_width=True): st.session_state.app_step="validate"; st.rerun()
504
  if st.sidebar.button("Proceed to Prediction ▶", use_container_width=True): st.session_state.app_step="predict"; st.rerun()
505
 
506
+ # Sticky helper
507
  if st.session_state.dev_file_loaded and st.session_state.dev_preview:
508
  sticky_header("Case Building", "Previewed ✓ — now click **Run Model**.")
509
  elif st.session_state.dev_file_loaded:
 
514
  if run and st.session_state.dev_file_bytes:
515
  book = read_book_bytes(st.session_state.dev_file_bytes)
516
  sh_train = find_sheet(book, ["Train","Training","training2","train","training"])
517
+ sh_test = find_sheet(book, ["Test","Testing","testing2","test","testing"])
518
  if sh_train is None or sh_test is None:
519
+ st.markdown('<div class="st-message-box st-error">Workbook must include Train/Training and Test/Testing sheets.</div>', unsafe_allow_html=True)
520
  st.stop()
521
  tr = book[sh_train].copy(); te = book[sh_test].copy()
522
+ if not (ensure_cols(tr, FEATURES) and ensure_cols(te, FEATURES)):
523
+ st.markdown('<div class="st-message-box st-error">Missing required feature columns.</div>', unsafe_allow_html=True); st.stop()
524
+
525
+ # predictions (handle log targets)
526
+ tr_pred_raw = model.predict(tr[FEATURES])
527
+ te_pred_raw = model.predict(te[FEATURES])
528
+ tr["GR_Pred"] = inverse_target(np.asarray(tr_pred_raw, dtype=float), TARGET_TRANSFORM)
529
+ te["GR_Pred"] = inverse_target(np.asarray(te_pred_raw, dtype=float), TARGET_TRANSFORM)
530
+
531
+ # actual GR (for metrics/plots)
532
+ tr["GR_Actual"] = to_actual_series(tr, TARGET, ACTUAL_COL, TARGET_TRANSFORM)
533
+ te["GR_Actual"] = to_actual_series(te, TARGET, ACTUAL_COL, TARGET_TRANSFORM)
534
 
535
  st.session_state.results["Train"]=tr; st.session_state.results["Test"]=te
536
  st.session_state.results["m_train"]={
537
+ "R": pearson_r(tr["GR_Actual"], tr["GR_Pred"]),
538
+ "RMSE": rmse(tr["GR_Actual"], tr["GR_Pred"]),
539
+ "MAE": mean_absolute_error(tr["GR_Actual"], tr["GR_Pred"])
540
  }
541
  st.session_state.results["m_test"]={
542
+ "R": pearson_r(te["GR_Actual"], te["GR_Pred"]),
543
+ "RMSE": rmse(te["GR_Actual"], te["GR_Pred"]),
544
+ "MAE": mean_absolute_error(te["GR_Actual"], te["GR_Pred"])
545
  }
546
 
547
  tr_min = tr[FEATURES].min().to_dict(); tr_max = tr[FEATURES].max().to_dict()
 
553
  c1.metric("R", f"{m['R']:.2f}")
554
  c2.metric("RMSE", f"{m['RMSE']:.2f}")
555
  c3.metric("MAE", f"{m['MAE']:.2f}")
 
 
556
  st.markdown("""
557
+ <div style='text-align:left;font-size:0.8em;color:#6b7280;margin-top:-16px;margin-bottom:8px;'>
558
  <strong>R:</strong> Pearson Correlation Coefficient<br>
559
  <strong>RMSE:</strong> Root Mean Square Error<br>
560
  <strong>MAE:</strong> Mean Absolute Error
561
  </div>
562
  """, unsafe_allow_html=True)
563
 
 
564
  col_track, col_cross = st.columns([2, 3], gap="large")
565
  with col_track:
566
  st.plotly_chart(
567
+ track_plot(df.rename(columns={"GR_Actual":"GR"}), include_actual=True,
568
+ pred_col="GR_Pred", actual_col="GR"),
569
+ use_container_width=False,
570
  config={"displayModeBar": False, "scrollZoom": True}
571
  )
572
  with col_cross:
573
+ st.pyplot(cross_plot_static(df["GR_Actual"], df["GR_Pred"]), use_container_width=False)
 
574
 
575
  if "Train" in st.session_state.results or "Test" in st.session_state.results:
576
  tab1, tab2 = st.tabs(["Training", "Testing"])
577
  if "Train" in st.session_state.results:
578
  with tab1: _dev_block(st.session_state.results["Train"], st.session_state.results["m_train"])
579
  if "Test" in st.session_state.results:
580
+ with tab2: _dev_block(st.session_state.results["Test"], st.session_state.results["m_test"])
581
 
582
  # =========================
583
+ # VALIDATION (with actual GR)
584
  # =========================
585
  if st.session_state.app_step == "validate":
586
  st.sidebar.header("Validate the Model")
 
591
  df0 = next(iter(book.values()))
592
  st.sidebar.caption(f"**Data loaded:** {up.name} • {df0.shape[0]} rows × {df0.shape[1]} cols")
593
  if st.sidebar.button("Preview data", use_container_width=True, disabled=(up is None)):
594
+ st.session_state.show_preview_modal = True
595
  go_btn = st.sidebar.button("Predict & Validate", type="primary", use_container_width=True)
596
  if st.sidebar.button("⬅ Back to Case Building", use_container_width=True): st.session_state.app_step="dev"; st.rerun()
597
  if st.sidebar.button("Proceed to Prediction ▶", use_container_width=True): st.session_state.app_step="predict"; st.rerun()
598
 
599
+ sticky_header("Validate the Model", "Upload a dataset with the same **features** and **GR** to evaluate performance.")
600
 
601
  if go_btn and up is not None:
602
  book = read_book_bytes(up.getvalue())
603
  name = find_sheet(book, ["Validation","Validate","validation2","Val","val"]) or list(book.keys())[0]
604
  df = book[name].copy()
605
+ if not ensure_cols(df, FEATURES): st.markdown('<div class="st-message-box st-error">Missing required feature columns.</div>', unsafe_allow_html=True); st.stop()
606
+
607
+ pred_raw = model.predict(df[FEATURES])
608
+ df["GR_Pred"] = inverse_target(np.asarray(pred_raw, dtype=float), TARGET_TRANSFORM)
609
+ # actual GR
610
+ try:
611
+ df["GR_Actual"] = to_actual_series(df, TARGET, ACTUAL_COL, TARGET_TRANSFORM)
612
+ except Exception:
613
+ st.markdown('<div class="st-message-box st-error">Validation sheet must include actual GR (or a target column that can be inverse-transformed).</div>', unsafe_allow_html=True); st.stop()
614
+
615
  st.session_state.results["Validate"]=df
616
 
617
  ranges = st.session_state.train_ranges; oor_pct = 0.0; tbl=None
 
623
  for c in FEATURES:
624
  if pd.api.types.is_numeric_dtype(tbl[c]): tbl[c] = tbl[c].round(2)
625
  tbl["Violations"] = pd.DataFrame({f:(df[f]<ranges[f][0])|(df[f]>ranges[f][1]) for f in FEATURES}).loc[any_viol].apply(lambda r:", ".join([c for c,v in r.items() if v]), axis=1)
626
+
627
  st.session_state.results["m_val"]={
628
+ "R": pearson_r(df["GR_Actual"], df["GR_Pred"]),
629
+ "RMSE": rmse(df["GR_Actual"], df["GR_Pred"]),
630
+ "MAE": mean_absolute_error(df["GR_Actual"], df["GR_Pred"])
631
  }
632
+ st.session_state.results["sv_val"]={"n":len(df),"pred_min":float(df["GR_Pred"].min()),"pred_max":float(df["GR_Pred"].max()),"oor":oor_pct}
633
  st.session_state.results["oor_tbl"]=tbl
634
 
635
  if "Validate" in st.session_state.results:
 
638
  c1.metric("R", f"{m['R']:.2f}")
639
  c2.metric("RMSE", f"{m['RMSE']:.2f}")
640
  c3.metric("MAE", f"{m['MAE']:.2f}")
 
 
641
  st.markdown("""
642
+ <div style='text-align:left;font-size:0.8em;color:#6b7280;margin-top:-16px;margin-bottom:8px;'>
643
  <strong>R:</strong> Pearson Correlation Coefficient<br>
644
  <strong>RMSE:</strong> Root Mean Square Error<br>
645
  <strong>MAE:</strong> Mean Absolute Error
646
  </div>
647
  """, unsafe_allow_html=True)
648
+
649
  col_track, col_cross = st.columns([2, 3], gap="large")
650
  with col_track:
651
  st.plotly_chart(
652
+ track_plot(st.session_state.results["Validate"].rename(columns={"GR_Actual":"GR"}),
653
+ include_actual=True, pred_col="GR_Pred", actual_col="GR"),
654
+ use_container_width=False, config={"displayModeBar": False, "scrollZoom": True}
655
  )
656
  with col_cross:
657
  st.pyplot(
658
+ cross_plot_static(st.session_state.results["Validate"]["GR_Actual"],
659
+ st.session_state.results["Validate"]["GR_Pred"]),
660
  use_container_width=False
661
  )
662
 
 
667
  df_centered_rounded(st.session_state.results["oor_tbl"])
668
 
669
  # =========================
670
+ # PREDICTION (no actual GR)
671
  # =========================
672
  if st.session_state.app_step == "predict":
673
+ st.sidebar.header("Prediction (No Actual GR)")
674
  up = st.sidebar.file_uploader("Upload Prediction Excel", type=["xlsx","xls"])
675
  if up is not None:
676
  book = read_book_bytes(up.getvalue())
 
678
  df0 = next(iter(book.values()))
679
  st.sidebar.caption(f"**Data loaded:** {up.name} • {df0.shape[0]} rows × {df0.shape[1]} cols")
680
  if st.sidebar.button("Preview data", use_container_width=True, disabled=(up is None)):
681
+ st.session_state.show_preview_modal = True
682
  go_btn = st.sidebar.button("Predict", type="primary", use_container_width=True)
683
  if st.sidebar.button("⬅ Back to Case Building", use_container_width=True): st.session_state.app_step="dev"; st.rerun()
684
 
685
+ sticky_header("Prediction", "Upload a dataset with the feature columns (no **GR**).")
686
 
687
  if go_btn and up is not None:
688
  book = read_book_bytes(up.getvalue()); name = list(book.keys())[0]
689
  df = book[name].copy()
690
+ if not ensure_cols(df, FEATURES): st.markdown('<div class="st-message-box st-error">Missing required feature columns.</div>', unsafe_allow_html=True); st.stop()
691
+
692
+ pred_raw = model.predict(df[FEATURES])
693
+ df["GR_Pred"] = inverse_target(np.asarray(pred_raw, dtype=float), TARGET_TRANSFORM)
694
  st.session_state.results["PredictOnly"]=df
695
 
696
  ranges = st.session_state.train_ranges; oor_pct = 0.0
 
699
  oor_pct = float(any_viol.mean()*100.0)
700
  st.session_state.results["sv_pred"]={
701
  "n":len(df),
702
+ "pred_min":float(df["GR_Pred"].min()),
703
+ "pred_max":float(df["GR_Pred"].max()),
704
+ "pred_mean":float(df["GR_Pred"].mean()),
705
+ "pred_std":float(df["GR_Pred"].std(ddof=0)),
706
  "oor":oor_pct
707
  }
708
 
 
713
  with col_left:
714
  table = pd.DataFrame({
715
  "Metric": ["# points","Pred min","Pred max","Pred mean","Pred std","OOR %"],
716
+ "Value": [sv["n"],
717
+ round(sv["pred_min"],2),
718
+ round(sv["pred_max"],2),
719
+ round(sv["pred_mean"],2),
720
+ round(sv["pred_std"],2),
721
+ f'{sv["oor"]:.1f}%']
722
  })
723
  st.markdown('<div class="st-message-box st-success">Predictions ready ✓</div>', unsafe_allow_html=True)
724
  df_centered_rounded(table, hide_index=True)
725
  st.caption("**★ OOR** = % of rows whose input features fall outside the training min–max range.")
726
  with col_right:
727
  st.plotly_chart(
728
+ track_plot(df.rename(columns={"GR_Pred":"GR_Pred"}), include_actual=False,
729
+ pred_col="GR_Pred", actual_col="GR"),
730
+ use_container_width=False, config={"displayModeBar": False, "scrollZoom": True}
731
  )
732
 
733
  # =========================
734
+ # Preview modal (re-usable)
735
  # =========================
736
  if st.session_state.show_preview_modal:
 
737
  book_to_preview = {}
738
  if st.session_state.app_step == "dev":
739
  book_to_preview = read_book_bytes(st.session_state.dev_file_bytes)
 
757
  .agg(['min','max','mean','std'])
758
  .T.rename(columns={"min":"Min","max":"Max","mean":"Mean","std":"Std"}))
759
  df_centered_rounded(tbl.reset_index(names="Feature"))
 
760
  st.session_state.show_preview_modal = False
 
761
  # =========================
762
  # Footer
763
  # =========================
 
765
  <br><br><br>
766
  <hr>
767
  <div style='text-align:center;color:#6b7280;font-size:0.8em;'>
768
+ © 2024 Smart Thinking AI-Solutions Team. All rights reserved.<br>
769
+ Contact: smartthinking@smartthinking.com.sa
770
  </div>
771
  """, unsafe_allow_html=True)