zaid002 commited on
Commit
a5d32c6
·
verified ·
1 Parent(s): d1d3efe

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +70 -3
src/streamlit_app.py CHANGED
@@ -4,16 +4,30 @@ import pandas as pd
4
  import streamlit as st
5
  import os
6
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  if os.path.exists("lightweight_model.json"):
 
8
  light_model = json.load(open("lightweight_model.json"))
 
9
  means = light_model["feature_means"]
10
  stds = light_model["feature_stds"]
 
11
  else:
12
  st.warning("⚠️ lightweight_model.json not found. Using fallback model.")
13
 
14
- # Define default averages (safe fallback)
15
  means = {
16
- "Age": 35,
17
  "MonthlyIncome": 6500,
18
  "JobSatisfaction": 3,
19
  "WorkLifeBalance": 3,
@@ -21,4 +35,57 @@ else:
21
  "OverTime": 0.2
22
  }
23
 
24
- stds = {k: 1 for k in means} # avoid divide-by-zero
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  import streamlit as st
5
  import os
6
 
7
+ st.set_page_config(
8
+ page_title="Employee Attrition Prediction",
9
+ page_icon="👩‍💼",
10
+ layout="centered"
11
+ )
12
+
13
+ st.title("👩‍💼 Employee Attrition Prediction (HF Safe Version)")
14
+ st.write("This app predicts whether an employee is likely to leave the company.")
15
+
16
+ # ---------------------------
17
+ # Load lightweight JSON model
18
+ # ---------------------------
19
  if os.path.exists("lightweight_model.json"):
20
+ st.success("✅ Model loaded successfully!")
21
  light_model = json.load(open("lightweight_model.json"))
22
+
23
  means = light_model["feature_means"]
24
  stds = light_model["feature_stds"]
25
+
26
  else:
27
  st.warning("⚠️ lightweight_model.json not found. Using fallback model.")
28
 
 
29
  means = {
30
+ "Age": 35,
31
  "MonthlyIncome": 6500,
32
  "JobSatisfaction": 3,
33
  "WorkLifeBalance": 3,
 
35
  "OverTime": 0.2
36
  }
37
 
38
+ # Avoid divide-by-zero
39
+ stds = {k: 1 for k in means}
40
+
41
+ # ---------------------------
42
+ # Prediction Logic (Simple Logistic)
43
+ # ---------------------------
44
+ def simple_predict(df):
45
+ # Normalize input
46
+ for col in df.columns:
47
+ df[col] = (df[col] - means[col]) / (stds[col] + 1e-6)
48
+
49
+ score = df.sum(axis=1).values[0]
50
+ probability = 1 / (1 + np.exp(-score))
51
+ return probability
52
+
53
+ # ---------------------------
54
+ # Input Form
55
+ # ---------------------------
56
+ st.header("🔮 Enter Employee Details")
57
+
58
+ age = st.number_input("Age", min_value=18, max_value=60, value=30)
59
+ income = st.number_input("Monthly Income", min_value=1000, max_value=20000, value=5000)
60
+ job_sat = st.slider("Job Satisfaction (1–4)", 1, 4, 3)
61
+ wlb = st.slider("Work-Life Balance (1–4)", 1, 4, 3)
62
+ years = st.number_input("Years at Company", min_value=0, max_value=40, value=5)
63
+ overtime = st.selectbox("OverTime", ["Yes", "No"])
64
+ overtime_val = 1 if overtime == "Yes" else 0
65
+
66
+ # Prepare DataFrame
67
+ input_df = pd.DataFrame([{
68
+ "Age": age,
69
+ "MonthlyIncome": income,
70
+ "JobSatisfaction": job_sat,
71
+ "WorkLifeBalance": wlb,
72
+ "YearsAtCompany": years,
73
+ "OverTime": overtime_val
74
+ }])
75
+
76
+ # ---------------------------
77
+ # Predict
78
+ # ---------------------------
79
+ if st.button("Predict Attrition"):
80
+ prob = simple_predict(input_df)
81
+
82
+ if prob > 0.5:
83
+ st.error(f"⚠️ Employee likely to leave the company. (Confidence: {prob:.2f})")
84
+ else:
85
+ st.success(f"✅ Employee likely to stay. (Confidence: {1 - prob:.2f})")
86
+
87
+ # ---------------------------
88
+ # Footer
89
+ # ---------------------------
90
+ st.markdown("---")
91
+ st.caption("Built with ❤️ using Streamlit — Safe for HuggingFace Spaces")