vedaco committed on
Commit
c4cd8de
·
verified ·
1 Parent(s): 594bc39

Update tokenizer.py

Browse files
Files changed (1) hide show
  1. tokenizer.py +175 -113
tokenizer.py CHANGED
@@ -1,21 +1,28 @@
1
- """Tokenizer for Veda Programming Assistant"""
2
 
3
  import json
4
  import re
5
- from typing import List, Dict, Optional
6
-
7
 
8
  class VedaTokenizer:
9
- """Tokenizer with conversation support"""
 
 
 
10
 
11
  def __init__(self, vocab_size: int = 8000):
12
  self.vocab_size = vocab_size
13
  self.token_to_idx: Dict[str, int] = {}
14
  self.idx_to_token: Dict[int, str] = {}
15
- self._init_vocab()
16
-
17
- def _init_vocab(self):
18
- """Initialize vocabulary with conversation tokens"""
 
 
 
 
 
19
  special = [
20
  "<PAD>", "<UNK>", "<START>", "<END>",
21
  "<CODE>", "<ENDCODE>",
@@ -26,143 +33,190 @@ class VedaTokenizer:
26
  self.token_to_idx[token] = idx
27
  self.idx_to_token[idx] = token
28
 
 
29
  idx = len(special)
 
30
  for i in range(32, 127):
31
  char = chr(i)
32
- self.token_to_idx[char] = idx
33
- self.idx_to_token[idx] = char
34
- idx += 1
35
-
36
- for char in ["\n", "\t"]:
37
- self.token_to_idx[char] = idx
38
- self.idx_to_token[idx] = char
39
- idx += 1
40
 
 
 
 
 
 
 
 
41
  self.base_vocab_size = idx
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  def fit(self, texts: List[str]):
44
- """Build vocabulary"""
45
- word_freq = {}
 
 
 
46
 
47
  for text in texts:
48
- words = re.findall(r'[a-zA-Z_][a-zA-Z0-9_]*|[0-9]+|[^\s]', text)
 
49
  for word in words:
50
- word_freq[word] = word_freq.get(word, 0) + 1
 
 
 
 
 
 
51
 
52
- sorted_words = sorted(word_freq.items(), key=lambda x: -x[1])
53
 
54
- idx = self.base_vocab_size
55
- for word, _ in sorted_words:
56
- if idx >= self.vocab_size:
57
  break
58
- if word not in self.token_to_idx and len(word) <= 25:
59
- self.token_to_idx[word] = idx
60
- self.idx_to_token[idx] = word
61
- idx += 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
- print(f"Vocabulary: {len(self.token_to_idx)} tokens")
64
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  def encode(self, text: str, max_length: Optional[int] = None) -> List[int]:
66
- """Encode text"""
67
- tokens = self._tokenize(text)
 
68
  encoded = []
69
 
70
- for token in tokens:
71
- if token in self.token_to_idx:
72
- encoded.append(self.token_to_idx[token])
73
  else:
74
- for char in token:
75
- encoded.append(self.token_to_idx.get(char, 1))
 
 
76
 
 
77
  if max_length:
78
- if len(encoded) < max_length:
79
- encoded += [0] * (max_length - len(encoded))
80
- else:
81
  encoded = encoded[:max_length]
 
 
82
 
83
  return encoded
84
 
85
- def _tokenize(self, text: str) -> List[str]:
86
- """Tokenize text"""
87
- tokens = []
88
- parts = re.split(r'(\s+)', text)
89
-
90
- for part in parts:
91
- if not part:
92
- continue
93
- if part.isspace():
94
- for char in part:
95
- tokens.append(char)
96
- elif part in self.token_to_idx:
97
- tokens.append(part)
98
- else:
99
- i = 0
100
- while i < len(part):
101
- matched = False
102
- for length in range(min(len(part) - i, 20), 0, -1):
103
- substr = part[i:i+length]
104
- if substr in self.token_to_idx:
105
- tokens.append(substr)
106
- i += length
107
- matched = True
108
- break
109
- if not matched:
110
- tokens.append(part[i])
111
- i += 1
112
-
113
- return tokens
114
-
115
  def decode(self, indices: List[int]) -> str:
116
  """Decode indices to text"""
117
- result = []
118
- prev = ""
119
-
120
  for idx in indices:
121
- if idx == 0:
122
- continue
123
- if idx not in self.idx_to_token:
124
- continue
125
-
126
- token = self.idx_to_token[idx]
127
-
128
- if token in ["<PAD>", "<UNK>", "<START>", "<END>", "<USER>", "<ASSISTANT>"]:
129
- continue
130
-
131
- if token == "<CODE>":
132
- result.append("\n```python\n")
133
- prev = "\n"
134
- continue
135
- if token == "<ENDCODE>":
136
- result.append("\n```\n")
137
- prev = "\n"
138
- continue
139
-
140
- if not result:
141
- result.append(token)
142
- elif token in "\n\t":
143
- result.append(token)
144
- elif token in ".,;:!?()[]{}":
145
- result.append(token)
146
- elif prev in "(\n\t[{":
147
- result.append(token)
148
- elif len(prev) > 0 and prev[-1].isalnum() and len(token) > 0 and token[0].isalnum():
149
- result.append(" " + token)
150
- else:
151
- result.append(token)
152
-
153
- prev = token
154
 
155
- return "".join(result)
156
 
157
  def save(self, path: str):
158
  """Save tokenizer"""
 
 
 
 
 
 
 
159
  with open(path, 'w') as f:
160
- json.dump({
161
- 'vocab_size': self.vocab_size,
162
- 'token_to_idx': self.token_to_idx,
163
- 'idx_to_token': {str(k): v for k, v in self.idx_to_token.items()},
164
- 'base_vocab_size': self.base_vocab_size
165
- }, f, indent=2)
166
 
167
  def load(self, path: str):
168
  """Load tokenizer"""
@@ -172,6 +226,14 @@ class VedaTokenizer:
172
  self.token_to_idx = data['token_to_idx']
173
  self.idx_to_token = {int(k): v for k, v in data['idx_to_token'].items()}
174
  self.base_vocab_size = data.get('base_vocab_size', 100)
 
 
 
 
 
 
 
 
175
 
176
  @property
177
  def vocabulary_size(self) -> int:
 
1
+ """Subword Tokenizer (BPE-like) for Veda Programming Assistant"""
2
 
3
  import json
4
  import re
5
+ from typing import List, Dict, Optional, Tuple
 
6
 
7
  class VedaTokenizer:
8
+ """
9
+ Subword tokenizer that learns common subwords/phrases.
10
+ Better than word-level or char-level tokenization.
11
+ """
12
 
13
  def __init__(self, vocab_size: int = 8000):
14
  self.vocab_size = vocab_size
15
  self.token_to_idx: Dict[str, int] = {}
16
  self.idx_to_token: Dict[int, str] = {}
17
+
18
+ # Base vocabulary (special tokens + ASCII)
19
+ self._init_base_vocab()
20
+
21
+ # Merges for subwords (pair -> new_token)
22
+ self.merges: Dict[Tuple[str, str], str] = {}
23
+
24
+ def _init_base_vocab(self):
25
+ """Initialize base vocabulary"""
26
  special = [
27
  "<PAD>", "<UNK>", "<START>", "<END>",
28
  "<CODE>", "<ENDCODE>",
 
33
  self.token_to_idx[token] = idx
34
  self.idx_to_token[idx] = token
35
 
36
+ # ASCII characters as base tokens
37
  idx = len(special)
38
+ # Printable ASCII range
39
  for i in range(32, 127):
40
  char = chr(i)
41
+ if char not in self.token_to_idx:
42
+ self.token_to_idx[char] = idx
43
+ self.idx_to_token[idx] = char
44
+ idx += 1
 
 
 
 
45
 
46
+ # Common whitespace
47
+ for char in ["\n", "\t", " "]: # spaces for indentation
48
+ if char not in self.token_to_idx:
49
+ self.token_to_idx[char] = idx
50
+ self.idx_to_token[idx] = char
51
+ idx += 1
52
+
53
  self.base_vocab_size = idx
54
 
55
+ def _get_stats(self, vocab: Dict[Tuple[str, ...], int]) -> Dict[Tuple[str, str], int]:
56
+ """Count frequency of adjacent pairs"""
57
+ pairs = {}
58
+ for word_tuple, freq in vocab.items():
59
+ for i in range(len(word_tuple) - 1):
60
+ pair = (word_tuple[i], word_tuple[i+1])
61
+ pairs[pair] = pairs.get(pair, 0) + freq
62
+ return pairs
63
+
64
+ def _merge_vocab(self, pair: Tuple[str, str], vocab: Dict[Tuple[str, ...], int]) -> Dict[Tuple[str, ...], int]:
65
+ """Merge all occurrences of pair in vocabulary"""
66
+ new_vocab = {}
67
+ bigram = pair
68
+ new_token = "".join(pair)
69
+
70
+ for word, freq in vocab.items():
71
+ new_word = []
72
+ i = 0
73
+ while i < len(word):
74
+ if i < len(word) - 1 and word[i] == bigram[0] and word[i+1] == bigram[1]:
75
+ new_word.append(new_token)
76
+ i += 2
77
+ else:
78
+ new_word.append(word[i])
79
+ i += 1
80
+ new_vocab[tuple(new_word)] = freq
81
+
82
+ return new_vocab
83
+
84
  def fit(self, texts: List[str]):
85
+ """Train BPE tokenizer on texts"""
86
+ # Pre-tokenize into words to avoid merging across word boundaries
87
+ # This regex splits by whitespace but keeps punctuation
88
+ # Also handles code symbols better
89
+ word_counts = {}
90
 
91
  for text in texts:
92
+ # Simple pre-tokenization for code
93
+ words = re.findall(r'[a-zA-Z0-9_]+|[^\s\w]', text)
94
  for word in words:
95
+ # Convert word to tuple of characters
96
+ token_tuple = tuple(c for c in word)
97
+ word_counts[token_tuple] = word_counts.get(token_tuple, 0) + 1
98
+
99
+ # BPE training loop
100
+ vocab = word_counts
101
+ num_merges = self.vocab_size - self.base_vocab_size
102
 
103
+ print(f"Training BPE tokenizer (target vocab: {self.vocab_size})...")
104
 
105
+ for i in range(num_merges):
106
+ pairs = self._get_stats(vocab)
107
+ if not pairs:
108
  break
109
+
110
+ # Find most frequent pair
111
+ best_pair = max(pairs, key=pairs.get)
112
+
113
+ # Stop if pair frequency is too low (e.g., 1)
114
+ if pairs[best_pair] < 2:
115
+ break
116
+
117
+ # Merge pair
118
+ vocab = self._merge_vocab(best_pair, vocab)
119
+
120
+ # Add new token to vocabulary
121
+ new_token = "".join(best_pair)
122
+ self.merges[best_pair] = new_token
123
+
124
+ idx = len(self.token_to_idx)
125
+ self.token_to_idx[new_token] = idx
126
+ self.idx_to_token[idx] = new_token
127
+
128
+ if (i + 1) % 100 == 0:
129
+ print(f"BPE merge {i+1}/{num_merges}: '{best_pair[0]}' + '{best_pair[1]}' -> '{new_token}'")
130
 
131
+ print(f"BPE training complete. Final vocab size: {len(self.token_to_idx)}")
132
+
133
+ def _tokenize_word(self, word: str) -> List[str]:
134
+ """Tokenize a single word using learned merges"""
135
+ if word in self.token_to_idx:
136
+ return [word]
137
+
138
+ # Start with characters
139
+ tokens = list(word)
140
+
141
+ # Apply merges iteratively
142
+ # Note: In a real BPE implementation we would apply in order of priority
143
+ # Here we do a simpler greedy application based on length
144
+ while True:
145
+ merged = False
146
+ i = 0
147
+ new_tokens = []
148
+
149
+ while i < len(tokens) - 1:
150
+ pair = (tokens[i], tokens[i+1])
151
+ pair_str = "".join(pair)
152
+
153
+ # Check if this pair forms a known token
154
+ if pair_str in self.token_to_idx:
155
+ new_tokens.append(pair_str)
156
+ i += 2
157
+ merged = True
158
+ else:
159
+ new_tokens.append(tokens[i])
160
+ i += 1
161
+
162
+ if i < len(tokens):
163
+ new_tokens.append(tokens[i])
164
+
165
+ if not merged:
166
+ break
167
+
168
+ tokens = new_tokens
169
+
170
+ return tokens
171
+
172
  def encode(self, text: str, max_length: Optional[int] = None) -> List[int]:
173
+ """Encode text to token indices"""
174
+ # Pre-tokenize same way as training
175
+ words = re.findall(r'[a-zA-Z0-9_]+|[^\s\w]|\s+', text)
176
  encoded = []
177
 
178
+ for word in words:
179
+ if word in self.token_to_idx:
180
+ encoded.append(self.token_to_idx[word])
181
  else:
182
+ # Apply BPE
183
+ subwords = self._tokenize_word(word)
184
+ for sw in subwords:
185
+ encoded.append(self.token_to_idx.get(sw, self.token_to_idx["<UNK>"]))
186
 
187
+ # Truncate or Pad
188
  if max_length:
189
+ if len(encoded) > max_length:
 
 
190
  encoded = encoded[:max_length]
191
+ elif len(encoded) < max_length:
192
+ encoded += [self.token_to_idx["<PAD>"]] * (max_length - len(encoded))
193
 
194
  return encoded
195
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
  def decode(self, indices: List[int]) -> str:
197
  """Decode indices to text"""
198
+ tokens = []
 
 
199
  for idx in indices:
200
+ # Skip special tokens if needed, but usually we decode them
201
+ # and let post-processing handle cleanup
202
+ if idx in self.idx_to_token:
203
+ token = self.idx_to_token[idx]
204
+ if token not in ["<PAD>", "<UNK>", "<START>", "<END>"]:
205
+ tokens.append(token)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
 
207
+ return "".join(tokens)
208
 
209
  def save(self, path: str):
210
  """Save tokenizer"""
211
+ data = {
212
+ 'vocab_size': self.vocab_size,
213
+ 'token_to_idx': self.token_to_idx,
214
+ 'idx_to_token': {str(k): v for k, v in self.idx_to_token.items()},
215
+ 'base_vocab_size': self.base_vocab_size,
216
+ 'merges': {f"{p[0]}|{p[1]}": m for p, m in self.merges.items()}
217
+ }
218
  with open(path, 'w') as f:
219
+ json.dump(data, f, indent=2)
 
 
 
 
 
220
 
221
  def load(self, path: str):
222
  """Load tokenizer"""
 
226
  self.token_to_idx = data['token_to_idx']
227
  self.idx_to_token = {int(k): v for k, v in data['idx_to_token'].items()}
228
  self.base_vocab_size = data.get('base_vocab_size', 100)
229
+
230
+ # Load merges
231
+ if 'merges' in data:
232
+ self.merges = {}
233
+ for k, v in data['merges'].items():
234
+ p = k.split('|')
235
+ if len(p) == 2:
236
+ self.merges[(p[0], p[1])] = v
237
 
238
  @property
239
  def vocabulary_size(self) -> int: