Spaces:
Sleeping
Sleeping
Upload app5.py
Browse files
app5.py
ADDED
|
@@ -0,0 +1,372 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# streamlit_app.py
|
| 2 |
+
import streamlit as st
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import numpy as np
|
| 5 |
+
import joblib
|
| 6 |
+
from typing import Tuple, Dict, Any, List
|
| 7 |
+
|
| 8 |
+
# Page-level configuration: browser-tab title and wide layout.
st.set_page_config(page_title="USA Salary — Synthetic/Hybrid Prediction", layout="wide")

# --- Top nav to teammates' apps ---
# Three equal-width columns, one link button per teammate, rendered left to right.
c1, c2, c3 = st.columns(3)
for _nav_column, (_nav_label, _nav_url) in zip(
    (c1, c2, c3),
    (
        ("Hamna", "https://example.com/hamna"),    # TODO: replace with real link
        ("Mahesh", "https://example.com/mahesh"),  # TODO: replace with real link
        ("Tian", "https://example.com/tian"),      # TODO: replace with real link
    ),
):
    with _nav_column:
        st.link_button(_nav_label, _nav_url)

st.title("USA Salary — Predict with Hybrid/Synthetic Inputs")
|
| 20 |
+
|
| 21 |
+
# =================== Load assets ===================
|
| 22 |
+
@st.cache_resource(show_spinner=False)
def _load_assets():
    """Load the prediction model and the two survey tables from disk.

    All three paths are relative to the working directory. The function is
    wrapped in ``st.cache_resource`` with the spinner disabled.

    Returns:
        A 3-tuple ``(model, table_2024, table_2025)``: the object unpickled
        by ``joblib.load`` from ``final_xgbr_usa_model.pkl`` (used elsewhere
        in this file through ``.predict``), then the DataFrames read from
        ``usa_salary_data.csv`` and ``2025_survey.csv``.
    """
    # Tuple elements are evaluated left to right: model first, then the CSVs.
    return (
        joblib.load("final_xgbr_usa_model.pkl"),
        pd.read_csv("usa_salary_data.csv"),
        pd.read_csv("2025_survey.csv"),
    )
|
| 28 |
+
|
| 29 |
+
# Model plus the 2024 and 2025 survey tables, as returned by _load_assets().
pipe, usa_data, usa_25 = _load_assets()
# Name of the target column; it is excluded wherever feature lists are built from usa_data.
LABEL = "CompTotal"
|
| 31 |
+
|
| 32 |
+
# =================== RNG utilities ===================
|
| 33 |
+
def _ensure_rng():
    """Return a new NumPy generator derived from the session's seed sequence.

    A ``SeedSequence`` is created and stored under ``rng_seedseq`` in
    ``st.session_state`` the first time this runs. Every call then spawns
    one child sequence from the stored parent and seeds a generator with it;
    the stored parent object itself is not replaced.
    """
    state = st.session_state
    if "rng_seedseq" not in state:
        state.rng_seedseq = np.random.SeedSequence()
    (child_seq,) = state.rng_seedseq.spawn(1)
    return np.random.default_rng(child_seq)
|
| 38 |
+
|
| 39 |
+
def new_rng():
    """Replace the session's seed sequence with a child and return its generator.

    Unlike ``_ensure_rng``, this overwrites ``st.session_state.rng_seedseq``
    with the spawned child, so later calls continue from that child.

    Returns:
        A ``numpy.random.Generator`` seeded with the new sequence.
    """
    state = st.session_state
    # Robustness: create the parent sequence on demand so this function does
    # not raise AttributeError when it is the first RNG helper to be called.
    if "rng_seedseq" not in state:
        state.rng_seedseq = np.random.SeedSequence()
    state.rng_seedseq = state.rng_seedseq.spawn(1)[0]
    return np.random.default_rng(state.rng_seedseq)
