# AureliusGPT / model / vocab / tokenizer.py
# (Hugging Face file-listing header preserved below as comments so the module parses:)
#   Tarush-AI's picture
#   Upload tokenizer.py
#   43993b6 verified
import os, io, sys
import sentencepiece as spm
from config import PROJECT_ROOT, vocab_size
from model.vocab.preprocess import Preprocessor
class Tokenizer:
    """Thin wrapper around a SentencePiece BPE tokenizer.

    Supports training a model from an iterable of sentences, loading a
    pre-trained model from disk, and round-trip encoding/decoding. The
    ``test`` helper runs an encode/decode round trip on a text sample and
    writes the result next to this file.
    """

    # SentencePieceTrainer writes "<model_prefix>.model"; with the
    # model_prefix used in train() below, the artifact lands here.
    MODEL_PATH = "data/tokenizer.model"

    def __init__(self):
        # Empty processor; a model must be trained or loaded before use.
        self.sp = spm.SentencePieceProcessor()

    def train(self, all_sentences):
        """Train a BPE model on *all_sentences* (iterable of str) and load it.

        Vocab size comes from ``config.vocab_size``; ``<BEGIN>``, ``<END>``
        and ``<PAD>`` are reserved as user-defined symbols.
        """
        spm.SentencePieceTrainer.train(
            sentence_iterator=iter(all_sentences),
            model_prefix="data/tokenizer",
            model_type="bpe",
            vocab_size=vocab_size,
            user_defined_symbols=["<BEGIN>", "<END>", "<PAD>"],
        )
        # BUG FIX: the original called self.sp.Load(self.path), but
        # self.path was never assigned anywhere, so every training run
        # ended in an AttributeError. Load the artifact the trainer
        # just wrote instead.
        self.sp.Load(self.MODEL_PATH)

    def load_weights(self, path):
        """Load a previously trained SentencePiece model from *path*."""
        self.sp.Load(path)

    def encode(self, text):
        """Return the list of token ids for *text*."""
        return self.sp.EncodeAsIds(text)

    def decode(self, ids):
        """Return the text reconstructed from a list of token *ids*."""
        return self.sp.DecodeIds(ids)

    def test(self, file):
        """Round-trip *file* (a str of text, or None) through the tokenizer.

        Falls back to ``tokenize_test.txt`` beside this module when *file*
        is falsy; writes the original text, token ids, and decoded text to
        ``tokenize_test_output.txt`` in the same directory.
        """
        text_content = file if file else None
        if not text_content:
            test_file_path = os.path.join(os.path.dirname(__file__), "tokenize_test.txt")
            if os.path.exists(test_file_path):
                # Explicit encoding so behavior doesn't depend on the
                # platform's locale default.
                with open(test_file_path, "r", encoding="utf-8") as f:
                    text_content = f.read()
            else:
                print(f"Default test file not found at {test_file_path}")
                return
        if text_content:
            try:
                encoded_ids = self.encode(text_content)
                decoded_text = self.decode(encoded_ids)
                output_content = f"Original Text:\n{text_content}\n\nToken IDs:\n{encoded_ids}\n\nDecoded Text:\n{decoded_text}\n"
                output_file_path = os.path.join(os.path.dirname(__file__), "tokenize_test_output.txt")
                with open(output_file_path, "w", encoding="utf-8") as f:
                    f.write(output_content)
                print(f"Saved to {output_file_path}.")
            except Exception as e:
                # Best-effort diagnostic helper: report and continue rather
                # than crash the caller.
                print(f"Error during tokenization test: {e}")
if __name__ == "__main__":
    # CLI: `python tokenizer.py test [filepath]` round-trips a text sample
    # through the trained tokenizer. Any other first argument is rejected.
    # Restructured from the original, which checked sys.argv[1] == "test"
    # twice and read the optional file even when the argument was invalid.
    file = None
    if len(sys.argv) > 1:
        if sys.argv[1] != "test":
            print("Only permitted argument is 'test'; Please try again.")
        else:
            print("Tokenization logic is wrapped into overall training functionality.")
            if len(sys.argv) > 2:
                try:
                    # Narrowed from `except Exception`: only file-system /
                    # decoding problems are expected here.
                    with open(sys.argv[2], "r") as f:
                        file = f.read()
                except (OSError, UnicodeDecodeError):
                    print("Invalid filepath, falling back to original test.")
                    file = None
            Tokenizer().test(file)