File size: 3,722 Bytes
c312c47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
"""BPE tokenizer — GPT-2 style byte-level BPE (matches Julia SLM tokenizer)."""
import json
from typing import Dict, List, Tuple

try:
    import regex as re
except ImportError:
    import re

# GPT-2 pre-tokenization pattern: common English contractions, runs of
# letters, runs of digits, runs of other non-space characters (each
# optionally preceded by one space), then trailing/other whitespace.
# NOTE(review): \p{L} / \p{N} are Unicode property classes supported only by
# the third-party `regex` module; if the stdlib `re` fallback above is taken,
# this compile raises at import time — confirm `regex` is a hard dependency.
_GPT2_PAT = re.compile(
    r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
    re.UNICODE,
)


def _build_byte_to_unicode() -> Dict[int, str]:
    bs = list(range(ord("!"), ord("~") + 1))
    bs += list(range(ord("¡"), ord("¬") + 1))
    bs += list(range(ord("®"), ord("ÿ") + 1))
    cs = list(bs)
    n = 0
    for b in range(256):
        if b not in bs:
            bs.append(b)
            cs.append(256 + n)
            n += 1
    return {b: chr(c) for b, c in zip(bs, cs)}


class BPETokenizer:
    """GPT-2 style byte-level BPE encoder/decoder.

    Text is pre-tokenized with ``_GPT2_PAT``, each piece is mapped
    byte-by-byte into printable stand-in characters, greedily merged by
    merge rank, and the resulting symbols are looked up in the vocabulary.
    """

    def __init__(self, encoder: Dict[str, int], merges: List[Tuple[str, str]]):
        """
        Args:
            encoder: token string -> token id vocabulary.
            merges: ordered BPE merge pairs; earlier pairs merge first.
        """
        self.encoder = encoder
        self.decoder = {v: k for k, v in encoder.items()}
        self.merges = merges
        # Lower rank = higher merge priority.
        self.merge_ranks = {pair: i for i, pair in enumerate(merges)}
        self.byte_to_unicode = _build_byte_to_unicode()
        self.unicode_to_byte = {v: k for k, v in self.byte_to_unicode.items()}
        # Per-word BPE result cache (the reference GPT-2 tokenizer does the
        # same): words repeat heavily in natural text, and without it every
        # occurrence redoes the full merge loop.
        self._bpe_cache: Dict[str, List[str]] = {}

    @classmethod
    def from_files(cls, vocab_path: str, merges_path: str) -> "BPETokenizer":
        """Load a tokenizer from a JSON vocab file and a text merges file.

        Comment/header lines (e.g. ``#version: ...``), blank lines, and
        malformed lines in the merges file are skipped.
        """
        with open(vocab_path, "r", encoding="utf-8") as f:
            encoder = json.load(f)
        merges: List[Tuple[str, str]] = []
        with open(merges_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line.startswith("#") or not line:
                    continue
                parts = line.split()
                if len(parts) == 2:
                    merges.append((parts[0], parts[1]))
        return cls(encoder, merges)

    @property
    def vocab_size(self) -> int:
        """Number of entries in the vocabulary."""
        return len(self.encoder)

    def encode(self, text: str) -> List[int]:
        """Encode *text* into a list of token ids.

        BPE symbols absent from the vocabulary are silently dropped
        (preserving the original behavior).
        """
        tokens: List[int] = []
        for match in _GPT2_PAT.finditer(text):
            word = match.group()
            symbols = self._bpe_cache.get(word)
            if symbols is None:
                # Map each UTF-8 byte to its printable stand-in, then merge.
                encoded_chars = [self.byte_to_unicode[b] for b in word.encode("utf-8")]
                symbols = self._bpe_encode_word(encoded_chars)
                self._bpe_cache[word] = symbols
            for tok in symbols:
                token_id = self.encoder.get(tok)
                if token_id is not None:
                    tokens.append(token_id)
        return tokens

    def decode(self, ids: List[int]) -> str:
        """Decode token ids back into text.

        Unknown ids decode to the empty string; characters outside the byte
        mapping pass through as their own UTF-8 bytes. Invalid byte
        sequences decode with U+FFFD replacement rather than raising.
        """
        joined = "".join(self.decoder.get(i, "") for i in ids)
        out = bytearray()
        for c in joined:
            b = self.unicode_to_byte.get(c)
            if b is not None:
                out.append(b)
            else:
                out.extend(c.encode("utf-8"))
        return out.decode("utf-8", errors="replace")

    def _bpe_encode_word(self, symbols: List[str]) -> List[str]:
        """Greedily apply merges (lowest rank first) until none apply.

        Args:
            symbols: the word as a list of single stand-in characters.
        Returns:
            The merged symbol list; each entry is a vocabulary candidate.
        """
        while len(symbols) > 1:
            # Find the highest-priority (lowest-rank) adjacent pair.
            best_pair = None
            best_rank = float("inf")
            for i in range(len(symbols) - 1):
                pair = (symbols[i], symbols[i + 1])
                rank = self.merge_ranks.get(pair, float("inf"))
                if rank < best_rank:
                    best_rank = rank
                    best_pair = pair
            if best_pair is None:
                # No adjacent pair is a known merge — the word is final.
                break
            # Merge every occurrence of best_pair in one left-to-right pass.
            merged: List[str] = []
            i = 0
            while i < len(symbols):
                if (
                    i < len(symbols) - 1
                    and symbols[i] == best_pair[0]
                    and symbols[i + 1] == best_pair[1]
                ):
                    merged.append(best_pair[0] + best_pair[1])
                    i += 2
                else:
                    merged.append(symbols[i])
                    i += 1
            symbols = merged
        return symbols