"""Tokenizer for Veda Programming Assistant"""
import json
import re
from typing import List, Dict, Optional


class VedaTokenizer:
    """Tokenizer with conversation support."""

    def __init__(self, vocab_size: int = 8000):
        self.vocab_size = vocab_size
        self.token_to_idx: Dict[str, int] = {}
        self.idx_to_token: Dict[int, str] = {}
        self._init_vocab()

    def _init_vocab(self):
        """Initialize the vocabulary with conversation tokens, printable
        ASCII, and common whitespace."""
        special = [
            "<PAD>", "<UNK>", "<START>", "<END>",
            "<CODE>", "<ENDCODE>",
            "<USER>", "<ASSISTANT>",
        ]
        for idx, token in enumerate(special):
            self.token_to_idx[token] = idx
            self.idx_to_token[idx] = token
        idx = len(special)
        # Printable ASCII: space (32) through '~' (126).
        for i in range(32, 127):
            char = chr(i)
            self.token_to_idx[char] = idx
            self.idx_to_token[idx] = char
            idx += 1
        # Newline and tab fall outside the printable range, so add them explicitly.
        for char in ["\n", "\t"]:
            self.token_to_idx[char] = idx
            self.idx_to_token[idx] = char
            idx += 1
        self.base_vocab_size = idx

    def fit(self, texts: List[str]):
        """Build the word-level vocabulary from a corpus."""
        word_freq = {}
        for text in texts:
            # Identifiers, runs of digits, and single non-space symbols.
            words = re.findall(r'[a-zA-Z_][a-zA-Z0-9_]*|[0-9]+|[^\s]', text)
            for word in words:
                word_freq[word] = word_freq.get(word, 0) + 1
        # Admit the most frequent words first, up to the configured vocab size.
        sorted_words = sorted(word_freq.items(), key=lambda x: -x[1])
        idx = self.base_vocab_size
        for word, _ in sorted_words:
            if idx >= self.vocab_size:
                break
            if word not in self.token_to_idx and len(word) <= 25:
                self.token_to_idx[word] = idx
                self.idx_to_token[idx] = word
                idx += 1
        print(f"Vocabulary: {len(self.token_to_idx)} tokens")

    def encode(self, text: str, max_length: Optional[int] = None) -> List[int]:
        """Encode text into token indices, optionally padded or truncated."""
        tokens = self._tokenize(text)
        encoded = []
        for token in tokens:
            if token in self.token_to_idx:
                encoded.append(self.token_to_idx[token])
            else:
                # Fall back to characters; anything unknown maps to <UNK> (index 1).
                for char in token:
                    encoded.append(self.token_to_idx.get(char, 1))
        if max_length is not None:
            if len(encoded) < max_length:
                encoded += [0] * (max_length - len(encoded))  # pad with <PAD> (index 0)
            else:
                encoded = encoded[:max_length]
        return encoded

    def _tokenize(self, text: str) -> List[str]:
        """Split text into vocabulary tokens via greedy longest-match."""
        tokens = []
        parts = re.split(r'(\s+)', text)
        for part in parts:
            if not part:
                continue
            if part.isspace():
                # Keep each whitespace character as its own token.
                for char in part:
                    tokens.append(char)
            elif part in self.token_to_idx:
                tokens.append(part)
            else:
                # Greedy longest-match against the vocabulary, falling back to
                # single characters. The 25-character window matches the word
                # length cap used in fit().
                i = 0
                while i < len(part):
                    matched = False
                    for length in range(min(len(part) - i, 25), 0, -1):
                        substr = part[i:i + length]
                        if substr in self.token_to_idx:
                            tokens.append(substr)
                            i += length
                            matched = True
                            break
                    if not matched:
                        tokens.append(part[i])
                        i += 1
        return tokens

    def decode(self, indices: List[int]) -> str:
        """Decode indices to text, re-inserting spaces between word tokens."""
        result = []
        prev = ""
        for idx in indices:
            if idx == 0:
                continue  # <PAD>
            if idx not in self.idx_to_token:
                continue
            token = self.idx_to_token[idx]
            if token in ["<PAD>", "<UNK>", "<START>", "<END>", "<USER>", "<ASSISTANT>"]:
                continue
            # Render code markers as fenced blocks.
            if token == "<CODE>":
                result.append("\n```python\n")
                prev = "\n"
                continue
            if token == "<ENDCODE>":
                result.append("\n```\n")
                prev = "\n"
                continue
            if not result:
                result.append(token)
            elif token in "\n\t":
                result.append(token)
            elif token in ".,;:!?()[]{}":
                result.append(token)
            elif prev in "(\n\t[{":
                result.append(token)
            elif prev and prev[-1].isalnum() and token and token[0].isalnum():
                # Two adjacent word-like tokens need a separating space.
                result.append(" " + token)
            else:
                result.append(token)
            prev = token
        return "".join(result)

    def save(self, path: str):
        """Save tokenizer state as JSON."""
        with open(path, 'w') as f:
            json.dump({
                'vocab_size': self.vocab_size,
                'token_to_idx': self.token_to_idx,
                # JSON object keys must be strings.
                'idx_to_token': {str(k): v for k, v in self.idx_to_token.items()},
                'base_vocab_size': self.base_vocab_size,
            }, f, indent=2)

    def load(self, path: str):
        """Load tokenizer state from JSON."""
        with open(path, 'r') as f:
            data = json.load(f)
        self.vocab_size = data['vocab_size']
        self.token_to_idx = data['token_to_idx']
        self.idx_to_token = {int(k): v for k, v in data['idx_to_token'].items()}
        # Fallback for older save files: 105 matches _init_vocab
        # (8 special tokens + 95 printable ASCII + newline and tab).
        self.base_vocab_size = data.get('base_vocab_size', 105)

    @property
    def vocabulary_size(self) -> int:
        return len(self.token_to_idx)
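

# A minimal usage sketch (an assumption, not part of the original module):
# builds a tiny vocabulary from a toy corpus and round-trips one line of code.
# The corpus and the output path are illustrative only.
if __name__ == "__main__":
    tokenizer = VedaTokenizer(vocab_size=8000)
    tokenizer.fit([
        "def add(a, b):\n    return a + b",
        "print(add(1, 2))",
    ])
    ids = tokenizer.encode("def add(a, b):", max_length=32)
    print(ids)                    # index sequence, padded with <PAD> (0)
    print(tokenizer.decode(ids))  # -> def add(a, b):
    tokenizer.save("veda_tokenizer.json")  # hypothetical output path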