suvradeepp committed on
Commit
7e315ff
·
verified ·
1 Parent(s): 1812ff9

Create decision_tree_steps.py

Browse files
Files changed (1) hide show
  1. decision_tree_steps.py +85 -0
decision_tree_steps.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import streamlit as st
3
+ import numpy as np
4
+ from sklearn.model_selection import train_test_split
5
+ from sklearn.datasets import make_moons
6
+ from sklearn.tree import DecisionTreeClassifier
7
+ from sklearn.metrics import accuracy_score
8
+ from sklearn.tree import plot_tree
9
+ from sklearn.tree import export_graphviz
10
+ from os import system
11
+ from graphviz import Source
12
+ from sklearn import tree
13
+
14
def draw_meshgrid(data=None, step=0.01):
    """Build a dense 2-D grid that covers the data with a margin of 1 unit.

    Args:
        data: Array-like of shape (n_samples, >=2). Only the first two
            columns are used. When omitted, the module-level ``X`` is used,
            which preserves the original zero-argument call.
        step: Spacing between neighbouring grid points on both axes.

    Returns:
        A tuple ``(XX, YY, input_array)`` where ``XX`` and ``YY`` are the
        ``np.meshgrid`` coordinate matrices and ``input_array`` is an
        ``(XX.size, 2)`` array of the flattened grid points, one per row.

    Raises:
        ValueError: If ``step`` is not positive or ``data`` is not a 2-D
            array with at least two columns.
    """
    if step <= 0:
        raise ValueError("step must be positive")

    # Fall back to the script's global data set for backward compatibility.
    points = np.asarray(X if data is None else data)
    if points.ndim != 2 or points.shape[1] < 2:
        raise ValueError("data must be a 2-D array with at least two columns")

    first, second = points[:, 0], points[:, 1]
    a = np.arange(start=first.min() - 1, stop=first.max() + 1, step=step)
    b = np.arange(start=second.min() - 1, stop=second.max() + 1, step=step)

    XX, YY = np.meshgrid(a, b)

    # One (x, y) row per grid point, in the same order as XX.ravel().
    input_array = np.column_stack((XX.ravel(), YY.ravel()))

    return XX, YY, input_array
23
+
24
# Synthetic two-class "moons" data set, split once with a fixed seed so every
# rerun of the script sees the same train/test partition.
X, y = make_moons(n_samples=500, noise=0.30, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

plt.style.use('fivethirtyeight')

st.sidebar.markdown("# Decision Tree Classifier")

# --- Hyper-parameter widgets -------------------------------------------------

criterion = st.sidebar.selectbox(
    'Criterion',
    ('gini', 'entropy'),
)

splitter = st.sidebar.selectbox(
    'Splitter',
    ('best', 'random'),
)

# Integer-typed input (min/value/step are all ints) so the widget returns an
# int and cannot go negative. 0 is the "no limit" sentinel used further down.
max_depth = st.sidebar.number_input('Max Depth', min_value=0, value=0, step=1)

# scikit-learn rejects an integer min_samples_split below 2, so the slider
# starts at 2 instead of 1.
min_samples_split = st.sidebar.slider(
    'Min Samples Split', 2, X_train.shape[0], 2, key=1234
)

min_samples_leaf = st.sidebar.slider(
    'Min Samples Leaf', 1, X_train.shape[0], 1, key=1235
)

# The data set has exactly two features, hence the 1..2 range.
max_features = st.sidebar.slider('Max Features', 1, 2, 2, key=1236)

# Same integer-typed input as Max Depth; 0 means "no limit".
max_leaf_nodes = st.sidebar.number_input(
    'Max Leaf Nodes', min_value=0, value=0, step=1
)

# A negative impurity decrease is not a valid estimator setting.
min_impurity_decrease = st.sidebar.number_input(
    'Min Impurity Decrease', min_value=0.0
)
52
+
53
# Scatter plot of the raw data, shown before the model is run.
fig, ax = plt.subplots()
ax.scatter(X[:, 0], X[:, 1], c=y, cmap='rainbow')

# Single-element placeholder: re-plotting into it replaces the initial chart
# in place rather than appending a second one.
plot_slot = st.empty()
plot_slot.pyplot(fig)

if st.sidebar.button('Run Algorithm'):

    # The widgets use 0 for "no limit"; the estimator expects None for that.
    # Non-positive depths and leaf counts below 2 are invalid for the
    # estimator, so they are treated as "no limit" as well.
    depth_limit = max_depth if max_depth > 0 else None
    leaf_limit = max_leaf_nodes if max_leaf_nodes >= 2 else None

    clf = DecisionTreeClassifier(
        criterion=criterion,
        splitter=splitter,
        max_depth=depth_limit,
        random_state=42,
        min_samples_split=min_samples_split,
        min_samples_leaf=min_samples_leaf,
        max_features=max_features,
        max_leaf_nodes=leaf_limit,
        min_impurity_decrease=min_impurity_decrease,
    )
    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_test)

    # Classify every point of a dense grid to visualise the decision regions.
    XX, YY, input_array = draw_meshgrid()
    labels = clf.predict(input_array)

    ax.contourf(XX, YY, labels.reshape(XX.shape), alpha=0.5, cmap='rainbow')
    # Label through the Axes object instead of relying on pyplot's implicit
    # "current axes" state.
    ax.set_xlabel("Col1")
    ax.set_ylabel("Col2")
    plot_slot.pyplot(fig)
    st.subheader(f"Accuracy for Decision Tree {round(accuracy_score(y_test, y_pred), 2)}")

    # Named dot_data so the imported `sklearn.tree` module is not shadowed.
    dot_data = export_graphviz(clf, feature_names=["Col1", "Col2"])

    st.graphviz_chart(dot_data)

# Streamlit re-executes the whole script on every interaction; closing the
# figure stops pyplot from accumulating one open figure per rerun.
plt.close(fig)