row56 committed on
Commit
cafe337
·
verified ·
1 Parent(s): f97c30f

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +4 -4
README.md CHANGED
@@ -196,9 +196,9 @@ Place those in data/ and then:
196
  import numpy as np
197
  import json
198
 
199
- train_embeds = np.load("data/train_embeds.npy") # shape (N, d)
200
  with open("data/train_texts.json", "r") as f:
201
- train_texts = json.load(f) # list[str]
202
 
203
  print(f"Loaded {train_embeds.shape[0]} embeddings of dim {train_embeds.shape[1]}")
204
  ```
@@ -220,13 +220,13 @@ batch = {
220
  with torch.no_grad():
221
  logits, metadata = model.proto_module(batch)
222
 
223
- attn_scores = metadata["attentions"][0] # [num_labels, seq_len]
224
  for label_id, scores in enumerate(attn_scores):
225
  topk = sorted(zip(batch["tokens"][0], scores.tolist()),
226
  key=lambda x: -x[1])[:5]
227
  print(f"Label {label_id} top tokens:", topk)
228
 
229
- proto_vecs = model.proto_module.prototype_vectors.cpu().numpy() # [num_labels, d]
230
  nn = NearestNeighbors(n_neighbors=1, metric="euclidean").fit(train_embeds)
231
 
232
  for label_id, u_c in enumerate(proto_vecs):
 
196
  import numpy as np
197
  import json
198
 
199
+ train_embeds = np.load("data/train_embeds.npy")
200
  with open("data/train_texts.json", "r") as f:
201
+ train_texts = json.load(f)
202
 
203
  print(f"Loaded {train_embeds.shape[0]} embeddings of dim {train_embeds.shape[1]}")
204
  ```
 
220
  with torch.no_grad():
221
  logits, metadata = model.proto_module(batch)
222
 
223
+ attn_scores = metadata["attentions"][0]
224
  for label_id, scores in enumerate(attn_scores):
225
  topk = sorted(zip(batch["tokens"][0], scores.tolist()),
226
  key=lambda x: -x[1])[:5]
227
  print(f"Label {label_id} top tokens:", topk)
228
 
229
+ proto_vecs = model.proto_module.prototype_vectors.cpu().numpy()
230
  nn = NearestNeighbors(n_neighbors=1, metric="euclidean").fit(train_embeds)
231
 
232
  for label_id, u_c in enumerate(proto_vecs):