"""Streamlit app: interactive walkthrough of the Decision Tree classifier.

Loads the Iris dataset, lets the user explore feature relationships, tune
the tree's splitting criterion and depth from the sidebar, and inspects the
fitted model via accuracy, a classification report, a confusion matrix, and
a rendered tree diagram.
"""

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import streamlit as st
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, plot_tree

# Set up the Streamlit page
st.set_page_config(page_title="Explore Decision Tree Algorithm", layout="wide")
st.title("🌳 Decision Tree Classifier: Explained with the Iris Dataset")

# ------------------------------------
# Introduction
# ------------------------------------
st.markdown("""
## 🧠 What is a Decision Tree?

A **Decision Tree** is a popular machine learning algorithm that uses a tree-like structure to make decisions.
Each **internal node** asks a question about a feature, each **branch** represents the outcome of that question,
and each **leaf node** gives the final prediction.

> 🧩 Think of it like playing "20 Questions" — each question helps narrow down the possibilities.

---

## ⚙️ How Decision Trees Work

1. Start with all the data at the root.
2. Select the **best feature** to split the data (based on Gini or Entropy).
3. Repeat the splitting process on each subset until:
   - All points are classified
   - Or a **stopping condition** (like max depth) is met

🔍 Criteria used to choose the best feature:
- **Gini Index** (default)
- **Entropy** (Information Gain)

---

### 📈 Pros and Cons

✅ Easy to understand and visualize
✅ Handles both numerical and categorical features
✅ No need for feature scaling

⚠️ Prone to overfitting — use `max_depth`, `min_samples_leaf`, or pruning

---
""")

# ------------------------------------
# Load and Explore the Dataset
# ------------------------------------
st.subheader("🌼 Let's Explore the Iris Dataset")

iris = load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)
df["target"] = iris.target
# Map the integer class labels to human-readable species names for plotting.
df["species"] = df["target"].apply(lambda x: iris.target_names[x])

st.markdown("Here's a quick look at the dataset 👇")
st.dataframe(df.head(), use_container_width=True)

# ------------------------------------
# Feature Visualization
# ------------------------------------
st.markdown("### 📊 Visualize Feature Relationships")
selected_features = st.multiselect(
    "Pick two features to visualize",
    iris.feature_names,
    default=iris.feature_names[:2],
)

# Only draw when exactly two features are chosen (a 2-D scatter needs both axes).
if len(selected_features) == 2:
    # Use an explicit figure/axes instead of pyplot's implicit global state,
    # which is fragile across Streamlit reruns.
    fig, ax = plt.subplots(figsize=(8, 5))
    sns.scatterplot(
        data=df,
        x=selected_features[0],
        y=selected_features[1],
        hue="species",
        palette="Set2",
        s=80,
        ax=ax,
    )
    st.pyplot(fig)
    plt.close(fig)

# ------------------------------------
# Sidebar: Model Settings
# ------------------------------------
st.sidebar.header("🌲 Model Settings")
criterion = st.sidebar.radio("Splitting Criterion", ["gini", "entropy"])
max_depth = st.sidebar.slider("Max Depth", min_value=1, max_value=10, value=3)

# ------------------------------------
# Train/Test Split
# ------------------------------------
# NOTE: no feature scaling here — decision trees are invariant to monotone
# feature transforms (as the intro above says, "No need for feature scaling").
# Training on raw values also keeps the split thresholds in the rendered tree
# in the original units (cm) so the diagram matches the feature names shown.
X = df[iris.feature_names]
y = df["target"]
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# ------------------------------------
# Train Model
# ------------------------------------
model = DecisionTreeClassifier(
    criterion=criterion, max_depth=max_depth, random_state=42
)
model.fit(X_train, y_train)
y_pred = model.predict(X_test)

# ------------------------------------
# Performance Metrics
# ------------------------------------
acc = accuracy_score(y_test, y_pred)
st.success(f"✅ Model Accuracy: {acc*100:.2f}%")

st.markdown("### 🧾 Classification Report")
st.text(classification_report(y_test, y_pred, target_names=iris.target_names))

st.markdown("### 🔍 Confusion Matrix")
cm = confusion_matrix(y_test, y_pred)
fig, ax = plt.subplots()
sns.heatmap(
    cm,
    annot=True,
    fmt="d",
    cmap="Blues",
    xticklabels=iris.target_names,
    yticklabels=iris.target_names,
    ax=ax,
)
ax.set_xlabel("Predicted")
ax.set_ylabel("Actual")
st.pyplot(fig)
plt.close(fig)

# ------------------------------------
# Visualize Decision Tree
# ------------------------------------
st.markdown("### 🌳 Visualizing the Tree Structure")
fig, ax = plt.subplots(figsize=(12, 6))
plot_tree(
    model,
    filled=True,
    feature_names=iris.feature_names,
    class_names=iris.target_names,
    fontsize=10,
    ax=ax,
)
st.pyplot(fig)
plt.close(fig)

# ------------------------------------
# Final Thoughts
# ------------------------------------
st.markdown("""
---

## 💡 Key Takeaways

- Decision Trees offer **clear visual explanations** of how decisions are made.
- They need **very little preprocessing** (like normalization or encoding).
- They're easy to overfit on small datasets — control complexity with `max_depth`, `min_samples_leaf`, or **pruning**.

## 📌 When Should You Use a Decision Tree?

- When model **interpretability** is important
- When your data contains both **numerical and categorical** features
- When you need a **fast prototype**

> 🎯 *Pro Tip:* Use ensembles like **Random Forest** or **Gradient Boosting** for better performance in real-world scenarios.

---
""")