# UCS / app.py
# UCS2014's picture
# Update app.py
# cbd43ea verified
# raw
# history blame
# 34.5 kB
import io, json, os, base64, math
from pathlib import Path
import streamlit as st
import pandas as pd
import numpy as np
import joblib
# Matplotlib for PREVIEW modal and for the CROSS-PLOT (static)
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter
import plotly.graph_objects as go
from sklearn.metrics import mean_squared_error, mean_absolute_error
# =========================
# Constants
# =========================
# Model input columns (exact spreadsheet headers) and the target column.
# NOTE: both may be overridden at runtime from models/meta.json.
FEATURES = [
    "Q, gpm",
    "SPP(psi)",
    "T (kft.lbf)",
    "WOB (klbf)",
    "ROP (ft/h)",
]
TARGET = "UCS"

# Model artefacts: the primary file first, then fallbacks tried in order.
MODELS_DIR = Path("models")
DEFAULT_MODEL = MODELS_DIR / "ucs_rf.joblib"
MODEL_FALLBACKS = [MODELS_DIR / fname for fname in ("model.joblib", "model.pkl")]

# Plot colours: prediction trace, actual-UCS trace, 1:1 reference line.
COLORS = dict(pred="#1f77b4", actual="#f2b702", ref="#5a5a5a")

# ---- Plot sizing controls ----
# Cross-plot figure size in px (converted to inches with the figure dpi).
CROSS_W = CROSS_H = 250
# Track plot size in px; both are applied with Plotly autosize disabled.
TRACK_H, TRACK_W = 1000, 500
FONT_SZ = 13
BOLD_FONT = "Arial Black, Arial, sans-serif"  # used for bold axis titles & ticks
# =========================
# Page / CSS
# =========================
# Browser-tab title/icon and a full-width ("wide") page layout.
st.set_page_config(page_title="ST_GeoMech_UCS", page_icon="logo.png", layout="wide")
# General CSS (logo helpers etc.)
# `.centered-container` is the class used by the sidebar logo/branding HTML.
st.markdown("""
<style>
  .brand-logo { width: 200px; height: auto; object-fit: contain; }
  .sidebar-header { display:flex; align-items:center; gap:12px; }
  .sidebar-header .text h1 { font-size: 1.05rem; margin:0; line-height:1.1; }
  .sidebar-header .text .tag { font-size: .85rem; color:#6b7280; margin:2px 0 0; }
  .centered-container {
    display: flex;
    flex-direction: column;
    align-items: center;
    text-align: center;
  }
</style>
""", unsafe_allow_html=True)
# CSS to make sticky headers work correctly by overriding Streamlit's overflow property
# (`position: sticky` only pins relative to the nearest scrolling ancestor, so the
# overflow on these containers is unset).
st.markdown("""
<style>
/* This targets the main content area */
.main .block-container {
    overflow: unset !important;
}
/* This targets the vertical block that holds all your elements */
div[data-testid="stVerticalBlock"] {
    overflow: unset !important;
}
</style>
""", unsafe_allow_html=True)
# Hide uploader helper text ("Drag and drop file here", limits, etc.)
# Several selectors are listed because the uploader's DOM differs between Streamlit releases.
# NOTE(review): the final fallback hides every <p>/<small> inside the uploader, which may
# also hide other uploader text (e.g. the uploaded file's details) — confirm intent.
st.markdown("""
<style>
/* Older builds (helper wrapped in a Markdown container) */
section[data-testid="stFileUploader"] div[data-testid="stMarkdownContainer"]{display:none !important;}
/* 1.31–1.34: helper is the first child in the dropzone */
section[data-testid="stFileUploader"] [data-testid="stFileUploaderDropzone"] > div:first-child{display:none !important;}
/* 1.35+: explicit helper container */
section[data-testid="stFileUploader"] [data-testid="stFileUploaderInstructions"]{display:none !important;}
/* Fallback: any paragraph/small text inside the uploader */
section[data-testid="stFileUploader"] p, section[data-testid="stFileUploader"] small{display:none !important;}
</style>
""", unsafe_allow_html=True)
# Make the Preview expander title & tabs sticky (pinned to the top)
# The tab list is offset (top: 42px) so it sits just below the pinned expander summary.
st.markdown("""
<style>
div[data-testid="stExpander"] > details > summary {
  position: sticky;
  top: 0;
  z-index: 10;
  background: #fff;
  border-bottom: 1px solid #eee;
}
div[data-testid="stExpander"] div[data-baseweb="tab-list"] {
  position: sticky;
  top: 42px;    /* adjust if your expander header height differs */
  z-index: 9;
  background: #fff;
  padding-top: 6px;
}
</style>
""", unsafe_allow_html=True)
# Center text in all pandas Styler tables (headers + cells)
# Passed to Styler.set_table_styles by df_centered_rounded().
TABLE_CENTER_CSS = [
    dict(selector="th", props=[("text-align", "center")]),
    dict(selector="td", props=[("text-align", "center")]),
]
# NEW: CSS for the message box
# Classes used by the `<div class="st-message-box st-...">` status messages:
# neutral by default, plus success / warning / error colour variants.
st.markdown("""
<style>
.st-message-box {
    background-color: #f0f2f6;
    color: #333333;
    padding: 10px;
    border-radius: 10px;
    border: 1px solid #e6e9ef;
}
.st-message-box.st-success {
    background-color: #d4edda;
    color: #155724;
    border-color: #c3e6cb;
}
.st-message-box.st-warning {
    background-color: #fff3cd;
    color: #856404;
    border-color: #ffeeba;
}
.st-message-box.st-error {
    background-color: #f8d7da;
    color: #721c24;
    border-color: #f5c6cb;
}
</style>
""", unsafe_allow_html=True)
# =========================
# Password gate
# =========================
def inline_logo(path="logo.png") -> str:
    """Return the image at *path* as a base64 ``data:`` URI for inline ``<img>`` tags.

    The MIME type is guessed from the file extension; anything that does not
    look like an image falls back to ``image/png``.

    Args:
        path: file path (str or Path) of the image.

    Returns:
        ``"data:<mime>;base64,<payload>"``. Returns ``""`` when the file is
        missing or unreadable, or when *path* is not a usable path. Callers
        can therefore always interpolate the result into HTML.
    """
    import mimetypes

    try:
        raw = Path(path).read_bytes()
    except (OSError, TypeError, ValueError):
        # Missing file, directory, permission error or invalid path argument:
        # render no image rather than failing the page.
        return ""
    mime = mimetypes.guess_type(str(path))[0]
    if not mime or not mime.startswith("image/"):
        mime = "image/png"
    return f"data:{mime};base64,{base64.b64encode(raw).decode('ascii')}"
