LisaMegaWatts committed
Commit 215b74b (verified)
Parent(s): d968181

Add BPE tokenizer (needed for Colab training notebooks)

Files changed (1):
  transformer_tokenizer.py  +151  -0
transformer_tokenizer.py ADDED
"""BPE tokenizer for causal LM training (Phase D: Symbiotic Distillation).

Minimal GPT-2 style byte-level BPE tokenizer that loads vocab.json and
merges.txt files produced by the text-pipeline. No HuggingFace dependency.

Matches the Julia SLM BPETokenizer (tokenizer.jl) encoding/decoding.
"""
import json
from typing import Dict, List, Tuple

try:
    import regex as re  # supports \p{L} Unicode property escapes
except ImportError:
    import re  # fallback (will fail on \p{L} patterns)


def _build_byte_to_unicode() -> Dict[int, str]:
    """GPT-2 byte-to-unicode mapping (matches Julia _build_byte_to_unicode)."""
    bs = list(range(ord("!"), ord("~") + 1))
    bs += list(range(ord("¡"), ord("¬") + 1))
    bs += list(range(ord("®"), ord("ÿ") + 1))
    cs = list(bs)
    n = 0
    for b in range(256):
        if b not in bs:
            bs.append(b)
            cs.append(256 + n)
            n += 1
    return {b: chr(c) for b, c in zip(bs, cs)}


# GPT-2 pre-tokenization pattern
_GPT2_PAT = re.compile(
    r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
    re.UNICODE,
)


class BPETokenizer:
    """Minimal byte-level BPE tokenizer for causal LM training.

    Loads vocab.json + merges.txt and provides encode/decode methods
    compatible with the Julia SLM tokenizer (0-indexed token IDs).
    """

    def __init__(
        self,
        encoder: Dict[str, int],
        merges: List[Tuple[str, str]],
    ):
        self.encoder = encoder
        self.decoder = {v: k for k, v in encoder.items()}
        self.merges = merges
        self.merge_ranks = {pair: i for i, pair in enumerate(merges)}
        self.byte_to_unicode = _build_byte_to_unicode()
        self.unicode_to_byte = {v: k for k, v in self.byte_to_unicode.items()}

    @classmethod
    def from_files(cls, vocab_path: str, merges_path: str) -> "BPETokenizer":
        """Load tokenizer from vocab.json and merges.txt files."""
        with open(vocab_path, "r", encoding="utf-8") as f:
            encoder = json.load(f)

        merges = []
        with open(merges_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line.startswith("#") or not line:
                    continue
                parts = line.split()
                if len(parts) == 2:
                    merges.append((parts[0], parts[1]))

        return cls(encoder, merges)

    @property
    def vocab_size(self) -> int:
        return len(self.encoder)

    @property
    def pad_token_id(self) -> int:
        return self.encoder.get("<|pad|>", 0)

    @property
    def eos_token_id(self) -> int:
        return self.encoder.get("<|eos|>", 1)

    def encode(self, text: str) -> List[int]:
        """Encode text to token IDs (0-indexed)."""
        tokens = []
        for match in _GPT2_PAT.finditer(text):
            word = match.group()
            # Convert bytes to unicode representation
            encoded_chars = [self.byte_to_unicode[b] for b in word.encode("utf-8")]
            # Apply BPE merges
            symbols = list(encoded_chars)
            symbols = self._bpe_encode_word(symbols)
            # Look up token IDs
            for tok in symbols:
                token_id = self.encoder.get(tok)
                if token_id is not None:
                    tokens.append(token_id)
        return tokens

    def decode(self, ids: List[int]) -> str:
        """Decode token IDs back to text."""
        token_strs = [self.decoder.get(i, "") for i in ids]
        joined = "".join(token_strs)
        # Convert unicode chars back to bytes
        out = bytearray()
        for c in joined:
            b = self.unicode_to_byte.get(c)
            if b is not None:
                out.append(b)
            else:
                out.extend(c.encode("utf-8"))
        return out.decode("utf-8", errors="replace")

    def _bpe_encode_word(self, symbols: List[str]) -> List[str]:
        """Iteratively merge the highest-priority pair."""
        while len(symbols) > 1:
            # Find best merge pair
            best_pair = None
            best_rank = float("inf")
            for i in range(len(symbols) - 1):
                pair = (symbols[i], symbols[i + 1])
                rank = self.merge_ranks.get(pair, float("inf"))
                if rank < best_rank:
                    best_rank = rank
                    best_pair = pair

            if best_rank == float("inf"):
                break  # no more merges

            # Apply the merge
            new_symbols = []
            i = 0
            while i < len(symbols):
                if (
                    i < len(symbols) - 1
                    and symbols[i] == best_pair[0]
                    and symbols[i + 1] == best_pair[1]
                ):
                    new_symbols.append(best_pair[0] + best_pair[1])
                    i += 2
                else:
                    new_symbols.append(symbols[i])
                    i += 1
            symbols = new_symbols

        return symbols
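
A minimal usage sketch for the Colab notebooks this commit targets. The "vocab.json" and "merges.txt" paths are placeholders for whatever artifacts the text-pipeline produced; decode round-trips encode only if the vocab contains all 256 base byte symbols, as a standard GPT-2 style vocab does.

    # hypothetical quick check; paths are placeholders
    from transformer_tokenizer import BPETokenizer

    tok = BPETokenizer.from_files("vocab.json", "merges.txt")
    ids = tok.encode("Hello, world!")   # 0-indexed token IDs
    print(ids)
    print(tok.decode(ids))              # expect "Hello, world!" back
    print(tok.vocab_size, tok.pad_token_id, tok.eos_token_id)

Note that encode() silently drops symbols missing from the vocab, so a mismatched vocab/merges pair will lose text rather than raise.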