# employee_attribute_2 / src/streamlit_app.py
# Author: zaid002 — last commit a5d32c6 "Update src/streamlit_app.py" (verified)
import json
import numpy as np
import pandas as pd
import streamlit as st
import os
# ---------------------------
# Page setup
# ---------------------------
# The icon/title strings are the real emoji (woman office worker, with a
# zero-width joiner); they had been mis-decoded into Thai/CP1252 mojibake.
st.set_page_config(
    page_title="Employee Attrition Prediction",
    page_icon="👩‍💼",
    layout="centered",
)
st.title("👩‍💼 Employee Attrition Prediction (HF Safe Version)")
st.write("This app predicts whether an employee is likely to leave the company.")
# ---------------------------
# Load lightweight JSON model
# ---------------------------
# Produces the module-level `means` / `stds` dicts (feature name -> value)
# that simple_predict() uses to standardize the inputs.
_MODEL_FILE = "lightweight_model.json"

# Look in the working directory first (the original behaviour), then next to
# this script, so the app also works when launched from another directory.
_model_candidates = [
    _MODEL_FILE,
    os.path.join(os.path.dirname(os.path.abspath(__file__)), _MODEL_FILE),
]
_model_path = next((p for p in _model_candidates if os.path.exists(p)), None)

means = stds = None
if _model_path is not None:
    try:
        # `with` closes the handle; the original json.load(open(...)) leaked it.
        with open(_model_path, encoding="utf-8") as fh:
            light_model = json.load(fh)
        means = light_model["feature_means"]
        stds = light_model["feature_stds"]
    except (OSError, ValueError, KeyError, TypeError) as exc:
        # ValueError covers json.JSONDecodeError; KeyError/TypeError cover a
        # file that parses but lacks the expected top-level keys.
        st.error(f"⚠️ Could not load {_model_path} ({exc}). Using fallback model.")
        means = stds = None
    else:
        # Only report success once the model has actually been read.
        st.success("✅ Model loaded successfully!")
else:
    st.warning("⚠️ lightweight_model.json not found. Using fallback model.")

if means is None:
    means = {
        "Age": 35,
        "MonthlyIncome": 6500,
        "JobSatisfaction": 3,
        "WorkLifeBalance": 3,
        "YearsAtCompany": 5,
        "OverTime": 0.2,
    }
    # With std == 1, simple_predict() only subtracts the mean and does not
    # rescale, so large-valued features (e.g. MonthlyIncome) dominate the
    # score in fallback mode.
    stds = {k: 1 for k in means}
# ---------------------------
# Prediction Logic (Simple Logistic)
# ---------------------------
def simple_predict(df, feature_means=None, feature_stds=None):
    """Return the predicted probability that the employee in ``df`` leaves.

    Each column is standardized as ``(x - mean) / (std + 1e-6)``. The
    standardized values of the first row are summed with equal weight. The
    sum goes through the logistic function ``1 / (1 + exp(-score))``.

    NOTE(review): every feature has weight +1, so above-average values of
    JobSatisfaction, WorkLifeBalance or MonthlyIncome *increase* the
    predicted leave probability. Confirm this is intended.

    Args:
        df: DataFrame of raw feature values. Every column name must be a key
            of ``feature_means`` and ``feature_stds``. Only the first row is
            scored. The frame is not modified.
        feature_means: Mapping column -> mean. Defaults to module-level ``means``.
        feature_stds: Mapping column -> std. Defaults to module-level ``stds``.

    Returns:
        The logistic probability (a numpy float in [0, 1]).

    Raises:
        KeyError: If a column of ``df`` has no mean/std entry.
        IndexError: If ``df`` has no rows.
    """
    if feature_means is None:
        feature_means = means
    if feature_stds is None:
        feature_stds = stds
    # Standardize a copy so the caller's DataFrame keeps its raw values.
    z = df.copy()
    for col in z.columns:
        # The 1e-6 keeps a zero std in the model file from dividing by zero.
        z[col] = (z[col] - feature_means[col]) / (feature_stds[col] + 1e-6)
    score = z.sum(axis=1).values[0]
    return 1 / (1 + np.exp(-score))
# ---------------------------
# Input Form
# ---------------------------
# Labels use the intended 🔮 emoji and en dash; they had been mis-decoded
# into mojibake.
st.header("🔮 Enter Employee Details")
age = st.number_input("Age", min_value=18, max_value=60, value=30)
income = st.number_input("Monthly Income", min_value=1000, max_value=20000, value=5000)
job_sat = st.slider("Job Satisfaction (1–4)", 1, 4, 3)
wlb = st.slider("Work-Life Balance (1–4)", 1, 4, 3)
years = st.number_input("Years at Company", min_value=0, max_value=40, value=5)
overtime = st.selectbox("OverTime", ["Yes", "No"])
# Encode the Yes/No choice as 1/0 so it can be standardized like the others.
overtime_val = 1 if overtime == "Yes" else 0

# Single-row DataFrame; the column names must match the keys of `means` and
# `stds`, because simple_predict() looks each column up in those dicts.
input_df = pd.DataFrame([{
    "Age": age,
    "MonthlyIncome": income,
    "JobSatisfaction": job_sat,
    "WorkLifeBalance": wlb,
    "YearsAtCompany": years,
    "OverTime": overtime_val,
}])
# ---------------------------
# Predict
# ---------------------------
if st.button("Predict Attrition"):
    prob = simple_predict(input_df)
    # The "Confidence" shown is the logistic output for the predicted class:
    # prob when predicting "leave", 1 - prob when predicting "stay".
    if prob > 0.5:
        st.error(f"⚠️ Employee likely to leave the company. (Confidence: {prob:.2f})")
    else:
        st.success(f"✅ Employee likely to stay. (Confidence: {1 - prob:.2f})")
# ---------------------------
# Footer
# ---------------------------
# Caption uses the intended ❤️ and em dash; they had been mis-decoded into
# mojibake.
st.markdown("---")
st.caption("Built with ❤️ using Streamlit — Safe for HuggingFace Spaces")