def add_password_gate() -> None:
    """Keep the app locked behind a shared access key until this session unlocks it.

    Where the key comes from:
        ``APP_PASSWORD`` from Streamlit secrets. If secrets are unavailable
        or do not define it, the ``APP_PASSWORD`` environment variable is
        used instead.

    Behaviour:
        - No key configured: shows a warning and stops the script.
        - ``st.session_state.auth_ok`` already set: returns immediately.
        - Otherwise renders the sidebar unlock form and stops the script.
        - Correct key: ``auth_ok`` is stored and the script is rerun.
    """
    import hmac

    try:
        required = st.secrets.get("APP_PASSWORD", "")
    except Exception:
        # st.secrets may raise (e.g. when no secrets file exists); use the environment.
        required = ""
    # Fall back to the environment when secrets exist but do not define the key.
    required = required or os.environ.get("APP_PASSWORD", "")
    if not required:
        st.warning("Set APP_PASSWORD in Secrets (or environment) and restart.")
        st.stop()
    if st.session_state.get("auth_ok", False):
        return
    st.sidebar.markdown(f"""
        <div class="centered-container">
            <img src="{inline_logo('logo.png')}" style="width: 200px; height: auto; object-fit: contain;">
            <div style='font-weight:800;font-size:1.2rem; margin-top: 10px;'>ST_GeoMech_UCS</div>
            <div style='color:#667085;'>Smart Thinking • Secure Access</div>
        </div>
        """, unsafe_allow_html=True
    )
    pwd = st.sidebar.text_input("Access key", type="password", placeholder="••••••••")
    if st.sidebar.button("Unlock", type="primary"):
        # Constant-time comparison of UTF-8 bytes: response timing does not reveal
        # how much of the key matched, and non-ASCII keys are accepted
        # (compare_digest rejects non-ASCII str).
        if hmac.compare_digest((pwd or "").encode("utf-8"), str(required).encode("utf-8")):
            st.session_state.auth_ok = True
            st.rerun()
        else:
            st.error("Incorrect key.")
    st.stop()
add_password_gate()
# =========================
# Utilities
# =========================
def rmse(y_true, y_pred) -> float:
    """Root-mean-square error between *y_true* and *y_pred*, returned as a plain float."""
    mse = mean_squared_error(y_true, y_pred)
    return float(np.sqrt(mse))
def pearson_r(y_true, y_pred) -> float:
    """Pearson correlation coefficient between actual and predicted values.

    Both inputs are flattened, so column-vector predictions such as an
    ``(n, 1)`` array are accepted.

    Returns:
        ``r`` as a float. Returns NaN when:
        - ``y_true`` has fewer than two values;
        - either input is constant (``r`` is undefined);
        - the inputs contain NaN (it propagates).

    Raises:
        ValueError: if the inputs have different numbers of elements.
    """
    a = np.asarray(y_true, dtype=float).ravel()
    p = np.asarray(y_pred, dtype=float).ravel()
    if a.size < 2:
        return float("nan")
    if a.size != p.size:
        raise ValueError(f"y_true has {a.size} values but y_pred has {p.size}")
    # Zero spread makes r undefined; return NaN directly instead of letting
    # corrcoef divide by zero (and emit a RuntimeWarning).
    if np.ptp(a) == 0 or np.ptp(p) == 0:
        return float("nan")
    return float(np.corrcoef(a, p)[0, 1])
@st.cache_resource(show_spinner=False)
def load_model(model_path: str):
    """Deserialize the estimator stored at *model_path* with joblib.

    Streamlit's resource cache keeps one loaded object per path argument.

    SECURITY: joblib files are pickles, and loading one can execute arbitrary
    code. Only point this at trusted model artefacts.
    """
    estimator = joblib.load(model_path)
    return estimator
@st.cache_data(show_spinner=False)
def parse_excel(data_bytes: bytes):
    """Parse every sheet of an Excel workbook supplied as raw bytes.

    Returns:
        A dict mapping sheet name to DataFrame, in workbook order. Streamlit
        caches the result keyed on the byte content.
    """
    # Use the ExcelFile as a context manager so its reader/handle is released
    # as soon as all sheets have been parsed.
    with pd.ExcelFile(io.BytesIO(data_bytes)) as xl:
        return {sh: xl.parse(sh) for sh in xl.sheet_names}
def read_book_bytes(b: bytes):
    """Return ``{sheet name: DataFrame}`` for workbook bytes *b*, or ``{}`` when *b* is empty or None."""
    if not b:
        return {}
    return parse_excel(b)
def ensure_cols(df, cols):
    """Check that *df* contains every column named in *cols*.

    Returns:
        True when nothing is missing. Otherwise shows a Streamlit error
        listing the missing and the available columns, and returns False.
    """
    absent = [name for name in cols if name not in df.columns]
    if not absent:
        return True
    st.error(f"Missing columns: {absent}\nFound: {list(df.columns)}")
    return False
