from __future__ import annotations

import os
from typing import Dict, List

from . import config as CFG
from .utils import load_json
from .pretokenizer import pretokenize_line


class BPETokenizer:
    """Byte-pair-encoding tokenizer loaded from pretrained vocab/merges files."""

    def __init__(self, tokenizer_dir: str = CFG.TOKENIZER_DIR):
        vocab_path = os.path.join(tokenizer_dir, 'bpe_vocab.json')
        merges_path = os.path.join(tokenizer_dir, 'bpe_merges.txt')
        config_path = os.path.join(tokenizer_dir, 'tokenizer_config.json')
        self.vocab: Dict[str, int] = load_json(vocab_path)
        self.id_to_tok = {i: t for t, i in self.vocab.items()}
        self.merges: List[tuple[str, str]] = []
        if os.path.exists(merges_path):
            with open(merges_path, 'r', encoding='utf-8') as f:
                for line in f:
                    parts = line.strip().split()
                    if len(parts) == 2:
                        self.merges.append((parts[0], parts[1]))
        self.config = load_json(config_path) if os.path.exists(config_path) else {}
        # Map each merged symbol back to its constituent pair for fast lookup
        # (encode() below checks vocab membership directly and does not use this).
        self.merge_map = {a + b: (a, b) for a, b in self.merges}
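        # On-disk formats assumed by the loads above (inferred from this
        # loader, not documented elsewhere in the source): bpe_vocab.json
        # maps token -> id, e.g. {"<ws>": 3, "lo": 87}; bpe_merges.txt holds
        # one space-separated pair per line, e.g. "l o".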

    def _byte_fallback(self, text: str) -> List[str]:
        # Represent arbitrary text as byte tokens of the form <b:XX> (hex).
        return [f"<b:{b:02X}>" for b in text.encode('utf-8')]
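    # Example (illustrative, not from the original source): "é" encodes to
    # the UTF-8 bytes C3 A9, so _byte_fallback("é") -> ['<b:C3>', '<b:A9>'].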

    def encode(self, text: str) -> List[int]:
        if not text:
            return []
        # Start with the same pretokenization used in training
        symbols = pretokenize_line(text)
        # Fallback decomposition for unknown tokens, based on granularity
        gran = getattr(CFG, 'TOKEN_GRANULARITY', 'byte')
        if gran == 'word':
            expanded: List[str] = []
            for s in symbols:
                if s in self.vocab or (s.startswith('<m:') and s.endswith('>')):
                    expanded.append(s)
                else:
                    # Decompose a rare word token into characters, then into
                    # bytes for any character not in the vocab
                    for ch in s:
                        if ch in self.vocab:
                            expanded.append(ch)
                        else:
                            expanded.extend(self._byte_fallback(ch))
            symbols = expanded
        elif gran == 'char':
            expanded = []
            for s in symbols:
                if s in self.vocab or (s.startswith('<m:') and s.endswith('>')):
                    expanded.append(s)
                else:
                    # Fall back to bytes for unseen characters
                    expanded.extend(self._byte_fallback(s))
            symbols = expanded
        # Greedily merge adjacent symbols, left to right, whenever their
        # concatenation is itself a vocab entry; repeat until a full pass
        # makes no change.
        changed = True
        while changed:
            changed = False
            i = 0
            new_syms: List[str] = []
            while i < len(symbols):
                if i < len(symbols) - 1:
                    pair = symbols[i] + symbols[i + 1]
                    if pair in self.vocab:
                        new_syms.append(pair)
                        i += 2
                        changed = True
                        continue
                new_syms.append(symbols[i])
                i += 1
            symbols = new_syms
        return [self.vocab.get(s, CFG.UNK_ID) for s in symbols]
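    # Worked example of the merge loop above (illustrative; assumes 'lo' and
    # 'low' are vocab entries): ['l', 'o', 'w'] -> ['lo', 'w'] on the first
    # pass, then ['low'] on the second; a third pass changes nothing, so the
    # loop stops.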

    def decode(self, ids: List[int]) -> str:
        # Map ids back to tokens and expand each token to raw bytes. Byte,
        # whitespace, newline, and macro tokens are handled directly; any
        # other composite token is scanned for embedded <b:XX> byte patterns
        # as a heuristic fallback.
        out_bytes: List[int] = []
        for i in ids:
            tok = self.id_to_tok.get(i, '<UNK>')
            # Byte token <b:XX>
            if tok.startswith('<b:') and tok.endswith('>') and len(tok) == 6:
                try:
                    out_bytes.append(int(tok[3:5], 16))
                    continue
                except ValueError:
                    pass
            # Structural tokens
            if tok == '<ws>':
                out_bytes.append(ord(' '))
                continue
            if tok == '<nl>':
                out_bytes.append(ord('\n'))
                continue
            # Macro token <m:...>: the payload sits between '<m:' and '>'
            if tok.startswith('<m:') and tok.endswith('>'):
                literal = tok[3:-1]
                out_bytes.extend(literal.encode('utf-8'))
                continue
            # Otherwise scan the composite token for embedded byte tokens;
            # anything unparseable is emitted as literal UTF-8.
            parts: List[int] = []
            j = 0
            while j < len(tok):
                if tok.startswith('<b:', j):
                    seg = tok[j:j + 6]
                    if seg.endswith('>') and len(seg) == 6:
                        try:
                            parts.append(int(seg[3:5], 16))
                            j += 6
                            continue
                        except ValueError:
                            pass
                # Not a parseable byte token: keep the rest as literal text.
                parts.extend(tok[j:].encode('utf-8'))
                break
            out_bytes.extend(parts)
        return bytes(out_bytes).decode('utf-8', errors='replace')
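

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original module: it assumes the
    # tokenizer artifacts (bpe_vocab.json, bpe_merges.txt) already exist under
    # CFG.TOKENIZER_DIR and that this file is run as a module (python -m ...)
    # so the relative imports resolve. Exact round-tripping depends on the
    # pretokenizer's normalization, so both sides are printed for inspection.
    tok = BPETokenizer()
    ids = tok.encode('hello world')
    print(ids)
    print(tok.decode(ids))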