Update src/streamlit_app.py
Browse files- src/streamlit_app.py +80 -48
src/streamlit_app.py
CHANGED
|
@@ -24,13 +24,29 @@ from sklearn.metrics import mean_squared_error, r2_score
|
|
| 24 |
|
| 25 |
# SHAP
|
| 26 |
import shap
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
if "llm_result" not in st.session_state:
|
| 28 |
st.session_state["llm_result"] = None
|
| 29 |
if "automl_summary" not in st.session_state:
|
| 30 |
st.session_state["automl_summary"] = {}
|
| 31 |
if "shap_recommendations" not in st.session_state:
|
| 32 |
st.session_state["shap_recommendations"] = []
|
| 33 |
-
|
|
|
|
| 34 |
|
| 35 |
# -------------------------
|
| 36 |
# Config & paths
|
|
@@ -667,8 +683,13 @@ with tabs[4]:
|
|
| 667 |
return {"model_obj": model, "cv_score": score, "best_params": best, "family": family_name, "study": study}
|
| 668 |
|
| 669 |
# --- Run tuning across available families (user triggered) ---
|
| 670 |
-
|
| 671 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 672 |
log("AutoML + Stacking initiated.")
|
| 673 |
with st.spinner("Tuning multiple families (this may take a while depending on choices)..."):
|
| 674 |
families_to_try = ["RandomForest", "ExtraTrees", "MLP"]
|
|
@@ -937,7 +958,7 @@ with tabs[4]:
|
|
| 937 |
recommended_shift[name] *= 0.97 # -3%
|
| 938 |
|
| 939 |
# Delta table
|
| 940 |
-
st.markdown("### Shift Adjustment Summary (vs Previous 200 Samples)")
|
| 941 |
deltas = pd.DataFrame({
|
| 942 |
"Current Avg": prev_shift,
|
| 943 |
"Suggested": recommended_shift,
|
|
@@ -1052,68 +1073,79 @@ with tabs[4]:
|
|
| 1052 |
else:
|
| 1053 |
st.session_state["shap_recommendations"] = recommendations
|
| 1054 |
|
| 1055 |
-
# --- AI Recommendation Assistant ---
|
| 1056 |
st.markdown("---")
|
| 1057 |
-
st.subheader("AI Recommendation Assistant")
|
| 1058 |
st.caption("Generates quick local AI suggestions — no file writes required.")
|
| 1059 |
|
|
|
|
|
|
|
|
|
|
| 1060 |
if "llm_result" not in st.session_state:
|
| 1061 |
st.session_state["llm_result"] = None
|
| 1062 |
|
| 1063 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1064 |
summary = st.session_state.get("automl_summary", {})
|
| 1065 |
if not summary:
|
| 1066 |
st.warning("Please run AutoML first to generate context.")
|
| 1067 |
-
|
| 1068 |
-
|
| 1069 |
-
|
| 1070 |
-
|
| 1071 |
-
|
| 1072 |
-
API_URL = "https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1"
|
| 1073 |
-
headers = {"Authorization": f"Bearer {st.secrets['HF_TOKEN']}"}
|
| 1074 |
-
|
| 1075 |
-
prompt = f"""
|
| 1076 |
-
You are an ML model tuning advisor.
|
| 1077 |
-
Based on this AutoML summary, suggest 3 concise, actionable steps
|
| 1078 |
-
to improve model performance if overfitting, underfitting, or data-quality issues are observed.
|
| 1079 |
|
| 1080 |
-
|
| 1081 |
-
|
| 1082 |
-
|
| 1083 |
-
|
| 1084 |
-
|
| 1085 |
-
|
| 1086 |
|
| 1087 |
-
|
| 1088 |
-
|
| 1089 |
-
|
| 1090 |
-
|
|
|
|
|
|
|
| 1091 |
|
| 1092 |
-
|
| 1093 |
-
|
| 1094 |
-
|
|
|
|
| 1095 |
|
| 1096 |
-
|
| 1097 |
-
|
| 1098 |
-
|
| 1099 |
-
|
| 1100 |
-
|
| 1101 |
-
|
| 1102 |
|
| 1103 |
-
|
| 1104 |
-
|
| 1105 |
-
|
| 1106 |
-
|
| 1107 |
|
| 1108 |
-
|
| 1109 |
-
|
| 1110 |
-
st.error(err_msg)
|
| 1111 |
-
log(err_msg)
|
| 1112 |
|
| 1113 |
|
| 1114 |
-
|
|
|
|
| 1115 |
if st.session_state["llm_result"]:
|
| 1116 |
-
st.
|
| 1117 |
st.markdown(st.session_state["llm_result"])
|
| 1118 |
|
| 1119 |
|
|
|
|
| 24 |
|
| 25 |
# SHAP
|
| 26 |
import shap
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# --- Safe defaults for Streamlit session state ---
|
| 30 |
+
defaults = {
|
| 31 |
+
"llm_result": None,
|
| 32 |
+
"automl_summary": {},
|
| 33 |
+
"shap_recommendations": [],
|
| 34 |
+
"hf_clicked": False,
|
| 35 |
+
"hf_ran_once": False,
|
| 36 |
+
"run_automl_clicked": False,
|
| 37 |
+
}
|
| 38 |
+
for k, v in defaults.items():
|
| 39 |
+
st.session_state.setdefault(k, v)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
if "llm_result" not in st.session_state:
|
| 43 |
st.session_state["llm_result"] = None
|
| 44 |
if "automl_summary" not in st.session_state:
|
| 45 |
st.session_state["automl_summary"] = {}
|
| 46 |
if "shap_recommendations" not in st.session_state:
|
| 47 |
st.session_state["shap_recommendations"] = []
|
| 48 |
+
if "hf_clicked" not in st.session_state:
|
| 49 |
+
st.session_state["hf_clicked"] = False
|
| 50 |
|
| 51 |
# -------------------------
|
| 52 |
# Config & paths
|
|
|
|
| 683 |
return {"model_obj": model, "cv_score": score, "best_params": best, "family": family_name, "study": study}
|
| 684 |
|
| 685 |
# --- Run tuning across available families (user triggered) ---
|
| 686 |
+
if "run_automl_clicked" not in st.session_state:
|
| 687 |
+
st.session_state["run_automl_clicked"] = False
|
| 688 |
+
|
| 689 |
+
if st.button("Run expanded AutoML + Stacking"):
|
| 690 |
+
st.session_state["run_automl_clicked"] = True
|
| 691 |
+
|
| 692 |
+
if st.session_state["run_automl_clicked"]:
|
| 693 |
log("AutoML + Stacking initiated.")
|
| 694 |
with st.spinner("Tuning multiple families (this may take a while depending on choices)..."):
|
| 695 |
families_to_try = ["RandomForest", "ExtraTrees", "MLP"]
|
|
|
|
| 958 |
recommended_shift[name] *= 0.97 # -3%
|
| 959 |
|
| 960 |
# Delta table
|
| 961 |
+
st.markdown("### Shift Adjustment Summary (vs Previous 200 Samples)")
|
| 962 |
deltas = pd.DataFrame({
|
| 963 |
"Current Avg": prev_shift,
|
| 964 |
"Suggested": recommended_shift,
|
|
|
|
| 1073 |
else:
|
| 1074 |
st.session_state["shap_recommendations"] = recommendations
|
| 1075 |
|
| 1076 |
+
# --- AI Recommendation Assistant ---
|
| 1077 |
st.markdown("---")
|
| 1078 |
+
st.subheader("AI Recommendation Assistant ")
|
| 1079 |
st.caption("Generates quick local AI suggestions — no file writes required.")
|
| 1080 |
|
| 1081 |
+
# Create or reset button states safely
|
| 1082 |
+
if "hf_clicked" not in st.session_state:
|
| 1083 |
+
st.session_state["hf_clicked"] = False
|
| 1084 |
if "llm_result" not in st.session_state:
|
| 1085 |
st.session_state["llm_result"] = None
|
| 1086 |
|
| 1087 |
+
# --- Buttons ---
|
| 1088 |
+
col1, col2 = st.columns(2)
|
| 1089 |
+
# Click handlers with isolated session flags
|
| 1090 |
+
if col1.button("Get AI Recommendation (via HF API)", key="ai_reco"):
|
| 1091 |
+
st.session_state["hf_clicked"] = True
|
| 1092 |
+
st.session_state["hf_ran_once"] = False # reset internal control
|
| 1093 |
+
|
| 1094 |
+
if col2.button("Reset Recommendation Output"):
|
| 1095 |
+
st.session_state["hf_clicked"] = False
|
| 1096 |
+
st.session_state["llm_result"] = None
|
| 1097 |
+
st.session_state["hf_ran_once"] = False
|
| 1098 |
+
st.info("Recommendation output cleared.")
|
| 1099 |
+
|
| 1100 |
+
# Execute API call only once
|
| 1101 |
+
if st.session_state["hf_clicked"] and not st.session_state.get("hf_ran_once", False):
|
| 1102 |
summary = st.session_state.get("automl_summary", {})
|
| 1103 |
if not summary:
|
| 1104 |
st.warning("Please run AutoML first to generate context.")
|
| 1105 |
+
else:
|
| 1106 |
+
try:
|
| 1107 |
+
import requests, json
|
| 1108 |
+
st.info("Contacting Hugging Face Inference API (Mixtral-8x7B-Instruct)…")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1109 |
|
| 1110 |
+
API_URL = "https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1"
|
| 1111 |
+
headers = {"Authorization": f"Bearer {st.secrets['HF_TOKEN']}"}
|
| 1112 |
+
prompt = f"""
|
| 1113 |
+
You are an ML model tuning advisor.
|
| 1114 |
+
Based on this AutoML summary, suggest 3 concise, actionable steps
|
| 1115 |
+
to improve model performance if overfitting, underfitting, or data-quality issues are observed.
|
| 1116 |
|
| 1117 |
+
Use case: {summary.get('use_case')}
|
| 1118 |
+
Target: {summary.get('target')}
|
| 1119 |
+
Final R²: {summary.get('final_r2')}
|
| 1120 |
+
Final RMSE: {summary.get('final_rmse')}
|
| 1121 |
+
Leaderboard: {summary.get('leaderboard')}
|
| 1122 |
+
"""
|
| 1123 |
|
| 1124 |
+
payload = {"inputs": prompt, "parameters": {"max_new_tokens": 200, "temperature": 0.7}}
|
| 1125 |
+
response = requests.post(API_URL, headers=headers, json=payload, timeout=60)
|
| 1126 |
+
response.raise_for_status()
|
| 1127 |
+
result = response.json()
|
| 1128 |
|
| 1129 |
+
if isinstance(result, list) and "generated_text" in result[0]:
|
| 1130 |
+
text = result[0]["generated_text"]
|
| 1131 |
+
elif isinstance(result, dict) and "generated_text" in result:
|
| 1132 |
+
text = result["generated_text"]
|
| 1133 |
+
else:
|
| 1134 |
+
text = json.dumps(result, indent=2)
|
| 1135 |
|
| 1136 |
+
st.session_state["llm_result"] = text.strip()
|
| 1137 |
+
st.session_state["hf_ran_once"] = True
|
| 1138 |
+
st.success("✅ AI Recommendation (Mixtral-8x7B-Instruct):")
|
| 1139 |
+
st.markdown(st.session_state["llm_result"])
|
| 1140 |
|
| 1141 |
+
except Exception as e:
|
| 1142 |
+
st.error(f"HF Inference API call failed: {e}")
|
|
|
|
|
|
|
| 1143 |
|
| 1144 |
|
| 1145 |
+
|
| 1146 |
+
# --- Always display cached result, even on rerun ---
|
| 1147 |
if st.session_state["llm_result"]:
|
| 1148 |
+
st.markdown("### Cached AI Recommendation:")
|
| 1149 |
st.markdown(st.session_state["llm_result"])
|
| 1150 |
|
| 1151 |
|