singhn9 committed on
Commit
bfb18ef
·
verified ·
1 Parent(s): b2f829a

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +20 -88
src/streamlit_app.py CHANGED
@@ -473,43 +473,31 @@ with tabs[4]:
473
 
474
  features = st.multiselect("Model input features (auto-suggested)", numeric_cols, default=suggested)
475
  st.markdown(f"Auto target: `{target}` · Suggested family hint: `{model_hint}`")
 
476
  # --- Sampling configuration ---
477
  max_rows = min(df.shape[0], 20000)
478
  sample_size = st.slider("Sample rows", 500, max_rows, min(1500, max_rows), step=100)
479
 
480
  # ---------- SAFE target & X preparation ----------
481
- # Ensure target is a single column name (string). If it's a list, pick the first and warn.
482
  if isinstance(target, (list, tuple)):
483
  st.warning(f"Target provided as list/tuple; using first element `{target[0]}` as target.")
484
  target = target[0]
485
 
486
- # Select only valid feature columns
487
- cols_needed = [c for c in features if c in df.columns]
488
- # Match exact name first
489
- if isinstance(target, (list, tuple)):
490
- st.warning(f"Target provided as list/tuple; using first element `{target[0]}` as target.")
491
- target = target[0]
492
-
493
- # Select only valid feature columns
494
  cols_needed = [c for c in features if c in df.columns]
495
 
496
- # --- Force single exact target column ---
497
  if target in df.columns:
498
  target_col = target
499
  else:
500
- # Case-insensitive exact match
501
  matches = [c for c in df.columns if c.lower() == target.lower()]
502
  if matches:
503
  target_col = matches[0]
504
  st.info(f"Auto-corrected to exact match: `{target_col}`")
505
  else:
506
- # Partial substring match (e.g., 'furnace_temp' vs 'furnace_temp_next')
507
  matches = [c for c in df.columns if target.lower() in c.lower()]
508
  if len(matches) == 1:
509
  target_col = matches[0]
510
  st.info(f"Auto-corrected to closest match: `{target_col}`")
511
  elif len(matches) > 1:
512
- # Prefer '_temp', '_ratio', or exact substring equality
513
  preferred = [m for m in matches if m.endswith("_temp") or m.endswith("_ratio") or m == target]
514
  if preferred:
515
  target_col = preferred[0]
@@ -521,7 +509,6 @@ with tabs[4]:
521
  st.error(f"Target `{target}` not found in dataframe columns.")
522
  st.stop()
523
 
524
- # --- Build sub_df safely — ensure unique and valid target ---
525
  valid_features = [c for c in cols_needed if c in df.columns and c != target_col]
526
  if not valid_features:
527
  st.error("No valid feature columns remain after cleaning. Check feature selection.")
@@ -530,22 +517,14 @@ with tabs[4]:
530
  sub_df = df.loc[:, valid_features + [target_col]].copy()
531
  sub_df = sub_df.sample(n=sample_size, random_state=42).reset_index(drop=True)
532
 
533
- # --- Construct clean X and y ---
534
  X = sub_df.drop(columns=[target_col])
535
  y = pd.Series(np.ravel(sub_df[target_col]), name=target_col)
536
 
537
-
538
-
539
-
540
-
541
-
542
- # Drop known leak or identifier columns
543
  leak_cols = ["furnace_temp_next", "pred_temp_30s", "run_timestamp", "timestamp", "batch_id_numeric", "batch_id"]
544
  for lc in leak_cols:
545
  if lc in X.columns:
546
  X.drop(columns=[lc], inplace=True)
547
 
548
- # Remove constant or near-constant columns
549
  nunique = X.nunique(dropna=False)
550
  const_cols = nunique[nunique <= 1].index.tolist()
551
  if const_cols:
@@ -555,7 +534,6 @@ with tabs[4]:
555
  st.error("No valid feature columns remain after cleaning. Check feature selection.")
556
  st.stop()
557
 
558
-
559
  st.markdown("### Ensemble & AutoML Settings")
560
  max_trials = st.slider("Optuna trials per family", 5, 80, 20, step=5)
561
  top_k = st.slider("Max base models in ensemble", 2, 8, 5)
@@ -653,20 +631,16 @@ with tabs[4]:
653
  st.caption(f"Tuning family: {fam}")
654
  result = tune_family(fam, X, y, n_trials=max_trials)
655
  model_obj = result.get("model_obj")
