saracandu commited on
Commit
2ec2e3a
·
verified ·
1 Parent(s): fb5373d

Update tokenizer.py

Browse files
Files changed (1) hide show
  1. tokenizer.py +9 -11
tokenizer.py CHANGED
@@ -10,7 +10,7 @@ logger = logging.get_logger(__name__)
10
 
11
  from huggingface_hub import hf_hub_download
12
  import json
13
- import os
14
 
15
  def load_json(path, repo_id=None):
16
  if repo_id:
@@ -128,16 +128,14 @@ class STLTokenizer(PreTrainedTokenizer):
128
  tokens.append(self.unk_token)
129
  i += 1
130
  return tokens
131
-
132
- def convert_tokens_to_ids(self, tokens: List[str]) -> List[int]:
133
- """
134
- Converts a list of tokens into a list of token IDs.
135
- Args:
136
- tokens (List[str]): A list of tokens to be converted into IDs.
137
- Returns:
138
- List[int]: A list of corresponding token IDs.
139
- """
140
- return [self.vocab.get(token, self.vocab[self.unk_token]) for token in tokens]
141
 
142
  def convert_ids_to_tokens(self, ids: List[int]) -> List[str]:
143
  """
 
10
 
11
  from huggingface_hub import hf_hub_download
12
  import json
13
+ from transformers import AddedToken
14
 
15
  def load_json(path, repo_id=None):
16
  if repo_id:
 
128
  tokens.append(self.unk_token)
129
  i += 1
130
  return tokens
131
+
132
def convert_tokens_to_ids(self, tokens):
    """
    Convert a token or a list of tokens into vocabulary id(s).

    Tokens absent from the vocabulary map to the id of ``unk_token``.
    Non-string tokens (e.g. ``transformers.AddedToken``) are converted
    with ``str()`` first, which yields the token's surface string.

    Args:
        tokens: a single token (``str`` or ``AddedToken``) or a list/tuple
            of such tokens.

    Returns:
        ``int`` for a single-token input, ``List[int]`` for a list input —
        matching the ``PreTrainedTokenizer.convert_tokens_to_ids`` contract
        (the previous version returned a one-element list for a single
        token, which breaks HF helpers like ``unk_token_id`` that expect
        a bare int).
    """
    # Hoist the fallback id; also fails fast if unk_token is missing.
    unk_id = self.vocab[self.unk_token]

    def _to_id(token):
        # str() normalizes AddedToken (and any token-like object) to its
        # surface form; plain strings pass through unchanged.
        key = token if isinstance(token, str) else str(token)
        return self.vocab.get(key, unk_id)

    # Single token (anything not already a sequence) -> single id.
    if not isinstance(tokens, (list, tuple)):
        return _to_id(tokens)
    return [_to_id(t) for t in tokens]
 
 
139
 
140
  def convert_ids_to_tokens(self, ids: List[int]) -> List[str]:
141
  """