saracandu committed
Commit 5bcd847 · verified · 1 Parent(s): 2a59121

Delete handcoded_tokenizer.py

Files changed (1)
  1. handcoded_tokenizer.py +0 -232
handcoded_tokenizer.py DELETED
@@ -1,232 +0,0 @@
- import json
- from typing import Dict, List, Optional, Tuple, Union
- from transformers import PreTrainedTokenizer
- from transformers.utils import logging
-
- logger = logging.get_logger(__name__)
-
- def load_json(path: str) -> Union[Dict, List]:
-     """
-     Load a JSON file from the given path.
-
-     Args:
-         path (str): The path to the JSON file to be loaded.
-
-     Returns:
-         Union[Dict, List]: The parsed content of the JSON file, which may be a dictionary or a list.
-     """
-     with open(path, "r") as f:
-         return json.load(f)
-
-
- class STLTokenizer(PreTrainedTokenizer):
-     """
-     A custom tokenizer class that extends `PreTrainedTokenizer` to handle a specific vocabulary and tokenization process.
-
-     This tokenizer can load a vocabulary from a JSON file, tokenize text, convert tokens to IDs,
-     and handle padding and special tokens.
-     """
-
-     def __init__(self, vocab_path: str, unk_token: str = "unk", pad_token: str = "pad",
-                  bos_token: str = "/s", eos_token: str = "s", model_max_length: int = 512):
-         """
-         Initializes the STLTokenizer with a given vocabulary and special tokens.
-
-         Args:
-             vocab_path (str): The path to the JSON file containing the vocabulary.
-             unk_token (str, optional): The token used for unknown words. Defaults to "unk".
-             pad_token (str, optional): The token used for padding. Defaults to "pad".
-             bos_token (str, optional): The token used for the beginning of a sequence. Defaults to "/s".
-             eos_token (str, optional): The token used for the end of a sequence. Defaults to "s".
-             model_max_length (int, optional): The maximum sequence length. Defaults to 512.
-         """
-         self.vocab = load_json(vocab_path)
-         self.id_to_token = {v: k for k, v in self.vocab.items()}  # Reverse mapping
-         # The vocabulary must exist before the base class initializes, since
-         # PreTrainedTokenizer registers the special tokens against it.
-         super().__init__(unk_token=unk_token, pad_token=pad_token, bos_token=bos_token,
-                          eos_token=eos_token, model_max_length=model_max_length)
-
-     @property
-     def vocab_size(self) -> int:
-         """
-         Returns the size of the vocabulary.
-
-         Returns:
-             int: The number of tokens in the vocabulary.
-         """
-         return len(self.vocab)
-
-     def prepad_sequence(self, sequence: str, space_token: str = ' ', new_space_token: str = '@', undo: bool = False) -> str:
-         """
-         Replaces spaces in the input sequence with a placeholder token (or vice versa).
-
-         Args:
-             sequence (str): The input sequence.
-             space_token (str, optional): The character treated as a space. Defaults to ' '.
-             new_space_token (str, optional): The placeholder that stands in for spaces. Defaults to '@'.
-             undo (bool): If True, replace the placeholder token with spaces; otherwise replace spaces with the placeholder. Defaults to False.
-
-         Returns:
-             str: The sequence with spaces or placeholder tokens replaced.
-         """
-         if undo:
-             return sequence.replace(new_space_token, space_token)
-         else:
-             return sequence.replace(space_token, new_space_token)
-
-     def add_bos_eos(self, sequence: str) -> str:
-         """
-         Adds the BOS token at the beginning and the EOS token at the end of the sequence.
-
-         Args:
-             sequence (str): The input sequence.
-
-         Returns:
-             str: The sequence with the BOS and EOS tokens added.
-         """
-         return f'{self.bos_token} {sequence} {self.eos_token}'
-
-     def tokenize(self, text: str) -> List[str]:
-         """
-         Tokenizes the input text into a list of tokens.
-
-         The method preprocesses the input text by replacing spaces with the placeholder token and then
-         greedily finds the longest possible vocabulary match at each position.
-
-         Args:
-             text (str): The input text to be tokenized.
-
-         Returns:
-             List[str]: A list of tokens representing the tokenized text.
-         """
-         text = self.add_bos_eos(text)
-         text = self.prepad_sequence(text)
-
-         tokens = []
-         i = 0
-         while i < len(text):
-             best_match = None
-             for j in range(len(text), i, -1):  # Try matching substrings of decreasing length
-                 subtoken = text[i:j]
-                 if subtoken in self.vocab:
-                     best_match = subtoken
-                     break
-             if best_match:
-                 tokens.append(best_match)
-                 i += len(best_match)
-             else:
-                 tokens.append(self.unk_token)
-                 i += 1
-         return tokens
-
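-     # Illustrative trace (editor's sketch; the toy vocabulary is hypothetical,
-     # not the real tokenizer.json): with vocab = {"/s": 0, "s": 1, "@": 2,
-     # "a": 3, "ab": 4}, tokenize("ab a") first becomes "/s@ab@a@s" via
-     # add_bos_eos and prepad_sequence, and the greedy loop then yields
-     # ['/s', '@', 'ab', '@', 'a', '@', 's'] -- "ab" wins over "a" because
-     # longer substrings are tried first.
-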
-     def convert_tokens_to_ids(self, tokens: List[str]) -> List[int]:
-         """
-         Converts a list of tokens into a list of token IDs.
-
-         Args:
-             tokens (List[str]): A list of tokens to be converted into IDs.
-
-         Returns:
-             List[int]: A list of corresponding token IDs.
-         """
-         return [self.vocab.get(token, self.vocab[self.unk_token]) for token in tokens]
-
-     def convert_ids_to_tokens(self, ids: List[int]) -> List[str]:
-         """
-         Converts a list of token IDs into a list of tokens.
-
-         Args:
-             ids (List[int]): A list of token IDs to be converted into tokens.
-
-         Returns:
-             List[str]: A list of corresponding tokens.
-         """
-         return [self.id_to_token.get(i, self.unk_token) for i in ids]
-
-     def encode(self, sequence: str) -> List[int]:
-         """
-         Encodes a string sequence into a list of token IDs.
-
-         This method tokenizes the input sequence using the `tokenize` method,
-         and then converts the resulting tokens into their corresponding token IDs
-         using the `convert_tokens_to_ids` method.
-
-         Args:
-             sequence (str): The input sequence (text) to be encoded.
-
-         Returns:
-             List[int]: A list of token IDs corresponding to the input sequence.
-         """
-         tokens = self.tokenize(sequence)
-         return self.convert_tokens_to_ids(tokens)
-
-     def postpad_sequence(self, sequence: List[int], pad_token_id: int) -> List[int]:
-         """
-         Pads the sequence with `pad_token_id` until it reaches model_max_length - 1 elements.
-         """
-         num_extra_elements = self.model_max_length - len(sequence) - 1
-         if num_extra_elements > 0:
-             sequence.extend([pad_token_id] * num_extra_elements)
-         return sequence
-
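-     # Editor's sketch (illustrative values): with model_max_length = 8,
-     # postpad_sequence([5, 6, 7], pad_token_id=0) extends the list by
-     # 8 - 3 - 1 = 4 pads, returning [5, 6, 7, 0, 0, 0, 0], i.e. length
-     # model_max_length - 1.
-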
-     def decode(self, token_ids: List[int]) -> str:
-         """
-         Decodes a list of token IDs into a string of text.
-
-         The method converts the IDs to tokens, joins them to form a string,
-         and restores the original spaces by undoing the placeholder substitution.
-
-         Args:
-             token_ids (List[int]): A list of token IDs to be decoded.
-
-         Returns:
-             str: The decoded string.
-         """
-         tokens = self.convert_ids_to_tokens(token_ids)
-         decoded = "".join(tokens)
-         return self.prepad_sequence(decoded, undo=True)
-
-     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
-         """
-         Saves the tokenizer's vocabulary to a file.
-         Mainly useful when the vocabulary has to be retrieved later rather than supplied up front
-         (not the case here; kept as a hook for future improvements, e.g. with sentencepiece).
-
-         This method saves the vocabulary to a JSON file in the specified directory.
-
-         Args:
-             save_directory (str): The directory where the vocabulary file will be saved.
-             filename_prefix (Optional[str]): An optional prefix for the filename.
-
-         Returns:
-             Tuple[str]: A tuple containing the path to the saved vocabulary file.
-         """
-         vocab_file = f"{save_directory}/{filename_prefix + '-' if filename_prefix else ''}vocab.json"
-         with open(vocab_file, "w", encoding="utf-8") as f:
-             json.dump(self.vocab, f, indent=2, ensure_ascii=False)
-         return (vocab_file,)
-
-     def get_vocab(self) -> dict:
-         """
-         Retrieves the vocabulary used by the tokenizer.
-
-         Returns:
-             dict: The vocabulary as a dictionary.
-         """
-         return self.vocab
-
-
- # EXAMPLE OF USAGE
-
- # sequence = "( not ( x_1 <= 0.2988 ) until[11,21] x_0 <= -0.7941 )"
- # tokenizer = STLTokenizer('tokenizer_files/tokenizer.json')
- # token_ids = tokenizer.encode(sequence)
- # decoded_sequence = tokenizer.decode(token_ids)
-
- # print("Original sequence: ", sequence)
- # print("Encoded sequence: ", token_ids)
- # print("Decoded sequence: ", decoded_sequence)
-
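- # SELF-CONTAINED VARIANT (editor's sketch: the toy vocabulary and temp-file
- # handling below are illustrative assumptions, not the project's real
- # tokenizer_files/tokenizer.json):
-
- # import json, os, tempfile
- # vocab = {"pad": 0, "unk": 1, "/s": 2, "s": 3, "@": 4, "(": 5, ")": 6,
- #          "not": 7, "x_0": 8, "x_1": 9, "<=": 10, "until[11,21]": 11,
- #          "0.2988": 12, "-0.7941": 13}
- # with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
- #     json.dump(vocab, f)
- # tokenizer = STLTokenizer(f.name)
- # token_ids = tokenizer.encode("( not ( x_1 <= 0.2988 ) until[11,21] x_0 <= -0.7941 )")
- # print(tokenizer.decode(token_ids))  # "/s ( not ( x_1 <= ... ) s"
- # os.remove(f.name)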