656
-
657
- # Fix: ensure model is safe to access before fitting
658
  if hasattr(model_obj, "estimators_"):
659
- delattr(model_obj, "estimators_") # clear stale ref if any
660
  result["model_obj"] = model_obj
661
  tuned_results.append(result)
662
 
663
- # --- Leaderboard
664
  lb = pd.DataFrame([{"family": r["family"], "cv_r2": r["cv_score"], "params": r["best_params"]} for r in tuned_results])
665
  lb = lb.sort_values("cv_r2", ascending=False).reset_index(drop=True)
666
  st.markdown("### Tuning Leaderboard (by CV R²)")
667
  st.dataframe(lb[["family","cv_r2"]].round(4))
668
 
669
- # --- Enhanced Ensemble Stacking ---
670
  from sklearn.feature_selection import SelectKBest, f_regression
671
  from sklearn.linear_model import LinearRegression
672
  from sklearn.model_selection import KFold
@@ -683,26 +657,21 @@ with tabs[4]:
683
  kf = KFold(n_splits=5, shuffle=True, random_state=42)
684
  base_models, oof_preds = [], pd.DataFrame(index=X_sel.index)
685
 
686
- # Prevent premature __len__() access on unfitted ensemble models
687
  for r in tuned_results:
688
  m = r.get("model_obj")
689
- # Avoid implicit truth check that calls __len__
690
  if m is not None:
691
  try:
692
- # If model defines __len__, override before fit
693
  if "__len__" in dir(m) and not hasattr(m, "estimators_"):
694
  setattr(m, "__len__", lambda self=m: 0)
695
  except Exception:
696
  pass
697
 
698
-
699
  for fam, entry in [(r["family"], r) for r in tuned_results if r.get("model_obj") is not None]:
700
  model_obj = entry["model_obj"]
701
  oof = np.zeros(X_sel.shape[0])
702
  for tr_idx, val_idx in kf.split(X_sel):
703
  X_tr, X_val = X_sel.iloc[tr_idx], X_sel.iloc[val_idx]
704
- y_tr = y[tr_idx] if not hasattr(y, "iloc") else y.iloc[tr_idx]
705
-
706
  try:
707
  model_obj.fit(X_tr, y_tr)
708
  preds = model_obj.predict(X_val)
@@ -753,69 +722,44 @@ with tabs[4]:
753
  ax.plot([y_val.min(), y_val.max()], [y_val.min(), y_val.max()], "r--")
754
  st.pyplot(fig, clear_figure=True)
755
 
756
- st.session_state["automl_summary"] = {
757
- "leaderboard": summary_df[["family","cv_r2"]].to_dict(orient="records"),
758
- "final_r2": float(final_r2),
759
- "final_rmse": float(final_rmse),
760
- "target": target,
761
- "use_case": use_case
762
- }
763
-
764
- # --- Operator Advisory System + Llama-3-70B-Instruct ---
765
  st.markdown("---")
766
  st.subheader("Operator Advisory System — Real-Time Shift Recommendations")
767
 
768
  try:
769
  top_base = next((b for b in base_models if b["family"] == selected[0]), None)
770
  if top_base and hasattr(top_base["model"], "predict"):
771
- # --- Ensure numeric dtypes for SHAP ---
772
  sample_X = X_val.sample(min(300, len(X_val)), random_state=42).copy()
773
-
774
  def _clean_to_float(x):
775
- """Safely convert any numeric-looking string (even '[1.55E3]') to float."""
776
  if isinstance(x, (int, float, np.floating)):
777
  return float(x)
778
  try:
779
  x_str = str(x).replace("[", "").replace("]", "").replace(",", "").strip()
780
- # handle common non-numeric tokens
781
  if x_str.lower() in ("nan", "none", "", "null", "na", "n/a"):
782
  return 0.0
783
  return float(x_str.replace("E", "e"))
784
  except Exception:
785
  return 0.0
786
-
787
- # Apply cleaning to every column
788
  for col in sample_X.columns:
789
  sample_X[col] = sample_X[col].map(_clean_to_float)
790
-
791
- # Verify numeric dtype and replace NaN
792
  sample_X = sample_X.apply(pd.to_numeric, errors="coerce").fillna(0)
793
-
794
- # Optional diagnostic
795
- non_numeric_cols = [c for c in sample_X.columns if not np.issubdtype(sample_X[c].dtype, np.number)]
796
- if non_numeric_cols:
797
- st.warning(f"Cleaned {len(non_numeric_cols)} potential non-numeric columns: {non_numeric_cols}")
798
-
799
 
