"""Tokenizer wrapper — loads a pretrained HuggingFace tokenizer.""" from typing import Optional, List from llm_lab.config import DataConfig class Tokenizer: """Pretrained tokenizer wrapper. Loads a pretrained HF tokenizer (e.g., LLaMA 2 tokenizer) and provides a unified encode/decode interface for the training pipeline. BPE (Byte Pair Encoding) core principle: 1) Split text into byte/character units 2) Repeatedly merge the most frequent adjacent pair 3) Repeat until vocab_size is reached → Frequent words become a single token; rare words are split into multiple tokens """ def __init__(self, config: DataConfig): self.config = config self._tokenizer = None self.vocab_size = config.vocab_size # Special token IDs (set after initialization) self.bos_id: int = 1 # Beginning of Sequence self.eos_id: int = 2 # End of Sequence self.pad_id: int = 0 # Padding def load_pretrained_hf(self, name_or_path: Optional[str] = None): """Loads a pretrained tokenizer from HuggingFace. Default: LLaMA 2 tokenizer (NousResearch/Llama-2-7b-hf mirror). - vocab_size : 32,000 - SentencePiece BPE — optimal for 1B-scale models (TinyLlama, LLaMA 1/2) - No HuggingFace authentication required (community mirror) Official source (requires HF auth): - "meta-llama/Llama-2-7b-hf" """ from transformers import AutoTokenizer name_or_path = name_or_path or self.config.tokenizer_name print(f"[Tokenizer] Loading HF tokenizer: {name_or_path}") tokenizer = AutoTokenizer.from_pretrained(name_or_path) self._tokenizer = tokenizer self.vocab_size = tokenizer.vocab_size self.bos_id = tokenizer.bos_token_id or 1 self.eos_id = tokenizer.eos_token_id or 2 self.pad_id = tokenizer.pad_token_id or 0 self._encode_fn = lambda text: tokenizer.encode(text, add_special_tokens=False) self._decode_fn = lambda ids: tokenizer.decode(ids) print(f"[Tokenizer] Loaded: vocab_size={self.vocab_size}") # ──────────────────────────────────────────────── # Common interface # ──────────────────────────────────────────────── def encode(self, text: str, add_special_tokens: bool = False) -> List[int]: """Text → list of token IDs.""" ids = self._encode_fn(text) if add_special_tokens: ids = [self.bos_id] + ids + [self.eos_id] return ids def decode(self, ids: List[int]) -> str: """List of token IDs → text.""" return self._decode_fn(ids) def __len__(self) -> int: return self.vocab_size