Token Classification
Safetensors
Tatar
bert
tatar
morphology
rubert
ArabovMK committed on
Commit
a8a1f53
·
verified ·
1 Parent(s): 4527c85

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +7 -6
README.md CHANGED
@@ -58,17 +58,18 @@ inputs = tokenizer(tokens, is_split_into_words=True, return_tensors="pt", trunca
58
  outputs = model(**inputs)
59
  predictions = torch.argmax(outputs.logits, dim=2)
60
 
61
- # Load tag mappings
62
- import json
63
- with open("id2tag.json", "r") as f:
64
- id2tag = json.load(f)
65
 
66
- # Convert predictions to tags
67
  word_ids = inputs.word_ids()
68
  prev_word = None
69
  for idx, word_idx in enumerate(word_ids):
70
  if word_idx is not None and word_idx != prev_word:
71
- tag = id2tag[str(predictions[0][idx].item())]
 
 
 
 
72
  print(tokens[word_idx], "->", tag)
73
  prev_word = word_idx
74
  ```
 
58
  outputs = model(**inputs)
59
  predictions = torch.argmax(outputs.logits, dim=2)
60
 
61
+ # Get tag mapping from model config
62
+ id2tag = model.config.id2label
 
 
63
 
 
64
  word_ids = inputs.word_ids()
65
  prev_word = None
66
  for idx, word_idx in enumerate(word_ids):
67
  if word_idx is not None and word_idx != prev_word:
68
+ tag_id = predictions[0][idx].item()
69
+ if isinstance(id2tag, dict):
70
+ tag = id2tag.get(str(tag_id), id2tag.get(tag_id, "UNK"))
71
+ else:
72
+ tag = id2tag[tag_id] if tag_id < len(id2tag) else "UNK"
73
  print(tokens[word_idx], "->", tag)
74
  prev_word = word_idx
75
  ```