800
-
801
- # --- SHAP computation ---
802
  model = top_base["model"]
803
  expl = shap.TreeExplainer(model)
804
  shap_vals = expl.shap_values(sample_X)
805
- if isinstance(shap_vals, list):
806
- shap_vals = shap_vals[0]
807
  shap_vals = np.array(shap_vals)
808
- mean_abs = np.abs(shap_vals).mean(axis=0)
809
- mean_sign = np.sign(shap_vals).mean(axis=0)
810
  importance = pd.DataFrame({
811
  "Feature": sample_X.columns,
812
- "Mean |SHAP|": mean_abs,
813
- "Mean SHAP Sign": mean_sign
814
  }).sort_values("Mean |SHAP|", ascending=False)
815
-
816
  st.markdown("### Top 5 Operational Drivers")
817
  st.dataframe(importance.head(5))
818
-
819
  recommendations = []
820
  for _, row in importance.head(5).iterrows():
821
  f, s = row["Feature"], row["Mean SHAP Sign"]
@@ -825,22 +769,17 @@ with tabs[4]:
825
  recommendations.append(f"Decrease `{f}` likely increases `{target}`")
826
  else:
827
  recommendations.append(f"`{f}` neutral for `{target}`")
828
-
829
  st.markdown("### Suggested Operator Adjustments")
830
  st.write("\n".join(recommendations))
831
 
832
-
833
- # --- Call HF Llama-3-70B-Instruct API for summary ---
834
- # --- Call HF Llama-3-70B-Instruct API for summary (robust + debug-safe) ---
835
  import requests, json, textwrap
836
-
837
- HF_TOKEN = os.getenv("HF_TOKEN") # Works on Hugging Face Spaces
838
  if not HF_TOKEN:
839
- st.error("HF_TOKEN not detected. Check the Secrets tab in your Space settings.")
840
  else:
841
  API_URL = "https://api-inference.huggingface.co/models/meta-llama/Llama-3-8B-Instruct"
842
  headers = {"Authorization": f"Bearer {HF_TOKEN}"}
843
-
844
  prompt = textwrap.dedent(f"""
845
  You are an expert metallurgical process advisor.
846
  Based on these SHAP-derived recommendations:
@@ -849,16 +788,9 @@ with tabs[4]:
849
  Use case: {use_case}
850
  Summarize in three concise, professional lines what the operator should do this shift.
851
  """)
852
-
853
- payload = {
854
- "inputs": prompt,
855
- "parameters": {"max_new_tokens": 150, "temperature": 0.6}
856
- }
857
-
858
  with st.spinner("Generating operator note (Llama-3-8B)…"):
859
  resp = requests.post(API_URL, headers=headers, json=payload, timeout=90)
860
-
861
- # --- Debug section (safe, no secrets printed) ---
862
  try:
863
  data = resp.json()
864
  st.caption("Raw HF response:")
@@ -867,8 +799,7 @@ with tabs[4]:
867
  st.warning(f"HF raw response parse error: {ex}")
868
  st.text(resp.text)
869
  data = None
870
-
871
- # --- Extract generated text robustly ---
872
  text = ""
873
  if isinstance(data, list) and len(data) > 0 and "generated_text" in data[0]:
874
  text = data[0]["generated_text"].strip()
@@ -876,13 +807,14 @@ with tabs[4]:
876
  text = data["generated_text"].strip()
877
  elif isinstance(data, str):
878
  text = data.strip()
879
-
880
  if text:
881
- st.success(" Operator Advisory Generated:")
882
  st.info(text)
883
  else:
884
  st.warning("Operator advisory skipped: no text returned from model.")
885
-
 
886
 
887
 
888
  # ----- Business Impact tab
 
473
 
474
  features = st.multiselect("Model input features (auto-suggested)", numeric_cols, default=suggested)
475
  st.markdown(f"Auto target: `{target}` · Suggested family hint: `{model_hint}`")
476
+
477
  # --- Sampling configuration ---
478
  max_rows = min(df.shape[0], 20000)
479
  sample_size = st.slider("Sample rows", 500, max_rows, min(1500, max_rows), step=100)
480
 
481
  # ---------- SAFE target & X preparation ----------
 
482
  if isinstance(target, (list, tuple)):
483
  st.warning(f"Target provided as list/tuple; using first element `{target[0]}` as target.")
