# ST_GeoMech_UCS — Streamlit app (Hugging Face Spaces)
| import io, json, os, base64 | |
| from pathlib import Path | |
| import streamlit as st | |
| import pandas as pd | |
| import numpy as np | |
| import joblib | |
| import matplotlib.pyplot as plt | |
| from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error | |
# =========================
# Defaults (overridden by models/meta.json or model.feature_names_in_)
# =========================
FEATURES = ["Q, gpm", "SPP(psi)", "T (kft.lbf)", "WOB (klbf)", "ROP (ft/h)"]  # expected input columns (exact header spelling)
TARGET = "UCS"  # target column name used for metrics/plots when present
MODELS_DIR = Path("models")
DEFAULT_MODEL = MODELS_DIR / "ucs_rf.joblib"  # primary model artifact
MODEL_FALLBACKS = [MODELS_DIR / "model.joblib", MODELS_DIR / "model.pkl"]  # alternates checked in order
# =========================
# Page / Theme
# =========================
st.set_page_config(page_title="ST_GeoMech_UCS", page_icon="logo.png", layout="wide")
# Hide Streamlit default header/footer and tighten layout
st.markdown("<style>header, footer{visibility:hidden !important;}</style>", unsafe_allow_html=True)
# App-wide CSS: white background, light sidebar, blue buttons, hero header styling.
st.markdown(
    """
    <style>
    .stApp { background: #FFFFFF; }
    section[data-testid="stSidebar"] { background: #F6F9FC; }
    .block-container { padding-top: .5rem; padding-bottom: .5rem; } /* less vertical padding */
    .stButton>button{ background:#007bff; color:#fff; font-weight:bold; border-radius:8px; border:none; padding:10px 24px; }
    .stButton>button:hover{ background:#0056b3; }
    .st-hero { display:flex; align-items:center; gap:16px; padding-top: 4px; }
    .st-hero .brand { width:110px; height:110px; object-fit:contain; } /* enlarged logo */
    .st-hero h1 { margin:0; line-height:1.05; }
    .st-hero .tagline { margin:2px 0 0 2px; color:#6b7280; font-size:1.05rem; font-style:italic; }
    [data-testid="stBlock"]{ margin-top:0 !important; }
    </style>
    """,
    unsafe_allow_html=True
)
| # ========================= | |
| # Helpers | |
| # ========================= | |
| def _get_model_url(): | |
| """Read optional MODEL_URL from environment only (avoid st.secrets banner).""" | |
| return (os.environ.get("MODEL_URL", "") or "").strip() | |
def rmse(y_true, y_pred):
    """Root-mean-squared error as a plain float."""
    mse = mean_squared_error(y_true, y_pred)
    return float(np.sqrt(mse))
def ensure_cols(df, cols):
    """Verify that `df` contains every column in `cols`.

    Returns True when all columns are present; otherwise reports the missing
    names via st.error and returns False.
    """
    miss = [c for c in cols if c not in df.columns]
    if not miss:
        return True
    st.error(f"Missing columns: {miss}\nFound: {list(df.columns)}")
    return False
def load_model(model_path: str):
    """Deserialize and return the persisted model at `model_path` via joblib."""
    return joblib.load(model_path)
def parse_excel(data_bytes: bytes):
    """Parse raw Excel bytes into a dict of {sheet_name: DataFrame}."""
    workbook = pd.ExcelFile(io.BytesIO(data_bytes))
    return {name: workbook.parse(name) for name in workbook.sheet_names}
def read_book(upload):
    """Return {sheet: DataFrame} for an uploaded Excel file.

    Returns {} when nothing was uploaded or the workbook fails to parse
    (the failure is surfaced via st.error).
    """
    if upload is None:
        return {}
    try:
        return parse_excel(upload.getvalue())
    except Exception as e:
        st.error(f"Failed to read Excel: {e}")
        return {}
def find_sheet(book, names):
    """Case-insensitively locate the first of `names` among `book`'s sheet keys.

    Returns the original-cased key on a match, otherwise None.
    """
    by_lower = {key.lower(): key for key in book}
    for candidate in names:
        hit = by_lower.get(candidate.lower())
        if hit is not None:
            return hit
    return None
