Spaces:
Sleeping
Sleeping
Upload 2 files
Browse files- app.py +168 -0
- requirements.txt +0 -0
app.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import numpy as np
|
| 3 |
+
import streamlit as st
|
| 4 |
+
import time
|
| 5 |
+
|
| 6 |
+
from sklearn import datasets
|
| 7 |
+
from sklearn.model_selection import train_test_split
|
| 8 |
+
from sklearn.tree import DecisionTreeClassifier
|
| 9 |
+
from sklearn import tree
|
| 10 |
+
from sklearn.metrics import accuracy_score
|
| 11 |
+
import matplotlib.pyplot as plt
|
| 12 |
+
|
| 13 |
+
st.set_page_config(
|
| 14 |
+
page_title="Decision Tree Visualizer",
|
| 15 |
+
page_icon=":chart_with_upwards_trend:",
|
| 16 |
+
layout="wide",
|
| 17 |
+
initial_sidebar_state="expanded")
|
| 18 |
+
|
| 19 |
+
# load dataset
|
| 20 |
+
iris=datasets.load_iris()
|
| 21 |
+
x = iris.data
|
| 22 |
+
y = iris.target
|
| 23 |
+
x_train, x_test, y_train, y_test = train_test_split(x,y, test_size=0.2,random_state=42)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# constants
|
| 27 |
+
min_weight_fraction_leaf=0.0
|
| 28 |
+
max_features = None
|
| 29 |
+
max_leaf_nodes = None
|
| 30 |
+
min_impurity_decrease=0.0
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# Load initial graph
|
| 36 |
+
fig, ax = plt.subplots()
|
| 37 |
+
|
| 38 |
+
# Plot initial graph
|
| 39 |
+
scatter = ax.scatter(x.T[0], x.T[1], c=y, cmap='rainbow')
|
| 40 |
+
ax.set_xlabel(iris.feature_names[0], fontsize=10)
|
| 41 |
+
ax.set_ylabel(iris.feature_names[1],fontsize=10)
|
| 42 |
+
ax.set_title('Sepal Length vs Sepal Width', fontsize=15)
|
| 43 |
+
legend1 = ax.legend(*scatter.legend_elements(),
|
| 44 |
+
title="Classes",loc="upper right")
|
| 45 |
+
ax.add_artist(legend1)
|
| 46 |
+
ax.legend()
|
| 47 |
+
orig = st.pyplot(fig)
|
| 48 |
+
|
| 49 |
+
# sidebar elements
|
| 50 |
+
st.sidebar.header(':blue[_Decision Tree_] Algo Visualizer', divider='rainbow')
|
| 51 |
+
|
| 52 |
+
criterion = st.sidebar.selectbox("Criterion",
|
| 53 |
+
("gini", "entropy", "log_loss"),
|
| 54 |
+
help="""The function to measure the quality of a split.
|
| 55 |
+
Supported criteria are “gini” for the Gini impurity and “log_loss” and “entropy”
|
| 56 |
+
both for the Shannon information gain""")
|
| 57 |
+
max_depth = st.sidebar.number_input("Max Depth",
|
| 58 |
+
min_value=0,
|
| 59 |
+
max_value=30,
|
| 60 |
+
step=1,
|
| 61 |
+
value=0,
|
| 62 |
+
help="""The maximum depth of the tree. If None, then nodes are expanded until all leaves are pure""")
|
| 63 |
+
if max_depth == 0:
|
| 64 |
+
max_depth=None
|
| 65 |
+
min_samples_split = st.sidebar.number_input("Min Sample Split",
|
| 66 |
+
min_value=0,
|
| 67 |
+
max_value=x_train.shape[0],
|
| 68 |
+
value=2,
|
| 69 |
+
help="""The minimum number of samples required to split an internal node.
|
| 70 |
+
If float, enter between 0 and 1""")
|
| 71 |
+
min_samples_leaf = st.sidebar.number_input("Min sample Leaf",
|
| 72 |
+
min_value=0,
|
| 73 |
+
max_value=x_train.shape[0],
|
| 74 |
+
value=1,
|
| 75 |
+
help="""The minimum number of samples required to be at a leaf node.
|
| 76 |
+
If float, enter between 0 and 1""")
|
| 77 |
+
random_state = st.sidebar.number_input("Random State",
|
| 78 |
+
min_value=0,
|
| 79 |
+
value=42)
|
| 80 |
+
|
| 81 |
+
# advance features
|
| 82 |
+
toggle = st.sidebar.toggle("Advance Features")
|
| 83 |
+
|
| 84 |
+
if toggle:
|
| 85 |
+
min_weight_fraction_leaf = st.sidebar.number_input("Min Weight Fraction Leaf",
|
| 86 |
+
min_value=0.0,
|
| 87 |
+
max_value=1.0,
|
| 88 |
+
value=0.0,
|
| 89 |
+
help="""The minimum weighted fraction of the sum total of weights
|
| 90 |
+
(of all the input samples) required to be at a leaf node. """)
|
| 91 |
+
max_features = st.sidebar.selectbox("Max Features",
|
| 92 |
+
(None,"sqrt", "log2","Custom"),
|
| 93 |
+
help="""The number of features to consider when looking for the best split""")
|
| 94 |
+
if max_features == "Custom":
|
| 95 |
+
max_features = st.sidebar.number_input("Enter Max Features",
|
| 96 |
+
value=None,
|
| 97 |
+
step=1)
|
| 98 |
+
|
| 99 |
+
max_leaf_nodes = st.sidebar.number_input("Max Leaf Nodes",
|
| 100 |
+
min_value=0,
|
| 101 |
+
help="""Grow a tree with max_leaf_nodes in best-first fashion. """)
|
| 102 |
+
if max_leaf_nodes==0:
|
| 103 |
+
max_leaf_nodes=None
|
| 104 |
+
min_impurity_decrease = st.sidebar.number_input("Min Impurity Decrase",
|
| 105 |
+
min_value=0.0,
|
| 106 |
+
help="""A node will be split if this split induces a decrease of the
|
| 107 |
+
impurity greater than or equal to this value.""")
|
| 108 |
+
train = st.sidebar.button("Train Model", type="primary")
|
| 109 |
+
if st.sidebar.button("Reset"):
|
| 110 |
+
st.experimental_rerun()
|
| 111 |
+
if train:
|
| 112 |
+
orig.empty()
|
| 113 |
+
|
| 114 |
+
msg = st.toast('Running', icon='🫸🏼')
|
| 115 |
+
# building model
|
| 116 |
+
clf = DecisionTreeClassifier(criterion=criterion,max_depth=max_depth,
|
| 117 |
+
min_samples_split=min_samples_split,
|
| 118 |
+
min_samples_leaf=min_samples_leaf,
|
| 119 |
+
min_weight_fraction_leaf=min_weight_fraction_leaf,
|
| 120 |
+
max_features=max_features,
|
| 121 |
+
random_state=random_state,
|
| 122 |
+
max_leaf_nodes=max_leaf_nodes,
|
| 123 |
+
min_impurity_decrease=min_impurity_decrease)
|
| 124 |
+
clf.fit(x_train[:, :2], y_train)
|
| 125 |
+
x_pred = clf.predict(x_train[:,:2])
|
| 126 |
+
y_pred = clf.predict(x_test[:, :2])
|
| 127 |
+
st.subheader("Train Accuracy " + str(round(accuracy_score(y_train, x_pred), 2)) + ", "+ "Test Accuracy " + str(round(accuracy_score(y_test, y_pred), 2)))
|
| 128 |
+
st.write("Total Depth: " + str(clf.tree_.max_depth))
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
# # define ranges for meshgrid
|
| 132 |
+
x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
|
| 133 |
+
y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
|
| 134 |
+
|
| 135 |
+
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.01),
|
| 136 |
+
np.arange(y_min, y_max, 0.01))
|
| 137 |
+
|
| 138 |
+
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
|
| 139 |
+
Z = Z.reshape(xx.shape)
|
| 140 |
+
|
| 141 |
+
# Plot the decision boundaries
|
| 142 |
+
plt.figure(figsize=(8, 6))
|
| 143 |
+
plt.contourf(xx, yy, Z, alpha=0.8)
|
| 144 |
+
plt.scatter(x[:, 0], x[:, 1], c=y, edgecolors='k', s=20)
|
| 145 |
+
plt.xlabel('Sepal length')
|
| 146 |
+
plt.ylabel('Sepal width')
|
| 147 |
+
plt.title('Decision Boundaries')
|
| 148 |
+
plt.tight_layout()
|
| 149 |
+
plt.savefig('decision_boundary_plot.png')
|
| 150 |
+
plt.close()
|
| 151 |
+
|
| 152 |
+
# Display decision boundary plot
|
| 153 |
+
st.image("decision_boundary_plot.png")
|
| 154 |
+
|
| 155 |
+
# Plot decision tree
|
| 156 |
+
plt.figure(figsize=(25, 20))
|
| 157 |
+
tree.plot_tree(clf, feature_names=iris.feature_names, class_names=iris.target_names, filled=True)
|
| 158 |
+
plt.xlim(plt.xlim()[0] * 2, plt.xlim()[1] * 2)
|
| 159 |
+
plt.ylim(plt.ylim()[0] * 2, plt.ylim()[1] * 2)
|
| 160 |
+
plt.savefig("decision_tree.png")
|
| 161 |
+
plt.close()
|
| 162 |
+
|
| 163 |
+
# Display decision tree plot
|
| 164 |
+
st.image("decision_tree.png")
|
| 165 |
+
|
| 166 |
+
msg.toast('Model run successfully!', icon='😎')
|
| 167 |
+
|
| 168 |
+
|
requirements.txt
ADDED
|
Binary file (188 Bytes). View file
|
|
|