"""Streamlit app: explains the Decision Tree algorithm, trains one on Iris, and visualizes it.

Three tabs:
    1. A markdown explainer of decision-tree concepts.
    2. Interactive training (criterion + max depth) on the Iris dataset.
    3. Tree plot and feature-importance chart for the trained model.

Streamlit re-executes this script top-to-bottom on every widget interaction,
so the model trained in tab 2 is always defined by the time tab 3 renders.
"""

import streamlit as st
import pandas as pd
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report, accuracy_score
import matplotlib.pyplot as plt
import seaborn as sns

st.set_page_config(page_title="Explore Decision Tree Algorithm", layout="wide")

st.title("🌳 Decision Tree Classifier Explained")

# Tabs for better UX
tab1, tab2, tab3 = st.tabs(["📘 About Decision Tree", "🧪 Try It Out", "📈 Visualization"])

with tab1:
    # Static educational content; LaTeX uses escaped backslashes so that the
    # rendered markdown receives e.g. "\sum" intact.
    st.markdown("""
## 🌳 What is a Decision Tree?

A **Decision Tree** is a flowchart-like structure used in supervised learning for both **classification** and **regression** tasks. It splits the dataset into smaller and smaller subsets while at the same time an associated decision tree is incrementally developed.

### 🧠 Core Concepts:
- **Root Node**: The first decision point (split) based on the most important feature.
- **Internal Nodes**: Decision points that split data based on feature values.
- **Leaf Nodes**: The final output class (for classification) or value (for regression).
- **Branches**: Represent outcomes of a decision or test.

---

### 📊 How it Works:
1. The algorithm chooses a feature and a split point that best separates the data.
2. Splits are made recursively until a stopping criterion is met (like max depth or pure leaves).
3. At each node, it chooses the split that **maximizes information gain** or **minimizes impurity**.

---

### ⚖️ Splitting Criteria:
- **Gini Impurity**:
  - Measures how often a randomly chosen element would be incorrectly labeled.
  - Best when you want a fast computation.
  - Formula: $Gini(t) = 1 - \\sum p(i)^2$

- **Entropy (Information Gain)**:
  - Measures the level of disorder or uncertainty.
  - Based on the concept from information theory.
  - Formula: $Entropy(t) = - \\sum p(i) \\log_2(p(i))$

---

### 🔍 Visual Intuition:
Imagine a dataset of flowers. You might first split by **petal length**. One branch contains flowers with short petals, another with long petals. Then, each of those groups is further split based on **petal width**, and so on — until the flowers are grouped into species.

---

### 📦 Use Cases:
- Medical Diagnosis (e.g., "Is this tumor malignant or benign?")
- Loan Approval ("Should this applicant get a loan?")
- Customer Churn Prediction
- Email Spam Detection
- Any tabular data with clear patterns and labeled outcomes

---

### ✅ Advantages:
- Simple to understand and interpret (like flowcharts)
- Requires little data preprocessing (no scaling or normalization)
- Can handle both numerical and categorical data
- Easily visualized

### ⚠️ Limitations:
- Prone to overfitting if not pruned or constrained
- Small changes in data can lead to very different trees
- Not good with very large datasets or high-dimensional data alone (better with ensembles like Random Forest)

---

🎯 **Tip**: Use pruning, set a maximum depth, or use ensemble methods (like Random Forest or Gradient Boosted Trees) to enhance performance and reduce overfitting.
""")

with tab2:
    st.subheader("🌸 Train a Decision Tree on the Iris Dataset")

    # Iris: 150 samples, 4 numeric features, 3 classes — small enough to
    # retrain from scratch on every rerun without caching.
    iris = load_iris()
    df = pd.DataFrame(iris.data, columns=iris.feature_names)
    df['target'] = iris.target

    st.dataframe(df.head(), use_container_width=True)

    # Hyperparameters exposed to the user.
    criterion = st.radio("Select Splitting Criterion", ["gini", "entropy"])
    max_depth = st.slider("Select Max Depth", 1, 10, value=3)

    X = df.drop('target', axis=1)
    y = df['target']
    # Fixed random_state keeps the split (and accuracy) stable across reruns.
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    model = DecisionTreeClassifier(criterion=criterion, max_depth=max_depth, random_state=42)
    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)

    acc = accuracy_score(y_test, y_pred)
    st.success(f"✅ Model Accuracy: {acc*100:.2f}%")

    st.markdown("### 📊 Classification Report")
    st.text(classification_report(y_test, y_pred, target_names=iris.target_names))

with tab3:
    st.subheader("📉 Visualize the Decision Tree")

    fig, ax = plt.subplots(figsize=(12, 6))
    # Pass ax explicitly rather than relying on pyplot's implicit current
    # axes, which is fragile when Streamlit reruns the script.
    plot_tree(model, feature_names=iris.feature_names, class_names=iris.target_names,
              filled=True, ax=ax)
    st.pyplot(fig)
    # Close the figure so repeated reruns don't accumulate open figures.
    plt.close(fig)

    st.subheader("📌 Feature Importance")
    importance_df = pd.DataFrame({
        'Feature': iris.feature_names,
        'Importance': model.feature_importances_
    }).sort_values(by="Importance", ascending=False)

    st.dataframe(importance_df, use_container_width=True)

    fig2, ax2 = plt.subplots()
    # hue= + legend=False: seaborn >= 0.14 deprecates palette without hue.
    sns.barplot(x='Importance', y='Feature', hue='Feature', data=importance_df,
                palette='viridis', legend=False, ax=ax2)
    st.pyplot(fig2)
    plt.close(fig2)