jiehou commited on
Commit
dc070db
·
1 Parent(s): 22a5515

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -10
app.py CHANGED
@@ -27,8 +27,7 @@ def plot_figure(N, u1, std1, u2, std2, show_dist, classifier=None):
27
  import numpy as np
28
  import matplotlib.pyplot as pp
29
  pp.style.use('default')
30
-
31
- val = 0. # this is the value where you want the data to appear on the y-axis.
32
 
33
  points_class1 = [stats.norm.rvs(loc=u1, scale = std1) for x in range(N)]
34
  points_class2 = [stats.norm.rvs(loc=u2, scale = std2) for x in range(N)]
@@ -94,16 +93,29 @@ def plot_figure(N, u1, std1, u2, std2, show_dist, classifier=None):
94
  from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
95
  model_sk = LinearDiscriminantAnalysis()
96
  model_sk.fit(X_data,y_labels)
97
- Z = model_sk.predict(np.c_[xx.ravel()])
98
- Z = Z.reshape(xx.shape)
 
 
 
 
 
99
 
100
  pp.pcolormesh(xx, yy, Z, cmap=cmap_light, alpha=0.2)
101
  elif classifier == 'QDA':
102
  from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis
103
  model_sk = QuadraticDiscriminantAnalysis()
104
  model_sk.fit(X_data,y_labels)
105
- Z = model_sk.predict(np.c_[xx.ravel()])
106
- Z = Z.reshape(xx.shape)
 
 
 
 
 
 
 
 
107
  print("Z: ",Z)
108
  pp.pcolormesh(xx, yy, Z, cmap=cmap_light, alpha=0.2)
109
  elif classifier == 'NaiveBayes':
@@ -119,9 +131,9 @@ def plot_figure(N, u1, std1, u2, std2, show_dist, classifier=None):
119
 
120
  pp.xlim([x_min, x_max])
121
  pp.ylim([y_min, y_max])
122
- pp.xlabel("Feature 1 (X)", size=20)
123
  pp.xticks(fontsize=20)
124
- pp.xlabel("Feature 1 (X)")
125
  pp.legend(loc='upper right', borderpad=0, handletextpad=0, fontsize = 20)
126
  pp.savefig('plot.png')
127
 
@@ -129,8 +141,6 @@ def plot_figure(N, u1, std1, u2, std2, show_dist, classifier=None):
129
 
130
 
131
 
132
- # 1. define mean and standard deviation for class 1
133
-
134
  set_mean_class1 = gr.inputs.Slider(-20, 20, step=0.5, default=1, label = 'Mean (Class 1)')
135
  set_std_class1 = gr.inputs.Slider(0, 10, step=0.5, default=1.5, label = 'Standard Deviation (Class 1)')
136
 
@@ -155,6 +165,8 @@ set_out_plot_table = gr.outputs.Dataframe(type='pandas', label ='Simulated Datas
155
 
156
 
157
 
 
 
158
  ### configure gradio, detailed can be found at https://www.gradio.app/docs/#i_slider
159
  interface = gr.Interface(fn=plot_figure,
160
  inputs=[set_number_points,set_mean_class1,set_std_class1,set_mean_class2,set_std_class2, set_show_dist, set_classifier],
 
27
  import numpy as np
28
  import matplotlib.pyplot as pp
29
  pp.style.use('default')
30
+ #val = 0. # this is the value where you want the data to appear on the y-axis.
 
31
 
32
  points_class1 = [stats.norm.rvs(loc=u1, scale = std1) for x in range(N)]
33
  points_class2 = [stats.norm.rvs(loc=u2, scale = std2) for x in range(N)]
 
93
  from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
94
  model_sk = LinearDiscriminantAnalysis()
95
  model_sk.fit(X_data,y_labels)
96
+ zz = model_sk.predict(np.c_[xx.ravel()])
97
+
98
+ #Predictions for each point on meshgrid
99
+ #zz = np.array( [model_sk.predict( [[xx,yy]])[0] for xx, yy in zip(np.ravel(X), np.ravel(Y)) ] )
100
+
101
+ #Reshaping the predicted class into the meshgrid shape
102
+ Z = zz.reshape(X.shape)
103
 
104
  pp.pcolormesh(xx, yy, Z, cmap=cmap_light, alpha=0.2)
105
  elif classifier == 'QDA':
106
  from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis
107
  model_sk = QuadraticDiscriminantAnalysis()
108
  model_sk.fit(X_data,y_labels)
109
+
110
+ model_sk.fit(X_data,y_labels)
111
+ zz = model_sk.predict(np.c_[xx.ravel()])
112
+
113
+ #Predictions for each point on meshgrid
114
+ #zz = np.array( [model_sk.predict( [[xx,yy]])[0] for xx, yy in zip(np.ravel(X), np.ravel(Y)) ] )
115
+
116
+ #Reshaping the predicted class into the meshgrid shape
117
+ Z = zz.reshape(X.shape)
118
+
119
  print("Z: ",Z)
120
  pp.pcolormesh(xx, yy, Z, cmap=cmap_light, alpha=0.2)
121
  elif classifier == 'NaiveBayes':
 
131
 
132
  pp.xlim([x_min, x_max])
133
  pp.ylim([y_min, y_max])
134
+ pp.xlabel("Feature 1 (X1)", size=20)
135
  pp.xticks(fontsize=20)
136
+ pp.ylabel("Feature 2 (X2)")
137
  pp.legend(loc='upper right', borderpad=0, handletextpad=0, fontsize = 20)
138
  pp.savefig('plot.png')
139
 
 
141
 
142
 
143
 
 
 
144
  set_mean_class1 = gr.inputs.Slider(-20, 20, step=0.5, default=1, label = 'Mean (Class 1)')
145
  set_std_class1 = gr.inputs.Slider(0, 10, step=0.5, default=1.5, label = 'Standard Deviation (Class 1)')
146
 
 
165
 
166
 
167
 
168
+
169
+
170
  ### configure gradio, detailed can be found at https://www.gradio.app/docs/#i_slider
171
  interface = gr.Interface(fn=plot_figure,
172
  inputs=[set_number_points,set_mean_class1,set_std_class1,set_mean_class2,set_std_class2, set_show_dist, set_classifier],