Update README.md
Browse files
README.md
CHANGED
|
@@ -196,9 +196,9 @@ Place those in data/ and then:
|
|
| 196 |
import numpy as np
|
| 197 |
import json
|
| 198 |
|
| 199 |
-
train_embeds = np.load("data/train_embeds.npy")
|
| 200 |
with open("data/train_texts.json", "r") as f:
|
| 201 |
-
train_texts = json.load(f)
|
| 202 |
|
| 203 |
print(f"Loaded {train_embeds.shape[0]} embeddings of dim {train_embeds.shape[1]}")
|
| 204 |
```
|
|
@@ -220,13 +220,13 @@ batch = {
|
|
| 220 |
with torch.no_grad():
|
| 221 |
logits, metadata = model.proto_module(batch)
|
| 222 |
|
| 223 |
-
attn_scores = metadata["attentions"][0]
|
| 224 |
for label_id, scores in enumerate(attn_scores):
|
| 225 |
topk = sorted(zip(batch["tokens"][0], scores.tolist()),
|
| 226 |
key=lambda x: -x[1])[:5]
|
| 227 |
print(f"Label {label_id} top tokens:", topk)
|
| 228 |
|
| 229 |
-
proto_vecs = model.proto_module.prototype_vectors.cpu().numpy()
|
| 230 |
nn = NearestNeighbors(n_neighbors=1, metric="euclidean").fit(train_embeds)
|
| 231 |
|
| 232 |
for label_id, u_c in enumerate(proto_vecs):
|
|
|
|
| 196 |
import numpy as np
|
| 197 |
import json
|
| 198 |
|
| 199 |
+
train_embeds = np.load("data/train_embeds.npy")
|
| 200 |
with open("data/train_texts.json", "r") as f:
|
| 201 |
+
train_texts = json.load(f)
|
| 202 |
|
| 203 |
print(f"Loaded {train_embeds.shape[0]} embeddings of dim {train_embeds.shape[1]}")
|
| 204 |
```
|
|
|
|
| 220 |
with torch.no_grad():
|
| 221 |
logits, metadata = model.proto_module(batch)
|
| 222 |
|
| 223 |
+
attn_scores = metadata["attentions"][0]
|
| 224 |
for label_id, scores in enumerate(attn_scores):
|
| 225 |
topk = sorted(zip(batch["tokens"][0], scores.tolist()),
|
| 226 |
key=lambda x: -x[1])[:5]
|
| 227 |
print(f"Label {label_id} top tokens:", topk)
|
| 228 |
|
| 229 |
+
proto_vecs = model.proto_module.prototype_vectors.cpu().numpy()
|
| 230 |
nn = NearestNeighbors(n_neighbors=1, metric="euclidean").fit(train_embeds)
|
| 231 |
|
| 232 |
for label_id, u_c in enumerate(proto_vecs):
|