|
| 42 |
+
|
| 43 |
+
# =================== Precompute dropdown choices (ONCE) ===================
|
| 44 |
+
@st.cache_resource(show_spinner=False)
def _precompute_choices(usa_2024: pd.DataFrame, usa_2025: pd.DataFrame, label: str) -> Dict[str, List]:
    """Build the option list offered for every feature column.

    The 2025 table is reindexed onto the 2024 columns and stacked under the
    2024 table, so options reflect both tables.

    Args:
        usa_2024: Table that defines the feature columns and their dtypes.
        usa_2025: Second table; columns it lacks are filled with NaN.
        label: Target column name; it gets no entry in the result.

    Returns:
        A dict mapping column name to a list of options.

        * Numeric columns (by the 2024 dtype): sorted unique integers taken
          from a fixed set of percentiles plus the 10 most frequent values
          after rounding to the nearest 1000, capped at 50 entries.
          ``[0]`` when no numeric value exists.
        * Other columns: the 40 most frequent values as strings, or ``[""]``
          when the column has no values.
    """
    percentiles = [5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 95]
    choices: Dict[str, List] = {}
    both = pd.concat(
        [usa_2024, usa_2025.reindex(columns=usa_2024.columns, fill_value=np.nan)],
        axis=0,
        ignore_index=True,
    )
    for col in usa_2024.columns:
        if col == label:
            continue
        pooled = both[col]
        if pd.api.types.is_numeric_dtype(usa_2024[col]):
            numeric = pd.to_numeric(pooled, errors="coerce").dropna()
            if numeric.empty:
                choices[col] = [0]
                continue
            quantiles = np.percentile(numeric, percentiles).tolist()
            # NOTE(review): rounding to the nearest 1000 maps every value of a
            # small-range column to 0, which then appears as an option.
            # Kept as-is; confirm whether that is intended.
            rounded = (np.round(numeric / 1000) * 1000).astype(int)
            frequent = rounded.value_counts().head(10).index.astype(int).tolist()
            # int() truncates fractional percentiles; the set removes duplicates.
            choices[col] = sorted({int(v) for v in quantiles + frequent})[:50]
        else:
            # Drop missing values BEFORE casting to str. Casting first turns
            # them into text ("nan", "None", "<NA>") that then has to be
            # filtered by string matching, which is easy to get wrong.
            labels = pooled.dropna().astype(str)
            choices[col] = labels.value_counts().head(40).index.tolist() if len(labels) else [""]
    return choices
