import streamlit as st
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.metrics import classification_report, accuracy_score
import matplotlib.pyplot as plt
import seaborn as sns
st.set_page_config(page_title="Explore Decision Tree Algorithm", layout="wide")
st.title("🌳 Decision Tree Classifier Explained")

# Tabs for better UX
tab1, tab2, tab3 = st.tabs(["📘 About Decision Tree", "🧪 Try It Out", "📊 Visualization"])
with tab1:
    st.markdown("""
## 🌳 What is a Decision Tree?

A **Decision Tree** is a flowchart-like structure used in supervised learning for both **classification** and **regression** tasks.
It recursively splits the dataset into smaller and smaller subsets, incrementally building the corresponding tree as it goes.

### 🧠 Core Concepts:
- **Root Node**: The first decision point (split), based on the most informative feature.
- **Internal Nodes**: Decision points that split the data based on feature values.
- **Leaf Nodes**: The final output — a class (for classification) or a value (for regression).
- **Branches**: Represent the outcomes of a decision or test.

---

### 🔍 How it Works:
1. The algorithm chooses a feature and a split point that best separate the data.
2. At each node, it picks the split that **maximizes information gain** or **minimizes impurity**.
3. Splits are made recursively until a stopping criterion is met (e.g., maximum depth or pure leaves).

---

### ⚙️ Splitting Criteria:
- **Gini Impurity**:
  - Measures how often a randomly chosen element would be incorrectly labeled.
  - A good default when you want fast computation.
  - Formula: $Gini(t) = 1 - \\sum_i p(i)^2$
- **Entropy (Information Gain)**:
  - Measures the level of disorder or uncertainty.
  - Based on concepts from information theory.
  - Formula: $Entropy(t) = -\\sum_i p(i) \\log_2 p(i)$

---

### 🌸 Visual Intuition:
Imagine a dataset of flowers. You might first split by **petal length**: one branch contains flowers with short petals, another those with long petals.
Each of those groups is then split further based on **petal width**, and so on, until the flowers are grouped into species.

---

### 📦 Use Cases:
- Medical diagnosis (e.g., "Is this tumor malignant or benign?")
- Loan approval ("Should this applicant get a loan?")
- Customer churn prediction
- Email spam detection
- Any tabular data with clear patterns and labeled outcomes

---

### ✅ Advantages:
- Simple to understand and interpret (like flowcharts)
- Requires little data preprocessing (no scaling or normalization)
- Can handle both numerical and categorical data
- Easily visualized

### ⚠️ Limitations:
- Prone to overfitting if not pruned or constrained
- Small changes in the data can produce very different trees
- Weak on very large or high-dimensional datasets on its own (better in ensembles such as Random Forest)

---

🎯 **Tip**: Use pruning, set a maximum depth, or use ensemble methods (such as Random Forest or Gradient Boosted Trees) to improve performance and reduce overfitting.
""")
with tab2:
    st.subheader("🌸 Train a Decision Tree on the Iris Dataset")

    iris = load_iris()
    df = pd.DataFrame(iris.data, columns=iris.feature_names)
    df['target'] = iris.target
    st.dataframe(df.head(), use_container_width=True)

    criterion = st.radio("Select Splitting Criterion", ["gini", "entropy"])
    max_depth = st.slider("Select Max Depth", 1, 10, value=3)

    X = df.drop('target', axis=1)
    y = df['target']
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    model = DecisionTreeClassifier(criterion=criterion, max_depth=max_depth, random_state=42)
    model.fit(X_train, y_train)

    y_pred = model.predict(X_test)
    acc = accuracy_score(y_test, y_pred)
    st.success(f"✅ Model Accuracy: {acc*100:.2f}%")

    st.markdown("### 📄 Classification Report")
    st.text(classification_report(y_test, y_pred, target_names=iris.target_names))
with tab3:
    st.subheader("🌳 Visualize the Decision Tree")
    fig, ax = plt.subplots(figsize=(12, 6))
    plot_tree(model, feature_names=iris.feature_names, class_names=iris.target_names,
              filled=True, ax=ax)
    st.pyplot(fig)

    st.subheader("📊 Feature Importance")
    importance_df = pd.DataFrame({
        'Feature': iris.feature_names,
        'Importance': model.feature_importances_
    }).sort_values(by="Importance", ascending=False)
    st.dataframe(importance_df, use_container_width=True)

    fig2, ax2 = plt.subplots()
    sns.barplot(x='Importance', y='Feature', data=importance_df, palette='viridis', ax=ax2)
    st.pyplot(fig2)
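# The tip in the About tab recommends pruning to reduce overfitting. A minimal,
# hypothetical sketch (not wired into the UI) of cost-complexity pruning via
# scikit-learn's `ccp_alpha` parameter; the function name and default alpha
# value are illustrative assumptions, and the top-level imports above are reused:

def compare_pruned_tree(ccp_alpha=0.02):
    """Fit an unconstrained and a cost-complexity-pruned tree on Iris and
    return (full_depth, pruned_depth, pruned_test_accuracy)."""
    data = load_iris()
    X_tr, X_te, y_tr, y_te = train_test_split(
        data.data, data.target, test_size=0.2, random_state=42)
    # With ccp_alpha > 0, subtrees whose complexity cost exceeds their
    # impurity reduction are collapsed, yielding a shallower tree.
    full = DecisionTreeClassifier(random_state=42).fit(X_tr, y_tr)
    pruned = DecisionTreeClassifier(random_state=42, ccp_alpha=ccp_alpha).fit(X_tr, y_tr)
    return full.get_depth(), pruned.get_depth(), pruned.score(X_te, y_te)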