aswin-raghavan commited on
Commit
e631c99
·
1 Parent(s): eb801a5

add test acc

Browse files
Files changed (1) hide show
  1. app.py +25 -1
app.py CHANGED
@@ -143,7 +143,31 @@ def update_exemplars(df, rng, exemplars, lut):
143
  preds[dist_to_ex1 < dist_to_ex0] = 1
144
  print(preds.shape, labels_train.shape, np.sum(preds == labels_train))
145
  train_acc = np.sum(preds == labels_train) / len(labels_train)
146
- return rng, exemplars, train_acc, 0. # score(embeds_train, exemplars, lut), score(embeds_test, exemplars, lut)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
 
148
 
149
  with gr.Blocks(title="End-User Personalization") as demo:
 
143
  preds[dist_to_ex1 < dist_to_ex0] = 1
144
  print(preds.shape, labels_train.shape, np.sum(preds == labels_train))
145
  train_acc = np.sum(preds == labels_train) / len(labels_train)
146
+ rng, test_acc = score(embeds_test, labels_test, rng, exemplars, lut)
147
+ return rng, exemplars, train_acc, test_acc
148
+
149
+ def score(embeds, labels, rng, exemplars, lut):
150
+ quantized_embeds, closest_bin = quantize_embeds(embeds)
151
+ # closest bin is nexample X 512
152
+ # lut[0] is nvals X dims
153
+ # hd_embeds in nexample x 512 x dims
154
+ hd_embeds_per_pos = lut[0][closest_bin]
155
+ # bundle along pos dimension 512
156
+ # lut[1] is 512 x dims
157
+ xor = lambda a,b: a*(1.-b) + b*(1.-a)
158
+ hd_embeds = xor(lut[1][np.newaxis, ...], hd_embeds_per_pos)
159
+ hd_embeds = np.sum(hd_embeds, axis=1) / embeds.shape[-1]
160
+ hd_embeds[hd_embeds >= 0.5] = 1.
161
+ hd_embeds[hd_embeds < 0.5] = 0.
162
+ # hd_embeds_integer is nexample x dims
163
+ print(exemplars[0].shape, exemplars[1].shape, np.abs(exemplars[0] - exemplars[1]).sum())
164
+ preds = np.zeros(hd_embeds.shape[0])
165
+ dist_to_ex0 = np.abs(hd_embeds - exemplars[0][np.newaxis, ...]).sum(axis=-1)
166
+ dist_to_ex1 = np.abs(hd_embeds - exemplars[1][np.newaxis, ...]).sum(axis=-1)
167
+ preds[dist_to_ex1 < dist_to_ex0] = 1
168
+ print(preds.shape, labels.shape, np.sum(preds == labels), len(labels))
169
+ acc = np.sum(preds == labels) / len(labels)
170
+ return rng, acc
171
 
172
 
173
  with gr.Blocks(title="End-User Personalization") as demo: