# LLM-1B-Lab / llm_lab/data/tokenizer.py
# Commit 33ba3d1 (Vjeong): Remove unused tokenizer training code
# (train_bpe, load_sentencepiece, load_trained_hf)
"""Tokenizer wrapper β€” loads a pretrained HuggingFace tokenizer."""
from typing import Optional, List
from llm_lab.config import DataConfig
class Tokenizer:
    """Pretrained tokenizer wrapper.

    Loads a pretrained HF tokenizer (e.g., LLaMA 2 tokenizer) and provides
    a unified encode/decode interface for the training pipeline.

    BPE (Byte Pair Encoding) core principle:
        1) Split text into byte/character units
        2) Repeatedly merge the most frequent adjacent pair
        3) Repeat until vocab_size is reached
        -> Frequent words become a single token; rare words are split into
           multiple tokens
    """

    def __init__(self, config: DataConfig):
        self.config = config
        self._tokenizer = None
        # Encode/decode callables are installed by load_pretrained_hf();
        # until then, encode()/decode() raise a clear RuntimeError instead
        # of an opaque AttributeError.
        self._encode_fn = None
        self._decode_fn = None
        self.vocab_size: int = config.vocab_size
        # Special token IDs (defaults; overwritten by the loaded tokenizer)
        self.bos_id: int = 1  # Beginning of Sequence
        self.eos_id: int = 2  # End of Sequence
        self.pad_id: int = 0  # Padding

    def load_pretrained_hf(self, name_or_path: Optional[str] = None):
        """Loads a pretrained tokenizer from HuggingFace.

        Default: LLaMA 2 tokenizer (NousResearch/Llama-2-7b-hf mirror).
        - vocab_size : 32,000
        - SentencePiece BPE — optimal for 1B-scale models (TinyLlama, LLaMA 1/2)
        - No HuggingFace authentication required (community mirror)

        Official source (requires HF auth):
        - "meta-llama/Llama-2-7b-hf"

        Args:
            name_or_path: HF model name or local path; falls back to
                ``config.tokenizer_name`` when None.
        """
        from transformers import AutoTokenizer

        name_or_path = name_or_path or self.config.tokenizer_name
        print(f"[Tokenizer] Loading HF tokenizer: {name_or_path}")
        tokenizer = AutoTokenizer.from_pretrained(name_or_path)
        self._tokenizer = tokenizer
        self.vocab_size = tokenizer.vocab_size
        # Explicit None checks: `x or default` would wrongly discard a
        # legitimate special-token ID of 0 (0 is falsy).
        self.bos_id = tokenizer.bos_token_id if tokenizer.bos_token_id is not None else 1
        self.eos_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 2
        self.pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
        self._encode_fn = lambda text: tokenizer.encode(text, add_special_tokens=False)
        self._decode_fn = lambda ids: tokenizer.decode(ids)
        print(f"[Tokenizer] Loaded: vocab_size={self.vocab_size}")

    # ────────────────────────────────────────────────
    # Common interface
    # ────────────────────────────────────────────────
    def encode(self, text: str, add_special_tokens: bool = False) -> List[int]:
        """Text → list of token IDs.

        Args:
            text: Input string.
            add_special_tokens: When True, wrap the IDs with BOS/EOS.

        Raises:
            RuntimeError: If no tokenizer has been loaded yet.
        """
        if self._encode_fn is None:
            raise RuntimeError("Tokenizer not loaded; call load_pretrained_hf() first.")
        ids = self._encode_fn(text)
        if add_special_tokens:
            ids = [self.bos_id] + ids + [self.eos_id]
        return ids

    def decode(self, ids: List[int]) -> str:
        """List of token IDs → text.

        Raises:
            RuntimeError: If no tokenizer has been loaded yet.
        """
        if self._decode_fn is None:
            raise RuntimeError("Tokenizer not loaded; call load_pretrained_hf() first.")
        return self._decode_fn(ids)

    def __len__(self) -> int:
        # Mirrors the loaded tokenizer's vocab size (or the configured one
        # before loading).
        return self.vocab_size