# Twitch-BPE / src / bpe_tokenizer.py
from __future__ import annotations
import os
from typing import List, Dict, Tuple
from . import config as CFG
from .utils import load_json
from .pretokenizer import pretokenize_line
class BPETokenizer:
def __init__(self, tokenizer_dir: str = CFG.TOKENIZER_DIR):
vocab_path = os.path.join(tokenizer_dir, 'bpe_vocab.json')
merges_path = os.path.join(tokenizer_dir, 'bpe_merges.txt')
config_path = os.path.join(tokenizer_dir, 'tokenizer_config.json')
self.vocab: Dict[str,int] = load_json(vocab_path)
self.id_to_tok = {i:t for t,i in self.vocab.items()}
        self.merges: List[Tuple[str, str]] = []
if os.path.exists(merges_path):
with open(merges_path, 'r', encoding='utf-8') as f:
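                # Each line holds two whitespace-separated symbols, one merge per line;
                # e.g. a line "t h" records the merge pair ('t', 'h') (format inferred from this parser)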
for line in f:
parts = line.strip().split()
if len(parts) == 2:
self.merges.append((parts[0], parts[1]))
self.config = load_json(config_path) if os.path.exists(config_path) else {}
        # Map each merge pair to its rank (training order) so merges can be applied in order
        self.merge_ranks: Dict[Tuple[str, str], int] = {pair: rank for rank, pair in enumerate(self.merges)}
def _byte_fallback(self, text: str) -> List[str]:
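        # e.g. 'é' (U+00E9) -> ['<b:C3>', '<b:A9>'], one token per UTF-8 byte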
return [f"<b:{b:02X}>" for b in text.encode('utf-8')]
def encode(self, text: str) -> List[int]:
if not text:
return []
# Start with the same pretokenization used in training
symbols = pretokenize_line(text)
# Fallback decomposition for unknown tokens based on granularity
gran = getattr(CFG, 'TOKEN_GRANULARITY', 'byte')
if gran == 'word':
expanded: List[str] = []
for s in symbols:
if s in self.vocab or (s.startswith('<m:') and s.endswith('>')):
expanded.append(s)
else:
# Decompose rare word token into characters, then bytes for unseen characters
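                    # e.g. with a hypothetical vocab containing 'n', 'a', 'v', 'e' but not 'ï',
                    # "naïve" becomes ['n', 'a', '<b:C3>', '<b:AF>', 'v', 'e']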
for ch in list(s):
if ch in self.vocab:
expanded.append(ch)
else:
expanded.extend([f"<b:{b:02X}>" for b in ch.encode('utf-8')])
symbols = expanded
elif gran == 'char':
expanded: List[str] = []
for s in symbols:
if s in self.vocab or (s.startswith('<m:') and s.endswith('>')):
expanded.append(s)
else:
# Fallback to bytes for unseen characters
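                    # e.g. an unseen '♥' (U+2665) -> ['<b:E2>', '<b:99>', '<b:A5>']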
                    expanded.extend(self._byte_fallback(s))
symbols = expanded
        # Apply merges in training order: repeatedly merge the lowest-ranked adjacent pair
        while len(symbols) > 1:
            best_idx = None
            best_rank = None
            for i in range(len(symbols) - 1):
                rank = self.merge_ranks.get((symbols[i], symbols[i + 1]))
                if rank is not None and (best_rank is None or rank < best_rank):
                    best_idx, best_rank = i, rank
            if best_idx is None:
                break
            merged = symbols[best_idx] + symbols[best_idx + 1]
            symbols = symbols[:best_idx] + [merged] + symbols[best_idx + 2:]
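        # e.g. with merges ranked [('t', 'h'), ('th', 'e')] (hypothetical),
        # ['t', 'h', 'e'] -> ['th', 'e'] -> ['the']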
return [self.vocab.get(s, CFG.UNK_ID) for s in symbols]
def decode(self, ids: List[int]) -> str:
        # Map each token back to raw bytes: byte, structural, and macro tokens are
        # handled explicitly; composite tokens fall through to the scanner below
out_bytes: List[int] = []
for i in ids:
tok = self.id_to_tok.get(i, '<UNK>')
# If token is a byte token <b:XX>
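            # e.g. '<b:41>' contributes the single byte 0x41 ('A')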
if tok.startswith('<b:') and tok.endswith('>') and len(tok) == 6:
try:
val = int(tok[3:5], 16)
out_bytes.append(val)
continue
except ValueError:
pass
# Structural tokens
if tok == '<ws>':
out_bytes.append(ord(' '))
continue
if tok == '<nl>':
out_bytes.append(ord('\n'))
continue
# Macro token <m:...>
if tok.startswith('<m:') and tok.endswith('>'):
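                # e.g. '<m:Kappa>' (hypothetical macro token) decodes to the literal 'Kappa'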
                literal = tok[3:-1]  # payload between '<m:' and the closing '>'
out_bytes.extend(literal.encode('utf-8'))
continue
            # Fallback for composite tokens: scan for embedded '<b:XX>' byte patterns,
            # emitting any other characters as their raw UTF-8 bytes
            parts: List[int] = []
            j = 0
            while j < len(tok):
                seg = tok[j:j + 6]
                if seg.startswith('<b:') and seg.endswith('>') and len(seg) == 6:
                    try:
                        parts.append(int(seg[3:5], 16))
                        j += 6
                        continue
                    except ValueError:
                        pass
                # Not a byte token at this position: emit one character and keep scanning
                parts.extend(tok[j].encode('utf-8'))
                j += 1
out_bytes.extend(parts)
return bytes(out_bytes).decode('utf-8', errors='replace')
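

if __name__ == '__main__':
    # Minimal smoke test: a sketch assuming trained artifacts (bpe_vocab.json,
    # bpe_merges.txt) already exist under CFG.TOKENIZER_DIR. Because of the
    # relative imports, run this as a module, e.g. `python -m src.bpe_tokenizer`
    # (the package path is an assumption); the sample text is illustrative only.
    tok = BPETokenizer()
    ids = tok.encode('hello chat')
    print(ids)
    print(tok.decode(ids))  # should round-trip to (a normalized form of) the input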