beihai commited on
Commit
e87f1da
·
1 Parent(s): 7865cbb

Upload visualize.py

Browse files
Files changed (1) hide show
  1. visualize.py +103 -0
visualize.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 本文用到的库
2
+ import numpy as np
3
+ import pandas as pd
4
+ from sklearn.tree import DecisionTreeClassifier
5
+ import base64
6
+ import streamlit as st
7
+ from sklearn import preprocessing
8
+ from dtreeviz.trees import *
9
+ from data import getDataSetOrigin,dataPreprocessing
10
+ import joblib
11
+ from DecisionTree import dt_param_selector
12
+ import numpy as np
13
+ import matplotlib.pyplot as plt
14
+ from data import dataPreprocessing
15
+ from sklearn.tree import DecisionTreeClassifier
16
+ import streamlit as st
17
+
18
+ def decisionTreeViz(clf):
19
+ df = dataPreprocessing()
20
+ X, y = df[df.columns[:-1]], df["label"]
21
+ viz = dtreeviz(
22
+ clf,
23
+ X,
24
+ y,
25
+ orientation="LR",
26
+ target_name="label",
27
+ feature_names=df.columns[:-1],
28
+ class_names=["good", "bad"], # need class_names for classifier
29
+ )
30
+
31
+ return viz
32
+
33
+
34
+ def svg_write(svg, center=True):
35
+ """
36
+ Disable center to left-margin align like other objects.
37
+ """
38
+ # Encode as base 64
39
+ b64 = base64.b64encode(svg.encode("utf-8")).decode("utf-8")
40
+
41
+ # Add some CSS on top
42
+ css_justify = "center" if center else "left"
43
+ css = (
44
+ f'<p style="text-align:center; display: flex; justify-content: {css_justify};">'
45
+ )
46
+ html = f'{css}<img src="data:image/svg+xml;base64,{b64}"/>'
47
+
48
+ # Write the HTML
49
+ st.write(html, unsafe_allow_html=True, width=800, caption="决策树")
50
+
51
+ def plotSurface():
52
+ st.set_option('deprecation.showPyplotGlobalUse', False)
53
+ # Parameters
54
+ n_classes = 2
55
+ plot_colors = "ryb"
56
+ plot_step = 0.02
57
+
58
+ # Load data
59
+ df = dataPreprocessing()
60
+ plt.figure(figsize=(8,4))
61
+ for pairidx, pair in enumerate([[1, 0], [1, 3], [1, 4], [1, 5],
62
+ [3, 0], [3, 2], [3, 4], [3, 5]]):
63
+ # We only take the two corresponding features
64
+ X, y = df[df.columns[:-1]].values[:, pair], df["label"]
65
+
66
+ # Train
67
+ clf = DecisionTreeClassifier().fit(X, y)
68
+
69
+ # Plot the decision boundary
70
+ fig=plt.subplot(2, 4, pairidx + 1)
71
+
72
+ x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
73
+ y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
74
+ xx, yy = np.meshgrid(
75
+ np.arange(x_min, x_max, plot_step), np.arange(y_min, y_max, plot_step)
76
+ )
77
+ plt.tight_layout(h_pad=0.5, w_pad=0.5, pad=2.5)
78
+
79
+ Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
80
+ Z = Z.reshape(xx.shape)
81
+ cs = plt.contourf(xx, yy, Z, cmap=plt.cm.RdYlBu)
82
+
83
+ plt.xlabel(df.columns[pair[0]])
84
+ plt.ylabel(df.columns[pair[1]])
85
+
86
+ # Plot the training points
87
+ for i, color in zip(range(n_classes), plot_colors):
88
+ idx = np.where(y == i)
89
+ plt.scatter(
90
+ X[idx, 0],
91
+ X[idx, 1],
92
+ c=color,
93
+ label=df["label"][i],
94
+ cmap=plt.cm.RdYlBu,
95
+ edgecolor="black",
96
+ s=15,
97
+ )
98
+ plt.suptitle("Decision surface of a decision tree using paired features")
99
+ plt.legend(loc="lower right", borderpad=0, handletextpad=0)
100
+ plt.axis("tight")
101
+ # plt.show()
102
+ plt.tight_layout()
103
+ st.pyplot()