Update README.md
Browse files
README.md
CHANGED
|
@@ -58,17 +58,18 @@ inputs = tokenizer(tokens, is_split_into_words=True, return_tensors="pt", trunca
|
|
| 58 |
outputs = model(**inputs)
|
| 59 |
predictions = torch.argmax(outputs.logits, dim=2)
|
| 60 |
|
| 61 |
-
#
|
| 62 |
-
|
| 63 |
-
with open("id2tag.json", "r") as f:
|
| 64 |
-
id2tag = json.load(f)
|
| 65 |
|
| 66 |
-
# Convert predictions to tags
|
| 67 |
word_ids = inputs.word_ids()
|
| 68 |
prev_word = None
|
| 69 |
for idx, word_idx in enumerate(word_ids):
|
| 70 |
if word_idx is not None and word_idx != prev_word:
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
print(tokens[word_idx], "->", tag)
|
| 73 |
prev_word = word_idx
|
| 74 |
```
|
|
|
|
| 58 |
outputs = model(**inputs)
|
| 59 |
predictions = torch.argmax(outputs.logits, dim=2)
|
| 60 |
|
| 61 |
+
# Get tag mapping from model config
|
| 62 |
+
id2tag = model.config.id2label
|
|
|
|
|
|
|
| 63 |
|
|
|
|
| 64 |
word_ids = inputs.word_ids()
|
| 65 |
prev_word = None
|
| 66 |
for idx, word_idx in enumerate(word_ids):
|
| 67 |
if word_idx is not None and word_idx != prev_word:
|
| 68 |
+
tag_id = predictions[0][idx].item()
|
| 69 |
+
if isinstance(id2tag, dict):
|
| 70 |
+
tag = id2tag.get(str(tag_id), id2tag.get(tag_id, "UNK"))
|
| 71 |
+
else:
|
| 72 |
+
tag = id2tag[tag_id] if tag_id < len(id2tag) else "UNK"
|
| 73 |
print(tokens[word_idx], "->", tag)
|
| 74 |
prev_word = word_idx
|
| 75 |
```
|