484
  target = target[0]
485
 
 
 
 
 
 
 
 
 
486
  cols_needed = [c for c in features if c in df.columns]
487
 
 
488
  if target in df.columns:
489
  target_col = target
490
  else:
 
491
  matches = [c for c in df.columns if c.lower() == target.lower()]
492
  if matches:
493
  target_col = matches[0]
494
  st.info(f"Auto-corrected to exact match: `{target_col}`")
495
  else:
 
496
  matches = [c for c in df.columns if target.lower() in c.lower()]
497
  if len(matches) == 1:
498
  target_col = matches[0]
499
  st.info(f"Auto-corrected to closest match: `{target_col}`")
500
  elif len(matches) > 1:
 
501
  preferred = [m for m in matches if m.endswith("_temp") or m.endswith("_ratio") or m == target]
502
  if preferred:
503
  target_col = preferred[0]
 
509
  st.error(f"Target `{target}` not found in dataframe columns.")
510
  st.stop()
511
 
 
512
  valid_features = [c for c in cols_needed if c in df.columns and c != target_col]
513
  if not valid_features:
514
  st.error("No valid feature columns remain after cleaning. Check feature selection.")
 
517
  sub_df = df.loc[:, valid_features + [target_col]].copy()
518
  sub_df = sub_df.sample(n=sample_size, random_state=42).reset_index(drop=True)
519
 
 
520
  X = sub_df.drop(columns=[target_col])
521
  y = pd.Series(np.ravel(sub_df[target_col]), name=target_col)
522
 
 
 
 
 
 
 
523
  leak_cols = ["furnace_temp_next", "pred_temp_30s", "run_timestamp", "timestamp", "batch_id_numeric", "batch_id"]
524
  for lc in leak_cols:
525
  if lc in X.columns:
526
  X.drop(columns=[lc], inplace=True)
527
 
 
528
  nunique = X.nunique(dropna=False)
529
  const_cols = nunique[nunique <= 1].index.tolist()
530
  if const_cols:
 
534
  st.error("No valid feature columns remain after cleaning. Check feature selection.")
535
  st.stop()
536
 
 
537
  st.markdown("### Ensemble & AutoML Settings")
538
  max_trials = st.slider("Optuna trials per family", 5, 80, 20, step=5)
539
  top_k = st.slider("Max base models in ensemble", 2, 8, 5)
 
631
  st.caption(f"Tuning family: {fam}")
632
  result = tune_family(fam, X, y, n_trials=max_trials)
633
  model_obj = result.get("model_obj")
 
 
634
  if hasattr(model_obj, "estimators_"):
635
+ delattr(model_obj, "estimators_")
636
  result["model_obj"] = model_obj
637
  tuned_results.append(result)
638
 
 
639
  lb = pd.DataFrame([{"family": r["family"], "cv_r2": r["cv_score"], "params": r["best_params"]} for r in tuned_results])
640
  lb = lb.sort_values("cv_r2", ascending=False).reset_index(drop=True)
641
  st.markdown("### Tuning Leaderboard (by CV R²)")
642
  st.dataframe(lb[["family","cv_r2"]].round(4))
643
 
 
644
  from sklearn.feature_selection import SelectKBest, f_regression
645
  from sklearn.linear_model import LinearRegression
646
  from sklearn.model_selection import KFold
 
657
  kf = KFold(n_splits=5, shuffle=True, random_state=42)
658
  base_models, oof_preds = [], pd.DataFrame(index=X_sel.index)
659
 
 
660
  for r in tuned_results:
661
  m = r.get("model_obj")
 
662
  if m is not None:
663
  try:
 
664
  if "__len__" in dir(m) and not hasattr(m, "estimators_"):
665
  setattr(m, "__len__", lambda self=m: 0)
666
  except Exception:
667
  pass
668
 
 
669
  for fam, entry in [(r["family"], r) for r in tuned_results if r.get("model_obj") is not None]:
670
  model_obj = entry["model_obj"]
671
  oof = np.zeros(X_sel.shape[0])
672
  for tr_idx, val_idx in kf.split(X_sel):
673
  X_tr, X_val = X_sel.iloc[tr_idx], X_sel.iloc[val_idx]
674
+ y_tr = y.iloc[tr_idx]
 
675
  try:
676
  model_obj.fit(X_tr, y_tr)
677
  preds = model_obj.predict(X_val)
 