def find_sheet(book, names):
    """Return the key of the first sheet in *book* that matches one of *names*.

    *names* are tried in order. An exact key match wins; otherwise matching
    ignores case and surrounding whitespace (so ``"Training "`` matches
    ``"training"``). The original key is returned so it can index *book*.

    Returns:
        The matching key, or None if no candidate matches.
    """
    def _norm(sheet_name) -> str:
        # str() tolerates non-string keys; strip/casefold make matching lenient.
        return str(sheet_name).strip().casefold()

    normalized = {}
    for key in book:
        # If several sheets normalize to the same name, keep the first (workbook order).
        normalized.setdefault(_norm(key), key)
    for candidate in names:
        if candidate in book:
            return candidate
        if _norm(candidate) in normalized:
            return normalized[_norm(candidate)]
    return None
def _nice_tick0(xmin: float, step: int = 100) -> float:
    """Round *xmin* down to the nearest multiple of *step*.

    Non-finite inputs (NaN, ±inf) are returned unchanged, since ``math.floor``
    cannot handle them.
    """
    if not np.isfinite(xmin):
        return xmin
    return step * math.floor(xmin / step)
def df_centered_rounded(df: pd.DataFrame, hide_index=True):
    """Show *df* with st.dataframe: centered headers and cells, numeric columns at 2 decimals.

    A copy is styled, so the caller's DataFrame is never modified.
    """
    frame = df.copy()
    numeric = frame.select_dtypes(include=[np.number]).columns
    two_decimals = {col: "{:.2f}" for col in numeric}
    styler = frame.style.format(two_decimals)
    styler = styler.set_properties(**{"text-align": "center"})
    styler = styler.set_table_styles(TABLE_CENTER_CSS)
    st.dataframe(styler, use_container_width=True, hide_index=hide_index)
# =========================
# Cross plot (Matplotlib, fixed limits & ticks)
# =========================
def cross_plot_static(actual, pred, *, lims=(6000, 10000), tick_step=1000):
    """Scatter predicted vs. actual UCS on square, fixed axes with a dashed 1:1 line.

    Args:
        actual: observed UCS values (x axis).
        pred: predicted UCS values (y axis), aligned with *actual*.
        lims: ``(min, max)`` applied to both axes. Points outside are clipped
            from view. Defaults to the original fixed 6,000–10,000 window.
        tick_step: spacing of major ticks on both axes.

    Returns:
        A matplotlib Figure sized CROSS_W x CROSS_H px. The caller owns it and
        should ``plt.close(fig)`` after rendering so figures do not accumulate.
    """
    a = pd.Series(actual, dtype=float)
    p = pd.Series(pred, dtype=float)
    lo, hi = lims
    ticks = np.arange(lo, hi + 1, tick_step)
    dpi = 110
    fig, ax = plt.subplots(
        figsize=(CROSS_W / dpi, CROSS_H / dpi),
        dpi=dpi,
        constrained_layout=False
    )
    ax.scatter(a, p, s=14, c=COLORS["pred"], alpha=0.9, linewidths=0)
    # Reference line: points on it are perfect predictions.
    ax.plot([lo, hi], [lo, hi],
            linestyle="--", linewidth=1.2, color=COLORS["ref"])
    ax.set_xlim(lo, hi)
    ax.set_ylim(lo, hi)
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.set_aspect("equal", adjustable="box")  # true 45° reference line
    fmt = FuncFormatter(lambda x, _: f"{int(x):,}")  # thousands separators
    ax.xaxis.set_major_formatter(fmt)
    ax.yaxis.set_major_formatter(fmt)
    ax.set_xlabel("Actual UCS (psi)", fontweight="bold", fontsize=4, color="black")
    ax.set_ylabel("Predicted UCS (psi)", fontweight="bold", fontsize=4, color="black")
    ax.tick_params(labelsize=2, colors="black")
    ax.grid(True, linestyle=":", alpha=0.3)
    for spine in ax.spines.values():
        spine.set_linewidth(1.1)
        spine.set_color("#444")
    fig.subplots_adjust(left=0.16, bottom=0.16, right=0.98, top=0.98)
    return fig