def cross_plot(actual, pred, title, size=(4.6, 4.6)):
    """Square actual-vs-predicted scatter with a 1:1 reference line.

    `actual` and `pred` must support .min()/.max() — assumes pandas Series or
    numpy arrays; TODO confirm all callers pass these. Returns the Figure.
    """
    fig, ax = plt.subplots(figsize=size, dpi=100)
    ax.scatter(actual, pred, s=14, alpha=0.8)
    # Common limits spanning both series, padded 3%; the fallback span of 1.0
    # guards against a zero-width range when all values are identical.
    lo = float(np.nanmin([actual.min(), pred.min()]))
    hi = float(np.nanmax([actual.max(), pred.max()]))
    pad = 0.03 * (hi - lo if hi > lo else 1.0)
    ax.plot([lo - pad, hi + pad], [lo - pad, hi + pad], '--', lw=1.2, color=(0.35, 0.35, 0.35))
    ax.set_xlim(lo - pad, hi + pad)
    ax.set_ylim(lo - pad, hi + pad)
    ax.set_aspect('equal', 'box')  # perfect 1:1
    ax.set_xlabel("Actual UCS"); ax.set_ylabel("Predicted UCS"); ax.set_title(title)
    ax.grid(True, ls=":", alpha=0.4)
    return fig
def depth_or_index_track(df, title=None, include_actual=True):
    """Plot a UCS track: predictions (and optionally actuals) vs depth or index.

    If a column whose name contains "depth" (case-insensitive) exists, it is
    used as the y-axis; otherwise a 1-based point index is used. The y-axis is
    inverted and the x-axis moved to the top, following log-track convention.

    Parameters:
        df: DataFrame with a "UCS_Pred" column (and optionally TARGET / depth).
        title: optional plot title; omitted when None or empty.
        include_actual: also draw the TARGET column when present in `df`.

    Returns the matplotlib Figure.
    """
    # Find depth-like column if available (first match wins).
    depth_col = next((c for c in df.columns if 'depth' in str(c).lower()), None)
    fig_h = 7.4 if depth_col is not None else 7.0  # taller track; still fits most screens
    fig, ax = plt.subplots(figsize=(6.0, fig_h), dpi=100)
    # Both branches previously duplicated the plotting code; compute the
    # y-series and its label once instead.
    if depth_col is not None:
        y_vals, y_label = df[depth_col], depth_col
    else:
        y_vals, y_label = np.arange(1, len(df) + 1), "Point Index"
    ax.plot(df["UCS_Pred"], y_vals, '--', lw=1.6, label="UCS_Pred")
    if include_actual and TARGET in df.columns:
        ax.plot(df[TARGET], y_vals, '-', lw=2.0, alpha=0.85, label="UCS (actual)")
    ax.set_ylabel(y_label); ax.set_xlabel("UCS")
    ax.xaxis.set_label_position('top'); ax.xaxis.tick_top(); ax.invert_yaxis()
    ax.grid(True, linestyle=":", alpha=0.4)
    if title: ax.set_title(title, pad=8)  # no title if None/empty
    ax.legend(loc="best")
    return fig
def export_workbook(sheets_dict, summary_df=None):
    """Serialize {sheet_name: DataFrame} (plus an optional Summary sheet) to xlsx bytes.

    Raises RuntimeError when openpyxl is not installed.
    """
    try:
        import openpyxl  # noqa
    except Exception:
        raise RuntimeError("Export requires openpyxl. Please add it to requirements or install it.")
    buffer = io.BytesIO()
    with pd.ExcelWriter(buffer, engine="openpyxl") as writer:
        for sheet_name, frame in sheets_dict.items():
            # Excel caps sheet names at 31 characters; NOTE(review): truncation
            # could collide two long names — confirm callers keep names short.
            frame.to_excel(writer, sheet_name=sheet_name[:31], index=False)
        if summary_df is not None:
            summary_df.to_excel(writer, sheet_name="Summary", index=False)
    return buffer.getvalue()
def toast(msg):
    """Show a transient toast notification, falling back to st.info where st.toast is unavailable."""
    try:
        st.toast(msg)
    except Exception:
        st.info(msg)
