Ramyamaheswari's picture
Update app.py
e2335ab verified
import streamlit as st
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report, accuracy_score, confusion_matrix
# Set up the Streamlit page
st.set_page_config(page_title="Explore Decision Tree Algorithm", layout="wide")
st.title("🌳 Decision Tree Classifier: Explained with the Iris Dataset")
# ------------------------------------
# Introduction
# ------------------------------------
st.markdown("""
## 🧠 What is a Decision Tree?
A **Decision Tree** is a popular machine learning algorithm that uses a tree-like structure to make decisions.
Each **internal node** asks a question about a feature, each **branch** represents the outcome of that question, and each **leaf node** gives the final prediction.
> 🧩 Think of it like playing "20 Questions" β€” each question helps narrow down the possibilities.
---
## βš™οΈ How Decision Trees Work
1. Start with all the data at the root.
2. Select the **best feature** to split the data (based on Gini or Entropy).
3. Repeat the splitting process on each subset until:
- All points are classified
- Or a **stopping condition** (like max depth) is met
πŸ” Criteria used to choose the best feature:
- **Gini Index** (default)
- **Entropy** (Information Gain)
---
### πŸ“ˆ Pros and Cons
βœ… Easy to understand and visualize
βœ… Handles both numerical and categorical features
βœ… No need for feature scaling
⚠️ Prone to overfitting β€” use `max_depth`, `min_samples_leaf`, or pruning
---
""")
# ------------------------------------
# Load and Explore the Dataset
# ------------------------------------
st.subheader("🌼 Let's Explore the Iris Dataset")
iris = load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)
df["target"] = iris.target
df["species"] = df["target"].apply(lambda x: iris.target_names[x])
st.markdown("Here's a quick look at the dataset πŸ‘‡")
st.dataframe(df.head(), use_container_width=True)
# ------------------------------------
# Feature Visualization
# ------------------------------------
st.markdown("### πŸ“Š Visualize Feature Relationships")
selected_features = st.multiselect("Pick two features to visualize", iris.feature_names, default=iris.feature_names[:2])
if len(selected_features) == 2:
plt.figure(figsize=(8, 5))
sns.scatterplot(data=df, x=selected_features[0], y=selected_features[1], hue="species", palette="Set2", s=80)
st.pyplot(plt.gcf())
plt.clf()
# ------------------------------------
# Sidebar: Model Settings
# ------------------------------------
st.sidebar.header("🌲 Model Settings")
criterion = st.sidebar.radio("Splitting Criterion", ["gini", "entropy"])
max_depth = st.sidebar.slider("Max Depth", min_value=1, max_value=10, value=3)
# ------------------------------------
# Preprocessing and Train/Test Split
# ------------------------------------
X = df[iris.feature_names]
y = df["target"]
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42)
# ------------------------------------
# Train Model
# ------------------------------------
model = DecisionTreeClassifier(criterion=criterion, max_depth=max_depth, random_state=42)
model.fit(X_train, y_train)
y_pred = model.predict(X_test)
# ------------------------------------
# Performance Metrics
# ------------------------------------
acc = accuracy_score(y_test, y_pred)
st.success(f"βœ… Model Accuracy: {acc*100:.2f}%")
st.markdown("### 🧾 Classification Report")
st.text(classification_report(y_test, y_pred, target_names=iris.target_names))
st.markdown("### πŸ” Confusion Matrix")
cm = confusion_matrix(y_test, y_pred)
fig, ax = plt.subplots()
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=iris.target_names, yticklabels=iris.target_names)
plt.xlabel("Predicted")
plt.ylabel("Actual")
st.pyplot(fig)
# ------------------------------------
# Visualize Decision Tree
# ------------------------------------
st.markdown("### 🌳 Visualizing the Tree Structure")
fig, ax = plt.subplots(figsize=(12, 6))
plot_tree(model, filled=True, feature_names=iris.feature_names, class_names=iris.target_names, fontsize=10)
st.pyplot(fig)
# ------------------------------------
# Final Thoughts
# ------------------------------------
st.markdown("""
---
## πŸ’‘ Key Takeaways
- Decision Trees offer **clear visual explanations** of how decisions are made.
- They need **very little preprocessing** (like normalization or encoding).
- They’re easy to overfit on small datasets β€” control complexity with `max_depth`, `min_samples_leaf`, or **pruning**.
## πŸ“Œ When Should You Use a Decision Tree?
- When model **interpretability** is important
- When your data contains both **numerical and categorical** features
- When you need a **fast prototype**
> 🎯 *Pro Tip:* Use ensembles like **Random Forest** or **Gradient Boosting** for better performance in real-world scenarios.
---
""")