UCS2014 committed on
Commit
074f655
·
verified ·
1 Parent(s): a380eee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +212 -137
app.py CHANGED
@@ -5,10 +5,18 @@ import pandas as pd
5
  import numpy as np
6
  import joblib
7
  import matplotlib
8
- matplotlib.use("Agg")
9
  import matplotlib.pyplot as plt
10
  from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
11
 
 
 
 
 
 
 
 
 
12
  # =========================
13
  # Defaults
14
  # =========================
@@ -18,7 +26,7 @@ MODELS_DIR = Path("models")
18
  DEFAULT_MODEL = MODELS_DIR / "ucs_rf.joblib"
19
  MODEL_FALLBACKS = [MODELS_DIR / "model.joblib", MODELS_DIR / "model.pkl"]
20
 
21
- COLORS = {"pred": "#1f77b4", "actual": "#f2b702", "ref": "#5a5a5a"}
22
 
23
  # =========================
24
  # Page / Theme
@@ -31,13 +39,27 @@ st.markdown(
31
  .stApp { background: #FFFFFF; }
32
  section[data-testid="stSidebar"] { background: #F6F9FC; }
33
  .block-container { padding-top: .5rem; padding-bottom: .5rem; }
34
- .stButton>button{ background:#007bff; color:#fff; font-weight:bold; border-radius:8px; border:none; padding:10px 24px; }
35
- .stButton>button:hover{ background:#0056b3; }
 
 
36
  .st-hero { display:flex; align-items:center; gap:16px; padding-top: 4px; }
37
  .st-hero .brand { width:110px; height:110px; object-fit:contain; }
38
  .st-hero h1 { margin:0; line-height:1.05; }
39
  .st-hero .tagline { margin:2px 0 0 2px; color:#6b7280; font-size:1.05rem; font-style:italic; }
40
  [data-testid="stBlock"]{ margin-top:0 !important; }
 
 
 
 
 
 
 
 
 
 
 
 
41
  </style>
42
  """,
43
  unsafe_allow_html=True
@@ -49,7 +71,6 @@ st.markdown(
49
  try:
50
  dialog = st.dialog
51
  except AttributeError:
52
- # Fallback (expander) if st.dialog is unavailable
53
  def dialog(title):
54
  def deco(fn):
55
  def wrapper(*args, **kwargs):
@@ -58,21 +79,16 @@ except AttributeError:
58
  return wrapper
59
  return deco
60
 
61
- def _get_model_url():
62
- return (os.environ.get("MODEL_URL", "") or "").strip()
63
-
64
  def rmse(y_true, y_pred): return float(np.sqrt(mean_squared_error(y_true, y_pred)))
65
 
66
  def ensure_cols(df, cols):
67
  miss = [c for c in cols if c not in df.columns]
68
- if miss:
69
- st.error(f"Missing columns: {miss}\nFound: {list(df.columns)}")
70
- return False
71
  return True
72
 
73
  @st.cache_resource(show_spinner=False)
74
- def load_model(model_path: str):
75
- return joblib.load(model_path)
76
 
77
  @st.cache_data(show_spinner=False)
78
  def parse_excel(data_bytes: bytes):
@@ -83,8 +99,7 @@ def parse_excel(data_bytes: bytes):
83
  def read_book_bytes(data_bytes: bytes):
84
  if not data_bytes: return {}
85
  try: return parse_excel(data_bytes)
86
- except Exception as e:
87
- st.error(f"Failed to read Excel: {e}"); return {}
88
 
89
  def find_sheet(book, names):
90
  low2orig = {k.lower(): k for k in book.keys()}
@@ -92,52 +107,6 @@ def find_sheet(book, names):
92
  if nm.lower() in low2orig: return low2orig[nm.lower()]
93
  return None
94
 
95
- def cross_plot(actual, pred, title, size=(3.9, 3.9)):
96
- fig, ax = plt.subplots(figsize=size, dpi=100)
97
- ax.scatter(actual, pred, s=14, alpha=0.85, color=COLORS["pred"])
98
- lo = float(np.nanmin([actual.min(), pred.min()]))
99
- hi = float(np.nanmax([actual.max(), pred.max()]))
100
- pad = 0.03 * (hi - lo if hi > lo else 1.0)
101
- ax.plot([lo - pad, hi + pad], [lo - pad, hi + pad], '--', lw=1.2, color=COLORS["ref"])
102
- ax.set_xlim(lo - pad, hi + pad); ax.set_ylim(lo - pad, hi + pad)
103
- ax.set_aspect('equal', 'box')
104
- ax.set_xlabel("Actual UCS"); ax.set_ylabel("Predicted UCS"); ax.set_title(title)
105
- ax.grid(True, ls=":", alpha=0.4)
106
- return fig
107
-
108
- def depth_or_index_track(df, title=None, include_actual=True):
109
- depth_col = next((c for c in df.columns if 'depth' in str(c).lower()), None)
110
- fig_w = 3.1
111
- fig_h = 7.6 if depth_col is not None else 7.2
112
- fig, ax = plt.subplots(figsize=(fig_w, fig_h), dpi=100)
113
- if depth_col is not None:
114
- ax.plot(df["UCS_Pred"], df[depth_col], '-', lw=1.8, color=COLORS["pred"], label="UCS_Pred")
115
- if include_actual and TARGET in df.columns:
116
- ax.plot(df[TARGET], df[depth_col], ':', lw=2.0, color=COLORS["actual"], alpha=0.95, label="UCS (actual)")
117
- ax.set.ylabel(depth_col); ax.set_xlabel("UCS")
118
- ax.xaxis.set_label_position('top'); ax.xaxis.tick_top(); ax.invert_yaxis()
119
- else:
120
- idx = np.arange(1, len(df) + 1)
121
- ax.plot(df["UCS_Pred"], idx, '-', lw=1.8, color=COLORS["pred"], label="UCS_Pred")
122
- if include_actual and TARGET in df.columns:
123
- ax.plot(df[TARGET], idx, ':', lw=2.0, color=COLORS["actual"], alpha=0.95, label="UCS (actual)")
124
- ax.set_ylabel("Point Index"); ax.set_xlabel("UCS")
125
- ax.xaxis.set_label_position('top'); ax.xaxis.tick_top(); ax.invert_yaxis()
126
- ax.grid(True, linestyle=":", alpha=0.4)
127
- if title: ax.set_title(title, pad=8)
128
- ax.legend(loc="best")
129
- return fig
130
-
131
- def export_workbook(sheets_dict, summary_df=None):
132
- try: import openpyxl # noqa
133
- except Exception: raise RuntimeError("Export requires openpyxl. Please add it to requirements or install it.")
134
- buf = io.BytesIO()
135
- with pd.ExcelWriter(buf, engine="openpyxl") as xw:
136
- for name, frame in sheets_dict.items():
137
- frame.to_excel(xw, sheet_name=name[:31], index=False)
138
- if summary_df is not None: summary_df.to_excel(xw, sheet_name="Summary", index=False)
139
- return buf.getvalue()
140
-
141
  def toast(msg):
142
  try: st.toast(msg)
143
  except Exception: st.info(msg)
@@ -163,27 +132,114 @@ def inline_logo(path="logo.png") -> str:
163
  except Exception:
164
  return ""
165
 
166
- # ---------- Preview modal helpers ----------
167
- def make_index_tracks(df: pd.DataFrame, cols: list[str]):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  cols = [c for c in cols if c in df.columns]
 
 
 
 
 
 
169
  n = len(cols)
170
- if n == 0:
171
- fig, ax = plt.subplots(figsize=(4, 2))
172
- ax.text(0.5, 0.5, "No selected columns in sheet", ha="center", va="center")
173
- ax.axis("off"); return fig
174
- width_per = 2.2
175
- fig_h = 7.0
176
- fig, axes = plt.subplots(1, n, figsize=(width_per * n, fig_h), sharey=True, dpi=100)
177
- if n == 1: axes = [axes]
178
  idx = np.arange(1, len(df) + 1)
179
- for ax, col in zip(axes, cols):
180
- ax.plot(df[col], idx, '-', lw=1.4, color="#333")
181
- ax.set_xlabel(col)
182
- ax.xaxis.set_label_position('top'); ax.xaxis.tick_top(); ax.invert_yaxis()
183
- ax.grid(True, linestyle=":", alpha=0.3)
184
- axes[0].set_ylabel("Point Index")
 
 
 
185
  return fig
186
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
  def stats_table(df: pd.DataFrame, cols: list[str]) -> pd.DataFrame:
188
  cols = [c for c in cols if c in df.columns]
189
  if not cols:
@@ -209,8 +265,13 @@ def preview_modal_dev(book: dict[str, pd.DataFrame], feature_cols: list[str]):
209
  for t, df in zip(t_objs, data):
210
  with t:
211
  t1, t2 = st.tabs(["Tracks", "Summary"])
212
- with t1: st.pyplot(make_index_tracks(df, feature_cols), use_container_width=True)
213
- with t2: st.dataframe(stats_table(df, feature_cols), use_container_width=True)
 
 
 
 
 
214
 
215
  @dialog("Preview data")
216
  def preview_modal_val(book: dict[str, pd.DataFrame], feature_cols: list[str]):
@@ -219,8 +280,13 @@ def preview_modal_val(book: dict[str, pd.DataFrame], feature_cols: list[str]):
219
  vname = find_sheet(book, ["Validation","Validate","validation2","Val","val"]) or list(book.keys())[0]
220
  df = book[vname]
221
  t1, t2 = st.tabs(["Tracks", "Summary"])
222
- with t1: st.pyplot(make_index_tracks(df, feature_cols), use_container_width=True)
223
- with t2: st.dataframe(stats_table(df, feature_cols), use_container_width=True)
 
 
 
 
 
224
 
225
  # =========================
226
  # Model presence
@@ -231,8 +297,7 @@ def ensure_model_present() -> Path:
231
  for p in [DEFAULT_MODEL, *MODEL_FALLBACKS]:
232
  if p.exists() and p.stat().st_size > 0:
233
  return p
234
- if not MODEL_URL:
235
- return None
236
  try:
237
  import requests
238
  DEFAULT_MODEL.parent.mkdir(parents=True, exist_ok=True)
@@ -276,17 +341,10 @@ if "app_step" not in st.session_state: st.session_state.app_step = "intro"
276
  if "results" not in st.session_state: st.session_state.results = {}
277
  if "train_ranges" not in st.session_state: st.session_state.train_ranges = None
278
 
279
- # Dev page state (persist file)
280
  for k, v in {
281
- "dev_ready": False,
282
- "dev_file_loaded": False,
283
- "dev_previewed": False,
284
- "dev_file_signature": None,
285
- "dev_preview_request": False,
286
- "dev_file_bytes": b"",
287
- "dev_file_name": "",
288
- "dev_file_rows": 0,
289
- "dev_file_cols": 0,
290
  }.items():
291
  if k not in st.session_state: st.session_state[k] = v
292
 
@@ -311,10 +369,8 @@ st.markdown(
311
  # =========================
312
  if st.session_state.app_step == "intro":
313
  st.header("Welcome!")
314
- st.markdown(
315
- "This software is developed by *Smart Thinking AI-Solutions Team* to estimate UCS from drilling data."
316
- )
317
- st.subheader("Expected Input Features")
318
  st.markdown(
319
  "- Q, gpm — Flow rate (gallons per minute) \n"
320
  "- SPP(psi) — Stand pipe pressure \n"
@@ -340,20 +396,17 @@ if st.session_state.app_step == "dev":
340
  dev_label = "Upload Data (Excel)" if not st.session_state.dev_file_name else "Replace data (Excel)"
341
  train_test_file = st.sidebar.file_uploader(dev_label, type=["xlsx","xls"], key="dev_upload")
342
 
343
- # Detect new/changed file and PERSIST BYTES
344
  if train_test_file is not None:
345
  try:
346
- file_bytes = train_test_file.getvalue()
347
- size = len(file_bytes)
348
  except Exception:
349
- file_bytes = b""
350
- size = 0
351
  sig = (train_test_file.name, size)
352
  if sig != st.session_state.dev_file_signature and size > 0:
353
  st.session_state.dev_file_signature = sig
354
  st.session_state.dev_file_name = train_test_file.name
355
  st.session_state.dev_file_bytes = file_bytes
356
- # Inspect first sheet for rows/cols
357
  _book_tmp = read_book_bytes(file_bytes)
358
  if _book_tmp:
359
  first_df = next(iter(_book_tmp.values()))
@@ -363,29 +416,23 @@ if st.session_state.app_step == "dev":
363
  st.session_state.dev_previewed = False
364
  st.session_state.dev_ready = False
365
 
366
- # Sidebar caption (from persisted info)
367
  if st.session_state.dev_file_loaded:
368
  st.sidebar.caption(
369
  f"**Data loaded:** {st.session_state.dev_file_name} • "
370
  f"{st.session_state.dev_file_rows} rows × {st.session_state.dev_file_cols} cols"
371
  )
372
 
373
- # Sidebar actions
 
374
  preview_btn = st.sidebar.button("Preview data", use_container_width=True, disabled=not st.session_state.dev_file_loaded)
375
- if preview_btn and st.session_state.dev_file_loaded:
376
- st.session_state.dev_preview_request = True
377
-
378
  run_btn = st.sidebar.button("Run Model", type="primary", use_container_width=True)
 
 
379
 
380
- proceed_clicked = st.sidebar.button(
381
- "Proceed to Prediction ▶",
382
- use_container_width=True,
383
- disabled=not st.session_state.dev_ready
384
- )
385
  if proceed_clicked and st.session_state.dev_ready:
386
  st.session_state.app_step = "predict"; st.rerun()
387
 
388
- # ----- ALWAYS-ON TOP: Title + helper -----
389
  helper_top = st.container()
390
  with helper_top:
391
  st.subheader("Model Development")
@@ -398,14 +445,11 @@ if st.session_state.app_step == "dev":
398
  else:
399
  st.write("**Upload your data to build a case, then run the model to review development performance.**")
400
 
401
- # If user clicked preview, open modal *after* helper so helper stays on top
402
- if st.session_state.dev_preview_request and st.session_state.dev_file_bytes:
403
  _book = read_book_bytes(st.session_state.dev_file_bytes)
404
  st.session_state.dev_previewed = True
405
- st.session_state.dev_preview_request = False
406
  preview_modal_dev(_book, FEATURES)
407
 
408
- # Run model (from persisted bytes)
409
  if run_btn and st.session_state.dev_file_bytes:
410
  with st.status("Processing…", expanded=False) as status:
411
  book = read_book_bytes(st.session_state.dev_file_bytes)
@@ -442,7 +486,7 @@ if st.session_state.app_step == "dev":
442
  status.update(label="Done ✓", state="complete"); toast("Model run complete 🚀")
443
  st.rerun()
444
 
445
- # Results (if available)
446
  if ("Train" in st.session_state.results) or ("Test" in st.session_state.results):
447
  tab1, tab2 = st.tabs(["Training", "Testing"])
448
  if "Train" in st.session_state.results:
@@ -452,9 +496,18 @@ if st.session_state.app_step == "dev":
452
  c1.metric("R²", f"{m['R2']:.4f}"); c2.metric("RMSE", f"{m['RMSE']:.4f}"); c3.metric("MAE", f"{m['MAE']:.4f}")
453
  left, right = st.columns([0.9, 0.55])
454
  with left:
455
- st.pyplot(cross_plot(df[TARGET], df["UCS_Pred"], "Training: Actual vs Predicted"), use_container_width=True)
 
 
 
 
 
456
  with right:
457
- st.pyplot(depth_or_index_track(df, title=None, include_actual=True), use_container_width=True)
 
 
 
 
458
  if "Test" in st.session_state.results:
459
  with tab2:
460
  df = st.session_state.results["Test"]; m = st.session_state.results["metrics_test"]
@@ -462,9 +515,18 @@ if st.session_state.app_step == "dev":
462
  c1.metric("R²", f"{m['R2']:.4f}"); c2.metric("RMSE", f"{m['RMSE']:.4f}"); c3.metric("MAE", f"{m['MAE']:.4f}")
463
  left, right = st.columns([0.9, 0.55])
464
  with left:
465
- st.pyplot(cross_plot(df[TARGET], df["UCS_Pred"], "Testing: Actual vs Predicted"), use_container_width=True)
 
 
 
 
 
466
  with right:
467
- st.pyplot(depth_or_index_track(df, title=None, include_actual=True), use_container_width=True)
 
 
 
 
468
 
469
  st.markdown("---")
470
  sheets = {}; rows = []
@@ -484,7 +546,7 @@ if st.session_state.app_step == "dev":
484
  st.warning(str(e))
485
 
486
  # =========================
487
- # PREDICTION (Validation) — (kept simple; uploader works fine)
488
  # =========================
489
  if st.session_state.app_step == "predict":
490
  st.sidebar.header("Prediction (Validation)")
@@ -495,17 +557,19 @@ if st.session_state.app_step == "predict":
495
  first_df = next(iter(_book_tmp.values()))
496
  st.sidebar.caption(f"**Data loaded:** {validation_file.name} • {first_df.shape[0]} rows × {first_df.shape[1]} cols")
497
 
 
498
  preview_val_btn = st.sidebar.button("Preview data", use_container_width=True, disabled=(validation_file is None))
499
- if preview_val_btn and validation_file is not None:
500
- _book = read_book_bytes(validation_file.getvalue())
501
- preview_modal_val(_book, FEATURES)
502
-
503
  predict_btn = st.sidebar.button("Predict", type="primary", use_container_width=True)
504
  st.sidebar.button("⬅ Back", on_click=lambda: st.session_state.update(app_step="dev"), use_container_width=True)
 
505
 
506
  st.subheader("Prediction")
507
  st.write("Upload a new dataset to generate UCS predictions and evaluate performance on unseen data.")
508
 
 
 
 
 
509
  if predict_btn and validation_file is not None:
510
  with st.status("Predicting…", expanded=False) as status:
511
  vbook = read_book_bytes(validation_file.getvalue())
@@ -524,7 +588,8 @@ if st.session_state.app_step == "predict":
524
  any_viol = pd.DataFrame(viol).any(axis=1); oor_pct = float(any_viol.mean()*100.0)
525
  if any_viol.any():
526
  offenders = df_val.loc[any_viol, FEATURES].copy()
527
- offenders["Violations"] = pd.DataFrame(viol).loc[any_viol].apply(lambda r: ", ".join([c for c,v in r.items() if v]), axis=1)
 
528
  offenders.index = offenders.index + 1; oor_table = offenders
529
 
530
  metrics_val = None
@@ -558,17 +623,27 @@ if st.session_state.app_step == "predict":
558
  left, right = st.columns([0.9, 0.55])
559
  with left:
560
  if TARGET in st.session_state.results["Validate"].columns:
561
- st.pyplot(cross_plot(st.session_state.results["Validate"][TARGET],
562
- st.session_state.results["Validate"]["UCS_Pred"],
563
- "Validation: Actual vs Predicted"),
564
- use_container_width=True)
 
 
 
 
 
 
565
  else:
566
  st.info("Actual UCS values are not available in the validation data. Cross-plot cannot be generated.")
567
  with right:
568
- st.pyplot(depth_or_index_track(
569
- st.session_state.results["Validate"], title=None,
570
- include_actual=(TARGET in st.session_state.results["Validate"].columns)
571
- ), use_container_width=True)
 
 
 
 
572
 
573
  if oor_table is not None:
574
  st.write("*Out-of-range rows (vs. Training min–max):*")
 
5
  import numpy as np
6
  import joblib
7
  import matplotlib
8
+ matplotlib.use("Agg") # fallback only
9
  import matplotlib.pyplot as plt
10
  from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
11
 
12
+ # Try Plotly for interactivity
13
+ HAVE_PLOTLY = True
14
+ try:
15
+ import plotly.graph_objects as go
16
+ from plotly.subplots import make_subplots
17
+ except Exception:
18
+ HAVE_PLOTLY = False
19
+
20
  # =========================
21
  # Defaults
22
  # =========================
 
26
  DEFAULT_MODEL = MODELS_DIR / "ucs_rf.joblib"
27
  MODEL_FALLBACKS = [MODELS_DIR / "model.joblib", MODELS_DIR / "model.pkl"]
28
 
29
+ COLORS = {"pred": "#1f77b4", "actual": "#f2b702", "ref": "#5a5a5a", "orange": "#f59e0b", "green": "#198754"}
30
 
31
  # =========================
32
  # Page / Theme
 
39
  .stApp { background: #FFFFFF; }
40
  section[data-testid="stSidebar"] { background: #F6F9FC; }
41
  .block-container { padding-top: .5rem; padding-bottom: .5rem; }
42
+ .stButton>button{ background:#0d6efd; color:#fff; font-weight:bold; border-radius:8px; border:none; padding:10px 24px; }
43
+ .stButton>button:hover{ filter: brightness(0.92); }
44
+
45
+ /* Hero header */
46
  .st-hero { display:flex; align-items:center; gap:16px; padding-top: 4px; }
47
  .st-hero .brand { width:110px; height:110px; object-fit:contain; }
48
  .st-hero h1 { margin:0; line-height:1.05; }
49
  .st-hero .tagline { margin:2px 0 0 2px; color:#6b7280; font-size:1.05rem; font-style:italic; }
50
  [data-testid="stBlock"]{ margin-top:0 !important; }
51
+
52
+ /* Colorize sidebar button groups we wrap with custom classes */
53
+ section[data-testid="stSidebar"] .dev-actions .stButton:nth-of-type(1) button { background: #f59e0b !important; } /* Preview - orange */
54
+ section[data-testid="stSidebar"] .dev-actions .stButton:nth-of-type(2) button { background: #0d6efd !important; } /* Run - blue */
55
+ section[data-testid="stSidebar"] .dev-actions .stButton:nth-of-type(3) button { background: #198754 !important; } /* Proceed - green */
56
+
57
+ section[data-testid="stSidebar"] .val-actions .stButton:nth-of-type(1) button { background: #f59e0b !important; } /* Preview - orange */
58
+ section[data-testid="stSidebar"] .val-actions .stButton:nth-of-type(2) button { background: #0d6efd !important; } /* Predict - blue */
59
+
60
+ /* Disabled look */
61
+ section[data-testid="stSidebar"] .dev-actions .stButton button:disabled,
62
+ section[data-testid="stSidebar"] .val-actions .stButton button:disabled { filter: grayscale(40%); opacity:.6; }
63
  </style>
64
  """,
65
  unsafe_allow_html=True
 
71
  try:
72
  dialog = st.dialog
73
  except AttributeError:
 
74
  def dialog(title):
75
  def deco(fn):
76
  def wrapper(*args, **kwargs):
 
79
  return wrapper
80
  return deco
81
 
82
def _get_model_url():
    """Return the ``MODEL_URL`` environment variable with surrounding whitespace removed.

    Returns:
        The stripped value, or an empty string when the variable is unset or empty.
    """
    raw_value = os.environ.get("MODEL_URL", "")
    # `or ""` keeps the result a string even if the lookup yields a falsy value.
    return (raw_value or "").strip()
 
 
83
  def rmse(y_true, y_pred): return float(np.sqrt(mean_squared_error(y_true, y_pred)))
84
 
85
  def ensure_cols(df, cols):
86
  miss = [c for c in cols if c not in df.columns]
87
+ if miss: st.error(f"Missing columns: {miss}\nFound: {list(df.columns)}"); return False
 
 
88
  return True
89
 
90
  @st.cache_resource(show_spinner=False)
91
+ def load_model(model_path: str): return joblib.load(model_path)
 
92
 
93
  @st.cache_data(show_spinner=False)
94
  def parse_excel(data_bytes: bytes):
 
99
  def read_book_bytes(data_bytes: bytes):
100
  if not data_bytes: return {}
101
  try: return parse_excel(data_bytes)
102
+ except Exception as e: st.error(f"Failed to read Excel: {e}"); return {}
 
103
 
104
  def find_sheet(book, names):
105
  low2orig = {k.lower(): k for k in book.keys()}
 
107
  if nm.lower() in low2orig: return low2orig[nm.lower()]
108
  return None
109
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  def toast(msg):
111
  try: st.toast(msg)
112
  except Exception: st.info(msg)
 
132
  except Exception:
133
  return ""
134
 
135
+ # -------- Plotting (Plotly first, Matplotlib fallback) --------
136
def cross_plotly(actual, pred, title):
    """Build an interactive actual-vs-predicted scatter with a dashed 1:1 reference line.

    Args:
        actual: Measured values; must provide ``.min()`` and ``.max()``.
        pred: Predicted values; same requirements as ``actual``.
        title: Chart title.

    Returns:
        A Plotly figure with the point cloud and the identity line.
    """
    # Common extent of both series, padded by 3% of the range
    # (or by a flat 0.03 when the range is not positive).
    low = float(np.nanmin([actual.min(), pred.min()]))
    high = float(np.nanmax([actual.max(), pred.max()]))
    span = high - low if high > low else 1.0
    margin = 0.03 * span
    line_start, line_end = low - margin, high + margin

    points = go.Scatter(
        x=actual,
        y=pred,
        mode="markers",
        marker={"size": 6, "color": COLORS["pred"]},
        hovertemplate="Actual: %{x:.2f}<br>Pred: %{y:.2f}<extra></extra>",
        name="Points",
    )
    identity = go.Scatter(
        x=[line_start, line_end],
        y=[line_start, line_end],
        mode="lines",
        line={"dash": "dash", "width": 1.5, "color": COLORS["ref"]},
        hoverinfo="skip",
        showlegend=False,
    )

    fig = go.Figure()
    fig.add_trace(points)
    fig.add_trace(identity)
    fig.update_layout(
        title=title,
        margin={"l": 10, "r": 10, "t": 40, "b": 10},
        height=350,
        legend={"orientation": "h", "yanchor": "bottom", "y": 1.02, "x": 0},
    )
    # Anchor x to y at ratio 1 so the identity line is drawn at 45 degrees.
    fig.update_xaxes(title_text="Actual UCS", scaleanchor="y", scaleratio=1)
    fig.update_yaxes(title_text="Predicted UCS")
    return fig
159
+
160
def track_plotly(df, include_actual=True):
    """Build a vertical track of predicted (and optionally actual) UCS.

    The y axis uses the first column whose name contains "depth"
    (case-insensitive); when there is none, a 1-based point index is used.

    Args:
        df: Frame holding a ``"UCS_Pred"`` column and, optionally, ``TARGET``.
        include_actual: Also draw the ``TARGET`` column when it is present.

    Returns:
        A Plotly figure with a reversed y axis and the x axis on top.
    """
    depth_col = next((c for c in df.columns if "depth" in str(c).lower()), None)
    if depth_col is None:
        y_values = np.arange(1, len(df) + 1)
        y_label = "Point Index"
    else:
        y_values = df[depth_col]
        y_label = depth_col

    # Hover text shared by every trace: second line shows the y value.
    hover_tail = "<br>" + y_label + ": %{y}<extra></extra>"

    # (column, legend name, line style) for each curve, in drawing order.
    curves = [("UCS_Pred", "UCS_Pred", {"color": COLORS["pred"], "width": 2})]
    if include_actual and TARGET in df.columns:
        curves.append(
            (TARGET, "UCS (actual)", {"color": COLORS["actual"], "dash": "dot", "width": 2.2})
        )

    fig = go.Figure()
    for column, label, line_style in curves:
        fig.add_trace(go.Scatter(
            x=df[column],
            y=y_values,
            mode="lines",
            line=line_style,
            name=label,
            hovertemplate=label + ": %{x:.2f}" + hover_tail,
        ))

    fig.update_yaxes(autorange="reversed", title_text=y_label)
    fig.update_xaxes(title_text="UCS", side="top")
    fig.update_layout(
        margin={"l": 10, "r": 10, "t": 40, "b": 10},
        height=650,
        legend={"orientation": "h", "yanchor": "bottom", "y": 1.02, "x": 0},
    )
    return fig
189
+
190
def make_index_tracks_plotly(df: pd.DataFrame, cols: list[str]):
    """Build side-by-side index tracks, one subplot per requested column.

    Args:
        df: Sheet to preview.
        cols: Column names to draw; names not present in ``df`` are ignored.

    Returns:
        A Plotly figure. When none of ``cols`` exist in ``df`` the figure
        contains only a centered "No selected columns in sheet" message.
    """
    present = [c for c in cols if c in df.columns]
    if not present:
        fig = go.Figure()
        # Paper coordinates keep the message centered regardless of axis ranges.
        fig.add_annotation(
            text="No selected columns in sheet", showarrow=False,
            x=0.5, y=0.5, xref="paper", yref="paper",
        )
        fig.update_xaxes(visible=False)
        fig.update_yaxes(visible=False)
        fig.update_layout(height=200, margin=dict(l=10, r=10, t=10, b=10))
        return fig

    n = len(present)
    # make_subplots limits horizontal_spacing to 1/(cols-1); clamp so a large
    # number of columns cannot push the fixed 0.05 gap past that limit.
    spacing = 0.05 if n == 1 else min(0.05, 1.0 / (n - 1))
    # The keyword is `shared_yaxes`; `shared_y` is not a make_subplots option.
    fig = make_subplots(rows=1, cols=n, shared_yaxes=True, horizontal_spacing=spacing)

    idx = np.arange(1, len(df) + 1)
    for i, col in enumerate(present, start=1):
        fig.add_trace(
            go.Scatter(
                x=df[col], y=idx, mode="lines",
                line=dict(color="#333", width=1.2),
                hovertemplate=f"{col}: " + "%{x:.2f}<br>Index: %{y}<extra></extra>",
                name=col, showlegend=False,
            ),
            row=1, col=i,
        )
        fig.update_xaxes(title_text=col, side="top", row=1, col=i)

    # Reverse every y axis so index 1 is at the top; label only the first track.
    fig.update_yaxes(autorange="reversed")
    fig.update_yaxes(title_text="Point Index", row=1, col=1)
    fig.update_layout(height=650, margin=dict(l=10, r=10, t=40, b=10))
    return fig
211
 
212
+ # Fallbacks (kept if Plotly missing)
213
def cross_plot_mpl(actual, pred, title, size=(3.9,3.9)):
    """Matplotlib version of the actual-vs-predicted cross-plot.

    Args:
        actual: Measured values; must provide ``.min()`` and ``.max()``.
        pred: Predicted values; same requirements as ``actual``.
        title: Axes title.
        size: Figure size in inches.

    Returns:
        The Matplotlib figure holding the scatter and the dashed 1:1 line.
    """
    fig, ax = plt.subplots(figsize=size, dpi=100)
    ax.scatter(actual, pred, s=14, alpha=0.85, color=COLORS["pred"])

    # Common extent of both series, padded by 3% of the range
    # (or by a flat 0.03 when the range is not positive).
    low = float(np.nanmin([actual.min(), pred.min()]))
    high = float(np.nanmax([actual.max(), pred.max()]))
    span = high - low if high > low else 1.0
    margin = 0.03 * span
    start, end = low - margin, high + margin

    ax.plot([start, end], [start, end], '--', lw=1.2, color=COLORS["ref"])
    ax.set_xlim(start, end)
    ax.set_ylim(start, end)
    ax.set_aspect('equal', 'box')
    ax.set_xlabel("Actual UCS")
    ax.set_ylabel("Predicted UCS")
    ax.set_title(title)
    ax.grid(True, ls=":", alpha=0.4)
    return fig
222
+
223
def depth_or_index_track_mpl(df, title=None, include_actual=True):
    """Matplotlib version of the vertical UCS track.

    The y axis uses the first column whose name contains "depth"
    (case-insensitive); when there is none, a 1-based point index is used.

    Args:
        df: Frame holding a ``"UCS_Pred"`` column and, optionally, ``TARGET``.
        title: Optional axes title; omitted when falsy.
        include_actual: Also draw the ``TARGET`` column when it is present.

    Returns:
        The Matplotlib figure, with the y axis inverted and the x axis on top.
    """
    depth_col = next((c for c in df.columns if 'depth' in str(c).lower()), None)
    fig, ax = plt.subplots(figsize=(3.1, 7.2), dpi=100)

    # Both layouts differ only in what is plotted on the y axis and its label.
    if depth_col is None:
        y_values = np.arange(1, len(df) + 1)
        y_label = "Point Index"
    else:
        y_values = df[depth_col]
        y_label = depth_col

    ax.plot(df["UCS_Pred"], y_values, '-', lw=1.8, color=COLORS["pred"], label="UCS_Pred")
    if include_actual and TARGET in df.columns:
        ax.plot(df[TARGET], y_values, ':', lw=2.0, color=COLORS["actual"],
                alpha=0.95, label="UCS (actual)")

    ax.set_ylabel(y_label)
    ax.set_xlabel("UCS")
    ax.xaxis.set_label_position('top')
    ax.xaxis.tick_top()
    ax.invert_yaxis()
    ax.grid(True, linestyle=":", alpha=0.4)
    if title:
        ax.set_title(title, pad=8)
    ax.legend(loc="best")
    return fig
241
+
242
+ # ---------- Preview modal helpers ----------
243
  def stats_table(df: pd.DataFrame, cols: list[str]) -> pd.DataFrame:
244
  cols = [c for c in cols if c in df.columns]
245
  if not cols:
 
265
  for t, df in zip(t_objs, data):
266
  with t:
267
  t1, t2 = st.tabs(["Tracks", "Summary"])
268
+ with t1:
269
+ if HAVE_PLOTLY:
270
+ st.plotly_chart(make_index_tracks_plotly(df, feature_cols), use_container_width=True, theme=None)
271
+ else:
272
+ st.pyplot(depth_or_index_track_mpl(df, title=None, include_actual=False), use_container_width=True)
273
+ with t2:
274
+ st.dataframe(stats_table(df, feature_cols), use_container_width=True)
275
 
276
  @dialog("Preview data")
277
  def preview_modal_val(book: dict[str, pd.DataFrame], feature_cols: list[str]):
 
280
  vname = find_sheet(book, ["Validation","Validate","validation2","Val","val"]) or list(book.keys())[0]
281
  df = book[vname]
282
  t1, t2 = st.tabs(["Tracks", "Summary"])
283
+ with t1:
284
+ if HAVE_PLOTLY:
285
+ st.plotly_chart(make_index_tracks_plotly(df, feature_cols), use_container_width=True, theme=None)
286
+ else:
287
+ st.pyplot(depth_or_index_track_mpl(df, title=None, include_actual=False), use_container_width=True)
288
+ with t2:
289
+ st.dataframe(stats_table(df, feature_cols), use_container_width=True)
290
 
291
  # =========================
292
  # Model presence
 
297
  for p in [DEFAULT_MODEL, *MODEL_FALLBACKS]:
298
  if p.exists() and p.stat().st_size > 0:
299
  return p
300
+ if not MODEL_URL: return None
 
301
  try:
302
  import requests
303
  DEFAULT_MODEL.parent.mkdir(parents=True, exist_ok=True)
 
341
  if "results" not in st.session_state: st.session_state.results = {}
342
  if "train_ranges" not in st.session_state: st.session_state.train_ranges = None
343
 
 
344
  for k, v in {
345
+ "dev_ready": False, "dev_file_loaded": False, "dev_previewed": False,
346
+ "dev_file_signature": None, "dev_preview_request": False,
347
+ "dev_file_bytes": b"", "dev_file_name": "", "dev_file_rows": 0, "dev_file_cols": 0,
 
 
 
 
 
 
348
  }.items():
349
  if k not in st.session_state: st.session_state[k] = v
350
 
 
369
  # =========================
370
  if st.session_state.app_step == "intro":
371
  st.header("Welcome!")
372
+ st.markdown("This software is developed by *Smart Thinking AI-Solutions Team* to estimate UCS from drilling data.")
373
+ st.subheader("Required Input Columns")
 
 
374
  st.markdown(
375
  "- Q, gpm — Flow rate (gallons per minute) \n"
376
  "- SPP(psi) — Stand pipe pressure \n"
 
396
  dev_label = "Upload Data (Excel)" if not st.session_state.dev_file_name else "Replace data (Excel)"
397
  train_test_file = st.sidebar.file_uploader(dev_label, type=["xlsx","xls"], key="dev_upload")
398
 
399
+ # Persist upload
400
  if train_test_file is not None:
401
  try:
402
+ file_bytes = train_test_file.getvalue(); size = len(file_bytes)
 
403
  except Exception:
404
+ file_bytes = b""; size = 0
 
405
  sig = (train_test_file.name, size)
406
  if sig != st.session_state.dev_file_signature and size > 0:
407
  st.session_state.dev_file_signature = sig
408
  st.session_state.dev_file_name = train_test_file.name
409
  st.session_state.dev_file_bytes = file_bytes
 
410
  _book_tmp = read_book_bytes(file_bytes)
411
  if _book_tmp:
412
  first_df = next(iter(_book_tmp.values()))
 
416
  st.session_state.dev_previewed = False
417
  st.session_state.dev_ready = False
418
 
 
419
  if st.session_state.dev_file_loaded:
420
  st.sidebar.caption(
421
  f"**Data loaded:** {st.session_state.dev_file_name} • "
422
  f"{st.session_state.dev_file_rows} rows × {st.session_state.dev_file_cols} cols"
423
  )
424
 
425
+ # Button group with wrapper to color via CSS
426
+ st.sidebar.markdown('<div class="dev-actions">', unsafe_allow_html=True)
427
  preview_btn = st.sidebar.button("Preview data", use_container_width=True, disabled=not st.session_state.dev_file_loaded)
 
 
 
428
  run_btn = st.sidebar.button("Run Model", type="primary", use_container_width=True)
429
+ proceed_clicked = st.sidebar.button("Proceed to Prediction ▶", use_container_width=True, disabled=not st.session_state.dev_ready)
430
+ st.sidebar.markdown('</div>', unsafe_allow_html=True)
431
 
 
 
 
 
 
432
  if proceed_clicked and st.session_state.dev_ready:
433
  st.session_state.app_step = "predict"; st.rerun()
434
 
435
+ # Pinned helper
436
  helper_top = st.container()
437
  with helper_top:
438
  st.subheader("Model Development")
 
445
  else:
446
  st.write("**Upload your data to build a case, then run the model to review development performance.**")
447
 
448
+ if preview_btn and st.session_state.dev_file_loaded and st.session_state.dev_file_bytes:
 
449
  _book = read_book_bytes(st.session_state.dev_file_bytes)
450
  st.session_state.dev_previewed = True
 
451
  preview_modal_dev(_book, FEATURES)
452
 
 
453
  if run_btn and st.session_state.dev_file_bytes:
454
  with st.status("Processing…", expanded=False) as status:
455
  book = read_book_bytes(st.session_state.dev_file_bytes)
 
486
  status.update(label="Done ✓", state="complete"); toast("Model run complete 🚀")
487
  st.rerun()
488
 
489
+ # Results
490
  if ("Train" in st.session_state.results) or ("Test" in st.session_state.results):
491
  tab1, tab2 = st.tabs(["Training", "Testing"])
492
  if "Train" in st.session_state.results:
 
496
  c1.metric("R²", f"{m['R2']:.4f}"); c2.metric("RMSE", f"{m['RMSE']:.4f}"); c3.metric("MAE", f"{m['MAE']:.4f}")
497
  left, right = st.columns([0.9, 0.55])
498
  with left:
499
+ if HAVE_PLOTLY:
500
+ st.plotly_chart(cross_plotly(df[TARGET], df["UCS_Pred"], "Training: Actual vs Predicted"),
501
+ use_container_width=True, theme=None)
502
+ else:
503
+ st.pyplot(cross_plot_mpl(df[TARGET], df["UCS_Pred"], "Training: Actual vs Predicted"),
504
+ use_container_width=True)
505
  with right:
506
+ if HAVE_PLOTLY:
507
+ st.plotly_chart(track_plotly(df, include_actual=True), use_container_width=True, theme=None)
508
+ else:
509
+ st.pyplot(depth_or_index_track_mpl(df, title=None, include_actual=True),
510
+ use_container_width=True)
511
  if "Test" in st.session_state.results:
512
  with tab2:
513
  df = st.session_state.results["Test"]; m = st.session_state.results["metrics_test"]
 
515
  c1.metric("R²", f"{m['R2']:.4f}"); c2.metric("RMSE", f"{m['RMSE']:.4f}"); c3.metric("MAE", f"{m['MAE']:.4f}")
516
  left, right = st.columns([0.9, 0.55])
517
  with left:
518
+ if HAVE_PLOTLY:
519
+ st.plotly_chart(cross_plotly(df[TARGET], df["UCS_Pred"], "Testing: Actual vs Predicted"),
520
+ use_container_width=True, theme=None)
521
+ else:
522
+ st.pyplot(cross_plot_mpl(df[TARGET], df["UCS_Pred"], "Testing: Actual vs Predicted"),
523
+ use_container_width=True)
524
  with right:
525
+ if HAVE_PLOTLY:
526
+ st.plotly_chart(track_plotly(df, include_actual=True), use_container_width=True, theme=None)
527
+ else:
528
+ st.pyplot(depth_or_index_track_mpl(df, title=None, include_actual=True),
529
+ use_container_width=True)
530
 
531
  st.markdown("---")
532
  sheets = {}; rows = []
 
546
  st.warning(str(e))
547
 
548
  # =========================
549
+ # PREDICTION (Validation)
550
  # =========================
551
  if st.session_state.app_step == "predict":
552
  st.sidebar.header("Prediction (Validation)")
 
557
  first_df = next(iter(_book_tmp.values()))
558
  st.sidebar.caption(f"**Data loaded:** {validation_file.name} • {first_df.shape[0]} rows × {first_df.shape[1]} cols")
559
 
560
+ st.sidebar.markdown('<div class="val-actions">', unsafe_allow_html=True)
561
  preview_val_btn = st.sidebar.button("Preview data", use_container_width=True, disabled=(validation_file is None))
 
 
 
 
562
  predict_btn = st.sidebar.button("Predict", type="primary", use_container_width=True)
563
  st.sidebar.button("⬅ Back", on_click=lambda: st.session_state.update(app_step="dev"), use_container_width=True)
564
+ st.sidebar.markdown('</div>', unsafe_allow_html=True)
565
 
566
  st.subheader("Prediction")
567
  st.write("Upload a new dataset to generate UCS predictions and evaluate performance on unseen data.")
568
 
569
+ if preview_val_btn and validation_file is not None:
570
+ _book = read_book_bytes(validation_file.getvalue())
571
+ preview_modal_val(_book, FEATURES)
572
+
573
  if predict_btn and validation_file is not None:
574
  with st.status("Predicting…", expanded=False) as status:
575
  vbook = read_book_bytes(validation_file.getvalue())
 
588
  any_viol = pd.DataFrame(viol).any(axis=1); oor_pct = float(any_viol.mean()*100.0)
589
  if any_viol.any():
590
  offenders = df_val.loc[any_viol, FEATURES].copy()
591
+ offenders["Violations"] = pd.DataFrame(viol).loc[any_viol].apply(
592
+ lambda r: ", ".join([c for c,v in r.items() if v]), axis=1)
593
  offenders.index = offenders.index + 1; oor_table = offenders
594
 
595
  metrics_val = None
 
623
  left, right = st.columns([0.9, 0.55])
624
  with left:
625
  if TARGET in st.session_state.results["Validate"].columns:
626
+ if HAVE_PLOTLY:
627
+ st.plotly_chart(cross_plotly(st.session_state.results["Validate"][TARGET],
628
+ st.session_state.results["Validate"]["UCS_Pred"],
629
+ "Validation: Actual vs Predicted"),
630
+ use_container_width=True, theme=None)
631
+ else:
632
+ st.pyplot(cross_plot_mpl(st.session_state.results["Validate"][TARGET],
633
+ st.session_state.results["Validate"]["UCS_Pred"],
634
+ "Validation: Actual vs Predicted"),
635
+ use_container_width=True)
636
  else:
637
  st.info("Actual UCS values are not available in the validation data. Cross-plot cannot be generated.")
638
  with right:
639
+ if HAVE_PLOTLY:
640
+ st.plotly_chart(track_plotly(st.session_state.results["Validate"],
641
+ include_actual=(TARGET in st.session_state.results["Validate"].columns)),
642
+ use_container_width=True, theme=None)
643
+ else:
644
+ st.pyplot(depth_or_index_track_mpl(st.session_state.results["Validate"], title=None,
645
+ include_actual=(TARGET in st.session_state.results["Validate"].columns)),
646
+ use_container_width=True)
647
 
648
  if oor_table is not None:
649
  st.write("*Out-of-range rows (vs. Training min–max):*")