def infer_features_from_model(m):
    """Best-effort recovery of training feature names from a fitted estimator.

    Checks the estimator's own `feature_names_in_` first, then the final step
    of a Pipeline-like object exposing `steps`. Returns a list of strings, or
    None when nothing usable is found.
    """
    try:
        names = getattr(m, "feature_names_in_", None)
        if names is not None and len(names):
            return [str(x) for x in names]
    except Exception:
        pass
    try:
        steps = getattr(m, "steps", None)
        if steps:
            final_est = steps[-1][1]
            names = getattr(final_est, "feature_names_in_", None)
            if names is not None and len(names):
                return [str(x) for x in names]
    except Exception:
        pass
    return None
def inline_logo(path="logo.png") -> str:
    """Return the logo file as a base64 PNG data-URI, or "" if missing/unreadable."""
    try:
        logo = Path(path)
        if not logo.exists():
            return ""
        encoded = base64.b64encode(logo.read_bytes()).decode('ascii')
        return f"data:image/png;base64,{encoded}"
    except Exception:
        return ""
| # ========================= | |
| # Model presence (local or optional download) | |
| # ========================= | |
MODEL_URL = _get_model_url()  # optional remote model location (MODEL_URL env var)
def ensure_model_present() -> "Path | None":
    """Return a path to an existing, non-empty model file, downloading if needed.

    Checks DEFAULT_MODEL then each MODEL_FALLBACKS entry. When none exist and
    MODEL_URL is set, streams the file into DEFAULT_MODEL. Returns None when
    no model can be obtained (download failures are shown via st.error).
    """
    for p in [DEFAULT_MODEL, *MODEL_FALLBACKS]:
        if p.exists() and p.stat().st_size > 0:
            return p
    if not MODEL_URL:
        return None
    try:
        # requests imported lazily so the app still starts without it when
        # a local model file is already present.
        import requests
        DEFAULT_MODEL.parent.mkdir(parents=True, exist_ok=True)
        with st.status("Downloading model…", expanded=False):
            with requests.get(MODEL_URL, stream=True, timeout=30) as r:
                r.raise_for_status()
                with open(DEFAULT_MODEL, "wb") as f:
                    # 1 MiB chunks; skip keep-alive empty chunks.
                    for chunk in r.iter_content(chunk_size=1<<20):
                        if chunk: f.write(chunk)
        return DEFAULT_MODEL
    except Exception as e:
        st.error(f"Failed to download model from MODEL_URL: {e}")
        return None
# Resolve and load the model; the app cannot proceed without it.
model_path = ensure_model_present()
if not model_path:
    st.error("Model not found. Upload models/ucs_rf.joblib (or set MODEL_URL in Settings → Variables).")
    st.stop()
try:
    model = load_model(str(model_path))
except Exception as e:
    st.error(f"Failed to load model: {model_path}\n{e}")
    st.stop()
# Meta overrides or inference: models/meta.json (keys "features"/"target")
# takes precedence; otherwise fall back to the model's own recorded features.
meta_path = MODELS_DIR / "meta.json"
if meta_path.exists():
    try:
        meta = json.loads(meta_path.read_text(encoding="utf-8"))
        FEATURES = meta.get("features", FEATURES); TARGET = meta.get("target", TARGET)
    except Exception: pass  # malformed meta.json: silently keep the defaults
else:
    infer = infer_features_from_model(model)
    if infer: FEATURES = infer
# =========================
# Session state
# =========================
# app_step routes between the "intro", "dev" and "predict" pages.
if "app_step" not in st.session_state: st.session_state.app_step = "intro"
# results caches prediction DataFrames and metric dicts across reruns.
if "results" not in st.session_state: st.session_state.results = {}
# train_ranges holds per-feature (min, max) from the last dev run, for OOR checks.
if "train_ranges" not in st.session_state: st.session_state.train_ranges = None
# dev_ready gates the "Proceed to Prediction" sidebar button.
if "dev_ready" not in st.session_state: st.session_state.dev_ready = False
# Re-enable Proceed whenever dev results already exist (e.g. after a rerun).
if ("Train" in st.session_state.results) or ("Test" in st.session_state.results):
    st.session_state.dev_ready = True
