# Yashvj123's picture
# Update Home.py
# eb736a2 verified
import streamlit as st
import numpy as np
import pickle
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
# Load Model
# The context manager closes the file handle as soon as unpickling finishes.
# Unpickling can execute arbitrary code, so this path must only ever point at
# a trusted, locally controlled file — never at user-supplied data.
with open("Employee_Attrition1.pkl", "rb") as model_file:
    model = pickle.load(model_file)
# Stylesheet shared by every page: background, headings, buttons, the
# prediction result banner, cards and form inputs.
_APP_CSS = """
<style>
.stApp {
background: linear-gradient(to right, #f7f7f7, #e3f2fd);
color: #333;
}
.title {
text-align: center;
font-size: 28px;
font-weight: bold;
color: #2C3E50;
}
.subtitle {
text-align: center;
font-size: 30px;
font-weight: bold;
color: #003366;
margin-top: 10px;
}
.stButton > button {
width: 100%;
background: linear-gradient(to right, #009688, #00796B);
color: white;
font-size: 18px;
font-weight: bold;
border-radius: 8px;
padding: 10px;
transition: 0.3s;
border: none;
box-shadow: 3px 3px 6px rgba(0, 0, 0, 0.2);
}
.stButton > button:hover {
background: linear-gradient(to right, #00796B, #004D40);
transform: scale(1.05);
box-shadow: 5px 5px 8px rgba(0, 0, 0, 0.3);
}
.result-box {
text-align: center;
font-size: 24px;
font-weight: bold;
color: white;
padding: 15px;
border-radius: 10px;
margin-top: 20px;
box-shadow: 3px 3px 8px rgba(0, 0, 0, 0.2);
}
.result-box.stay {
background: linear-gradient(to right, #43A047, #2E7D32);
}
.result-box.leave {
background: linear-gradient(to right, #E53935, #C62828);
}
.card {
background: white;
padding: 15px;
border-radius: 12px;
box-shadow: 2px 2px 8px rgba(0, 0, 0, 0.1);
margin-top: 15px;
}
.stTextInput > div > div > input {
border: 2px solid #00796B;
border-radius: 6px;
padding: 10px;
}
.stSelectbox > div > div {
border: 2px solid #00796B !important;
border-radius: 6px !important;
}
</style>
"""

# Browser-tab title/icon and a centred layout; issued before any element is drawn.
st.set_page_config(page_title="Employee Attrition", page_icon="📊", layout="centered")

st.markdown(_APP_CSS, unsafe_allow_html=True)
# Navigation State
# The active page is kept in session state so it survives script reruns.
if "current_page" not in st.session_state:
    st.session_state.current_page = "Model Pipeline"


def switch_page(page):
    """Make *page* the active page and redraw immediately.

    Navigation buttons rendered inside a page are evaluated after the page
    dispatch has already decided what to draw in the current run, so simply
    updating the state would only show the new page on the user's next
    interaction. Rerunning right away renders the target page now. The rerun
    is skipped when *page* is already active, so repeated calls cannot loop.

    Args:
        page: Name of the page to show; must match one of the page names the
            dispatch compares ``st.session_state.current_page`` against.
    """
    if st.session_state.current_page != page:
        st.session_state.current_page = page
        st.rerun()


# Sidebar Navigation
st.sidebar.title("Navigation")
if st.sidebar.button("Model Pipeline"):
    switch_page("Model Pipeline")
if st.sidebar.button("Hands-on Model"):
    switch_page("Hands-on Model")
# Importing Data
@st.cache_data
def _load_attrition_data(path):
    """Read the attrition CSV at *path* and strip whitespace from its column names.

    Cached with ``st.cache_data`` so the file is parsed once instead of on every
    rerun of this script.

    Args:
        path: Location of the CSV file.

    Returns:
        The dataset as a DataFrame with stripped column labels.
    """
    frame = pd.read_csv(path)
    frame.columns = frame.columns.str.strip()
    return frame