|
| 69 |
+
|
| 70 |
+
# Per-column option lists for the editor select boxes, built from both survey tables.
CHOICES_DICT = _precompute_choices(usa_data, usa_25, LABEL)
|
| 71 |
+
|
| 72 |
+
# =================== Core sampling funcs ===================
|
| 73 |
+
def _sample_from_2024(colname: str, df_2024: pd.DataFrame, rng: np.random.Generator):
|
| 74 |
+
series = df_2024[colname].dropna()
|
| 75 |
+
if series.empty:
|
| 76 |
+
return 0 if pd.api.types.is_numeric_dtype(df_2024[colname]) else ""
|
| 77 |
+
return series.sample(1, random_state=rng.integers(0, 10_000)).iloc[0]
|
| 78 |
+
|
| 79 |
+
def build_synthetic_row_with_trace(
|
| 80 |
+
usa_25: pd.DataFrame,
|
| 81 |
+
usa_2024: pd.DataFrame,
|
| 82 |
+
label: str = "CompTotal",
|
| 83 |
+
rng: np.random.Generator | None = None
|
| 84 |
+
) -> Tuple[pd.DataFrame, float | None, Dict[str, str], Dict[str, Any]]:
|
| 85 |
+
if rng is None:
|
| 86 |
+
rng = _ensure_rng()
|
| 87 |
+
expected_features = [c for c in usa_2024.columns if c != label]
|
| 88 |
+
row25 = usa_25.sample(1, random_state=rng.integers(0, 10_000)).iloc[0]
|
| 89 |
+
|
| 90 |
+
synthetic, source_info = {}, {}
|
| 91 |
+
for col in expected_features:
|
| 92 |
+
use_25_val, val = False, None
|
| 93 |
+
if col in row25.index:
|
| 94 |
+
val = row25[col]
|
| 95 |
+
if pd.api.types.is_numeric_dtype(usa_2024[col]):
|
| 96 |
+
val = pd.to_numeric(val, errors="coerce")
|
| 97 |
+
if not pd.isna(val):
|
| 98 |
+
use_25_val = True
|
| 99 |
+
if not use_25_val:
|
| 100 |
+
val = _sample_from_2024(col, usa_2024, rng)
|
| 101 |
+
source_info[col] = "2024"
|
| 102 |
+
else:
|
| 103 |
+
source_info[col] = "2025"
|
| 104 |
+
|
| 105 |
+
if pd.api.types.is_numeric_dtype(usa_2024[col]):
|
| 106 |
+
val = pd.to_numeric(val, errors="coerce")
|
| 107 |
+
if pd.isna(val):
|
| 108 |
+
val = _sample_from_2024(col, usa_2024, rng); source_info[col] = "2024"
|
| 109 |
+
else:
|
| 110 |
+
if pd.isna(val):
|
| 111 |
+
val = _sample_from_2024(col, usa_2024, rng); source_info[col] = "2024"
|
| 112 |
+
if not isinstance(val, str):
|
| 113 |
+
val = str(val)
|
| 114 |
+
synthetic[col] = val
|
| 115 |
+
|
| 116 |
+
X_one = pd.DataFrame([synthetic], columns=expected_features)
|
| 117 |
+
|
| 118 |
+
y_true = None
|
| 119 |
+
if label in row25.index:
|
| 120 |
+
y_true = pd.to_numeric(row25[label], errors="coerce")
|
| 121 |
+
if pd.isna(y_true): y_true = None
|
| 122 |
+
|
| 123 |
+
used_from_2025 = sum(v == "2025" for v in source_info.values())
|
| 124 |
+
used_from_2024 = sum(v == "2024" for v in source_info.values())
|
| 125 |
+
total = len(source_info)
|
| 126 |
+
report = dict(
|
| 127 |
+
total_features=total,
|
| 128 |
+
n_2025=used_from_2025,
|
| 129 |
+
n_2024=used_from_2024,
|
| 130 |
+
pct_2025=0.0 if total == 0 else used_from_2025 / total * 100,
|
| 131 |
+
pct_2024=0.0 if total == 0 else used_from_2024 / total * 100,
|
| 132 |
+
filled_cols=[c for c, s in source_info.items() if s == "2024"]
|
| 133 |
+
)
|
| 134 |
+
return X_one, y_true, source_info, report
|
| 135 |
+
|
| 136 |
+
def random_sample_row_2024(df_2024: pd.DataFrame, label: str, rng: np.random.Generator) -> pd.DataFrame:
    """Assemble a one-row DataFrame by sampling every feature column independently.

    Each column of ``df_2024`` other than ``label`` contributes one value
    drawn with ``_sample_from_2024``, so the resulting row is generally not
    an actual row of the table.
    """
    feature_names = [name for name in df_2024.columns if name != label]
    drawn = {}
    for name in feature_names:
        drawn[name] = _sample_from_2024(name, df_2024, rng)
    return pd.DataFrame([drawn], columns=feature_names)
