singhn9 committed on
Commit
54e00e0
·
verified ·
1 Parent(s): 7986d55

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +80 -48
src/streamlit_app.py CHANGED
@@ -24,13 +24,29 @@ from sklearn.metrics import mean_squared_error, r2_score
24
 
25
  # SHAP
26
  import shap
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  if "llm_result" not in st.session_state:
28
  st.session_state["llm_result"] = None
29
  if "automl_summary" not in st.session_state:
30
  st.session_state["automl_summary"] = {}
31
  if "shap_recommendations" not in st.session_state:
32
  st.session_state["shap_recommendations"] = []
33
-
 
34
 
35
  # -------------------------
36
  # Config & paths
@@ -667,8 +683,13 @@ with tabs[4]:
667
  return {"model_obj": model, "cv_score": score, "best_params": best, "family": family_name, "study": study}
668
 
669
  # --- Run tuning across available families (user triggered) ---
670
- run_btn = st.button(" Run expanded AutoML + Stacking")
671
- if run_btn:
 
 
 
 
 
672
  log("AutoML + Stacking initiated.")
673
  with st.spinner("Tuning multiple families (this may take a while depending on choices)..."):
674
  families_to_try = ["RandomForest", "ExtraTrees", "MLP"]
@@ -937,7 +958,7 @@ with tabs[4]:
937
  recommended_shift[name] *= 0.97 # -3%
938
 
939
  # Delta table
