saracandu committed
Commit 0013017 · verified · 1 parent: 6c5ccd2

Upload handcoded_tokenizer.py with huggingface_hub
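The commit message indicates the file was pushed with the huggingface_hub client. A minimal sketch of such an upload (the repo id below is a hypothetical placeholder, not taken from this commit):

    from huggingface_hub import HfApi

    api = HfApi()  # authenticate beforehand, e.g. via `huggingface-cli login`
    api.upload_file(
        path_or_fileobj="handcoded_tokenizer.py",
        path_in_repo="handcoded_tokenizer.py",
        repo_id="saracandu/<repo-name>",  # hypothetical: the actual repo id is not shown here
    )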

Files changed (1)
  1. handcoded_tokenizer.py +232 -0
handcoded_tokenizer.py ADDED
@@ -0,0 +1,232 @@
+ import re
+ import json
+ from typing import Any, Dict, List, Optional, Tuple, Union
+ from transformers import PreTrainedTokenizer
+ from transformers.utils import logging
+
+ logger = logging.get_logger(__name__)
+
+
+ def load_json(path: str) -> Union[Dict, List]:
+     """
+     Load a JSON file from the given path.
+
+     Args:
+         path (str): The path to the JSON file to be loaded.
+
+     Returns:
+         Union[Dict, List]: The parsed content of the JSON file, which could be a dictionary or a list.
+     """
+     with open(path, "r") as f:
+         return json.load(f)
+
+
+ class STLTokenizer(PreTrainedTokenizer):
+     """
+     A custom tokenizer class that extends `PreTrainedTokenizer` to handle a specific vocabulary and tokenization process.
+
+     This tokenizer can load a vocabulary from a JSON file, tokenize text, convert tokens to IDs,
+     and handle padding and special tokens.
+     """
+
+     def __init__(self, vocab_path: str, unk_token: str = "unk", pad_token: str = "pad",
+                  bos_token: str = "/s", eos_token: str = "s", model_max_length: int = 512):
+         """
+         Initializes the STLTokenizer with a given vocabulary and special tokens.
+
+         Args:
+             vocab_path (str): The path to the JSON file containing the vocabulary.
+             unk_token (str, optional): The token used for unknown words. Defaults to "unk".
+             pad_token (str, optional): The token used for padding. Defaults to "pad".
+             bos_token (str, optional): The token used for the beginning of a sequence. Defaults to "/s".
+             eos_token (str, optional): The token used for the end of a sequence. Defaults to "s".
+             model_max_length (int, optional): The maximum sequence length. Defaults to 512.
+         """
+         self.vocab = load_json(vocab_path)
+         self.unk_token = unk_token
+         self.pad_token = pad_token
+         self.bos_token = bos_token
+         self.eos_token = eos_token
+         self.model_max_length = model_max_length
+         self.id_to_token = {v: k for k, v in self.vocab.items()}  # Reverse mapping from IDs to tokens
+
+     @property
+     def vocab_size(self) -> int:
+         """
+         Returns the size of the vocabulary.
+
+         Returns:
+             int: The number of tokens in the vocabulary.
+         """
+         return len(self.vocab)
+
+     def prepad_sequence(self, sequence, space_token=' ', new_space_token='@', undo=False):
+         """
+         Replaces spaces in the input sequence with a placeholder token, or undoes that replacement.
+
+         Args:
+             sequence (str): The input sequence.
+             space_token (str, optional): The space character to replace. Defaults to ' '.
+             new_space_token (str, optional): The placeholder that stands in for a space. Defaults to '@'.
+             undo (bool): If True, replaces the placeholder token with spaces; otherwise replaces spaces with the placeholder. Defaults to False.
+
+         Returns:
+             str: The preprocessed sequence with spaces or placeholder tokens replaced.
+         """
+         if undo:
+             return sequence.replace(new_space_token, space_token)
+         else:
+             return sequence.replace(space_token, new_space_token)
+
+     def add_bos_eos(self, sequence: str) -> str:
+         """
+         Adds the BOS token at the beginning and the EOS token at the end of the sequence.
+
+         Args:
+             sequence (str): The input sequence.
+
+         Returns:
+             str: The sequence with the BOS and EOS tokens added.
+         """
+         return f'{self.bos_token} {sequence} {self.eos_token}'
+
+     def tokenize(self, text: str) -> List[str]:
+         """
+         Tokenizes the input text into a list of tokens.
+
+         The method preprocesses the input text by replacing spaces with the placeholder token
+         and then greedily matches the longest possible vocabulary entry at each position.
+
+         Args:
+             text (str): The input text to be tokenized.
+
+         Returns:
+             List[str]: A list of tokens representing the tokenized text.
+         """
+         text = self.add_bos_eos(text)
+         text = self.prepad_sequence(text)
+
+         tokens = []
+         i = 0
+         while i < len(text):
+             best_match = None
+             for j in range(len(text), i, -1):  # Try matching substrings of decreasing length
+                 subtoken = text[i:j]
+                 if subtoken in self.vocab:
+                     best_match = subtoken
+                     break
+             if best_match:
+                 tokens.append(best_match)
+                 i += len(best_match)
+             else:
+                 tokens.append(self.unk_token)
+                 i += 1
+         return tokens
+
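+     # Illustrative note (toy vocabulary assumed): if the vocab contained both
+     # 'no' and 'not', a scan starting at "not@x_0" would match the longer
+     # 'not' first and advance three characters; any character with no
+     # vocabulary match at all is emitted as unk_token instead.
+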
+     def convert_tokens_to_ids(self, tokens: List[str]) -> List[int]:
+         """
+         Converts a list of tokens into a list of token IDs.
+
+         Args:
+             tokens (List[str]): A list of tokens to be converted into IDs.
+
+         Returns:
+             List[int]: A list of corresponding token IDs.
+         """
+         return [self.vocab.get(token, self.vocab[self.unk_token]) for token in tokens]
+
+     def convert_ids_to_tokens(self, ids: List[int]) -> List[str]:
+         """
+         Converts a list of token IDs into a list of tokens.
+
+         Args:
+             ids (List[int]): A list of token IDs to be converted into tokens.
+
+         Returns:
+             List[str]: A list of corresponding tokens.
+         """
+         return [self.id_to_token.get(i, self.unk_token) for i in ids]
+
+     def encode(self, sequence: str) -> List[int]:
+         """
+         Encodes a string sequence into a list of token IDs.
+
+         This method tokenizes the input sequence using the `tokenize` method,
+         and then converts the resulting tokens into their corresponding token IDs
+         using the `convert_tokens_to_ids` method.
+
+         Args:
+             sequence (str): The input sequence (text) to be encoded.
+
+         Returns:
+             List[int]: A list of token IDs corresponding to the input sequence.
+         """
+         tokens = self.tokenize(sequence)
+         return self.convert_tokens_to_ids(tokens)
+
+     def postpad_sequence(self, sequence, pad_token_id):
+         """
+         Pads the sequence in place with `pad_token_id` until its length is one short of `model_max_length`.
+         """
+         num_extra_elements = self.model_max_length - len(sequence) - 1
+         if num_extra_elements > 0:
+             sequence.extend([pad_token_id] * num_extra_elements)
+         return sequence
+
+     def decode(self, token_ids: List[int]) -> str:
+         """
+         Decodes a list of token IDs into a string of text.
+
+         The method converts the IDs to tokens, joins them into a string, and restores
+         the original spaces by undoing the placeholder substitution.
+
+         Args:
+             token_ids (List[int]): A list of token IDs to be decoded.
+
+         Returns:
+             str: The decoded string.
+         """
+         tokens = self.convert_ids_to_tokens(token_ids)
+         decoded = "".join(tokens)
+         return self.prepad_sequence(decoded, undo=True)
+
+     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+         """
+         Saves the tokenizer's vocabulary to a file.
+         Only needed when the vocabulary has to be rebuilt rather than supplied up front
+         (not the case here; kept as a hook for future improvements, e.g. with sentencepiece).
+
+         This method saves the vocabulary to a JSON file in the specified directory.
+
+         Args:
+             save_directory (str): The directory where the vocabulary file will be saved.
+             filename_prefix (Optional[str]): An optional prefix for the filename.
+
+         Returns:
+             Tuple[str]: A tuple containing the path to the saved vocabulary file.
+         """
+         vocab_file = f"{save_directory}/{filename_prefix + '-' if filename_prefix else ''}vocab.json"
+         with open(vocab_file, "w", encoding="utf-8") as f:
+             json.dump(self.vocab, f, indent=2, ensure_ascii=False)
+         return (vocab_file,)
+
+     def get_vocab(self) -> dict:
+         """
+         Retrieves the vocabulary used by the tokenizer.
+
+         Returns:
+             dict: The vocabulary as a dictionary.
+         """
+         return self.vocab
+
+
+ # EXAMPLE OF USAGE
+
+ # sequence = "( not ( x_1 <= 0.2988 ) until[11,21] x_0 <= -0.7941 )"
+ # tokenizer = STLTokenizer('tokenizer_files/tokenizer.json')
+ # token_ids = tokenizer.encode(sequence)
+ # decoded_sequence = tokenizer.decode(token_ids)
+
+ # print("Original sequence: ", sequence)
+ # print("Encoded sequence: ", token_ids)
+ # print("Decoded sequence: ", decoded_sequence)
+
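+ # A padding sketch (illustrative, assuming the vocabulary maps the 'pad'
+ # token to an ID): postpad_sequence can extend the encoded IDs toward
+ # model_max_length to build fixed-length batches.
+ # padded_ids = tokenizer.postpad_sequence(token_ids, tokenizer.vocab[tokenizer.pad_token])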