data = _load_attrition_data("Employee_Attrition_Data.csv")
# Model Report Page
if st.session_state.current_page == "Model Pipeline":
    # Landing page: header image, one button per pipeline stage, a GitHub link
    # and the author card.
    st.markdown("<h1 class='title'>Employee Attrition Prediction</h1>", unsafe_allow_html=True)
    st.markdown("<hr style='border:1px solid #ddd;'>", unsafe_allow_html=True)
    st.image("images/Employee_Atrition.png",  # file name on disk keeps its own spelling
             caption="Employee Attrition",
             use_container_width=True)
    # (button label, page it opens), rendered top to bottom in this order.
    for label, target in (
        ("**Problem Statement**", "Problem Statement"),
        ("**Data Collection**", "Data Collection"),
        ("**Simple EDA**", "Simple EDA"),
        ("**Data Pre-processing**", "Data Pre-processing"),
        ("**Exploratory Data Analysis**", "EDA"),
        ("**Model Building**", "Model Building"),
        ("**Final Model**", "Final Model"),
    ):
        if st.button(label):
            switch_page(target)
    st.markdown("<hr style='border:1px solid #ddd;'>", unsafe_allow_html=True)
    st.markdown(
        """
<div style="text-align: center;">
<a href="https://github.com/Yashvj22/Employee_Attrition_Classification_Model" target="_blank" style="
background: linear-gradient(to right, #009688, #00796B);
color: white;
padding: 14px 28px;
text-decoration: none;
font-size: 18px;
font-weight: bold;
border-radius: 10px;
display: inline-block;
transition: 0.3s;
box-shadow: 3px 3px 6px rgba(0, 0, 0, 0.2);">
🔗 See Whole Code on GitHub
</a>
</div>
""",
        unsafe_allow_html=True
    )
    st.markdown("<hr style='border:1px solid #ddd;'>", unsafe_allow_html=True)
    st.markdown('''
<h2 style="text-align:center;"> About Author</h2>
<div style="background-color:#f5f5f5; border-radius:10px; padding:20px; margin-top:20px;">
<p style="font-size:16px; text-align:center; font-family:Georgia; line-height:1.6; color:#000;">
Hello! I’m <b>Yash Jadhav</b>, a passionate <span style="color:#FF6347;">Data Scientist</span>
and <span style="color:#4682B4;">Data Analyst</span>.
I specialize in transforming raw data into actionable insights and helping others master the art of Machine Learning.
</p>
<div style="text-align:center; margin-top:20px;">
<a href="https://www.linkedin.com/in/yash-jadhav-454b0a237/" target="_blank" style="
background-color:#0073b1; color:white; padding:10px 20px; border-radius:5px;
text-decoration:none; margin-right:10px;">LinkedIn</a>
<a href="https://github.com/Yashvj22" target="_blank" style="
background-color:black; color:white; padding:10px 20px; border-radius:5px;
text-decoration:none; margin-right:10px;">GitHub</a>
<a href="https://medium.com/@yashvj2222" target="_blank" style="
background-color:grey; color:white; padding:10px 20px; border-radius:5px;
text-decoration:none;">Medium</a>
</div>
</div>
''', unsafe_allow_html=True)
# Individual Sections
elif st.session_state.current_page == "Problem Statement":
st.markdown("<h1 class='title'>Problem Statement</h1>", unsafe_allow_html=True)
st.markdown("""
<h5 style="text-align: center; margin-top: 20px;">
Understand the factors contributing to attrition and develop predictive models to identify at-risk employees and
predict the whether the employee has left the company.
</h5>
""", unsafe_allow_html=True)
st.markdown("<br>", unsafe_allow_html=True)
st.image("images/problem_statement.png",
caption="Employee Attrition Prediction Overview",
use_container_width=True)
if st.button("🔙 Go Back to Model Pipeline"):
switch_page("Model Pipeline")
elif st.session_state.current_page == "Data Collection":
st.markdown("<h1 class='title'>Data Collection</h1>", unsafe_allow_html=True)
st.markdown("<hr style='border:1px solid #ddd;'>", unsafe_allow_html=True)
st.markdown("""
<h5 style="text-align: center; margin-top: 20px;">
The dataset used in this project is sourced from Kaggle, containing information on employee attrition
along with various workplace, economic, and demographic factors.
</h5>
""", unsafe_allow_html=True)
st.markdown("<br>", unsafe_allow_html=True)
st.markdown("""
<h5 style="text-align: center; margin-top: 10px;">
📌 <a href="https://www.kaggle.com/datasets/HRAnalyticRepository/employee-attrition-data" target="_blank" style="font-weight: bold; color: #007BFF; text-decoration: none;">
Click here to access the dataset on Kaggle</a>
</h5>
""", unsafe_allow_html=True)
st.markdown("<h2 class='subtitle' style='text-align: center; margin-top: 20px;'>Dataset Overview</h2>", unsafe_allow_html=True)
st.markdown("""
<h5 style="text-align: center; margin-top: 15px; margin-bottom: 20px;">
The dataset consists of <b>10,000 rows</b> and <b>22 columns</b>, capturing crucial indicators such as Employee Attrition,
Job Satisfaction, Monthly Income, Work-Life Balance, and more. Below is a summary of the dataset features:
</h5>
""", unsafe_allow_html=True)
st.markdown("<br>", unsafe_allow_html=True)
data_info = """
<div style="font-size: 16px; background-color: #F5F5F5; padding: 15px; border-radius: 10px;">
• <b>Employee ID:</b> A unique identifier assigned to each employee.<br>
• <b>Age:</b> The age of the employee, ranging from 18 to 60 years.<br>
• <b>Gender:</b> The gender of the employee.<br>
• <b>Years at Company:</b> The number of years the employee has been working at the company.<br>
• <b>Monthly Income:</b> The monthly salary of the employee, in dollars.<br>
• <b>Job Role:</b> The department or role the employee works in, categorized into Finance, Healthcare, Technology, Education, and Media.<br>
• <b>Work-Life Balance:</b> The employee's perceived balance between work and personal life (Poor, Below Average, Good, Excellent).<br>
• <b>Job Satisfaction:</b> The employee's satisfaction with their job (Very Low, Low, Medium, High).<br>
• <b>Performance Rating:</b> The employee's performance rating (Low, Below Average, Average, High).<br>
• <b>Number of Promotions:</b> The total number of promotions the employee has received.<br>
• <b>Distance from Home:</b> The distance between the employee's home and workplace, in miles.<br>
• <b>Education Level:</b> The highest education level attained by the employee (High School, Associate Degree, Bachelor’s Degree, Master’s Degree, PhD).<br>
• <b>Marital Status:</b> The marital status of the employee (Divorced, Married, Single).<br>
• <b>Job Level:</b> The job level of the employee (Entry, Mid, Senior).<br>
• <b>Company Size:</b> The size of the company the employee works for (Small, Medium, Large).<br>
• <b>Company Tenure:</b> The total number of years the employee has been working in the industry.<br>
• <b>Remote Work:</b> Whether the employee works remotely (Yes or No).<br>
• <b>Leadership Opportunities:</b> Whether the employee has leadership opportunities (Yes or No).<br>
• <b>Innovation Opportunities:</b> Whether the employee has opportunities for innovation (Yes or No).<br>
• <b>Company Reputation:</b> The employee's perception of the company's reputation (Very Poor, Poor, Good, Excellent).<br>
• <b>Employee Recognition:</b> The level of recognition the employee receives (Very Low, Low, Medium, High).<br>
• <b>Attrition:</b> Whether the employee has left the company, encoded as 0 (Stayed) and 1 (Left).<br>
</div>
"""
st.markdown(data_info, unsafe_allow_html=True)
st.markdown("<br>", unsafe_allow_html=True)
if st.button("🔙 Go Back to Model Pipeline"):
switch_page("Model Pipeline")
elif st.session_state.current_page == "Simple EDA":
st.markdown("<h1 class='title'>Simple Exploratory Data Analysis</h1>", unsafe_allow_html=True)
st.markdown("<hr style='border:1px solid #ddd;'>", unsafe_allow_html=True)
st.markdown("""
<h5 style="text-align: center; margin-top: 20px;">
Exploratory Data Analysis (EDA) helps in understanding the structure, patterns, and missing values in the dataset.
Below is an initial preview of the data, followed by a missing values summary.
</h5>
""", unsafe_allow_html=True)
st.markdown("<br>", unsafe_allow_html=True)
st.markdown("<h3 class='subtitle' style='text-align: center;'>Sample Dataset</h3>", unsafe_allow_html=True)
st.dataframe(data.head())
st.markdown("<br>", unsafe_allow_html=True)
st.markdown("<h3 class='subtitle' style='text-align: center;'>Missing Values Summary</h3>", unsafe_allow_html=True)
missing_values = data.isna().sum().reset_index()
missing_values.columns = ["Column Name", "Missing Values"]
col1, col2, col3 = st.columns([1, 2, 1])
with col2:
st.dataframe(missing_values)
st.markdown("<br>", unsafe_allow_html=True)
st.markdown("<h3 class='subtitle' style='text-align: center;'>Data Description</h3>", unsafe_allow_html=True)
st.dataframe(data.describe())
st.markdown("<br>", unsafe_allow_html=True)
st.markdown("<h3 class='subtitle' style='text-align: center;'>Boxplots for Numerical Features</h3>", unsafe_allow_html=True)
numeric_columns = ['Age', 'Years at Company', 'Monthly Income', 'Number of Promotions', 'Distance from Home', 'Company Tenure']
fig, axes = plt.subplots(nrows=len(numeric_columns), ncols=1, figsize=(10, 5 * len(numeric_columns)))
axes = axes.flatten()
for i, col in enumerate(numeric_columns):
sns.boxplot(x=data[col], ax=axes[i], color="skyblue")
axes[i].set_title(f'Boxplot of {col}', fontsize=12)
axes[i].set_xlabel("")
plt.tight_layout()
st.pyplot(fig)
st.markdown("<br>", unsafe_allow_html=True)
st.markdown("<h3 class='subtitle' style='text-align: center;'>Bar Plots for Categorical Features</h3>", unsafe_allow_html=True)
categorical_columns = [
'Performance Rating', 'Number of Promotions',
'Education Level', 'Marital Status', 'Job Level',
'Company Size', 'Company Reputation', 'Employee Recognition']
fig, axes = plt.subplots(nrows=len(categorical_columns)//2, ncols=2, figsize=(15, 25))
axes = axes.flatten()
for i, col in enumerate(categorical_columns):
sns.countplot(x=data[col], ax=axes[i], palette="coolwarm")
axes[i].set_title(f'Distribution of {col}', fontsize=12)
axes[i].set_xlabel("")
axes[i].tick_params(axis='x', rotation=45) # Rotate x labels for clarity
plt.tight_layout()
st.pyplot(fig)
st.markdown("<br>", unsafe_allow_html=True)
if st.button("🔙 Go Back to Model Pipeline"):
switch_page("Model Pipeline")
elif st.session_state.current_page == "Data Pre-processing":
st.markdown("<h1 class='title'>Data Preprocessing</h1>", unsafe_allow_html=True)
st.markdown("<hr style='border:1px solid #ddd;'>", unsafe_allow_html=True)
st.markdown("<h2 class='subtitle' style='text-align: center;'>Feature Engineering</h2>", unsafe_allow_html=True)
st.markdown("""
<h5 style="text-align: center;">
<b>Refining Features for Improved Model Performance by Best-fit Random Forest</b>
</h5>
""", unsafe_allow_html=True)
st.markdown("""
<div style="
border: 1px solid #ddd;
border-radius: 8px;
padding: 15px;
background-color: #f9f9f9;
text-align: justify;">
Feature Engineering enhances the dataset by selecting, transforming, and creating new features to optimize
model predictions. In this project, the <b>Best-fit Random Forest</b> method was used to identify and refine
the most impactful features. The final set of engineered features includes:
<ul>
<li><b>Age:</b> Retained as a key factor in employee attrition analysis.</li>
<li><b>Monthly Income:</b> Helps assess financial stability and job satisfaction.</li>
<li><b>Distance from Home:</b> Used to determine the effect of commuting on attrition.</li>
<li><b>Education Level:</b> Encoded to quantify qualification influence.</li>
<li><b>Number of Dependents:</b> Indicates financial responsibilities affecting career decisions.</li>
<li><b>Company Tenure:</b> Represents employee loyalty and career growth.</li>
<li><b>Remote Work:</b> Evaluates flexibility as a factor in job retention.</li>
<li><b>Leadership Opportunities:</b> Identifies leadership roles impacting engagement.</li>
<li><b>Company Reputation:</b> Measures organizational perception and its influence on attrition.</li>
</ul>
</div>
""", unsafe_allow_html=True)
st.markdown("<hr style='border:1px solid #ddd;'>", unsafe_allow_html=True)
# Ordinal Encoding Section
st.markdown("<h2 class='subtitle' style='text-align: center;'>Ordinal Encoding</h2>", unsafe_allow_html=True)
st.markdown("<br>", unsafe_allow_html=True)
st.markdown("""
<h5 style="text-align: center;">
<b>Applying Ordinal Encoding to Ordered Categories</b>
</h5>
""", unsafe_allow_html=True)
st.markdown("""
<div style="
border: 1px solid #ddd;
border-radius: 8px;
padding: 15px;
background-color: #f9f9f9;
text-align: justify;">
<b>Ordinal Encoding</b> is applied to features where the order of categories matters.
For example:
<ul>
<li><b>Education Level</b>: "High School" < "Associate Degree" < "Bachelor’s Degree" < "Master’s Degree" < "PhD"</li>
<li><b>Company Reputation</b>: "Poor" < "Fair" < "Good" < "Excellent"</li>
</ul>
This ensures that the encoded values respect the natural ordering of the categories.
</div>
""", unsafe_allow_html=True)
st.markdown("<hr style='border:1px solid #ddd;'>", unsafe_allow_html=True)
# One-Hot Encoding Section
st.markdown("<h2 class='subtitle' style='text-align: center;'>One-Hot Encoding</h2>", unsafe_allow_html=True)
st.markdown("<br>", unsafe_allow_html=True)
st.markdown("""
<h5 style="text-align: center;">
<b>Applying One-Hot Encoding to Nominal Categorical Variables</b>
</h5>
""", unsafe_allow_html=True)
st.markdown("""
<div style="
border: 1px solid #ddd;
border-radius: 8px;
padding: 15px;
background-color: #f9f9f9;
text-align: justify;">
<b>One-Hot Encoding</b> is applied to categorical features that do not have an inherent order.
These include:
<ul>
<li><b>Remote Work</b>: Whether an employee works remotely or not.</li>
<li><b>Leadership Opportunities</b>: Availability of leadership roles.</li>
</ul>
To avoid multicollinearity, we drop the first category from each feature.
</div>
""", unsafe_allow_html=True)
st.markdown("<hr style='border:1px solid #ddd;'>", unsafe_allow_html=True)
# Scaling Numerical Features
st.markdown("<h2 class='subtitle' style='text-align: center;'>Scaling Continuous Variables</h2>", unsafe_allow_html=True)
st.markdown("<br>", unsafe_allow_html=True)
st.markdown("""
<h5 style="text-align: center;">
<b>Applying Robust Scaling to Handle Outliers</b>
</h5>
""", unsafe_allow_html=True)
st.markdown("""
<div style="
border: 1px solid #ddd;
border-radius: 8px;
padding: 15px;
background-color: #f9f9f9;
text-align: justify;">
<b>Robust Scaling</b> is applied to numerical features to make them less sensitive to outliers.
Instead of using the mean and standard deviation, Robust Scaling scales features using the median and interquartile range (IQR).
This is particularly useful for variables with skewed distributions, such as:
<ul>
<li>Age</li>
<li>Monthly Income</li>
<li>Distance from Home</li>
<li>Number of Dependents</li>
<li>Company Tenure</li>
</ul>
</div>
""", unsafe_allow_html=True)
st.markdown("<br>", unsafe_allow_html=True)
if st.button("🔙 Go Back to Model Pipeline"):
switch_page("Model Pipeline")
elif st.session_state.current_page == "EDA":
st.markdown("<h1 class='title'>Exploratory Data Analysis (EDA)</h1>", unsafe_allow_html=True)
st.markdown("<hr style='border:1px solid #ddd;'>", unsafe_allow_html=True)
st.markdown("<h3 class='subtitle' style='text-align: center;'>Attrition Bar Plot</h3>", unsafe_allow_html=True)
st.image("images/Attrition bar plot.png", caption="Attrition distribution", use_container_width=True)
st.markdown("""
<h5 style="text-align: center;">
Insights: Our Target Column is mostly balanced
</h5>
""", unsafe_allow_html=True)
st.markdown("<hr style='border:1px solid #ddd;'>", unsafe_allow_html=True)
st.markdown("<h2 class='subtitle' style='text-align: center;'>Age Distribution by Attrition</h2>", unsafe_allow_html=True)
st.image("images/Age Distribution of Employees by Attrition.png", caption="Age Distribution by Attrition", use_container_width=True)
st.markdown("""
<h5 style="text-align: center;">
Insight: Younger employees (18-30 years) have a <b>higher attrition rate</b>, while mid-career employees (30-50 years) show more <b>job stability</b>.
Attrition slightly <b>increases after 50</b>, possibly due to retirement or career transitions.
The <b>highest employee retention</b> is observed in the <b>30-50 age group</b>, indicating job stability in mid-career.
</h5>
""", unsafe_allow_html=True)
st.markdown("<hr style='border:1px solid #ddd;'>", unsafe_allow_html=True)
st.markdown("<h2 class='subtitle' style='text-align: center;'>Monthly Income Distribution by Attrition</h2>", unsafe_allow_html=True)
st.image("images/Monthly Income Distribution by Attrition.png", caption="Monthly Income vs. Attrition", use_container_width=True)
st.markdown("""
<h5 style="text-align: center;">
Insights: Mostly if we see Employess are leaving in monthly income range of Rs(6000 - 9000) and Rs(14,000 - 15,000)
</h5>
""", unsafe_allow_html=True)
st.markdown("<hr style='border:1px solid #ddd;'>", unsafe_allow_html=True)
if st.button("🔙 Go Back to Model Pipeline"):
switch_page("Model Pipeline")
# Model Building
elif st.session_state.current_page == "Model Building":
st.markdown("<h1 class='title'>Model Building</h1>", unsafe_allow_html=True)
st.markdown("<hr style='border:2px solid #1363DF;'>", unsafe_allow_html=True)
# Introduction
st.markdown("""
<h2 style="text-align: center;">Introduction</h2>
<p style="text-align: justify;">In this section, we explore different <b>Ensemble Learning</b> techniques to improve model performance.</p>
<p style="text-align: center; font-size:18px;">
🥇 <b style="color:#FF9800;">Voting Classifier</b> - 🎯 <b style="color:#E91E63;">Bagging Classifier</b> - 🌲 <b style="color:#4CAF50;">Random Forest Classifier</b>
</p>
""", unsafe_allow_html=True)
st.markdown("<hr style='border:1px solid #bbb;'>", unsafe_allow_html=True)
st.markdown("""
<h3>🥇 Voting Classifier</h3>
<p>Voting Classifier combines predictions from multiple different models and selects the most common (majority vote) as the final output.</p>
<ul>
<li><b>Hard Voting:</b> Chooses the most frequent class label.</li>
<li><b>Soft Voting:</b> Averages predicted probabilities and selects the class with the highest probability.</li>
</ul>
<br>
<h3>🎯 Bagging Classifier</h3>
<p>Bagging (Bootstrap Aggregation) improves model performance by training multiple models on different random subsets of the data and averaging their predictions.</p>
<ul>
<li><b>Reduces variance</b> by averaging multiple high-variance models.</li>
<li>Each model is trained on a different <b>bootstrap sample</b> of the dataset.</li>
</ul>
<br>
<h3>🌲 Random Forest</h3>
<p>Random Forest is an extension of Bagging that trains multiple Decision Trees but adds randomness by selecting different feature subsets.</p>
<ul>
<li>More diverse trees reduce overfitting.</li>
<li>Handles missing values and outliers well.</li>
</ul>
""", unsafe_allow_html=True)
st.markdown("<hr style='border:1px solid #bbb;'>", unsafe_allow_html=True)
# Combining High & Low Variance Models
with st.expander("Combining High & Low Variance Models for Better Performance ⚖️"):
st.markdown("""
<p>To improve ensemble performance, we strategically combine models with different variance levels:</p>
<ul>
<li><b>Voting Classifier:</b> Mixes <b>high-variance</b> (Decision Tree, KNN with small K) & <b>low-variance</b> (KNN with large K) models.</li>
<li><b>Bagging & Random Forest:</b> Use <b>only high-variance models</b> (Deep Decision Trees) to maximize variance reduction.</li>
</ul>
<p style='color:green;'><b>✅ This balance prevents excessive overfitting or underfitting, making our model more robust!</b></p>
""", unsafe_allow_html=True)
st.markdown("<hr style='border:1px solid #bbb;'>", unsafe_allow_html=True)
# Hyperparameter Tuning Section
with st.expander("⚡ Hyperparameter Tuning using Optuna"):
st.markdown("""
<p>We optimized hyperparameters for <b>KNN, Decision Tree, Bagging Classifier, and Random Forest</b> using <b>Optuna</b> to enhance model performance.</p>
<p><b>Below are the optimized parameters:</b></p>
""", unsafe_allow_html=True)
col1, col2 = st.columns(2)
with col1:
st.markdown("""
<h5>🔹 K-Nearest Neighbors (KNN)</h5>
<ul>
<li><code>n_neighbors</code></li>
<li><code>p</code></li>
<li><code>weights</code></li>
<li><code>algorithm</code></li>
</ul>
<h5>🔹 Decision Tree</h5>
<ul>
<li><code>max_depth</code></li>
<li><code>min_samples_split</code></li>
<li><code>min_samples_leaf</code></li>
<li><code>max_features</code></li>
<li><code>min_impurity_decrease</code></li>
</ul>
""", unsafe_allow_html=True)
with col2:
st.markdown("""
<h5>🔹 Bagging Classifier</h5>
<ul>
<li><code>n_estimators</code></li>
<li><code>max_samples</code></li>
</ul>
<h5>🔹 Random Forest</h5>
<ul>
<li><code>n_estimators</code></li>
<li><code>max_samples</code></li>
</ul>
""", unsafe_allow_html=True)
st.markdown("<hr style='border:2px solid #1363DF;'>", unsafe_allow_html=True)
# Model Performance
st.markdown("<h2 style='text-align: center;'>Model Performance</h2>", unsafe_allow_html=True)
st.markdown("""
<style>
table {
width: 100%;
border-collapse: collapse;
text-align: center;
font-size: 18px;
}
th, td {
padding: 12px;
border-bottom: 2px solid #ddd;
}
th {
background-color: #1363DF;
color: white;
}
tr:nth-child(even) {background-color: #F9F9F9;}
tr:hover {background-color: #E3F2FD;}
</style>
""", unsafe_allow_html=True)
st.markdown("""
<table>
<tr>
<th>Ensemble Model</th>
<th>Training Score</th>
<th>Test Score</th>
<th>Generalized Score</th>
</tr>
<tr>
<td>Voting Ensemble</td>
<td>98.80%</td>
<td>80.00%</td>
<td>72.33%</td>
</tr>
<tr>
<td>Bagging Ensemble</td>
<td>85.18%</td>
<td>78.07%</td>
<td>71.64%</td>
</tr>
<tr>
<td>Random Forest</td>
<td>94.88%</td>
<td>83.27%</td>
<td><b>75.2%</b></td>
</tr>
</table>
""", unsafe_allow_html=True)
st.markdown("<br>", unsafe_allow_html=True)
if st.button("🔙 Go Back to Model Pipeline"):
switch_page("Model Pipeline")
# Fianl Model
elif st.session_state.current_page == "Final Model":
st.markdown(
"""
<style>
.title {
text-align: center;
font-size: 36px;
font-weight: bold;
color: #1E3A8A;
}
.subtitle {
text-align: center;
font-size: 20px;
color: #475569;
margin-bottom: 20px;
}
.image-container {
display: flex;
justify-content: center;
}
.caption {
text-align: center;
font-size: 16px;
font-style: italic;
color: #6B7280;
}
.box {
background-color: #F8FAFC;
padding: 15px;
border-radius: 10px;
box-shadow: 2px 2px 10px rgba(0, 0, 0, 0.1);
margin-bottom: 20px;
}
</style>
""",
unsafe_allow_html=True,
)
# Title
st.markdown("<h1 class='title'>Final Model</h1>", unsafe_allow_html=True)
st.markdown("<hr style='border:1px solid #ddd;'>", unsafe_allow_html=True)
st.markdown(
"<div class='box'>"
"<p><strong>After experimenting with multiple trials using Optuna, we selected the best-fit model "
"by analyzing the training and test scores of different trials. "
"The following scatter plots provide insights into this selection process.</strong></p>"
"</div>",
unsafe_allow_html=True,
)
st.markdown("<hr style='border:1px solid #ddd;'>", unsafe_allow_html=True)
st.markdown("<h3 style='text-align: center;'>Training vs Test Score</h3>", unsafe_allow_html=True)
st.markdown(
"<p class='subtitle'>This scatter plot visualizes the training and test scores of all trials. "
"The goal was to identify a model where both scores are closely aligned, ensuring minimal overfitting or underfitting.</p>",
unsafe_allow_html=True,
)
st.image("images/random_forest.png",
caption="All Trails",
use_container_width=True)
st.markdown(
"<p style='text-align: center; font-weight: bold; font-size: 16px;'>"
"From the above trials, we selected the <b>29th trial</b> as its train score and test score have minimal difference."
"</p>",
unsafe_allow_html=True
)
st.markdown("<hr style='border:1px solid #ddd;'>", unsafe_allow_html=True)
st.markdown("<h3 style='text-align: center;'>Selected Best-Fit Model</h3>", unsafe_allow_html=True)
st.markdown(
"<div class='box'>"
"<li><b>Ensemble Method:</b> RandomForestClassifier</li>"
"<li><b>Hyperparameters:</b>"
"<ul>"
"<li>n_estimators = 84</li>"
"<li>max_samples = 0.604329</li>"
"<li>max_features = 'sqrt'</li>"
"<li>min_samples_leaf = 5</li>"
"<li>min_samples_split = 7</li>"
"<li>max_depth = 20</li>"
"</ul></li>"
"</ul>"
"<p>This model was selected as it demonstrated a balance between generalization and performance.</p>"
"</div>",
unsafe_allow_html=True,
)
if st.button("🔙 Go Back to Model Pipeline"):
switch_page("Model Pipeline")
# Hands-on Model Page
elif st.session_state.current_page == "Hands-on Model":
st.markdown("<h1 class='title'>Hands-on Model</h1>", unsafe_allow_html=True)
st.markdown("<hr style='border:1px solid #ddd;'>", unsafe_allow_html=True)
col1, col2 = st.columns(2)
with col1:
age = st.slider("Age", 18, 59, 38)
monthly_income = st.slider("Monthly Income", 1316, 14276, 7312)
distance_from_home = st.slider("Distance from Home", 1, 99, 50)
education_level = st.selectbox("Education Level", ["High School", "Associate Degree", "Bachelor’s Degree", "Master’s Degree", "PhD"])
num_dependents = st.slider("Number of Dependents", 0, 6, 1)
with col2:
company_tenure = st.slider("Company Tenure", 2, 126, 55)
remote_work = st.selectbox("Remote Work", ["Yes", "No"])
leadership_opportunities = st.selectbox("Leadership Opportunities", ["Yes", "No"])
company_reputation = st.selectbox("Company Reputation", ["Poor", "Fair", "Good", "Excellent"])
input_data = np.array([[age,monthly_income,distance_from_home, education_level, num_dependents,
company_tenure, remote_work,leadership_opportunities, company_reputation]])
if st.button("Predict Attrition"):
prediction = model.predict(input_data)
result = "Likely to Stay" if prediction[0] == 0 else "Likely to Leave"
result_class = "stay" if prediction[0] == 0 else "leave"
st.markdown(
f"""
<div class="result-box {result_class}">
{result}
</div>
""",
unsafe_allow_html=True,
)