jiehou commited on
Commit
ce34bb9
·
1 Parent(s): f7a21ca

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -2
app.py CHANGED
@@ -106,12 +106,17 @@ def KNN_predict(train_features, train_labels, test_feature, K):
106
 
107
  ### get final prediction
108
  final_prediction = scipy.stats.mode(major_class).mode[0]
 
 
 
 
 
109
 
110
  ### get neighbor images and save to local
111
  neighbor_imgs =np.array(neighbor_imgs)
112
  image_path = plot_digits(neighbor_imgs, images_per_row=6)
113
 
114
- return final_prediction, image_path
115
 
116
  ### main function for gradio to call to classify image
117
  def call_our_KNN(test_image, K=7):
@@ -128,13 +133,17 @@ set_image = gr.inputs.Image(shape=(28, 28), image_mode='L')
128
  set_K = gr.inputs.Slider(1, 24, step=1, default=7)
129
 
130
  set_label = gr.outputs.Textbox(label="Predicted Digit")
 
 
 
131
  set_out_images = gr.outputs.Image(label="Closest Neighbors")
132
 
133
 
 
134
  ### configure gradio, detailed can be found at https://www.gradio.app/docs/#i_slider
135
  interface = gr.Interface(fn=call_our_KNN,
136
  inputs=[set_image, set_K],
137
- outputs=[set_label,set_out_images],
138
  examples_per_page = 2,
139
  examples = sample_images,
140
  title="CSCI4750/5750 Demo 1: Digit classification using KNN algorithm",
 
106
 
107
  ### get final prediction
108
  final_prediction = scipy.stats.mode(major_class).mode[0]
109
+
110
+ ### get frequency of classes
111
+ class_freq = {}
112
+ for i in range(0,10):
113
+ class_freq['digit '+str(i)] = float(major_class.count(i)) / len(major_class)
114
 
115
  ### get neighbor images and save to local
116
  neighbor_imgs =np.array(neighbor_imgs)
117
  image_path = plot_digits(neighbor_imgs, images_per_row=6)
118
 
119
+ return final_prediction, class_freq, image_path
120
 
121
  ### main function for gradio to call to classify image
122
  def call_our_KNN(test_image, K=7):
 
133
  set_K = gr.inputs.Slider(1, 24, step=1, default=7)
134
 
135
  set_label = gr.outputs.Textbox(label="Predicted Digit")
136
+ # define output as the single class text
137
+ set_probability = gr.outputs.Label(num_top_classes=10, label="Predicted Class")
138
+
139
  set_out_images = gr.outputs.Image(label="Closest Neighbors")
140
 
141
 
142
+
143
  ### configure gradio, detailed can be found at https://www.gradio.app/docs/#i_slider
144
  interface = gr.Interface(fn=call_our_KNN,
145
  inputs=[set_image, set_K],
146
+ outputs=[set_label,set_probability,set_out_images],
147
  examples_per_page = 2,
148
  examples = sample_images,
149
  title="CSCI4750/5750 Demo 1: Digit classification using KNN algorithm",