jiehou commited on
Commit
d859776
·
1 Parent(s): 2706ff8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -61,14 +61,14 @@ def plot_digits(instances, images_per_row=3):
61
 
62
  n = len(instances)
63
 
64
- fig = plt.figure(figsize=(30,15))
65
  for i in range(len(instances)):
66
  # Debug, plot figure
67
  fig.add_subplot(n_rows, images_per_row, i + 1)
68
  #print(instances[i])
69
  plt.imshow(instances[i].reshape(size,size), cmap = mpl.cm.binary)
70
  plt.axis("off")
71
- plt.title("Neighbor "+str(i+1), size=22)
72
  fig.tight_layout()
73
 
74
  plt.savefig('results.png', dpi=300)
@@ -99,8 +99,8 @@ def KNN_predict(train_features, train_labels, test_feature, K):
99
  for k in range(K):
100
  major_class.append(sorted_labels[k][1])
101
 
102
- # at most 60 neighbors for visualization
103
- if k <60:
104
  neighbor_feature = sorted_labels[k][2]
105
  neighbor_imgs.append(neighbor_feature)
106
 
@@ -125,7 +125,7 @@ sample_images = get_sample_images(10)
125
 
126
  ### configure inputs/outputs
127
  set_image = gr.inputs.Image(shape=(28, 28), image_mode='L')
128
- set_K = gr.inputs.Slider(0, 60, default=7)
129
 
130
  set_label = gr.outputs.Textbox(label="Predicted Digit")
131
  set_out_images = gr.outputs.Image(label="Closest Neighbors")
 
61
 
62
  n = len(instances)
63
 
64
+ fig = plt.figure(figsize=(15,8))
65
  for i in range(len(instances)):
66
  # Debug, plot figure
67
  fig.add_subplot(n_rows, images_per_row, i + 1)
68
  #print(instances[i])
69
  plt.imshow(instances[i].reshape(size,size), cmap = mpl.cm.binary)
70
  plt.axis("off")
71
+ plt.title("Neighbor "+str(i+1), size=20)
72
  fig.tight_layout()
73
 
74
  plt.savefig('results.png', dpi=300)
 
99
  for k in range(K):
100
  major_class.append(sorted_labels[k][1])
101
 
102
+ # at most 24 neighbors for visualization
103
+ if k <24:
104
  neighbor_feature = sorted_labels[k][2]
105
  neighbor_imgs.append(neighbor_feature)
106
 
 
125
 
126
  ### configure inputs/outputs
127
  set_image = gr.inputs.Image(shape=(28, 28), image_mode='L')
128
+ set_K = gr.inputs.Slider(0, 24, default=7)
129
 
130
  set_label = gr.outputs.Textbox(label="Predicted Digit")
131
  set_out_images = gr.outputs.Image(label="Closest Neighbors")