# =========================
# Hero header (logo + title)
# =========================
# Logo is embedded as a base64 data-URI so no static-file route is needed.
st.markdown(
    f"""
    <div class="st-hero">
      <img src="{inline_logo()}" class="brand" />
      <div>
        <h1>ST_GeoMech_UCS</h1>
        <div class="tagline">Real-Time UCS Tracking While Drilling — Cloud Ready</div>
      </div>
    </div>
    """,
    unsafe_allow_html=True,
)
# =========================
# INTRO PAGE
# =========================
if st.session_state.app_step == "intro":
    st.header("Welcome!")
    st.markdown(
        "This software is developed by *Smart Thinking AI-Solutions Team* to estimate UCS from drilling data."
    )
    st.subheader("Required Input Columns")
    st.markdown(
        "- Q, gpm — Flow rate (gallons per minute) \n"
        "- SPP(psi) — Stand pipe pressure \n"
        "- T (kft.lbf) — Torque (thousand foot-pounds) \n"
        "- WOB (klbf) — Weight on bit \n"
        "- ROP (ft/h) — Rate of penetration"
    )
    st.subheader("How It Works")
    st.markdown(
        "1. **Upload your development data (Excel)** and click **Run Model** to compute metrics and review plots. \n"
        "2. Click **Proceed to Prediction** to upload a new dataset for validation and view results. \n"
        "3. Export results to Excel at any time."
    )
    # Advance to the development page on click.
    if st.button("Start Showcase", type="primary", key="start_showcase"):
        st.session_state.app_step = "dev"; st.rerun()