# =========================
# Track plot (Plotly)
# =========================
def track_plot(df, include_actual=True):
    """Build a vertical "track" Plotly figure of predicted (and optionally actual) UCS.

    Y axis:
        The first column whose name contains "depth" (case-insensitive), or
        a 1-based point index when there is none. It is drawn reversed, so
        larger values sit lower.

    X axis:
        Spans the plotted UCS values with a 3 % pad and sits at the top. If
        the data is empty or all-NaN, the affected axis is left to Plotly's
        autorange instead of receiving a NaN range.

    Args:
        df: frame containing a "UCS_Pred" column (and TARGET for the actual trace).
        include_actual: also draw TARGET as a dotted line when that column exists.

    Returns:
        A ``go.Figure`` of fixed size TRACK_W x TRACK_H px.
    """
    depth_col = next((c for c in df.columns if 'depth' in str(c).lower()), None)
    if depth_col is not None:
        y = pd.Series(df[depth_col]).astype(float)
        ylab = str(depth_col)  # str() so it can be concatenated into hover templates
    else:
        y = pd.Series(np.arange(1, len(df) + 1))
        ylab = "Point Index"
    y_lo, y_hi = float(y.min()), float(y.max())
    show_actual = include_actual and TARGET in df.columns

    # X (UCS) range: extent of all plotted series plus a 3 % pad on both sides.
    x_series = pd.Series(df.get("UCS_Pred", pd.Series(dtype=float))).astype(float)
    if show_actual:
        x_series = pd.concat([x_series, pd.Series(df[TARGET]).astype(float)], ignore_index=True)
    x_lo, x_hi = float(x_series.min()), float(x_series.max())
    x_pad = 0.03 * (x_hi - x_lo if x_hi > x_lo else 1.0)
    xmin, xmax = x_lo - x_pad, x_hi + x_pad
    # Empty or all-NaN data gives NaN bounds; such axes fall back to autorange.
    x_finite = bool(np.isfinite(xmin) and np.isfinite(xmax))
    y_finite = bool(np.isfinite(y_lo) and np.isfinite(y_hi))

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=df["UCS_Pred"], y=y, mode="lines",
        line=dict(color=COLORS["pred"], width=1.8),
        name="UCS_Pred",
        hovertemplate="UCS_Pred: %{x:.0f}<br>" + ylab + ": %{y}<extra></extra>"
    ))
    if show_actual:
        fig.add_trace(go.Scatter(
            x=df[TARGET], y=y, mode="lines",
            line=dict(color=COLORS["actual"], width=2.0, dash="dot"),
            name="UCS (actual)",
            hovertemplate="UCS (actual): %{x:.0f}<br>" + ylab + ": %{y}<extra></extra>"
        ))
    fig.update_layout(
        height=TRACK_H,
        width=TRACK_W,
        autosize=False,  # keep the fixed TRACK_W x TRACK_H size
        paper_bgcolor="#fff", plot_bgcolor="#fff",
        margin=dict(l=64, r=16, t=36, b=48), hovermode="closest",
        font=dict(size=FONT_SZ, color="#000"),
        legend=dict(
            x=0.98, y=0.05, xanchor="right", yanchor="bottom",
            bgcolor="rgba(255,255,255,0.75)", bordercolor="#ccc", borderwidth=1
        ),
        legend_title_text=""
    )
    # Bold, black axis titles & ticks
    fig.update_xaxes(
        title_text="UCS (psi)",
        title_font=dict(size=20, family=BOLD_FONT, color="#000"),
        tickfont=dict(size=15, family=BOLD_FONT, color="#000"),
        side="top",
        range=[xmin, xmax] if x_finite else None,
        ticks="outside",
        tickformat=",.0f",
        tickmode="auto",
        tick0=_nice_tick0(xmin, step=100) if x_finite else None,
        showline=True, linewidth=1.2, linecolor="#444", mirror=True,
        showgrid=True, gridcolor="rgba(0,0,0,0.12)", automargin=True
    )
    fig.update_yaxes(
        title_text=ylab,
        title_font=dict(size=20, family=BOLD_FONT, color="#000"),
        tickfont=dict(size=15, family=BOLD_FONT, color="#000"),
        range=[y_hi, y_lo] if y_finite else None,  # max first = reversed axis
        autorange=None if y_finite else "reversed",
        ticks="outside",
        showline=True, linewidth=1.2, linecolor="#444", mirror=True,
        showgrid=True, gridcolor="rgba(0,0,0,0.12)", automargin=True
    )
    return fig
# ---------- Preview modal (matplotlib) ----------
def preview_tracks(df: pd.DataFrame, cols: list[str]):
    """Draw one vertical line track per requested column against the row index.

    Columns not present in *df* are skipped. If none remain, a small
    placeholder figure saying "No selected columns" is returned. All tracks
    share one y axis (1-based point index) that increases downward.

    Returns:
        A matplotlib Figure. The caller should close it after rendering.
    """
    cols = [c for c in cols if c in df.columns]
    n = len(cols)
    if n == 0:
        fig, ax = plt.subplots(figsize=(4, 2))
        ax.text(0.5, 0.5, "No selected columns", ha="center", va="center")
        ax.axis("off")
        return fig
    # squeeze=False always yields a 2-D array, so n == 1 needs no special case.
    fig, axes = plt.subplots(1, n, figsize=(2.2 * n, 7.0), sharey=True, dpi=100, squeeze=False)
    axes = axes[0]
    idx = np.arange(1, len(df) + 1)
    for ax, col in zip(axes, cols):
        ax.plot(df[col], idx, '-', lw=1.4, color="#333")
        ax.set_xlabel(col)
        ax.xaxis.set_label_position('top')
        ax.xaxis.tick_top()
        ax.grid(True, linestyle=":", alpha=0.3)
        for s in ax.spines.values():
            s.set_visible(True)
    # The y axis is shared, so invert it exactly once. Inverting on every track
    # toggles the shared axis and leaves it upright when the track count is even.
    axes[0].invert_yaxis()
    axes[0].set_ylabel("Point Index")
    return fig
# Modal wrapper (Streamlit compatibility)
# Prefer st.dialog, then st.experimental_dialog (older Streamlit releases);
# otherwise fall back to rendering inside an expanded st.expander.
dialog = getattr(st, "dialog", None) or getattr(st, "experimental_dialog", None)
if dialog is None:
    def dialog(title, **_kwargs):
        """Fallback decorator: run the wrapped function inside an expander titled *title*.

        Extra keyword arguments that st.dialog accepts (e.g. ``width``) are
        ignored, so call sites work unchanged on every Streamlit version.
        """
        from functools import wraps

        def deco(fn):
            @wraps(fn)
            def wrapper(*args, **kwargs):
                with st.expander(title, expanded=True):
                    return fn(*args, **kwargs)
            return wrapper
        return deco
def preview_modal(book: dict[str, pd.DataFrame]):
    """Render one tab per sheet, each with feature tracks and a summary-statistics table.

    Only the FEATURES columns actually present in a sheet are summarized.
    A sheet without them shows a note instead of raising KeyError.
    """
    if not book:
        st.info("No data loaded yet.")
        return
    names = list(book.keys())
    for tab, name in zip(st.tabs(names), names):
        with tab:
            df = book[name]
            t1, t2 = st.tabs(["Tracks", "Summary"])
            with t1:
                fig = preview_tracks(df, FEATURES)
                st.pyplot(fig, use_container_width=True)
                plt.close(fig)  # pyplot keeps figures alive until explicitly closed
            with t2:
                present = [c for c in FEATURES if c in df.columns]
                if not present:
                    st.info("None of the model feature columns were found in this sheet.")
                else:
                    tbl = (df[present]
                           .agg(['min', 'max', 'mean', 'std'])
                           .T.rename(columns={"min": "Min", "max": "Max", "mean": "Mean", "std": "Std"}))
                    df_centered_rounded(tbl.reset_index(names="Feature"))
