singhn9 committed on
Commit
dceb721
·
verified ·
1 Parent(s): ce5795d

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +122 -7
src/streamlit_app.py CHANGED
@@ -31,6 +31,9 @@ import shap
31
  # -------------------------
32
 
33
  st.set_page_config(page_title="Steel Authority of India Limited (MODEX)", layout="wide")
 
 
 
34
 
35
  LOG_DIR = "./logs"
36
  os.makedirs(LOG_DIR, exist_ok=True)
@@ -414,15 +417,37 @@ with tabs[0]:
414
 
415
  # ----- Visualize tab
416
  with tabs[1]:
417
- st.subheader("Feature visualization")
418
  col = st.selectbox("Choose numeric feature", numeric_cols, index=0)
419
  bins = st.slider("Histogram bins", 10, 200, 50)
420
- fig, ax = plt.subplots(figsize=(8,4))
421
- sns.histplot(df[col], bins=bins, kde=True, ax=ax)
422
- ax.set_title(col)
423
- st.pyplot(fig)
 
 
 
 
 
424
  st.write(df[col].describe().to_frame().T)
425
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
426
  # ----- Correlations tab
427
  with tabs[2]:
428
  st.subheader("Correlation explorer")
@@ -431,8 +456,14 @@ with tabs[2]:
431
  if len(corr_sel) >= 2:
432
  corr = df[corr_sel].corr()
433
  fig, ax = plt.subplots(figsize=(10,8))
434
- sns.heatmap(corr, cmap="coolwarm", center=0, ax=ax)
435
- st.pyplot(fig)
 
 
 
 
 
 
436
  else:
437
  st.info("Choose at least 2 numeric features to compute correlation.")
438
 
@@ -659,6 +690,21 @@ with tabs[4]:
659
  lb = lb.sort_values("cv_r2", ascending=False).reset_index(drop=True)
660
  st.markdown("### Tuning Leaderboard (by CV R²)")
661
  st.dataframe(lb[["family","cv_r2"]].round(4))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
662
 
663
  # --- Build base-models and collect out-of-fold preds for stacking ---
664
  st.markdown("### Building base models & out-of-fold predictions for stacking")
@@ -849,6 +895,75 @@ with tabs[4]:
849
 
850
  st.success(" AutoML + Stacking complete — metrics, artifacts, and SHAP ready.")
851
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
852
 
853
 
854
  # ----- Target & Business Impact tab
 
31
  # -------------------------
32
 
33
  st.set_page_config(page_title="Steel Authority of India Limited (MODEX)", layout="wide")
34
+ plt.style.use("seaborn-v0_8-muted")
35
+ sns.set_palette("muted")
36
+ sns.set_style("whitegrid")
37
 
38
  LOG_DIR = "./logs"
39
  os.makedirs(LOG_DIR, exist_ok=True)
 
417
 
418
  # ----- Visualize tab
419
  with tabs[1]:
420
+ st.subheader("Feature Visualization")
421
  col = st.selectbox("Choose numeric feature", numeric_cols, index=0)
422
  bins = st.slider("Histogram bins", 10, 200, 50)
423
+
424
+ # --- Improved Histogram with style ---
425
+ fig, ax = plt.subplots(figsize=(8, 4))
426
+ sns.histplot(df[col], bins=bins, kde=True, ax=ax, color="#2C6E91", alpha=0.8)
427
+ ax.set_title(f"Distribution of {col.replace('_', ' ').title()}", fontsize=12)
428
+ ax.set_xlabel(col.replace("_", " ").title(), fontsize=10)
429
+ ax.set_ylabel("Frequency", fontsize=10)
430
+ sns.despine()
431
+ st.pyplot(fig, clear_figure=True)
432
  st.write(df[col].describe().to_frame().T)
433
 
434
+ # --- Add PCA scatter visualization ---
435
+ if all(x in df.columns for x in ["pca_1", "pca_2", "operating_mode"]):
436
+ st.markdown("### PCA Feature Space — Colored by Operating Mode")
437
+ fig2, ax2 = plt.subplots(figsize=(6, 5))
438
+ sns.scatterplot(
439
+ data=df.sample(min(1000, len(df)), random_state=42),
440
+ x="pca_1", y="pca_2", hue="operating_mode",
441
+ palette="tab10", alpha=0.7, s=40, ax=ax2
442
+ )
443
+ ax2.set_title("Operating Mode Clusters (PCA Projection)", fontsize=12)
444
+ ax2.set_xlabel("PCA 1")
445
+ ax2.set_ylabel("PCA 2")
446
+ ax2.legend(title="Operating Mode", bbox_to_anchor=(1.05, 1), loc="upper left")
447
+ sns.despine()
448
+ st.pyplot(fig2, clear_figure=True)
449
+
450
+
451
  # ----- Correlations tab
452
  with tabs[2]:
453
  st.subheader("Correlation explorer")
 
456
  if len(corr_sel) >= 2:
457
  corr = df[corr_sel].corr()
458
  fig, ax = plt.subplots(figsize=(10,8))
