from __future__ import annotations

import os
from typing import List, Dict

from . import config as CFG
from .utils import load_json
from .pretokenizer import pretokenize_line


class BPETokenizer:
    def __init__(self, tokenizer_dir: str = CFG.TOKENIZER_DIR):
        vocab_path = os.path.join(tokenizer_dir, 'bpe_vocab.json')
        merges_path = os.path.join(tokenizer_dir, 'bpe_merges.txt')
        config_path = os.path.join(tokenizer_dir, 'tokenizer_config.json')
        self.vocab: Dict[str, int] = load_json(vocab_path)
        self.id_to_tok = {i: t for t, i in self.vocab.items()}
        self.merges: List[tuple[str, str]] = []
        if os.path.exists(merges_path):
            with open(merges_path, 'r', encoding='utf-8') as f:
                for line in f:
                    parts = line.strip().split()
                    if len(parts) == 2:
                        self.merges.append((parts[0], parts[1]))
        self.config = load_json(config_path) if os.path.exists(config_path) else {}
        # Build merge map for fast lookup (kept for parity with training;
        # encode() checks merged pairs against the vocab directly)
        self.merge_map = {a + b: (a, b) for a, b in self.merges}

    def _byte_fallback(self, text: str) -> List[str]:
        # Represent raw bytes as '<0xHH>' tokens: the 6-character format that
        # decode() parses (3-char '<0x' prefix, two hex digits, trailing '>')
        return [f"<0x{b:02X}>" for b in text.encode('utf-8')]

    def encode(self, text: str) -> List[int]:
        if not text:
            return []
        # Start with the same pretokenization used in training
        symbols = pretokenize_line(text)
        # Fallback decomposition for unknown tokens based on granularity
        gran = getattr(CFG, 'TOKEN_GRANULARITY', 'byte')
        if gran == 'word':
            expanded: List[str] = []
            for s in symbols:
                if s in self.vocab or s.startswith('<'):
                    # Known tokens and '<...>' special tokens pass through unchanged
                    expanded.append(s)
                else:
                    # Decompose rare word token into characters, then bytes for unseen characters
                    for ch in s:
                        if ch in self.vocab:
                            expanded.append(ch)
                        else:
                            expanded.extend(self._byte_fallback(ch))
            symbols = expanded
        elif gran == 'char':
            expanded = []
            for s in symbols:
                if s in self.vocab or s.startswith('<'):
                    expanded.append(s)
                else:
                    # Fall back to bytes for unseen characters
                    expanded.extend(self._byte_fallback(s))
            symbols = expanded
        # Apply merges greedily left-to-right, repeating until no pair merges.
        # A pair is merged whenever its concatenation is itself a vocab entry.
        changed = True
        while changed:
            changed = False
            i = 0
            new_syms: List[str] = []
            while i < len(symbols):
                if i < len(symbols) - 1:
                    pair = symbols[i] + symbols[i + 1]
                    if pair in self.vocab:
                        new_syms.append(pair)
                        i += 2
                        changed = True
                        continue
                new_syms.append(symbols[i])
                i += 1
            symbols = new_syms
        return [self.vocab.get(s, CFG.UNK_ID) for s in symbols]

    def decode(self, ids: List[int]) -> str:
        # Expand merged tokens back to bytes by heuristically splitting them
        # into the embedded '<0xHH>' byte tokens they were merged from
        out_bytes: List[int] = []
        for i in ids:
            tok = self.id_to_tok.get(i, '')
            # Single byte token, e.g. '<0x41>'
            if tok.startswith('<0x') and len(tok) == 6:
                try:
                    out_bytes.append(int(tok[3:5], 16))
                    continue
                except ValueError:
                    pass
            # Structural whitespace tokens (token names '<SP>'/'<NL>' are assumed)
            if tok == '<SP>':
                out_bytes.append(ord(' '))
                continue
            if tok == '<NL>':
                out_bytes.append(ord('\n'))
                continue
            # Macro token: emit its literal payload. A 4-character marker prefix
            # (assumed here to be '<MC:') and a trailing '>' wrap the literal.
            if tok.startswith('<MC:'):
                literal = tok[3 + 1:-1]  # drop the 4-char prefix and trailing '>'
                out_bytes.extend(literal.encode('utf-8'))
                continue
            # Merged token: walk the string, peeling off embedded '<0xHH>' segments
            j = 0
            parts: List[int] = []
            while j < len(tok):
                seg = tok[j:j + 6]
                if seg.startswith('<0x') and len(seg) == 6:
                    try:
                        parts.append(int(seg[3:5], 16))
                        j += 6
                        continue
                    except ValueError:
                        pass
                # If not parseable, encode the remaining substring as UTF-8
                remainder = tok[j:]
                parts.extend(remainder.encode('utf-8'))
                break
            out_bytes.extend(parts)
        return bytes(out_bytes).decode('utf-8', errors='replace')
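
# Minimal usage sketch: assumes a trained tokenizer (bpe_vocab.json,
# bpe_merges.txt) already exists under CFG.TOKENIZER_DIR, and that this file
# is run as a module (python -m <package>.<module>) since it uses relative
# imports. The sample string is illustrative only.
if __name__ == '__main__':
    tok = BPETokenizer()
    ids = tok.encode('hello world')
    print(ids)              # token ids; unknown symbols map to CFG.UNK_ID
    print(tok.decode(ids))  # round-trips back to (approximately) the input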