ML_ALGORITHMS / pages /Decision_Tree.py
sree4411's picture
Update pages/Decision_Tree.py
96c53e1 verified
import streamlit as st
import pandas as pd
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report, accuracy_score
import matplotlib.pyplot as plt
import seaborn as sns
st.set_page_config(page_title="Explore Decision Tree Algorithm", layout="wide")
st.title("🌳 Decision Tree Classifier Explained")
# Tabs for better UX
tab1, tab2, tab3 = st.tabs(["πŸ“˜ About Decision Tree", "πŸ§ͺ Try It Out", "πŸ“ˆ Visualization"])
with tab1:
st.markdown("""
## 🌳 What is a Decision Tree?
A **Decision Tree** is a flowchart-like structure used in supervised learning for both **classification** and **regression** tasks.
It splits the dataset into smaller and smaller subsets while at the same time an associated decision tree is incrementally developed.
### 🧠 Core Concepts:
- **Root Node**: The first decision point (split) based on the most important feature.
- **Internal Nodes**: Decision points that split data based on feature values.
- **Leaf Nodes**: The final output class (for classification) or value (for regression).
- **Branches**: Represent outcomes of a decision or test.
---
### πŸ“Š How it Works:
1. The algorithm chooses a feature and a split point that best separates the data.
2. Splits are made recursively until a stopping criterion is met (like max depth or pure leaves).
3. At each node, it chooses the split that **maximizes information gain** or **minimizes impurity**.
---
### βš–οΈ Splitting Criteria:
- **Gini Impurity**:
- Measures how often a randomly chosen element would be incorrectly labeled.
- Best when you want a fast computation.
- Formula: $Gini(t) = 1 - \\sum p(i)^2$
- **Entropy (Information Gain)**:
- Measures the level of disorder or uncertainty.
- Based on the concept from information theory.
- Formula: $Entropy(t) = - \\sum p(i) \\log_2(p(i))$
---
### πŸ” Visual Intuition:
Imagine a dataset of flowers. You might first split by **petal length**. One branch contains flowers with short petals, another with long petals.
Then, each of those groups is further split based on **petal width**, and so on β€” until the flowers are grouped into species.
---
### πŸ“¦ Use Cases:
- Medical Diagnosis (e.g., "Is this tumor malignant or benign?")
- Loan Approval ("Should this applicant get a loan?")
- Customer Churn Prediction
- Email Spam Detection
- Any tabular data with clear patterns and labeled outcomes
---
### βœ… Advantages:
- Simple to understand and interpret (like flowcharts)
- Requires little data preprocessing (no scaling or normalization)
- Can handle both numerical and categorical data
- Easily visualized
### ⚠️ Limitations:
- Prone to overfitting if not pruned or constrained
- Small changes in data can lead to very different trees
- Not good with very large datasets or high-dimensional data alone (better with ensembles like Random Forest)
---
🎯 **Tip**: Use pruning, set a maximum depth, or use ensemble methods (like Random Forest or Gradient Boosted Trees) to enhance performance and reduce overfitting.
""")
with tab2:
st.subheader("🌸 Train a Decision Tree on the Iris Dataset")
iris = load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)
df['target'] = iris.target
st.dataframe(df.head(), use_container_width=True)
criterion = st.radio("Select Splitting Criterion", ["gini", "entropy"])
max_depth = st.slider("Select Max Depth", 1, 10, value=3)
X = df.drop('target', axis=1)
y = df['target']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
model = DecisionTreeClassifier(criterion=criterion, max_depth=max_depth, random_state=42)
model.fit(X_train, y_train)
y_pred = model.predict(X_test)
acc = accuracy_score(y_test, y_pred)
st.success(f"βœ… Model Accuracy: {acc*100:.2f}%")
st.markdown("### πŸ“Š Classification Report")
st.text(classification_report(y_test, y_pred, target_names=iris.target_names))
with tab3:
st.subheader("πŸ“‰ Visualize the Decision Tree")
fig, ax = plt.subplots(figsize=(12, 6))
plot_tree(model, feature_names=iris.feature_names, class_names=iris.target_names, filled=True)
st.pyplot(fig)
st.subheader("πŸ“Œ Feature Importance")
importance_df = pd.DataFrame({
'Feature': iris.feature_names,
'Importance': model.feature_importances_
}).sort_values(by="Importance", ascending=False)
st.dataframe(importance_df, use_container_width=True)
fig2, ax2 = plt.subplots()
sns.barplot(x='Importance', y='Feature', data=importance_df, palette='viridis')
st.pyplot(fig2)