# =========================
# Load model
# =========================
def ensure_model() -> Path|None:
    """Locate the model file, downloading it from ``$MODEL_URL`` when none is present.

    Returns:
        The first non-empty file among DEFAULT_MODEL and MODEL_FALLBACKS.
        Otherwise, if MODEL_URL is set and the download succeeds,
        DEFAULT_MODEL. Otherwise None.

    The download is streamed into a sibling ``.part`` file. That file is
    renamed onto DEFAULT_MODEL only after the transfer completes, so an
    interrupted download never leaves a truncated file that a later run
    would accept as the model.

    SECURITY: the file is later loaded with joblib (pickle). Only point
    MODEL_URL at a trusted source.
    """
    for p in [DEFAULT_MODEL, *MODEL_FALLBACKS]:
        if p.exists() and p.stat().st_size > 0:
            return p
    url = os.environ.get("MODEL_URL", "")
    if not url:
        return None
    tmp = DEFAULT_MODEL.with_name(DEFAULT_MODEL.name + ".part")
    try:
        import requests
        DEFAULT_MODEL.parent.mkdir(parents=True, exist_ok=True)
        with requests.get(url, stream=True, timeout=30) as r:
            r.raise_for_status()
            with open(tmp, "wb") as f:
                for chunk in r.iter_content(1 << 20):
                    if chunk:
                        f.write(chunk)
        if tmp.stat().st_size == 0:
            raise ValueError("downloaded model file is empty")
        os.replace(tmp, DEFAULT_MODEL)  # atomic rename on the same filesystem
        return DEFAULT_MODEL
    except Exception:
        # Best effort: the caller reports "Model not found" when None is returned.
        # Remove any partial download so it cannot be mistaken for a model later.
        try:
            tmp.unlink(missing_ok=True)
        except OSError:
            pass
        return None
# Resolve and load the estimator; the app cannot continue without it.
mpath = ensure_model()
if mpath is None:
    st.error("Model not found. Upload models/ucs_rf.joblib (or set MODEL_URL).")
    st.stop()
try:
    model = load_model(str(mpath))
except Exception as load_err:
    st.error(f"Failed to load model: {load_err}")
    st.stop()
# Optional metadata next to the model. It may override the feature list and the
# target column name that the rest of the app reads from FEATURES / TARGET.
meta_path = MODELS_DIR / "meta.json"
if meta_path.exists():
    try:
        meta = json.loads(meta_path.read_text(encoding="utf-8"))
    except (OSError, ValueError) as meta_err:
        # ValueError covers JSONDecodeError and UnicodeDecodeError. Warn instead of
        # silently using the defaults, so a broken meta file is noticed.
        st.warning(f"Ignoring unreadable {meta_path}: {meta_err}")
    else:
        if not isinstance(meta, dict):
            st.warning(f"Ignoring {meta_path}: expected a JSON object.")
            meta = {}
        meta_features = meta.get("features", FEATURES)
        meta_target = meta.get("target", TARGET)
        if (isinstance(meta_features, list) and meta_features
                and all(isinstance(f, str) for f in meta_features)):
            FEATURES = meta_features
        else:
            st.warning(f"Ignoring invalid 'features' in {meta_path}; using defaults.")
        if isinstance(meta_target, str) and meta_target:
            TARGET = meta_target
        else:
            st.warning(f"Ignoring invalid 'target' in {meta_path}; using default.")
# =========================
# Session state
# =========================
# Initialise per-session keys once; setdefault leaves existing values untouched on reruns.
_SESSION_DEFAULTS = {
    "app_step": "intro",            # page shown: intro / dev / validate / predict
    "results": {},                  # computed frames and metrics per step
    "train_ranges": None,           # {feature: (min, max)} from the Train sheet
    "dev_file_name": "",
    "dev_file_bytes": b"",
    "dev_file_loaded": False,
    "dev_preview": False,
    "show_preview_modal": False,    # render the preview section on this run
}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    st.session_state.setdefault(_state_key, _state_default)
# =========================
# Branding in Sidebar
# =========================
# Sidebar logo + product name, centred via the `.centered-container` class.
_brand_logo_src = inline_logo('logo.png')
st.sidebar.markdown(
    f"""
    <div class="centered-container">
        <img src="{_brand_logo_src}" style="width: 200px; height: auto; object-fit: contain;">
        <div style='font-weight:800;font-size:1.2rem;'>ST_GeoMech_UCS</div>
        <div style='color:#667085;'>Real-Time UCS Tracking While Drilling</div>
    </div>
    """,
    unsafe_allow_html=True,
)
# =========================
# Reusable Sticky Header Function
# =========================
def sticky_header(title, message):
    """Render a section header (title + message) that stays pinned while scrolling.

    The header is emitted as a raw HTML block, and Markdown is not processed
    inside raw HTML. Therefore ``**bold**`` spans in *title* / *message* are
    converted to ``<strong>`` here, and all other text is HTML-escaped.
    """
    import html
    import re

    def _to_html(text) -> str:
        escaped = html.escape(str(text), quote=False)
        return re.sub(r"\*\*(.+?)\*\*", r"<strong>\1</strong>", escaped)

    title_html = _to_html(title)
    message_html = _to_html(message)
    st.markdown(
        f"""
        <style>
        .sticky-container {{
            position: sticky;
            top: 0;
            background-color: white;
            z-index: 100;
            padding-top: 10px;
            padding-bottom: 10px;
            border-bottom: 1px solid #eee;
        }}
        </style>
        <div class="sticky-container">
            <h3>{title_html}</h3>
            <p>{message_html}</p>
        </div>
        """,
        unsafe_allow_html=True
    )
