Machine_learning / pages /10_Decision_Tree.py
Harika22's picture
Update pages/10_Decision_Tree.py
c70fab9 verified
import streamlit as st
from sklearn import tree
from sklearn.datasets import load_iris
from sklearn.tree import export_graphviz
import graphviz
import pandas as pd
st.set_page_config(page_title="Decision Tree Explorer", page_icon="🌳", layout="wide")
st.title("🌳 Decision Tree Algorithm Explorer")
st.write("Understand how Decision Trees work with simple explanations, visuals, and real-world examples.")
section = st.radio("Choose a topic to explore:", [
"What is a Decision Tree?",
"How It Works",
"Entropy vs Gini",
"Tree Construction",
"Classification vs Regression",
"Pruning",
"Feature Importance",
"Visualize Example Tree",
"Try with Iris Data"
])
if section == "What is a Decision Tree?":
st.header("πŸ“˜ What is a Decision Tree?")
st.markdown("""
A **Decision Tree** is a flowchart-like model that makes decisions based on a series of questions.
- 🎯 Used in both **classification** (e.g., spam vs. not spam) and **regression** (e.g., predicting price).
- 🌱 It starts at a **root**, asks a question, and branches out based on answers.
- πŸ”š Ends at a **leaf node** β€” which is the prediction.
**Real-life example:**
You're deciding what to wear. You ask:
1. Is it raining?
2. Is it cold?
β†’ Based on your answers, you decide: jacket, umbrella, or just a T-shirt.
""")
elif section == "How It Works":
st.header("βš™οΈ How Does It Work?")
st.markdown("""
**Step-by-step:**
1. Start with the whole dataset.
2. Choose the feature that best splits the data.
3. Split the dataset.
4. Repeat until you reach a stopping condition.
**Used concepts:**
- Entropy (information gain)
- Gini impurity
""")
elif section == "Entropy vs Gini":
st.header("πŸ“Š Entropy vs Gini")
st.markdown("""
### Entropy
Measures randomness or disorder in data.
$$
H(Y) = - \sum p_i \log_2 p_i
$$
### Gini Impurity
Measures the probability of wrong classification.
$$
Gini(Y) = 1 - \sum p_i^2
$$
**Which to use?**
- Gini is faster β†’ default in scikit-learn.
- Entropy gives more information-theoretic understanding.
""")
elif section == "Tree Construction":
st.header("πŸ”§ How is the Tree Built?")
st.markdown("""
The tree is built **top-down** using a greedy algorithm:
- Best feature is chosen using Gini or Entropy.
- Splits continue until stopping criteria (e.g., max depth, pure leaf).
**Tip**: Too many splits = overfitting!
""")
elif section == "Classification vs Regression":
st.header("πŸ“ˆ Classification vs Regression")
st.markdown("""
- **Classification Tree**: Predicts categories (Yes/No, Spam/Ham).
- **Regression Tree**: Predicts continuous values (e.g., house price).
**Example:**
- Classification: Will a customer churn?
- Regression: What will be the next month’s sales?
""")
elif section == "Pruning":
st.header("βœ‚οΈ Pruning Techniques")
st.markdown("""
**Why prune?**
To avoid overfitting by cutting unnecessary branches.
### Pre-Pruning
- `max_depth`: limit depth
- `min_samples_split`: split only if enough samples
- `min_samples_leaf`: limit how small leaves can be
### Post-Pruning
- Cost Complexity Pruning (using Ξ±)
""")
elif section == "Feature Importance":
st.header("πŸ“Œ Feature Importance")
st.markdown("""
Decision Trees calculate how important each feature is by how much it reduces impurity.
**Formula:**
$$
Importance = \frac{Total\ Gain\ from\ Feature}{Total\ Gain\ from\ All\ Features}
$$
Useful for feature selection and explaining model decisions.
""")
elif section == "Visualize Example Tree":
st.header("🌿 Visualize a Small Tree Example")
iris = load_iris()
clf = tree.DecisionTreeClassifier(max_depth=3)
clf = clf.fit(iris.data, iris.target)
dot_data = tree.export_graphviz(clf, out_file=None,
feature_names=iris.feature_names,
class_names=iris.target_names,
filled=True, rounded=True,
special_characters=True)
st.graphviz_chart(dot_data)
elif section == "Try with Iris Data":
st.header("🌸 Try with Iris Dataset")
iris = load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)
df['target'] = iris.target
st.write("Here's a preview of the dataset:")
st.dataframe(df.head())
st.markdown("### Build and visualize a Decision Tree")
max_depth = st.slider("Select max depth of the tree:", 1, 5, 3)
clf = tree.DecisionTreeClassifier(max_depth=max_depth)
clf = clf.fit(iris.data, iris.target)
dot_data = tree.export_graphviz(clf, out_file=None,
feature_names=iris.feature_names,
class_names=iris.target_names,
filled=True, rounded=True)
st.graphviz_chart(dot_data)
st.markdown("---")
st.success("βœ… Decision Trees are simple yet powerful! Tune them well, visualize their structure, and understand every split.")