jiehou commited on
Commit
5a472d7
·
1 Parent(s): 169f859

Create new file

Browse files
Files changed (1) hide show
  1. app.py +147 -0
app.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+
4
+
5
+ def plot_figure(N, u1, std1, u2, std2, show_dist, classifier=None):
6
+ #N = 100
7
+ import numpy as np
8
+ import matplotlib.pyplot as pp
9
+ pp.style.use('default')
10
+ #val = 0. # this is the value where you want the data to appear on the y-axis.
11
+
12
+ points_class1 = [stats.norm.rvs(loc=u1, scale = std1) for x in range(N)]
13
+ points_class2 = [stats.norm.rvs(loc=u2, scale = std2) for x in range(N)]
14
+
15
+ pd_class1 = pd.DataFrame({'Feature 1 (X)': points_class1, 'Label (Y)': np.repeat(0,len(points_class1))})
16
+ pd_class2 = pd.DataFrame({'Feature 1 (X)': points_class2, 'Label (Y)': np.repeat(1,len(points_class2))})
17
+
18
+
19
+ pd_all = pd.concat([pd_class1, pd_class2]).reset_index(drop=True)
20
+
21
+ import numpy as np
22
+ X_data= pd_all['Feature 1 (X)'].to_numpy().reshape((len(pd_all),1))
23
+ y_labels= pd_all['Label (Y)']
24
+
25
+
26
+ # define x, y limits
27
+ x_min, x_max = X_data[:, 0].min() - 1, X_data[:, 0].max() + 1
28
+ y_min, y_max = 0-1, 1 + 1
29
+
30
+ fig = pp.figure(figsize=(8, 6)) # figure size in inches
31
+ fig.subplots_adjust(left=0, right=1, bottom=0, top=1, hspace=0.3, wspace=0.05)
32
+
33
+
34
+ pp.tick_params(left = False, right = False , labelleft = False ,
35
+ labelbottom = True, bottom = False)
36
+
37
+ #reference = [stats.uniform.rvs(loc=1, scale = 1) for x in range(N)]
38
+ pp.plot(points_class1, np.zeros_like(points_class1) + val, 'x', label = 'Class 1', markersize = 10)
39
+ pp.plot(points_class2, np.zeros_like(points_class2) + val, 'o', label = 'Class 2', markersize = 10)
40
+
41
+
42
+ if show_dist:
43
+ x = np.arange(x_min, x_max, 0.01, dtype=np.float) # define range of x
44
+ x, y, u, s = gaussian(x, 10000, np.mean(points_class1), np.std(points_class1) )
45
+ pp.plot(x, y)
46
+ #pp.plot(x, y, label=r'$Gaussian (\mu=%.2f,\ \sigma=%.2f)$' % (u, s))
47
+
48
+
49
+ x = np.arange(x_min, x_max, 0.01, dtype=np.float) # define range of x
50
+ x, y, u, s = gaussian(x, 10000, np.mean(points_class2), np.std(points_class2) )
51
+ pp.plot(x, y)
52
+ #pp.plot(x, y, label=r'$Gaussian (\mu=%.2f,\ \sigma=%.2f)$' % (u, s))
53
+
54
+
55
+
56
+ ### draw decision boundary on knn
57
+
58
+ import numpy as np
59
+ from matplotlib import pyplot as plt
60
+ from sklearn import neighbors, datasets
61
+ from matplotlib.colors import ListedColormap
62
+
63
+ # Create color maps for 3-class classification problem, as with iris
64
+ cmap_light = ListedColormap(['#FFAAAA', '#AAFFAA'])
65
+ cmap_bold = ListedColormap(['#FF0000', '#00FF00'])
66
+
67
+
68
+ xx, yy = np.meshgrid(np.linspace(x_min, x_max, 100),
69
+ np.linspace(y_min, y_max, 100))
70
+
71
+
72
+ if classifier == 'LDA':
73
+ from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
74
+ model_sk = LinearDiscriminantAnalysis()
75
+ model_sk.fit(X_data,y_labels)
76
+ Z = model_sk.predict(np.c_[xx.ravel()])
77
+ Z = Z.reshape(xx.shape)
78
+
79
+ pp.pcolormesh(xx, yy, Z, cmap=cmap_light, alpha=0.2)
80
+ elif classifier == 'QDA':
81
+ from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis
82
+ model_sk = QuadraticDiscriminantAnalysis()
83
+ model_sk.fit(X_data,y_labels)
84
+ Z = model_sk.predict(np.c_[xx.ravel()])
85
+ Z = Z.reshape(xx.shape)
86
+ print("Z: ",Z)
87
+ pp.pcolormesh(xx, yy, Z, cmap=cmap_light, alpha=0.2)
88
+ elif classifier == 'NaiveBayes':
89
+ from sklearn.naive_bayes import GaussianNB
90
+ model_sk = GaussianNB(priors = None)
91
+ model_sk.fit(X_data,y_labels)
92
+ Z = model_sk.predict(np.c_[xx.ravel()])
93
+ Z = Z.reshape(xx.shape)
94
+
95
+ pp.pcolormesh(xx, yy, Z, cmap=cmap_light, alpha=0.2)
96
+
97
+
98
+
99
+ pp.xlim([x_min, x_max])
100
+ pp.ylim([y_min, y_max])
101
+ pp.xlabel("Feature 1 (X)", size=20)
102
+ pp.xticks(fontsize=20)
103
+ pp.xlabel("Feature 1 (X)")
104
+ pp.legend(loc='upper right', borderpad=0, handletextpad=0, fontsize = 20)
105
+ pp.savefig('plot.png')
106
+
107
+ return 'plot.png', pd_all
108
+
109
+
110
+ # 1. define mean and standard deviation for class 1
111
+
112
+ set_mean_class1 = gr.inputs.Slider(-20, 20, step=0.5, default=1, label = 'Mean (Class 1)')
113
+ set_std_class1 = gr.inputs.Slider(0, 10, step=0.5, default=1.5, label = 'Standard Deviation (Class 1)')
114
+
115
+ # 2. define mean and standard deviation for class 2
116
+
117
+ set_mean_class2 = gr.inputs.Slider(-20, 20, step=0.5, default=10, label = 'Mean (Class 2)')
118
+ set_std_class2 = gr.inputs.Slider(0, 10, step=0.5, default=1.5, label = 'Standard Deviation (Class 2)')
119
+
120
+ # 3. Define the number of data points
121
+ set_number_points = gr.inputs.Slider(10, 100, step=5, default=20, label = 'Number of samples in each class')
122
+
123
+ # 4. show distribution or not
124
+ set_show_dist = gr.inputs.Checkbox(label="Show data distribution")
125
+
126
+ # 5. set classifier type
127
+ set_classifier = gr.inputs.Dropdown(["None", "LDA", "QDA", "NaiveBayes"])
128
+
129
+ # 6. define output imagem model
130
+ set_out_plot_images = gr.outputs.Image(label="Data visualization")
131
+
132
+ set_out_plot_table = gr.outputs.Dataframe(type='pandas', label ='Simulated Dataset')
133
+
134
+
135
+
136
+ ### configure gradio, detailed can be found at https://www.gradio.app/docs/#i_slider
137
+ interface = gr.Interface(fn=plot_figure,
138
+ inputs=[set_number_points,set_mean_class1,set_std_class1,set_mean_class2,set_std_class2, set_show_dist, set_classifier],
139
+ outputs=[set_out_plot_images,set_out_plot_table],
140
+ examples_per_page = 2,
141
+ #examples = get_sample_data(10),
142
+ title="CSCI4750/5750 Demo: Web Application for Probabilistic Classifier (Single feature)",
143
+ description= "Click examples below for a quick demo",
144
+ theme = 'huggingface',
145
+ layout = 'vertical', live=True
146
+ )
147
+ interface.launch(debug=True)