# =========================
# INTRO
# =========================
if st.session_state.app_step == "intro":
    # Landing page: short description, the three workflow steps, and the start button.
    st.header("Welcome!")
    st.markdown("This software is developed by *Smart Thinking AI-Solutions Team* to estimate UCS from drilling data.")
    st.subheader("How It Works")
    _intro_steps = (
        "1) **Upload your data to build the case and preview the performance of our model.** \n",
        "2) Click **Run Model** to compute metrics and plots.  \n",
        "3) **Proceed to Validation** (with actual UCS) or **Proceed to Prediction** (no UCS).",
    )
    st.markdown("".join(_intro_steps))
    if st.button("Start Showcase", type="primary"):
        st.session_state.app_step = "dev"
        st.rerun()
# =========================
# CASE BUILDING
# =========================
if st.session_state.app_step == "dev":
    st.sidebar.header("Case Building")
    up = st.sidebar.file_uploader("Upload Your Data File", type=["xlsx","xls"])
    if up is not None:
        uploaded_bytes = up.getvalue()
        # The uploader returns the same file on every rerun. Only a genuinely new
        # upload should replace the stored workbook and reset the "previewed" flag.
        if (uploaded_bytes != st.session_state.dev_file_bytes
                or up.name != st.session_state.dev_file_name):
            st.session_state.dev_file_bytes = uploaded_bytes
            st.session_state.dev_file_name = up.name
            st.session_state.dev_preview = False
        st.session_state.dev_file_loaded = True
    if st.session_state.dev_file_loaded:
        tmp = read_book_bytes(st.session_state.dev_file_bytes)
        if tmp:
            df0 = next(iter(tmp.values()))
            st.sidebar.caption(
                f"**Data loaded:** {st.session_state.dev_file_name} • "
                f"{df0.shape[0]} rows × {df0.shape[1]} cols"
            )
    if st.sidebar.button("Preview data", use_container_width=True, disabled=not st.session_state.dev_file_loaded):
        st.session_state.show_preview_modal = True  # rendered by the preview section at the end of the script
        st.session_state.dev_preview = True
    run = st.sidebar.button("Run Model", type="primary", use_container_width=True)
    if st.sidebar.button("Proceed to Validation ▶", use_container_width=True):
        st.session_state.app_step = "validate"; st.rerun()
    if st.sidebar.button("Proceed to Prediction ▶", use_container_width=True):
        st.session_state.app_step = "predict"; st.rerun()
    # Sticky header reflecting upload / preview progress
    if st.session_state.dev_file_loaded and st.session_state.dev_preview:
        sticky_header("Case Building", "Previewed ✓ — now click **Run Model**.")
    elif st.session_state.dev_file_loaded:
        sticky_header("Case Building", "📄 **Preview uploaded data** using the sidebar button, then click **Run Model**.")
    else:
        sticky_header("Case Building", "**Upload your data to build a case, then run the model to review development performance.**")

    def _metrics(frame):
        """R / RMSE / MAE of the UCS_Pred column against the actual TARGET column."""
        actual, predicted = frame[TARGET], frame["UCS_Pred"]
        return {
            "R": pearson_r(actual, predicted),
            "RMSE": rmse(actual, predicted),
            "MAE": mean_absolute_error(actual, predicted),
        }

    if run and st.session_state.dev_file_bytes:
        book = read_book_bytes(st.session_state.dev_file_bytes)
        sh_train = find_sheet(book, ["Train","Training","training2","train","training"])
        sh_test  = find_sheet(book, ["Test","Testing","testing2","test","testing"])
        if sh_train is None or sh_test is None:
            st.markdown('<div class="st-message-box st-error">Workbook must include Train/Training/training2 and Test/Testing/testing2 sheets.</div>', unsafe_allow_html=True)
            st.stop()
        tr = book[sh_train].copy(); te = book[sh_test].copy()
        if not (ensure_cols(tr, FEATURES+[TARGET]) and ensure_cols(te, FEATURES+[TARGET])):
            st.markdown('<div class="st-message-box st-error">Missing required columns.</div>', unsafe_allow_html=True)
            st.stop()
        tr["UCS_Pred"] = model.predict(tr[FEATURES])
        te["UCS_Pred"] = model.predict(te[FEATURES])
        st.session_state.results["Train"] = tr
        st.session_state.results["Test"] = te
        st.session_state.results["m_train"] = _metrics(tr)
        st.session_state.results["m_test"] = _metrics(te)
        # Training min–max per feature; later steps flag inputs outside these ranges.
        tr_min = tr[FEATURES].min().to_dict(); tr_max = tr[FEATURES].max().to_dict()
        st.session_state.train_ranges = {f:(float(tr_min[f]), float(tr_max[f])) for f in FEATURES}
        st.markdown('<div class="st-message-box st-success">Case has been built and results are displayed below.</div>', unsafe_allow_html=True)
    elif run:
        st.markdown('<div class="st-message-box st-warning">Upload a data file before running the model.</div>', unsafe_allow_html=True)

    def _dev_block(df, m):
        """Show the metric cards, their legend, the cross-plot and the depth track for one split."""
        c1,c2,c3 = st.columns(3)
        c1.metric("R", f"{m['R']:.2f}")
        c2.metric("RMSE", f"{m['RMSE']:.2f}")
        c3.metric("MAE", f"{m['MAE']:.2f}")
        # Footer for metric abbreviations
        st.markdown("""
            <div style='text-align: left; font-size: 0.8em; color: #6b7280; margin-top: -16px; margin-bottom: 8px;'>
                <strong>R:</strong> Pearson Correlation Coefficient<br>
                <strong>RMSE:</strong> Root Mean Square Error<br>
                <strong>MAE:</strong> Mean Absolute Error
            </div>
        """, unsafe_allow_html=True)
        # 2-column layout, big gap (prevents overlap)
        col_cross, col_track = st.columns([3, 2], gap="large")
        with col_cross:
            fig = cross_plot_static(df[TARGET], df["UCS_Pred"])
            st.pyplot(fig, use_container_width=False)
            plt.close(fig)  # release the figure; pyplot keeps it alive otherwise
        with col_track:
            st.plotly_chart(
                track_plot(df, include_actual=True),
                use_container_width=False,  # False so the fixed width set in track_plot() is honoured
                config={"displayModeBar": False, "scrollZoom": True}
            )

    if "Train" in st.session_state.results or "Test" in st.session_state.results:
        tab1, tab2 = st.tabs(["Training", "Testing"])
        if "Train" in st.session_state.results:
            with tab1: _dev_block(st.session_state.results["Train"], st.session_state.results["m_train"])
        if "Test" in st.session_state.results:
            with tab2: _dev_block(st.session_state.results["Test"],  st.session_state.results["m_test"])