459
+ sns.heatmap(
460
+ corr, cmap="RdBu_r", center=0, annot=True, fmt=".2f",
461
+ linewidths=0.5, cbar_kws={"shrink": 0.7}, ax=ax
462
+ )
463
+ ax.set_title("Feature Correlation Matrix", fontsize=12)
464
+ sns.despine()
465
+ st.pyplot(fig, clear_figure=True)
466
+
467
  else:
468
  st.info("Choose at least 2 numeric features to compute correlation.")
469
 
 
690
  lb = lb.sort_values("cv_r2", ascending=False).reset_index(drop=True)
691
  st.markdown("### Tuning Leaderboard (by CV R²)")
692
  st.dataframe(lb[["family","cv_r2"]].round(4))
693
+ # --- Bonus Visualization: Model Performance Summary ---
694
+ if not lb.empty:
695
+ st.markdown("#### Model Performance Summary (CV R²)")
696
+ fig_perf, ax_perf = plt.subplots(figsize=(7, 4))
697
+ colors = ["#2C6E91" if fam != lb.iloc[0]["family"] else "#C65F00" for fam in lb["family"]]
698
+ ax_perf.barh(lb["family"], lb["cv_r2"], color=colors, alpha=0.85)
699
+ ax_perf.set_xlabel("Cross-Validated R² Score", fontsize=10)
700
+ ax_perf.set_ylabel("Model Family", fontsize=10)
701
+ ax_perf.set_title("Performance Comparison Across Model Families", fontsize=12)
702
+ ax_perf.invert_yaxis()
703
+ for i, v in enumerate(lb["cv_r2"]):
704
+ ax_perf.text(v + 0.005, i, f"{v:.3f}", va="center", fontsize=9)
705
+ sns.despine()
706
+ st.pyplot(fig_perf, clear_figure=True)
707
+
708
 
709
  # --- Build base-models and collect out-of-fold preds for stacking ---
710
  st.markdown("### Building base models & out-of-fold predictions for stacking")
 
895
 
896
  st.success(" AutoML + Stacking complete — metrics, artifacts, and SHAP ready.")
897
 
898
+ # --- Store AutoML summary for optional LLM advisory ---
899
+ st.session_state["automl_summary"] = {
900
+ "leaderboard": lb[["family", "cv_r2"]].round(4).to_dict(orient="records"),
901
+ "final_r2": float(final_r2),
902
+ "final_rmse": float(final_rmse),
903
+ "target": target,
904
+ "use_case": use_case
905
+ }
906
+
907
+ # --- Optional: AI Model Recommendation Assistant ---
908
+ st.markdown("---")
909
+ st.subheader("AI Recommendation Assistant (cached local model)")
910
+ st.caption("Get quick local AI suggestions without internet — cached inside ./logs")
911
+
912
+ if st.button("Get AI Recommendation (tiny local LLM)", key="ai_reco"):
913
+ summary = st.session_state.get("automl_summary", {})
914
+ st.info("Loading local model... first time may take ~10s.")
915
+ try:
916
+ import importlib.util, os
917
+ from pathlib import Path
918
+
919
+ # Ensure transformers is available
920
+ if importlib.util.find_spec("transformers") is None:
921
+ st.error("Transformers not installed. Run `pip install transformers`.")
922
+ else:
923
+ from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
924
+
925
+ MODEL_NAME = "sshleifer/tiny-gpt2" # very small 6 MB model
926
+ MODEL_DIR = Path(LOG_DIR) / "cached_tiny_llm"
927
+ os.makedirs(MODEL_DIR, exist_ok=True)
928
+
929
+ # If model is already cached locally, load from there
930
+ if (MODEL_DIR / "config.json").exists():
931
+ st.caption("Loading tiny model from local cache...")
932
+ model = AutoModelForCausalLM.from_pretrained(MODEL_DIR)
933
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
934
+ else:
935
+ st.caption("☁️ Downloading tiny model (once only)...")
936
+ model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
937
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
938
+ model.save_pretrained(MODEL_DIR)
939
+ tokenizer.save_pretrained(MODEL_DIR)
940
+ st.success("Cached tiny LLM in ./logs/cached_tiny_llm")
941
+
942
+ assistant = pipeline("text-generation", model=model, tokenizer=tokenizer)
943
+
944
+ prompt = f"""
945
+ You are an ML model tuning assistant.
946
+ Given this AutoML summary, provide 3 actionable steps for improvement if overfitting,
947
+ underfitting, or data quality issues are suspected.
948
+
949
+ Use case: {summary.get('use_case')}
950
+ Target: {summary.get('target')}
951
+ Final R²: {summary.get('final_r2')}
952
+ Final RMSE: {summary.get('final_rmse')}
953
+ Leaderboard: {summary.get('leaderboard')}
954
+
955
+ Respond in concise numbered steps.
956
+ """
957
+ out = assistant(prompt, max_new_tokens=90, temperature=0.7, do_sample=True)[0]["generated_text"]
958
+ st.success("LLM Recommendation:")
959
+ st.markdown(out)
960
+ log("Tiny LLM recommendation generated successfully.")
961
+ except Exception as e:
962
+ st.error(f"LLM generation failed: {e}")
963
+ st.info("If the model download failed, rerun once — it will cache afterward.")
964
+
965
+
966
+
967
 
968
 
969
  # ----- Target & Business Impact tab