# =========================
# MODEL DEVELOPMENT (Train/Test)
# =========================
if st.session_state.app_step == "dev":
    st.sidebar.header("Model Development Data")
    train_test_file = st.sidebar.file_uploader("Upload Data (Excel)", type=["xlsx","xls"], key="dev_upload")
    run_btn = st.sidebar.button("Run Model", type="primary", use_container_width=True)
    # Proceed button BELOW run, always visible; enables immediately after first successful run
    st.sidebar.button(
        "Proceed to Prediction ▶",
        use_container_width=True,
        disabled=not st.session_state.dev_ready,
        on_click=(lambda: st.session_state.update(app_step="predict")) if st.session_state.dev_ready else None,
    )
    # ---- Header + helper sentence positioned under the header (your request) ----
    st.subheader("Model Development")
    st.write("Upload your data to train the model and review the development performance.")
    if run_btn and train_test_file is not None:
        with st.status("Processing…", expanded=False) as status:
            book = read_book(train_test_file)
            if not book: status.update(label="Failed to read workbook.", state="error"); st.stop()
            status.update(label="Workbook read ✓")
            # Internally still expect Train/Test sheets
            sh_train = find_sheet(book, ["Train","Training","training2","train","training"])
            sh_test = find_sheet(book, ["Test","Testing","testing2","test","testing"])
            if sh_train is None or sh_test is None:
                status.update(label="Workbook must include Train/Training/training2 and Test/Testing/testing2.", state="error"); st.stop()
            df_tr = book[sh_train].copy(); df_te = book[sh_test].copy()
            if not (ensure_cols(df_tr, FEATURES + [TARGET]) and ensure_cols(df_te, FEATURES + [TARGET])):
                status.update(label="Missing required columns.", state="error"); st.stop()
            status.update(label="Columns validated ✓"); status.update(label="Predicting…")
            # Score both splits with the pre-trained model and cache everything
            # in session state so results survive reruns.
            df_tr["UCS_Pred"] = model.predict(df_tr[FEATURES])
            df_te["UCS_Pred"] = model.predict(df_te[FEATURES])
            st.session_state.results["Train"] = df_tr; st.session_state.results["Test"] = df_te
            st.session_state.results["metrics_train"] = {
                "R2": r2_score(df_tr[TARGET], df_tr["UCS_Pred"]),
                "RMSE": rmse(df_tr[TARGET], df_tr["UCS_Pred"]),
                "MAE": mean_absolute_error(df_tr[TARGET], df_tr["UCS_Pred"]),
            }
            st.session_state.results["metrics_test"] = {
                "R2": r2_score(df_te[TARGET], df_te["UCS_Pred"]),
                "RMSE": rmse(df_te[TARGET], df_te["UCS_Pred"]),
                "MAE": mean_absolute_error(df_te[TARGET], df_te["UCS_Pred"]),
            }
            # Record per-feature training ranges for out-of-range checks later.
            tr_min = df_tr[FEATURES].min().to_dict(); tr_max = df_tr[FEATURES].max().to_dict()
            st.session_state.train_ranges = {f:(float(tr_min[f]), float(tr_max[f])) for f in FEATURES}
            st.session_state.dev_ready = True  # enable Proceed button immediately
            status.update(label="Done ✓", state="complete"); toast("Model run complete 🚀")
        st.rerun()  # refresh to enable the sidebar button without a second click
    # Render cached results (also runs straight after the rerun above).
    if ("Train" in st.session_state.results) or ("Test" in st.session_state.results):
        tab1, tab2 = st.tabs(["Training", "Testing"])
        if "Train" in st.session_state.results:
            with tab1:
                df = st.session_state.results["Train"]; m = st.session_state.results["metrics_train"]
                c1,c2,c3 = st.columns(3)
                c1.metric("R²", f"{m['R2']:.4f}"); c2.metric("RMSE", f"{m['RMSE']:.4f}"); c3.metric("MAE", f"{m['MAE']:.4f}")
                left,right = st.columns([1,1])
                with left:
                    st.pyplot(cross_plot(df[TARGET], df["UCS_Pred"], "Training: Actual vs Predicted"), use_container_width=True)
                with right:
                    # no title on the track (cleaner)
                    st.pyplot(depth_or_index_track(df, title=None, include_actual=True), use_container_width=True)
        if "Test" in st.session_state.results:
            with tab2:
                df = st.session_state.results["Test"]; m = st.session_state.results["metrics_test"]
                c1,c2,c3 = st.columns(3)
                c1.metric("R²", f"{m['R2']:.4f}"); c2.metric("RMSE", f"{m['RMSE']:.4f}"); c3.metric("MAE", f"{m['MAE']:.4f}")
                left,right = st.columns([1,1])
                with left:
                    st.pyplot(cross_plot(df[TARGET], df["UCS_Pred"], "Testing: Actual vs Predicted"), use_container_width=True)
                with right:
                    st.pyplot(depth_or_index_track(df, title=None, include_actual=True), use_container_width=True)
        st.markdown("---")
        # Build the export workbook: each split with predictions plus a metric summary.
        sheets = {}; rows = []
        if "Train" in st.session_state.results:
            sheets["Train_with_pred"] = st.session_state.results["Train"]
            rows.append({"Split":"Train", **{k:round(v,6) for k,v in st.session_state.results["metrics_train"].items()}})
        if "Test" in st.session_state.results:
            sheets["Test_with_pred"] = st.session_state.results["Test"]
            rows.append({"Split":"Test", **{k:round(v,6) for k,v in st.session_state.results["metrics_test"].items()}})
        summary_df = pd.DataFrame(rows) if rows else None
        try:
            data_bytes = export_workbook(sheets, summary_df)
            st.download_button("Export Development Results to Excel",
                data=data_bytes, file_name="UCS_Dev_Results.xlsx",
                mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet")
        except RuntimeError as e:
            st.warning(str(e))  # openpyxl missing: warn but keep the page usable
# =========================
# PREDICTION (Validation)
# =========================
if st.session_state.app_step == "predict":
    st.sidebar.header("Prediction (Validation)")
    validation_file = st.sidebar.file_uploader("Upload Validation Excel", type=["xlsx","xls"], key="val_upload")
    predict_btn = st.sidebar.button("Predict", type="primary", use_container_width=True)
    st.sidebar.button("⬅ Back", on_click=lambda: st.session_state.update(app_step="dev"), use_container_width=True)
    st.subheader("Prediction")
    st.write("Upload a new dataset to generate UCS predictions and evaluate performance on unseen data.")
    if predict_btn and validation_file is not None:
        with st.status("Predicting…", expanded=False) as status:
            vbook = read_book(validation_file)
            if not vbook: status.update(label="Could not read the Validation Excel.", state="error"); st.stop()
            status.update(label="Workbook read ✓")
            # Prefer a Validation-named sheet; otherwise fall back to the first one.
            vname = find_sheet(vbook, ["Validation","Validate","validation2","Val","val"]) or list(vbook.keys())[0]
            df_val = vbook[vname].copy()
            if not ensure_cols(df_val, FEATURES): status.update(label="Missing required columns.", state="error"); st.stop()
            status.update(label="Columns validated ✓")
            df_val["UCS_Pred"] = model.predict(df_val[FEATURES])
            st.session_state.results["Validate"] = df_val
            # Out-of-range (OOR) check against the training min–max recorded on the dev page.
            ranges = st.session_state.train_ranges; oor_table = None; oor_pct = 0.0
            if ranges:
                viol = {f: (df_val[f] < ranges[f][0]) | (df_val[f] > ranges[f][1]) for f in FEATURES}
                any_viol = pd.DataFrame(viol).any(axis=1); oor_pct = float(any_viol.mean()*100.0)
                if any_viol.any():
                    offenders = df_val.loc[any_viol, FEATURES].copy()
                    offenders["Violations"] = pd.DataFrame(viol).loc[any_viol].apply(lambda r: ", ".join([c for c,v in r.items() if v]), axis=1)
                    offenders.index = offenders.index + 1; oor_table = offenders  # 1-based row numbers for display
            # Metrics are only computable when the validation set carries actual UCS.
            metrics_val = None
            if TARGET in df_val.columns:
                metrics_val = {
                    "R2": r2_score(df_val[TARGET], df_val["UCS_Pred"]),
                    "RMSE": rmse(df_val[TARGET], df_val["UCS_Pred"]),
                    "MAE": mean_absolute_error(df_val[TARGET], df_val["UCS_Pred"])
                }
            st.session_state.results["metrics_val"] = metrics_val
            st.session_state.results["summary_val"] = {
                "n_points": len(df_val),
                "pred_min": float(df_val["UCS_Pred"].min()),
                "pred_max": float(df_val["UCS_Pred"].max()),
                "oor_pct": oor_pct
            }
            st.session_state.results["oor_table"] = oor_table
            status.update(label="Predictions ready ✓", state="complete")
    # Render cached validation results.
    if "Validate" in st.session_state.results:
        st.subheader("Validation Results")
        sv = st.session_state.results["summary_val"]; oor_table = st.session_state.results.get("oor_table")
        # ---- NEW: show OOR warning above the plots when applicable ----
        if sv["oor_pct"] > 0:
            st.warning("Some validation inputs fall outside the **training min–max** ranges. Interpret predictions with caution.")
        c1,c2,c3,c4 = st.columns(4)
        c1.metric("points", f"{sv['n_points']}"); c2.metric("Pred min", f"{sv['pred_min']:.2f}")
        c3.metric("Pred max", f"{sv['pred_max']:.2f}"); c4.metric("OOR %", f"{sv['oor_pct']:.1f}%")
        left,right = st.columns([1,1])
        with left:
            if TARGET in st.session_state.results["Validate"].columns:
                st.pyplot(cross_plot(st.session_state.results["Validate"][TARGET], st.session_state.results["Validate"]["UCS_Pred"], "Validation: Actual vs Predicted"), use_container_width=True)
            else:
                st.info("Actual UCS values are not available in the validation data. Cross-plot cannot be generated.")
        with right:
            st.pyplot(depth_or_index_track(st.session_state.results["Validate"], title=None, include_actual=(TARGET in st.session_state.results["Validate"].columns)), use_container_width=True)
        if oor_table is not None:
            st.write("*Out-of-range rows (vs. Training min–max):*")
            st.dataframe(oor_table, use_container_width=True)
        st.markdown("---")
        # Export: validation predictions plus every metric set computed so far.
        sheets = {"Validate_with_pred": st.session_state.results["Validate"]}
        rows = []
        for name, key in [("Train","metrics_train"), ("Test","metrics_test"), ("Validate","metrics_val")]:
            m = st.session_state.results.get(key)
            if m: rows.append({"Split": name, **{k: round(v,6) for k,v in m.items()}})
        summary_df = pd.DataFrame(rows) if rows else None
        try:
            data_bytes = export_workbook(sheets, summary_df)
            st.download_button("Export Validation Results to Excel",
                data=data_bytes, file_name="UCS_Validation_Results.xlsx",
                mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet")
        except RuntimeError as e:
            st.warning(str(e))  # openpyxl missing: warn but keep the page usable
# =========================
# Footer
# =========================
st.markdown("---")
# Branding footer shown on every page.
st.markdown("<div style='text-align:center; color:#6b7280;'>ST_GeoMech_UCS • © Smart Thinking</div>", unsafe_allow_html=True)