940
- st.markdown("### 🧾 Shift Adjustment Summary (vs Previous 200 Samples)")
941
  deltas = pd.DataFrame({
942
  "Current Avg": prev_shift,
943
  "Suggested": recommended_shift,
@@ -1052,68 +1073,79 @@ with tabs[4]:
1052
  else:
1053
  st.session_state["shap_recommendations"] = recommendations
1054
 
1055
- # --- AI Recommendation Assistant (in-memory safe for Hugging Face) ---
1056
  st.markdown("---")
1057
- st.subheader("AI Recommendation Assistant (in-memory mode)")
1058
  st.caption("Generates quick local AI suggestions — no file writes required.")
1059
 
 
 
 
1060
  if "llm_result" not in st.session_state:
1061
  st.session_state["llm_result"] = None
1062
 
1063
- if st.button("Get AI Recommendation (via HF API)", key="ai_reco"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1064
  summary = st.session_state.get("automl_summary", {})
1065
  if not summary:
1066
  st.warning("Please run AutoML first to generate context.")
1067
- st.stop()
1068
- try:
1069
- import requests, json
1070
- st.info("Contacting Hugging Face Inference API (Mixtral-8x7B-Instruct)…")
1071
-
1072
- API_URL = "https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1"
1073
- headers = {"Authorization": f"Bearer {st.secrets['HF_TOKEN']}"}
1074
-
1075
- prompt = f"""
1076
- You are an ML model tuning advisor.
1077
- Based on this AutoML summary, suggest 3 concise, actionable steps
1078
- to improve model performance if overfitting, underfitting, or data-quality issues are observed.
1079
 
1080
- Use case: {summary.get('use_case')}
1081
- Target: {summary.get('target')}
1082
- Final R²: {summary.get('final_r2')}
1083
- Final RMSE: {summary.get('final_rmse')}
1084
- Leaderboard: {summary.get('leaderboard')}
1085
- """
1086
 
1087
- payload = {
1088
- "inputs": prompt,
1089
- "parameters": {"max_new_tokens": 200, "temperature": 0.7}
1090
- }
 
 
1091
 
1092
- response = requests.post(API_URL, headers=headers, json=payload, timeout=60)
1093
- response.raise_for_status()
1094
- result = response.json()
 
1095
 
1096
- if isinstance(result, list) and "generated_text" in result[0]:
1097
- text = result[0]["generated_text"]
1098
- elif isinstance(result, dict) and "generated_text" in result:
1099
- text = result["generated_text"]
1100
- else:
1101
- text = json.dumps(result, indent=2)
1102
 
1103
- st.session_state["llm_result"] = text.strip()
1104
- log("HF API recommendation generated successfully.")
1105
- st.success("AI Recommendation (Mixtral-8x7B-Instruct):")
1106
- st.markdown(st.session_state["llm_result"])
1107
 
1108
- except Exception as e:
1109
- err_msg = f"HF Inference API call failed: {e}"
1110
- st.error(err_msg)
1111
- log(err_msg)
1112
 
1113
 
1114
- # Persist output even after rerun
 
1115
  if st.session_state["llm_result"]:
1116
- st.success("AI Recommendation (cached):")
1117
  st.markdown(st.session_state["llm_result"])
1118
 
1119
 
 
24
 
25
  # SHAP
26
  import shap
27
+
28
+
29
+ # --- Safe defaults for Streamlit session state ---
30
+ defaults = {
31
+ "llm_result": None,
32
+ "automl_summary": {},
33
+ "shap_recommendations": [],
34
+ "hf_clicked": False,
35
+ "hf_ran_once": False,
36
+ "run_automl_clicked": False,
37
+ }
38
+ for k, v in defaults.items():
39
+ st.session_state.setdefault(k, v)
40
+
41
+
42
  if "llm_result" not in st.session_state:
43
  st.session_state["llm_result"] = None
44
  if "automl_summary" not in st.session_state:
45
  st.session_state["automl_summary"] = {}
46
  if "shap_recommendations" not in st.session_state:
47
  st.session_state["shap_recommendations"] = []
48
+ if "hf_clicked" not in st.session_state:
49
+ st.session_state["hf_clicked"] = False
50
 
51
  # -------------------------
52
  # Config & paths
 
683
  return {"model_obj": model, "cv_score": score, "best_params": best, "family": family_name, "study": study}
684
 
685
  # --- Run tuning across available families (user triggered) ---
686
+ if "run_automl_clicked" not in st.session_state:
687
+ st.session_state["run_automl_clicked"] = False
688
+
689
+ if st.button("Run expanded AutoML + Stacking"):
690
+ st.session_state["run_automl_clicked"] = True
691
+
692
+ if st.session_state["run_automl_clicked"]:
693
  log("AutoML + Stacking initiated.")
694
  with st.spinner("Tuning multiple families (this may take a while depending on choices)..."):
695
  families_to_try = ["RandomForest", "ExtraTrees", "MLP"]
 
958
  recommended_shift[name] *= 0.97 # -3%
959
 
960
  # Delta table
961
+ st.markdown("### Shift Adjustment Summary (vs Previous 200 Samples)")
962
  deltas = pd.DataFrame({
963
  "Current Avg": prev_shift,
964
  "Suggested": recommended_shift,
 
1073
  else:
1074
  st.session_state["shap_recommendations"] = recommendations
1075
 
1076
+ # --- AI Recommendation Assistant ---
1077
  st.markdown("---")
1078
+ st.subheader("AI Recommendation Assistant ")
1079
  st.caption("Generates quick local AI suggestions — no file writes required.")
1080
 
1081
+ # Create or reset button states safely
1082
+ if "hf_clicked" not in st.session_state:
1083
+ st.session_state["hf_clicked"] = False
1084
  if "llm_result" not in st.session_state:
1085
  st.session_state["llm_result"] = None
1086
 
1087
+ # --- Buttons ---
1088
+ col1, col2 = st.columns(2)
1089
+ # Click handlers with isolated session flags
1090
+ if col1.button("Get AI Recommendation (via HF API)", key="ai_reco"):
1091
+ st.session_state["hf_clicked"] = True
1092
+ st.session_state["hf_ran_once"] = False # reset internal control
1093
+
1094
+ if col2.button("Reset Recommendation Output"):
1095
+ st.session_state["hf_clicked"] = False
1096
+ st.session_state["llm_result"] = None
1097
+ st.session_state["hf_ran_once"] = False
1098
+ st.info("Recommendation output cleared.")
1099
+
1100
+ # Execute API call only once
1101
+ if st.session_state["hf_clicked"] and not st.session_state.get("hf_ran_once", False):
1102
  summary = st.session_state.get("automl_summary", {})
1103
  if not summary:
1104
  st.warning("Please run AutoML first to generate context.")
1105
+ else:
1106
+ try:
1107
+ import requests, json
1108
+ st.info("Contacting Hugging Face Inference API (Mixtral-8x7B-Instruct)…")
 
 
 
 
 
 
 
 
1109
 
1110
+ API_URL = "https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1"
1111
+ headers = {"Authorization": f"Bearer {st.secrets['HF_TOKEN']}"}
1112
+ prompt = f"""
1113
+ You are an ML model tuning advisor.
1114
+ Based on this AutoML summary, suggest 3 concise, actionable steps
1115
+ to improve model performance if overfitting, underfitting, or data-quality issues are observed.
1116
 
1117
+ Use case: {summary.get('use_case')}
1118
+ Target: {summary.get('target')}
1119
+ Final R²: {summary.get('final_r2')}
1120
+ Final RMSE: {summary.get('final_rmse')}
1121
+ Leaderboard: {summary.get('leaderboard')}
1122
+ """
1123
 
1124
+ payload = {"inputs": prompt, "parameters": {"max_new_tokens": 200, "temperature": 0.7}}
1125
+ response = requests.post(API_URL, headers=headers, json=payload, timeout=60)
1126
+ response.raise_for_status()
1127
+ result = response.json()
1128
 
1129
+ if isinstance(result, list) and "generated_text" in result[0]:
1130
+ text = result[0]["generated_text"]
1131
+ elif isinstance(result, dict) and "generated_text" in result:
1132
+ text = result["generated_text"]
1133
+ else:
1134
+ text = json.dumps(result, indent=2)
1135
 
1136
+ st.session_state["llm_result"] = text.strip()
1137
+ st.session_state["hf_ran_once"] = True
1138
+ st.success("AI Recommendation (Mixtral-8x7B-Instruct):")
1139
+ st.markdown(st.session_state["llm_result"])
1140
 
1141
+ except Exception as e:
1142
+ st.error(f"HF Inference API call failed: {e}")
 
 
1143
 
1144
 
1145
+
1146
+ # --- Always display cached result, even on rerun ---
1147
  if st.session_state["llm_result"]:
1148
+ st.markdown("### Cached AI Recommendation:")
1149
  st.markdown(st.session_state["llm_result"])
1150
 
1151