# =========================
# VALIDATION (with actual UCS)
# =========================
if st.session_state.app_step == "validate":
    st.sidebar.header("Validate the Model")
    up = st.sidebar.file_uploader("Upload Validation Excel", type=["xlsx","xls"])
    if up is not None:
        book = read_book_bytes(up.getvalue())
        if book:
            df0 = next(iter(book.values()))
            st.sidebar.caption(f"**Data loaded:** {up.name} • {df0.shape[0]} rows × {df0.shape[1]} cols")
    if st.sidebar.button("Preview data", use_container_width=True, disabled=(up is None)):
        st.session_state.show_preview_modal = True  # rendered by the preview section at the end of the script
    go_btn = st.sidebar.button("Predict", type="primary", use_container_width=True)
    if st.sidebar.button("⬅ Back to Case Building", use_container_width=True):
        st.session_state.app_step = "dev"; st.rerun()
    if st.sidebar.button("Proceed to Prediction ▶", use_container_width=True):
        st.session_state.app_step = "predict"; st.rerun()
    sticky_header("Validate the Model", "Upload a dataset with the same **features** and **UCS** to evaluate performance.")
    if go_btn and up is not None:
        book = read_book_bytes(up.getvalue())
        name = find_sheet(book, ["Validation","Validate","validation2","Val","val"]) or list(book.keys())[0]
        df = book[name].copy()
        if not ensure_cols(df, FEATURES+[TARGET]):
            st.markdown('<div class="st-message-box st-error">Missing required columns.</div>', unsafe_allow_html=True)
            st.stop()
        df["UCS_Pred"] = model.predict(df[FEATURES])
        st.session_state.results["Validate"] = df
        ranges = st.session_state.train_ranges
        oor_pct = 0.0
        tbl = None
        if ranges:
            # One boolean column per feature: True where the value lies outside the training min–max.
            viol = pd.DataFrame({f: (df[f] < ranges[f][0]) | (df[f] > ranges[f][1]) for f in FEATURES})
            any_viol = viol.any(axis=1)
            oor_pct = float(any_viol.mean() * 100.0)
            if any_viol.any():
                tbl = df.loc[any_viol, FEATURES].copy()
                for c in FEATURES:
                    if pd.api.types.is_numeric_dtype(tbl[c]):
                        tbl[c] = tbl[c].round(2)
                # Name the offending features for each flagged row.
                tbl["Violations"] = viol.loc[any_viol].apply(
                    lambda r: ", ".join([c for c, v in r.items() if v]), axis=1
                )
        st.session_state.results["m_val"] = {
            "R": pearson_r(df[TARGET], df["UCS_Pred"]),
            "RMSE": rmse(df[TARGET], df["UCS_Pred"]),
            "MAE": mean_absolute_error(df[TARGET], df["UCS_Pred"])
        }
        st.session_state.results["sv_val"] = {"n":len(df),"pred_min":float(df["UCS_Pred"].min()),"pred_max":float(df["UCS_Pred"].max()),"oor":oor_pct}
        st.session_state.results["oor_tbl"] = tbl
    if "Validate" in st.session_state.results:
        m = st.session_state.results["m_val"]
        c1,c2,c3 = st.columns(3)
        c1.metric("R", f"{m['R']:.2f}")
        c2.metric("RMSE", f"{m['RMSE']:.2f}")
        c3.metric("MAE", f"{m['MAE']:.2f}")
        # Footer for metric abbreviations
        st.markdown("""
            <div style='text-align: left; font-size: 0.8em; color: #6b7280; margin-top: -16px; margin-bottom: 8px;'>
                <strong>R:</strong> Pearson Correlation Coefficient<br>
                <strong>RMSE:</strong> Root Mean Square Error<br>
                <strong>MAE:</strong> Mean Absolute Error
            </div>
        """, unsafe_allow_html=True)

        col_cross, col_track = st.columns([3, 2], gap="large")
        with col_cross:
            fig = cross_plot_static(st.session_state.results["Validate"][TARGET],
                                    st.session_state.results["Validate"]["UCS_Pred"])
            st.pyplot(fig, use_container_width=False)
            plt.close(fig)  # release the figure; pyplot keeps it alive otherwise
        with col_track:
            st.plotly_chart(
                track_plot(st.session_state.results["Validate"], include_actual=True),
                use_container_width=False,  # False so the fixed width set in track_plot() is honoured
                config={"displayModeBar": False, "scrollZoom": True}
            )
        sv = st.session_state.results["sv_val"]
        if sv["oor"] > 0:
            # Raw HTML block: Markdown is not rendered here, so use <strong> rather than **.
            st.markdown('<div class="st-message-box st-warning">Some inputs fall outside <strong>training min–max</strong> ranges.</div>', unsafe_allow_html=True)
        if st.session_state.results["oor_tbl"] is not None:
            st.write("*Out-of-range rows (vs. Training min–max):*")
            df_centered_rounded(st.session_state.results["oor_tbl"])
