"""Tokenizer wrapper — loads a pretrained HuggingFace tokenizer."""
|
|
from typing import Optional, List
|
|
from llm_lab.config import DataConfig
|
|
|
|
class Tokenizer:
    """Pretrained tokenizer wrapper.

    Loads a pretrained HF tokenizer (e.g., LLaMA 2 tokenizer) and provides
    a unified encode/decode interface for the training pipeline.

    BPE (Byte Pair Encoding) core principle:
    1) Split text into byte/character units
    2) Repeatedly merge the most frequent adjacent pair
    3) Repeat until vocab_size is reached
    -> Frequent words become a single token; rare words are split into
       multiple tokens.
    """

    def __init__(self, config: "DataConfig"):
        """Initialize the wrapper; no tokenizer is loaded yet.

        Args:
            config: DataConfig providing ``vocab_size`` (provisional until a
                tokenizer is loaded) and ``tokenizer_name`` (default HF repo).
        """
        self.config = config
        self._tokenizer = None  # set by load_pretrained_hf()
        # Initialize the callables explicitly so encode()/decode() can give a
        # clear error before load_pretrained_hf() is called (instead of an
        # opaque AttributeError).
        self._encode_fn = None
        self._decode_fn = None
        self.vocab_size: int = config.vocab_size

        # Provisional special-token IDs (LLaMA/SentencePiece convention);
        # overwritten with the real IDs once a tokenizer is loaded.
        self.bos_id: int = 1
        self.eos_id: int = 2
        self.pad_id: int = 0

    def load_pretrained_hf(self, name_or_path: Optional[str] = None) -> None:
        """Loads a pretrained tokenizer from HuggingFace.

        Default: LLaMA 2 tokenizer (NousResearch/Llama-2-7b-hf mirror).
        - vocab_size : 32,000
        - SentencePiece BPE -> optimal for 1B-scale models (TinyLlama, LLaMA 1/2)
        - No HuggingFace authentication required (community mirror)
        Official source (requires HF auth):
        - "meta-llama/Llama-2-7b-hf"

        Args:
            name_or_path: HF repo name or local path; falls back to
                ``config.tokenizer_name`` when omitted.
        """
        from transformers import AutoTokenizer

        name_or_path = name_or_path or self.config.tokenizer_name
        print(f"[Tokenizer] Loading HF tokenizer: {name_or_path}")
        tokenizer = AutoTokenizer.from_pretrained(name_or_path)

        self._tokenizer = tokenizer
        self.vocab_size = tokenizer.vocab_size
        # BUGFIX: `x or default` would silently replace a legitimate token ID
        # of 0 (0 is falsy). Only substitute the fallback when the tokenizer
        # genuinely has no such special token (None).
        self.bos_id = tokenizer.bos_token_id if tokenizer.bos_token_id is not None else 1
        self.eos_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 2
        self.pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0

        self._encode_fn = lambda text: tokenizer.encode(text, add_special_tokens=False)
        self._decode_fn = lambda ids: tokenizer.decode(ids)

        print(f"[Tokenizer] Loaded: vocab_size={self.vocab_size}")

    def encode(self, text: str, add_special_tokens: bool = False) -> List[int]:
        """Text -> list of token IDs.

        Args:
            text: Input string.
            add_special_tokens: When True, wrap the IDs as [bos] + ids + [eos].

        Raises:
            RuntimeError: If no tokenizer has been loaded yet.
        """
        if self._encode_fn is None:
            raise RuntimeError(
                "Tokenizer not loaded; call load_pretrained_hf() first."
            )
        ids = self._encode_fn(text)
        if add_special_tokens:
            ids = [self.bos_id] + ids + [self.eos_id]
        return ids

    def decode(self, ids: List[int]) -> str:
        """List of token IDs -> text.

        Raises:
            RuntimeError: If no tokenizer has been loaded yet.
        """
        if self._decode_fn is None:
            raise RuntimeError(
                "Tokenizer not loaded; call load_pretrained_hf() first."
            )
        return self._decode_fn(ids)

    def __len__(self) -> int:
        """Vocabulary size, so ``len(tokenizer)`` works like a mapping."""
        return self.vocab_size
|
|