Tarush-AI committed on
Commit
43993b6
·
verified ·
1 Parent(s): a8689db

Upload tokenizer.py

Browse files
Files changed (1) hide show
  1. model/vocab/tokenizer.py +81 -0
model/vocab/tokenizer.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, io, sys
2
+ import sentencepiece as spm
3
+ from config import PROJECT_ROOT, vocab_size
4
+ from model.vocab.preprocess import Preprocessor
5
+
6
class Tokenizer:
    """Thin wrapper around a SentencePiece BPE tokenizer.

    Supports training a new model from an iterable of sentences,
    loading previously trained weights, and round-tripping text
    to/from token ids.
    """

    # Prefix passed to SentencePieceTrainer; the trained model is written
    # to "<MODEL_PREFIX>.model" (plus a "<MODEL_PREFIX>.vocab" file).
    MODEL_PREFIX = "data/tokenizer"

    def __init__(self):
        # Processor starts empty; a model must be trained or loaded
        # before encode()/decode() can be used.
        self.sp = spm.SentencePieceProcessor()

    def train(self, all_sentences):
        """Train a BPE model on `all_sentences` and load the result.

        Parameters
        ----------
        all_sentences : iterable of str
            Training corpus, one sentence per element.
        """
        spm.SentencePieceTrainer.train(
            sentence_iterator=iter(all_sentences),
            model_prefix=self.MODEL_PREFIX,
            model_type="bpe",
            vocab_size=vocab_size,
            user_defined_symbols=["<BEGIN>", "<END>", "<PAD>"],
        )
        # Bug fix: the original called self.sp.Load(self.path), but no
        # `path` attribute is ever assigned anywhere in this class, so
        # train() crashed with AttributeError after a successful training
        # run. Load the file the trainer just wrote instead.
        self.sp.Load(self.MODEL_PREFIX + ".model")

    def load_weights(self, path):
        """Load a previously trained SentencePiece model from `path`."""
        self.sp.Load(path)

    def encode(self, text):
        """Return the list of token ids for `text`."""
        return self.sp.EncodeAsIds(text)

    def decode(self, ids):
        """Return the text reconstructed from a list of token `ids`."""
        return self.sp.DecodeIds(ids)

    def test(self, file):
        """Round-trip a text through encode/decode and save a report.

        Parameters
        ----------
        file : str or None
            Text to tokenize. When falsy, falls back to the bundled
            "tokenize_test.txt" next to this module; if that file is
            also missing, a message is printed and nothing is written.

        Writes "tokenize_test_output.txt" next to this module containing
        the original text, the token ids, and the decoded text.
        """
        text_content = None
        if file:
            text_content = file

        if not text_content:
            # Fall back to the default fixture shipped alongside this module.
            test_file_path = os.path.join(os.path.dirname(__file__), "tokenize_test.txt")
            if os.path.exists(test_file_path):
                with open(test_file_path, "r") as f:
                    text_content = f.read()
            else:
                print(f"Default test file not found at {test_file_path}")
                return

        if text_content:
            try:
                encoded_ids = self.encode(text_content)
                decoded_text = self.decode(encoded_ids)

                output_content = f"Original Text:\n{text_content}\n\nToken IDs:\n{encoded_ids}\n\nDecoded Text:\n{decoded_text}\n"

                output_file_path = os.path.join(os.path.dirname(__file__), "tokenize_test_output.txt")
                with open(output_file_path, "w") as f:
                    f.write(output_content)

                print(f"Saved to {output_file_path}.")
            except Exception as e:
                # Best-effort diagnostic: report (e.g. model-not-loaded
                # errors from SentencePiece) rather than crash the caller.
                print(f"Error during tokenization test: {e}")
58
+
59
if __name__ == "__main__":
    # CLI: `python tokenizer.py test [filepath]` — runs the round-trip
    # test, optionally on the contents of `filepath`.
    args = sys.argv
    file = None

    if len(args) <= 1:
        print("Tokenization logic is wrapped into overall training functionality.")
    elif args[1] != "test":
        print("Only permitted argument is 'test'; Please try again.")

    if len(args) > 2:
        # Optional second argument: a file whose contents replace the
        # default test fixture. Unreadable paths fall back to None.
        try:
            with open(args[2], "r") as f:
                file = f.read()
        except Exception:
            print("Invalid filepath, falling back to original test.")
            file = None

    if len(args) > 1 and args[1] == "test":
        Tokenizer().test(file)