r"""Streamlit app: an interactive explainer for the Decision Tree algorithm.

A radio control selects one of nine sections; most render markdown
explainers, and the last two fit a small ``DecisionTreeClassifier`` on the
Iris dataset and display the resulting tree with Graphviz.

Fixes over the previous revision:

- LaTeX-bearing markdown now uses raw strings: in a regular string ``\f``
  (as in ``\frac``) is a form-feed escape that silently corrupts the
  formula, and ``\s`` / ``\l`` are invalid escapes (SyntaxWarning on
  Python 3.12+).
- Mis-encoded (mojibake) emoji / dash / arrow characters in the UI strings
  have been repaired (UTF-8 text that had been decoded as cp1252).
- Tree fitting is deterministic (``random_state``) so Streamlit reruns
  always render the same tree.
"""

import streamlit as st
from sklearn import tree
from sklearn.datasets import load_iris
from sklearn.tree import export_graphviz  # noqa: F401 -- kept: may be used by other chunks of this file
import graphviz  # noqa: F401 -- kept: backing renderer behind st.graphviz_chart
import pandas as pd

st.set_page_config(page_title="Decision Tree Explorer", page_icon="🌳", layout="wide")

st.title("🌳 Decision Tree Algorithm Explorer")
st.write("Understand how Decision Trees work with simple explanations, visuals, and real-world examples.")


def _fit_iris_dot(max_depth, special_characters=False):
    """Fit a depth-limited classifier on Iris and return Graphviz DOT source.

    Parameters
    ----------
    max_depth : int
        Maximum depth passed to ``DecisionTreeClassifier``.
    special_characters : bool
        Forwarded to ``export_graphviz`` (the fixed example tree uses True,
        the interactive one keeps the default False, as before).
    """
    iris = load_iris()
    # random_state pins the (otherwise randomly tie-broken) splits so the
    # rendered tree is stable across Streamlit reruns.
    clf = tree.DecisionTreeClassifier(max_depth=max_depth, random_state=0)
    clf.fit(iris.data, iris.target)
    return tree.export_graphviz(
        clf,
        out_file=None,
        feature_names=iris.feature_names,
        class_names=iris.target_names,
        filled=True,
        rounded=True,
        special_characters=special_characters,
    )


def _section_what_is():
    """Concept overview with a real-life analogy."""
    st.header("📘 What is a Decision Tree?")
    st.markdown("""
A **Decision Tree** is a flowchart-like model that makes decisions based on a series of questions.

- 🎯 Used in both **classification** (e.g., spam vs. not spam) and **regression** (e.g., predicting price).
- 🌱 It starts at a **root**, asks a question, and branches out based on answers.
- 🔚 Ends at a **leaf node** — which is the prediction.

**Real-life example:**
You're deciding what to wear. You ask:
1. Is it raining?
2. Is it cold?
→ Based on your answers, you decide: jacket, umbrella, or just a T-shirt.
""")


def _section_how_it_works():
    """Step-by-step description of greedy tree training."""
    st.header("⚙️ How Does It Work?")
    st.markdown("""
**Step-by-step:**
1. Start with the whole dataset.
2. Choose the feature that best splits the data.
3. Split the dataset.
4. Repeat until you reach a stopping condition.

**Used concepts:**
- Entropy (information gain)
- Gini impurity
""")


def _section_entropy_gini():
    """Entropy vs Gini impurity; raw string keeps the LaTeX backslashes intact."""
    st.header("📊 Entropy vs Gini")
    st.markdown(r"""
### Entropy
Measures randomness or disorder in data.

$$ H(Y) = - \sum p_i \log_2 p_i $$

### Gini Impurity
Measures the probability of wrong classification.

$$ Gini(Y) = 1 - \sum p_i^2 $$

**Which to use?**
- Gini is faster → default in scikit-learn.
- Entropy gives more information-theoretic understanding.
""")


def _section_construction():
    """How the greedy top-down construction proceeds."""
    st.header("🔧 How is the Tree Built?")
    st.markdown("""
The tree is built **top-down** using a greedy algorithm:

- Best feature is chosen using Gini or Entropy.
- Splits continue until stopping criteria (e.g., max depth, pure leaf).

**Tip**: Too many splits = overfitting!
""")


def _section_cls_vs_reg():
    """Classification trees vs regression trees."""
    st.header("📈 Classification vs Regression")
    st.markdown("""
- **Classification Tree**: Predicts categories (Yes/No, Spam/Ham).
- **Regression Tree**: Predicts continuous values (e.g., house price).

**Example:**
- Classification: Will a customer churn?
- Regression: What will be the next month’s sales?
""")


def _section_pruning():
    """Pre- and post-pruning techniques against overfitting."""
    st.header("✂️ Pruning Techniques")
    st.markdown("""
**Why prune?** To avoid overfitting by cutting unnecessary branches.

### Pre-Pruning
- `max_depth`: limit depth
- `min_samples_split`: split only if enough samples
- `min_samples_leaf`: limit how small leaves can be

### Post-Pruning
- Cost Complexity Pruning (using α)
""")


def _section_feature_importance():
    """Impurity-based feature importance; raw string preserves \frac / \ in LaTeX."""
    st.header("📌 Feature Importance")
    st.markdown(r"""
Decision Trees calculate how important each feature is by how much it reduces impurity.

**Formula:**

$$ Importance = \frac{Total\ Gain\ from\ Feature}{Total\ Gain\ from\ All\ Features} $$

Useful for feature selection and explaining model decisions.
""")


def _section_visualize():
    """Render a fixed depth-3 example tree trained on Iris."""
    st.header("🌿 Visualize a Small Tree Example")
    st.graphviz_chart(_fit_iris_dot(max_depth=3, special_characters=True))


def _section_try_iris():
    """Preview the Iris data and build a tree with a user-chosen depth."""
    st.header("🌸 Try with Iris Dataset")

    iris = load_iris()
    df = pd.DataFrame(iris.data, columns=iris.feature_names)
    df['target'] = iris.target

    st.write("Here's a preview of the dataset:")
    st.dataframe(df.head())

    st.markdown("### Build and visualize a Decision Tree")
    max_depth = st.slider("Select max depth of the tree:", 1, 5, 3)
    st.graphviz_chart(_fit_iris_dot(max_depth=max_depth))


# Radio label -> renderer. Dict order defines the option order shown to the
# user and matches the original hand-written if/elif chain exactly.
_SECTIONS = {
    "What is a Decision Tree?": _section_what_is,
    "How It Works": _section_how_it_works,
    "Entropy vs Gini": _section_entropy_gini,
    "Tree Construction": _section_construction,
    "Classification vs Regression": _section_cls_vs_reg,
    "Pruning": _section_pruning,
    "Feature Importance": _section_feature_importance,
    "Visualize Example Tree": _section_visualize,
    "Try with Iris Data": _section_try_iris,
}

section = st.radio("Choose a topic to explore:", list(_SECTIONS))
_SECTIONS[section]()

st.markdown("---")
st.success("✅ Decision Trees are simple yet powerful! Tune them well, visualize their structure, and understand every split.")