flopml committed on
Commit
cdd1aa0
·
verified ·
1 Parent(s): 915f29f

Create tokenizer.py

Browse files
Files changed (1) hide show
  1. py_src/tokenizer.py +207 -0
py_src/tokenizer.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import struct
3
+
4
# Constants
MAX_VOCAB_SIZE = 32000  # maximum number of tokens the vocabulary can hold
MAX_WORD_LEN = 16  # fixed byte width of one token slot in the binary file (incl. NUL padding)
7
+
8
def ERROR(message, *args):
    """Write a %-formatted error message to stderr and terminate with status 1."""
    import sys
    formatted = message % args
    sys.stderr.write(formatted)
    sys.exit(1)
13
+
14
def INFO(message, *args):
    """Print a %-formatted informational message to stdout."""
    text = message % args
    print(text)
17
+
18
class Tokenizer:
    """Fixed-capacity vocabulary tokenizer with greedy longest-prefix encoding.

    The vocabulary holds at most MAX_VOCAB_SIZE tokens.  In the binary
    format (save_tokenizer / load_tokenizer) every token occupies a fixed
    MAX_WORD_LEN-byte NUL-padded slot, so tokens are limited to
    MAX_WORD_LEN - 1 bytes.  encode_word() uses binary search, which
    assumes vocab[0:vocab_size] is in sorted order -- see add_word().
    """

    def __init__(self, fname=None):
        """Create an empty tokenizer, optionally loading a binary vocab file.

        Args:
            fname (str | None): path to a binary file written by
                save_tokenizer(); when None, the vocabulary starts empty.
        """
        self.vocab_size = 0  # number of populated entries in self.vocab
        self.vocab = [''] * MAX_VOCAB_SIZE  # preallocated token table

        if fname:
            self.load_tokenizer(fname)

        INFO("vocabulary size: %d (%d max)", self.vocab_size, MAX_VOCAB_SIZE)
        INFO("max token length: %d", MAX_WORD_LEN)
        # Approximate footprint: one fixed-width slot per populated token.
        structure_size = self.vocab_size * MAX_WORD_LEN
        INFO("size of structure: %d bytes", structure_size)

    def add_word(self, word):
        """Append *word* to the vocabulary.

        Words longer than MAX_WORD_LEN - 1 characters are truncated so
        they fit one binary-format slot.

        NOTE(review): appending does not keep the table sorted, but
        encode_word()'s binary search requires sorted order -- callers
        must add words in sorted order (or only use vocabularies loaded
        from a pre-sorted file).

        Returns:
            int: the new token's ID, or -1 when the vocabulary is full.
        """
        if self.vocab_size >= MAX_VOCAB_SIZE:
            return -1
        if len(word) >= MAX_WORD_LEN:
            word = word[:MAX_WORD_LEN - 1]
        self.vocab[self.vocab_size] = word
        self.vocab_size += 1
        return self.vocab_size - 1

    def encode_word(self, word):
        """Return the ID of *word* via binary search, or -1 if absent.

        Precondition: vocab[0:vocab_size] is sorted (see add_word()).
        """
        left = 0
        right = self.vocab_size - 1

        while left <= right:
            mid = left + (right - left) // 2
            cmp = self._compare(word, self.vocab[mid])
            if cmp == 0:
                return mid
            elif cmp < 0:
                right = mid - 1
            else:
                left = mid + 1

        return -1

    def encode_stream(self, stream):
        """Greedily encode the longest known prefix of *stream*.

        Args:
            stream (list of str): characters of the stream; the matched
                prefix is removed from the front in place.

        Returns:
            int: ID of the longest matching prefix, or -1 if no prefix
            (up to MAX_WORD_LEN characters) is in the vocabulary.
        """
        word = ''
        best_id = -1  # renamed from `id` to avoid shadowing the builtin
        consumed = 0

        for i in range(min(MAX_WORD_LEN, len(stream))):
            word += stream[i]
            tmp = self.encode_word(word)
            if tmp != -1:
                # Keep extending: a longer match supersedes a shorter one.
                best_id = tmp
                consumed = i + 1

        # Drop only the characters covered by the longest match.
        del stream[:consumed]

        return best_id

    def encode_file(self, fd):
        """Greedily encode the longest known prefix read from binary file *fd*.

        Reads up to MAX_WORD_LEN bytes, finds the longest vocabulary
        match, then seeks the file back so only the matched bytes are
        consumed.

        NOTE(review): matching decodes one byte at a time with
        errors='ignore', so multi-byte UTF-8 tokens may not round-trip
        -- confirm the vocabulary is effectively ASCII.

        Args:
            fd (file object): binary-mode, seekable file to read from.

        Returns:
            int: ID of the longest matching prefix, or -1 if none.
        """
        word = ''
        best_id = -1
        matched_bytes = 0  # bytes consumed by the longest match so far
        bytes_read = 0     # bytes actually read (may be < MAX_WORD_LEN at EOF)

        for _ in range(MAX_WORD_LEN):
            c = fd.read(1)
            if not c:
                break
            bytes_read += 1
            word += c.decode('utf-8', errors='ignore')
            tmp = self.encode_word(word)
            if tmp != -1:
                best_id = tmp
                matched_bytes = bytes_read

        # Bug fix: the original rewound MAX_WORD_LEN - len(word), which
        # over-rewinds when EOF truncated the read (and miscounts when
        # errors='ignore' drops bytes).  Rewind exactly the bytes read
        # beyond the match.
        to_seek = bytes_read - matched_bytes
        if to_seek > 0:
            fd.seek(-to_seek, os.SEEK_CUR)

        return best_id

    def decode(self, id):
        """Return the word for *id*, or None when the ID is out of range."""
        if 0 <= id < self.vocab_size:
            return self.vocab[id]
        return None

    def decode_file(self, fd):
        """Read a 4-byte token ID from *fd* and decode it to its word.

        Exits via ERROR() if fewer than 4 bytes remain in the file.

        Args:
            fd (file object): binary-mode file to read the ID from.

        Returns:
            str | None: the decoded word, or None for an out-of-range ID.
        """
        data = fd.read(4)  # one native-endian 32-bit signed integer
        if len(data) < 4:
            ERROR("read EOF from file\n")

        token_id = struct.unpack('i', data)[0]
        return self.decode(token_id)

    def save_vocab(self, fname):
        """Save the vocabulary to a UTF-8 text file, one word per line.

        Exits via ERROR() on I/O failure.
        """
        try:
            with open(fname, 'w', encoding='utf-8') as f:
                max_len = 0
                for i in range(self.vocab_size):
                    word = self.vocab[i]
                    f.write(word + '\n')
                    if len(word) > max_len:
                        max_len = len(word)
            INFO("wrote %d tokens to file \"%s\"\nMax token length was %d",
                 self.vocab_size, fname, max_len)
        except IOError as e:
            ERROR("failed to write to \"%s\": %s\n", fname, str(e))

    def load_vocab(self, fname):
        """Load the vocabulary from a text file, one word per line.

        Blank lines are skipped; each word is appended via add_word()
        (so the file must already be sorted for encode_word() to work).
        Exits via ERROR() on I/O failure.
        """
        try:
            with open(fname, 'r', encoding='utf-8') as f:
                for line in f:
                    word = line.strip()
                    if word:
                        self.add_word(word)
        except IOError as e:
            ERROR("failed to open \"%s\": %s\n", fname, str(e))

    def save_tokenizer(self, fname):
        """Save the vocabulary to a binary file of fixed-width slots.

        Always writes MAX_VOCAB_SIZE slots of MAX_WORD_LEN bytes each;
        unused slots are all-NUL.  Each token is UTF-8 encoded, truncated
        to MAX_WORD_LEN - 1 bytes, and NUL-padded.  Exits via ERROR() on
        I/O failure.
        """
        try:
            with open(fname, 'wb') as f:
                for i in range(MAX_VOCAB_SIZE):
                    if i < self.vocab_size:
                        word = self.vocab[i].encode('utf-8')
                        if len(word) >= MAX_WORD_LEN:
                            word = word[:MAX_WORD_LEN - 1]
                        word += b'\0' * (MAX_WORD_LEN - len(word))
                    else:
                        word = b'\0' * MAX_WORD_LEN
                    f.write(word)
            INFO("wrote %d bytes (%d tokens) to \"%s\"",
                 MAX_VOCAB_SIZE * MAX_WORD_LEN, self.vocab_size, fname)
        except IOError as e:
            ERROR("failed to write to \"%s\": %s\n", fname, str(e))

    def load_tokenizer(self, fname):
        """Load the vocabulary from a binary file written by save_tokenizer().

        Reads fixed MAX_WORD_LEN-byte slots; each slot is decoded up to
        its first NUL byte.  Stops at a short (truncated) slot.  Exits
        via ERROR() on I/O failure.
        """
        try:
            with open(fname, 'rb') as f:
                for _ in range(MAX_VOCAB_SIZE):
                    bytes_word = f.read(MAX_WORD_LEN)
                    if len(bytes_word) < MAX_WORD_LEN:
                        break
                    # Decode up to the first null byte.
                    word = bytes_word.split(b'\0', 1)[0].decode('utf-8', errors='ignore')
                    if word:
                        # Bug fix: index by vocab_size rather than the slot
                        # number, so an empty slot mid-file cannot leave a
                        # gap in vocab[0:vocab_size], which encode_word()'s
                        # binary search scans densely.
                        self.vocab[self.vocab_size] = word
                        self.vocab_size += 1
            INFO("read %d bytes (%d tokens) from \"%s\"",
                 self.vocab_size * MAX_WORD_LEN, self.vocab_size, fname)
        except IOError as e:
            ERROR("failed to read from \"%s\": %s\n", fname, str(e))

    @staticmethod
    def _compare(a, b):
        """Three-way string comparison: negative/zero/positive like C strcmp."""
        return (a > b) - (a < b)