File size: 1,997 Bytes
6c7e241
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
try:
    from .base import Tokenizer, get_stats, merge, visualise_tokens
except ImportError:  # allow running as a script from inside `tokenizer/`
    from base import Tokenizer, get_stats, merge, visualise_tokens

class BasicTokenizer(Tokenizer):
    """Byte-level BPE tokenizer with no pre-splitting and no special tokens.

    After ``train``:
      * ``self.merges`` maps a pair of token ids to the new token id created
        for it (ids are assigned 256, 257, ... in the order merges are learned);
      * ``self.vocab`` maps every token id to the bytes it stands for, starting
        from the 256 single-byte tokens.
    """

    def __init__(self):
        super().__init__()

    def train(self, text: str, vocab_size: int, verbose: bool = False) -> None:
        """Learn up to ``vocab_size - 256`` merges from ``text``.

        Args:
            text: Training text. It is UTF-8 encoded; each resulting byte is an
                initial token id in 0..255.
            vocab_size: Target vocabulary size, including the 256 byte tokens.
                Values <= 256 learn no merges.
            verbose: If true, print input sizes and progress every 100 merges.

        Training stops early when there are no more pairs to merge (e.g. the
        text is empty or too short), so the learned vocabulary may be smaller
        than ``vocab_size`` instead of raising an error.
        """
        # Raw UTF-8 bytes are the initial token ids.
        ids = list(text.encode("utf-8"))  # list[int]
        if verbose:
            print(f"len(text) = {len(text)}")
            print(f"len(tokens) = {len(ids)}")

        num_merges = vocab_size - 256

        merges = {}
        vocab = {idx: bytes([idx]) for idx in range(256)}
        for i in range(num_merges):
            stats = {}
            get_stats(ids, stats)  # fills `stats` with pair -> occurrence count
            if not stats:
                # Nothing left to merge; max() would raise ValueError on an
                # empty dict, so finish with the merges learned so far.
                break
            pair = max(stats, key=stats.get)  # most frequent pair, (int, int)
            idx = 256 + i
            ids = merge(ids, pair, idx)
            merges[pair] = idx
            vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
            if verbose and i % 100 == 0:
                print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")

        self.vocab = vocab
        self.merges = merges

    def decode(self, ids: list[int]) -> str:
        """Convert a sequence of token ids back into text.

        Byte sequences that are not valid UTF-8 (possible when decoding an
        arbitrary id sequence) are replaced with U+FFFD rather than raising.

        Raises:
            KeyError: if an id is not in ``self.vocab``.
        """
        text_bytes = b"".join([self.vocab[token_id] for token_id in ids])
        return text_bytes.decode(encoding="utf-8", errors="replace")

    def encode(self, text: str, verbose: bool = False) -> list[int]:
        """Encode ``text`` into token ids by applying the learned merges.

        Each step merges the pair with the lowest merge id, which is the pair
        learned earliest in ``train``. This replays merges in training order.
        If ``verbose`` is true, the current tokens are visualised before
        every step.
        """
        tokens = list(text.encode("utf-8"))
        while len(tokens) >= 2:
            if verbose:
                visualise_tokens([self.vocab[token] for token in tokens])
            stats = {}
            get_stats(tokens, stats)
            # Unlearned pairs rank as +inf, so they are picked only when no
            # learned pair is present.
            pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
            if pair not in self.merges:
                break  # no learned merge applies any more
            idx = self.merges[pair]
            tokens = merge(tokens, pair, idx)
        return tokens