flpelerin committed
Commit 1907275 · 1 Parent(s): 3400bb8

Update 4 files


- /tokenizer.py
- /dataset.py
- /tokenizer.cli.py
- /trainer.cli.py

Files changed (4)
  1. dataset.py +7 -1
  2. tokenizer.cli.py +1 -0
  3. tokenizer.py +148 -0
  4. trainer.cli.py +9 -3
dataset.py CHANGED
@@ -11,5 +11,11 @@ class Dataset:
         self.text = ''.join(s for s in self.dataset['train']['text']).encode('ascii', 'ignore').decode('ascii')
 
 
-    def Batch(self, ids):
+    def __iadd__(self, value):
+        attr_name = value.__name__ if hasattr(value, '__name__') else type(value).__name__.lower()
+        setattr(self, attr_name, value)
+        return self
+
+
+    def batch(self, value):  # TODO: Implement
         pass
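
For context, the new `__iadd__` stores whatever is added under an attribute named after the value: its `__name__` if it has one, otherwise its lowercased type name. A minimal sketch of that behavior using a hypothetical stub class (constructing a real `Dataset` would load the configured dataset):

# Standalone sketch of Dataset.__iadd__'s attribute naming;
# DatasetStub is a stand-in, not part of the repository.
class DatasetStub:
    def __iadd__(self, value):
        attr_name = value.__name__ if hasattr(value, '__name__') else type(value).__name__.lower()
        setattr(self, attr_name, value)
        return self

ds = DatasetStub()
ds += [5, 12, 99]   # a list has no __name__, so the attribute is named 'list'
print(ds.list)      # [5, 12, 99] -- this is what `dataset += ids` produces in trainer.cli.py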
tokenizer.cli.py ADDED
@@ -0,0 +1 @@
+# TODO: Implement
tokenizer.py ADDED
@@ -0,0 +1,148 @@
+from collections import Counter
+import struct
+import re
+
+
+class Token:
+    def __init__(self, byte, prev):
+        self.byte = byte  # single character; prev is the id of the prefix token (0 = none)
+        self.prev = prev
+
+    def pack(self):
+        if not 0 <= ord(self.byte) <= 255:
+            raise ValueError(f"Byte value is out of range, got {self.byte} ({ord(self.byte)})")
+
+        return struct.pack("=B H", ord(self.byte), self.prev)  # 3 bytes: uint8 byte, uint16 prev
+
+    @classmethod
+    def from_binary(cls, data):
+        # Inverse of pack(); a short or empty read raises ValueError so that
+        # Tokenizer.from_file() can stop cleanly at end of file.
+        try:
+            byte, prev = struct.unpack("=B H", data)
+        except struct.error as e:
+            raise ValueError("Truncated token record") from e
+        return cls(chr(byte), prev)
+
+    def __str__(self):
+        return f"{self.byte}, {self.prev}"
+
+    def to_binary(self):
+        return self.pack()
+
+
+class Tokenizer:
+    def __init__(self):
+        self.vocab = [Token(chr(i), 0) for i in range(256)]  # base vocab: one token per byte value
+
+    def find(self, byte, prev):
+        # A token's id is always greater than its prefix's id, so the scan
+        # can start at prev. Returns 0 when there is no match.
+        for i in range(prev, self.vocab_size):
+            token = self.vocab[i]
+            if token.byte == byte and token.prev == prev:
+                return i
+
+        return 0
+
+    def append(self, byte, prev):
+        token = self.find(byte, prev)
+        if token:
+            return token
+
+        self.vocab.append(Token(byte, prev))
+        return self.vocab_size - 1
+
+    def encode_one(self, text):
+        # Greedy longest-prefix match: extend the chain until no token
+        # continues it, then return the last match and the remaining text.
+        prev = 0
+
+        for i in range(len(text)):
+            token = self.find(text[i], prev)
+
+            if token == 0:
+                return prev, text[i:]
+
+            prev = token
+
+        return prev, ''
+
+    def encode(self, text):
+        ids = []
+
+        while text:
+            token, text = self.encode_one(text)
+            ids.append(token)
+
+        return ids
+
+    def decode_one(self, token):
+        # Walk the prefix chain back to the root, then reverse.
+        text = ""
+
+        while token:
+            text += self.vocab[token].byte
+            token = self.vocab[token].prev
+
+        return text[::-1]
+
+    def decode(self, ids):
+        text = ""
+
+        for token in ids:
+            text += self.decode_one(token)
+
+        return text
+
+    def add_special(self, text):
+        # Register text as a chain of tokens, reusing existing prefixes.
+        token = ord(text[0])
+        for byte in text[1:]:
+            token = self.append(byte, token)
+
+    @property
+    def vocab_size(self):
+        return len(self.vocab)
+
+    def __str__(self):
+        return '[' + ', '.join(str(token) for token in self.vocab) + ']'
+
+    def to_file(self, file):
+        with open(file, 'ab') as f:  # note: append mode, so remove any stale file first
+            for token in self.vocab:
+                f.write(token.to_binary())
+
+    def from_file(self, file):
+        self.vocab = []
+        with open(file, 'rb') as f:
+            while True:
+                try:
+                    data = f.read(3)
+                    self.vocab.append(Token.from_binary(data))
+                except ValueError:
+                    break
+
+    def train(self, text, max_length=32000):
+        # Rank words by frequency and add each one as a token chain until the
+        # vocabulary reaches max_length.
+        words = text.split()
+        words = [' ' + ''.join(re.findall(r'\w', word)) for word in words]
+        words = [word for word in words if len(word) >= 2]
+
+        word_freq = Counter(words)
+        sorted_words = sorted(word_freq, key=lambda x: (-word_freq[x], x))
+
+        for word in sorted_words:
+            if self.vocab_size >= max_length:
+                break
+
+            self.add_special(word)
+            print(f"adding word: {word} | current vocab size: {self.vocab_size} | max length: {max_length}")
+
+    def c_encode(self, text):  # TODO: Implement
+        return []
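
Taken together, the new tokenizer is an LZ78-style prefix tree: every token is one byte plus the id of its prefix token, encoding greedily matches the longest known chain, and decoding walks the chain backwards. A short usage sketch (the exact ids depend on insertion order; the values shown are what this toy input produces):

# Roundtrip sketch for the committed Tokenizer.
from tokenizer import Tokenizer

tok = Tokenizer()                          # starts with 256 single-byte base tokens
tok.train(" hello hello world", max_length=300)

ids = tok.encode(" hello world")           # each trained word collapses to one id
print(ids)                                 # [260, 265]
print(tok.decode(ids))                     # " hello world"

Bytes that don't extend a trained chain fall back to their single-byte base tokens, so encoding degrades gracefully on unseen words instead of failing.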
trainer.cli.py CHANGED
@@ -6,7 +6,7 @@ from logger import Wandb
 
 from trainer import Trainer
 from dataset import Dataset
-#from tokenizer import Tokenizer
+from tokenizer import Tokenizer
 
 
 
@@ -27,6 +27,12 @@ if __name__ == '__main__':
 
     dataset = Dataset(config.dataset)
 
-    #tokenizer = Tokenizer()
+    tokenizer = Tokenizer()
+    tokenizer.train(dataset.text, max_length=config.tokenizer.max_length)
+    ids = tokenizer.c_encode(dataset.text)
+
+    dataset += ids
+    dataset.batch(ids)
 
-    trainer = Trainer(config)
+    trainer = Trainer(config)
+    trainer.train(dataset)
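
The entry point now assumes the config object exposes a nested `tokenizer.max_length` field alongside the existing `dataset` field. A hypothetical minimal stand-in (only those two fields are grounded in the diff; the values are assumptions):

# Hypothetical stand-in for `config` as consumed by trainer.cli.py;
# the real config presumably comes from the project's own loader.
from types import SimpleNamespace

config = SimpleNamespace(
    dataset='wikitext',                            # assumed dataset identifier
    tokenizer=SimpleNamespace(max_length=32000),   # matches Tokenizer.train's default
)

Note that `c_encode` is still a stub returning [], so `ids` (and therefore the `dataset += ids` and `dataset.batch(ids)` calls) will see an empty list until it is implemented.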