# =========================
# PREDICTION (no actual UCS)
# =========================
if st.session_state.app_step == "predict":
    st.sidebar.header("Prediction (No Actual UCS)")
    up = st.sidebar.file_uploader("Upload Prediction Excel", type=["xlsx","xls"])
    if up is not None:
        book = read_book_bytes(up.getvalue())
        if book:
            df0 = next(iter(book.values()))
            st.sidebar.caption(f"**Data loaded:** {up.name} • {df0.shape[0]} rows × {df0.shape[1]} cols")
    if st.sidebar.button("Preview data", use_container_width=True, disabled=(up is None)):
        st.session_state.show_preview_modal = True  # rendered by the preview section at the end of the script
    go_btn = st.sidebar.button("Predict", type="primary", use_container_width=True)
    if st.sidebar.button("⬅ Back to Case Building", use_container_width=True):
        st.session_state.app_step = "dev"; st.rerun()
    sticky_header("Prediction", "Upload a dataset with the feature columns (no **UCS**).")
    if go_btn and up is not None:
        book = read_book_bytes(up.getvalue())
        name = list(book.keys())[0]  # predictions always use the first sheet
        df = book[name].copy()
        if not ensure_cols(df, FEATURES):
            st.markdown('<div class="st-message-box st-error">Missing required columns.</div>', unsafe_allow_html=True)
            st.stop()
        df["UCS_Pred"] = model.predict(df[FEATURES])
        st.session_state.results["PredictOnly"] = df
        ranges = st.session_state.train_ranges
        oor_pct = 0.0
        if ranges:
            any_viol = pd.DataFrame({f:(df[f]<ranges[f][0])|(df[f]>ranges[f][1]) for f in FEATURES}).any(axis=1)
            oor_pct = float(any_viol.mean()*100.0)
        st.session_state.results["sv_pred"] = {
            "n":len(df),
            "pred_min":float(df["UCS_Pred"].min()),
            "pred_max":float(df["UCS_Pred"].max()),
            "pred_mean":float(df["UCS_Pred"].mean()),
            "pred_std":float(df["UCS_Pred"].std(ddof=0)),
            "oor":oor_pct
        }
    if "PredictOnly" in st.session_state.results:
        df = st.session_state.results["PredictOnly"]; sv = st.session_state.results["sv_pred"]
        col_left, col_right = st.columns([2,3], gap="large")
        with col_left:
            # Every value is a pre-formatted string so the column has a single type
            # (ints, floats and a percentage string would otherwise be mixed) and
            # all statistics consistently show 2 decimals.
            table = pd.DataFrame({
                "Metric": ["# points","Pred min","Pred max","Pred mean","Pred std","OOR %"],
                "Value":  [str(sv["n"]),
                           f'{sv["pred_min"]:.2f}',
                           f'{sv["pred_max"]:.2f}',
                           f'{sv["pred_mean"]:.2f}',
                           f'{sv["pred_std"]:.2f}',
                           f'{sv["oor"]:.1f}%']
            })
            st.markdown('<div class="st-message-box st-success">Predictions ready ✓</div>', unsafe_allow_html=True)
            df_centered_rounded(table, hide_index=True)
            st.caption("**★ OOR** = % of rows whose input features fall outside the training min–max range.")
        with col_right:
            st.plotly_chart(
                track_plot(df, include_actual=False),
                use_container_width=False,  # False so the fixed width set in track_plot() is honoured
                config={"displayModeBar": False, "scrollZoom": True}
            )
# =========================
# Run preview modal after all other elements
# =========================
if st.session_state.show_preview_modal:
    # Choose the workbook for the current step: the stored upload while case
    # building, or the step's current uploader widget for validate / predict.
    book_to_preview = {}
    if st.session_state.app_step == "dev":
        book_to_preview = read_book_bytes(st.session_state.dev_file_bytes)
    elif st.session_state.app_step in ("validate", "predict"):
        # `up` is only bound if one of the step sections above ran this time.
        current_upload = globals().get("up")
        if current_upload is not None:
            book_to_preview = read_book_bytes(current_upload.getvalue())
    with st.expander("Preview data", expanded=True):
        if not book_to_preview:
            st.markdown('<div class="st-message-box">No data loaded yet.</div>', unsafe_allow_html=True)
        else:
            names = list(book_to_preview.keys())
            for t, name in zip(st.tabs(names), names):
                with t:
                    df = book_to_preview[name]
                    t1, t2 = st.tabs(["Tracks", "Summary"])
                    with t1:
                        fig = preview_tracks(df, FEATURES)
                        st.pyplot(fig, use_container_width=True)
                        plt.close(fig)  # release the figure; pyplot keeps it alive otherwise
                    with t2:
                        # Summarize only the feature columns this sheet actually has.
                        present = [c for c in FEATURES if c in df.columns]
                        if present:
                            tbl = (df[present]
                                     .agg(['min','max','mean','std'])
                                     .T.rename(columns={"min":"Min","max":"Max","mean":"Mean","std":"Std"}))
                            df_centered_rounded(tbl.reset_index(names="Feature"))
                        else:
                            st.markdown('<div class="st-message-box st-warning">None of the model feature columns were found in this sheet.</div>', unsafe_allow_html=True)
    # One-shot: the preview is shown only on the run triggered by the button.
    st.session_state.show_preview_modal = False
# =========================
# Footer
# =========================
# Centered product / company footer shown at the bottom of every page.
_FOOTER_HTML = """
    <div style='text-align:center; color:#6b7280; line-height:1.6'>
      ST_GeoMech_UCS • © Smart Thinking<br/>
      <strong>Visit our website:</strong> <a href='https://www.smartthinking.com.sa' target='_blank'>smartthinking.com.sa</a>
    </div>
    """
st.markdown("---")  # horizontal rule separating page content from the footer
st.markdown(_FOOTER_HTML, unsafe_allow_html=True)