|
| 140 |
+
|
| 141 |
+
# =================== Session state ===================
# Seed each key only when it is absent, so existing values are left untouched.
#   history      - list of result dicts appended after each scored prediction
#   prepared     - the current candidate input (None until one is built)
#   edit_version - version token to force editor widgets to refresh when a
#                  new candidate is created
for _state_key, _make_default in (
    ("history", lambda: []),
    ("prepared", lambda: None),
    ("edit_version", lambda: 0),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()
|
| 149 |
+
|
| 150 |
+
def _bump_edit_version():
    """Increase ``st.session_state.edit_version`` by one.

    The version is embedded in the editor form and widget keys, so bumping it
    gives the next candidate a fresh set of widgets.
    """
    st.session_state.edit_version = st.session_state.edit_version + 1
|
| 152 |
+
|
| 153 |
+
# =================== Sidebar ===================
|
| 154 |
+
st.sidebar.header("Controls")
# The selected label is tested with startswith("Hybrid") wherever a candidate is built.
mode = st.sidebar.radio(
    "Choose input mode:",
    ["Hybrid (Random 2025 + fill from 2024)", "Pure 2024 synthetic"],
    index=0
)

# Manual reload: advance the session RNG, build a candidate for the selected
# mode, and bump the edit version so the editor widgets are recreated.
if st.sidebar.button("Reload new random data"):
    rng = new_rng()
    if not mode.startswith("Hybrid"):
        X_2024 = random_sample_row_2024(usa_data, LABEL, rng)
        st.session_state.prepared = {
            "X": X_2024, "y_true": None, "info": {}, "report": {}, "mode": "pure2024",
        }
    else:
        X_one, y_true, src_info, rep = build_synthetic_row_with_trace(
            usa_25, usa_data, label=LABEL, rng=rng
        )
        st.session_state.prepared = {
            "X": X_one, "y_true": y_true, "info": src_info, "report": rep, "mode": "hybrid",
        }
    _bump_edit_version()
    st.toast("New random input prepared.", icon="✅")
|
| 171 |
+
|
| 172 |
+
# =================== Explanation ===================
|
| 173 |
+
# Collapsible help text: states the goal, the two input modes, and the
# edit -> apply -> submit workflow used further down the page.
with st.expander("What is happening here? (plain English)"):
    st.markdown(
        """
**Goal:** USA coder annual income prediction.

**Modes:**
1) **Hybrid** — Random 2025 row; missing required features are filled with samples from the 2024 distribution.
2) **Pure 2024** — Entire row sampled from the 2024 distribution.

**Editing flow:**
Use the editors below to adjust values. Then click **Apply edits** (this commits your selections to the current row).
Finally click **Submit & Predict**.
"""
    )
|
| 187 |
+
|
| 188 |
+
st.markdown("---")
|
| 189 |
+
|
| 190 |
+
# =================== Prepare first candidate if needed ===================
|
| 191 |
+
def _prepare_candidate_if_needed():
    """Build the first candidate input when none is stored yet.

    Does nothing if ``st.session_state.prepared`` is already set. Otherwise
    it builds a candidate for the mode selected in the sidebar, stores it
    under ``prepared`` and bumps the edit version.
    """
    if st.session_state.prepared is not None:
        return

    rng = _ensure_rng()
    if not mode.startswith("Hybrid"):
        X_2024 = random_sample_row_2024(usa_data, LABEL, rng)
        st.session_state.prepared = {
            "X": X_2024, "y_true": None, "info": {}, "report": {}, "mode": "pure2024",
        }
    else:
        X_one, y_true, src_info, rep = build_synthetic_row_with_trace(
            usa_25, usa_data, label=LABEL, rng=rng
        )
        st.session_state.prepared = {
            "X": X_one, "y_true": y_true, "info": src_info, "report": rep, "mode": "hybrid",
        }
    _bump_edit_version()


_prepare_candidate_if_needed()
|
| 203 |
+
|
| 204 |
+
# =================== Helper: coerce edited values ===================
|
| 205 |
+
def _coerce_value(col: str, val):
    """Convert an edited widget value to the type used for column ``col``.

    Columns that are numeric in ``usa_data`` go through ``pd.to_numeric``;
    any exception raised by that conversion yields ``np.nan``. All other
    columns become strings, with ``None`` mapped to ``""``.
    """
    if not pd.api.types.is_numeric_dtype(usa_data[col]):
        return "" if val is None else str(val)
    try:
        return pd.to_numeric(val)
    except Exception:
        return np.nan
|
| 213 |
+
# =================== Predict & History ===================
|
| 214 |
+
# Two-column layout: prediction panel on the left, results history on the right.
left, right = st.columns([1.1, 0.9], gap="large")

with left:
    st.subheader("Submit a prediction")

    # The candidate currently stored in session state, with a caption for its mode.
    curr = st.session_state.prepared
    mode_tag = "Hybrid (2025+2024)" if curr["mode"] == "hybrid" else "Pure 2024"
    st.caption(f"Next input mode: **{mode_tag}**")

    with st.expander("Show current input row (after Apply)", expanded=False):
        # Transpose so each feature is a row; the single data column is named "value".
        st.dataframe(curr["X"].T.rename(columns={0: "value"}))

    if curr["mode"] == "hybrid":
        rep = curr["report"]
        st.caption(
            f"Data completion: **{rep['n_2025']}** from 2025 "
            f"({rep['pct_2025']:.1f}%); **{rep['n_2024']}** from 2024 "
            f"({rep['pct_2024']:.1f}%)"
        )

    submitted = st.button("Submit & Predict", type="primary", use_container_width=True)

    if submitted:
        X_one = curr["X"].copy()
        y_true = curr["y_true"]

        # dtype alignment: match each column to the dtype family it has in usa_data.
        for col in X_one.columns:
            if pd.api.types.is_numeric_dtype(usa_data[col]):
                X_one[col] = pd.to_numeric(X_one[col], errors="coerce")
            else:
                # Fill missing values BEFORE casting. astype(str) would turn
                # NaN into the text "nan", leaving fillna("") nothing to do.
                X_one[col] = X_one[col].fillna("").astype(str)

        try:
            y_pred = float(pipe.predict(X_one)[0])
        except Exception as e:
            # Show the failure in the UI instead of aborting the script run.
            st.error(f"Prediction failed: {e}")
            y_pred = None

        if y_pred is not None:
            st.success(f"**Predicted CompTotal:** {y_pred:,.0f} USD")
            if y_true is not None:
                st.info(f"**2025 true:** {y_true:,.0f} USD")
                abs_err = abs(y_pred - y_true)
                # Divide by the magnitude of the truth so the percentage is
                # never negative; NaN when the truth is exactly 0.
                pct_err = abs_err / abs(y_true) * 100 if y_true != 0 else np.nan
                st.write(f"**Absolute error:** {abs_err:,.0f} USD")
                st.write(f"**Percentage error:** {pct_err:.2f}%")
                # Only predictions that have a ground truth are recorded.
                st.session_state.history.append(
                    dict(pred=y_pred, truth=y_true, abs_err=abs_err, pct_err=pct_err)
                )
            else:
                st.warning("No ground-truth value available for this input (pure 2024 synthetic).")

        # auto-prepare a new candidate after submit
        # NOTE(review): the original indentation was lost; this step is assumed
        # to run after every submit, including a failed prediction - confirm.
        rng = new_rng()
        if mode.startswith("Hybrid"):
            X_one2, y_true2, src_info2, rep2 = build_synthetic_row_with_trace(usa_25, usa_data, label=LABEL, rng=rng)
            st.session_state.prepared = dict(X=X_one2, y_true=y_true2, info=src_info2, report=rep2, mode="hybrid")
        else:
            X_2024b = random_sample_row_2024(usa_data, LABEL, rng)
            st.session_state.prepared = dict(X=X_2024b, y_true=None, info={}, report={}, mode="pure2024")
        _bump_edit_version()  # refresh editor defaults for the new candidate
        st.toast("New random input prepared.", icon="✨")
|
| 277 |
+
|
| 278 |
+
with right:
    st.subheader("Results history")
    if st.session_state.history:
        # One table row per recorded submission, with number formatting per column.
        hist_df = pd.DataFrame(st.session_state.history)
        column_formats = {
            "pred": "{:,.0f}",
            "truth": "{:,.0f}",
            "abs_err": "{:,.0f}",
            "pct_err": "{:.2f}",
        }
        st.dataframe(hist_df.style.format(column_formats), use_container_width=True)

        # Aggregate metrics use only the rows that carry a ground-truth value.
        valid = hist_df.dropna(subset=["truth"])
        if len(valid) == 0:
            st.write("No entries with ground truth yet.")
        else:
            mae = valid["abs_err"].mean()
            mape = valid["pct_err"].mean()
            st.metric(label="Mean Absolute Error (USD)", value=f"{mae:,.0f}")
            st.metric(label="Mean Absolute Percentage Error", value=f"{mape:.2f}%")
    else:
        st.write("No submissions yet.")
|
| 296 |
+
|
| 297 |
+
st.markdown("---")
st.caption("Workflow: Edit → **Apply edits** → **Submit & Predict**. Use the sidebar to reload a fresh random row.")
# =================== Editable input row (FORM + Apply button) ===================
st.subheader("Edit current input (optional)")

if st.session_state.prepared is not None:
    curr = st.session_state.prepared
    X_row = curr["X"].iloc[0]  # the single candidate row; never mutated here
    version = st.session_state.edit_version

    # Classify each feature once by its dtype in usa_data. The result is
    # reused for both expanders and for rebuilding the widget keys on apply.
    num_cols = [c for c in X_row.index if pd.api.types.is_numeric_dtype(usa_data[c])]
    numeric_names = set(num_cols)
    cat_cols = [c for c in X_row.index if c not in numeric_names]

    # Form and widget keys include the edit version, so a new candidate gets
    # fresh widgets instead of inheriting the previous selections.
    with st.form(key=f"edit_form_{version}", clear_on_submit=False):

        applied = st.form_submit_button("Apply edits", use_container_width=True, type="primary")
        exp1 = st.expander("Categorical features", expanded=True)
        exp2 = st.expander("Numeric features", expanded=False)

        with exp1:
            st.caption("Pick from common categories (precomputed).")
            for col in cat_cols:
                choices = CHOICES_DICT.get(col, [])
                curr_val = "" if pd.isna(X_row[col]) else str(X_row[col])
                # Put the current value first when it is not a precomputed
                # option. A new list is built; the cached one is not modified.
                if curr_val not in choices and curr_val != "":
                    choices = [curr_val] + choices
                # The selection is read back via its key, so the return value is unused.
                st.selectbox(
                    label=col,
                    options=choices or [""],
                    index=(choices.index(curr_val) if curr_val in choices else 0),
                    key=f"edit_cat_{col}_{version}",
                )

        with exp2:
            st.caption("Pick typical numeric values (percentiles/rounded).")
            for col in num_cols:
                choices = CHOICES_DICT.get(col, [])
                curr_val = X_row[col]
                if pd.isna(curr_val):
                    # Missing value: default to the first option, or 0 if there is none.
                    curr_val = choices[0] if choices else 0
                if not choices:
                    choices = [curr_val]
                elif curr_val not in choices:
                    choices = [curr_val] + choices
                st.selectbox(
                    label=col,
                    options=choices,
                    index=(choices.index(curr_val) if curr_val in choices else 0),
                    key=f"edit_num_{col}_{version}",
                )

    if applied:
        # Build a fresh row dict from widget states
        new_row = {}
        for col in X_row.index:
            prefix = "edit_num" if col in numeric_names else "edit_cat"
            key = f"{prefix}_{col}_{version}"
            if key in st.session_state:
                new_row[col] = _coerce_value(col, st.session_state[key])
            else:
                # No widget state for this column: keep the existing value.
                new_row[col] = X_row[col]

        # IMPORTANT: replace the whole DF to avoid chained-assignment pitfalls
        new_df = pd.DataFrame([new_row], columns=curr["X"].columns)

        # Commit atomically (don't mutate nested objects in-place)
        st.session_state.prepared = {
            **st.session_state.prepared,
            "X": new_df
        }

        st.toast("Edits applied to current input.", icon="✍️")
        st.rerun()  # rerun so the input preview rendered earlier in the script shows the new DF
|
| 371 |
+
|
| 372 |
+
|