722
  ax.plot([y_val.min(), y_val.max()], [y_val.min(), y_val.max()], "r--")
723
  st.pyplot(fig, clear_figure=True)
724
 
725
+ # --- Operator Advisory ---
 
 
 
 
 
 
 
 
726
  st.markdown("---")
727
  st.subheader("Operator Advisory System — Real-Time Shift Recommendations")
728
 
729
  try:
730
  top_base = next((b for b in base_models if b["family"] == selected[0]), None)
731
  if top_base and hasattr(top_base["model"], "predict"):
 
732
  sample_X = X_val.sample(min(300, len(X_val)), random_state=42).copy()
733
+
734
  def _clean_to_float(x):
 
735
  if isinstance(x, (int, float, np.floating)):
736
  return float(x)
737
  try:
738
  x_str = str(x).replace("[", "").replace("]", "").replace(",", "").strip()
 
739
  if x_str.lower() in ("nan", "none", "", "null", "na", "n/a"):
740
  return 0.0
741
  return float(x_str.replace("E", "e"))
742
  except Exception:
743
  return 0.0
744
+
 
745
  for col in sample_X.columns:
746
  sample_X[col] = sample_X[col].map(_clean_to_float)
 
 
747
  sample_X = sample_X.apply(pd.to_numeric, errors="coerce").fillna(0)
 
 
 
 
 
 
748
 
 
 
749
  model = top_base["model"]
750
  expl = shap.TreeExplainer(model)
751
  shap_vals = expl.shap_values(sample_X)
752
+ if isinstance(shap_vals, list): shap_vals = shap_vals[0]
 
753
  shap_vals = np.array(shap_vals)
 
 
754
  importance = pd.DataFrame({
755
  "Feature": sample_X.columns,
756
+ "Mean |SHAP|": np.abs(shap_vals).mean(axis=0),
757
+ "Mean SHAP Sign": np.sign(shap_vals).mean(axis=0)
758
  }).sort_values("Mean |SHAP|", ascending=False)
759
+
760
  st.markdown("### Top 5 Operational Drivers")
761
  st.dataframe(importance.head(5))
762
+
763
  recommendations = []
764
  for _, row in importance.head(5).iterrows():
765
  f, s = row["Feature"], row["Mean SHAP Sign"]
 
769
  recommendations.append(f"Decrease `{f}` likely increases `{target}`")
770
  else:
771
  recommendations.append(f"`{f}` neutral for `{target}`")
772
+
773
  st.markdown("### Suggested Operator Adjustments")
774
  st.write("\n".join(recommendations))
775
 
 
 
 
776
  import requests, json, textwrap
777
+ HF_TOKEN = os.getenv("HF_TOKEN")
 
778
  if not HF_TOKEN:
779
+ st.error("HF_TOKEN not detected. Check the Secrets tab.")
780
  else:
781
  API_URL = "https://api-inference.huggingface.co/models/meta-llama/Llama-3-8B-Instruct"
782
  headers = {"Authorization": f"Bearer {HF_TOKEN}"}
 
783
  prompt = textwrap.dedent(f"""
784
  You are an expert metallurgical process advisor.
785
  Based on these SHAP-derived recommendations:
 
788
  Use case: {use_case}
789
  Summarize in three concise, professional lines what the operator should do this shift.
790
  """)
791
+ payload = {"inputs": prompt, "parameters": {"max_new_tokens": 150, "temperature": 0.6}}
 
 
 
 
 
792
  with st.spinner("Generating operator note (Llama-3-8B)…"):
793
  resp = requests.post(API_URL, headers=headers, json=payload, timeout=90)
 
 
794
  try:
795
  data = resp.json()
796
  st.caption("Raw HF response:")
 
799
  st.warning(f"HF raw response parse error: {ex}")
800
  st.text(resp.text)
801
  data = None
802
+
 
803
  text = ""
804
  if isinstance(data, list) and len(data) > 0 and "generated_text" in data[0]:
805
  text = data[0]["generated_text"].strip()
 
807
  text = data["generated_text"].strip()
808
  elif isinstance(data, str):
809
  text = data.strip()
810
+
811
  if text:
812
+ st.success(" Operator Advisory Generated:")
813
  st.info(text)
814
  else:
815
  st.warning("Operator advisory skipped: no text returned from model.")
816
+ except Exception as e:
817
+ st.warning(f"Operator advisory skipped: {e}")
818
 
819
 
820
  # ----- Business Impact tab