from __future__ import annotations from pathlib import Path from typing import Iterable, List import sentencepiece as spm from .utils import resolve_path SAMPLE_BYTES = 4 * 1024 * 1024 def choose_vocab_size(text: str, requested_vocab_size: int) -> int: # Keep the requested vocab for reasonably large corpora. # Only shrink for truly small samples that cannot support it. if len(text) >= requested_vocab_size * 64: return requested_vocab_size unique_chars = len(set(text)) lower_bound = max(512, unique_chars + 256) sample_limited = max(512, len(text) // 8) return min(requested_vocab_size, max(lower_bound, sample_limited)) class VisdomTokenizer: def __init__(self, model_path: str | Path): self.model_path = resolve_path(model_path) if not self.model_path.exists(): raise FileNotFoundError(f"Tokenizer model not found: {self.model_path}") self.sp = spm.SentencePieceProcessor(model_file=str(self.model_path)) @property def vocab_size(self) -> int: return int(self.sp.vocab_size()) @property def bos_id(self) -> int: return int(self.sp.bos_id()) @property def eos_id(self) -> int: return int(self.sp.eos_id()) def encode(self, text: str, add_bos: bool = False, add_eos: bool = False) -> List[int]: ids = self.sp.encode(text, out_type=int) if add_bos and self.bos_id >= 0: ids = [self.bos_id] + ids if add_eos and self.eos_id >= 0: ids = ids + [self.eos_id] return ids def decode(self, ids: Iterable[int]) -> str: return self.sp.decode(list(map(int, ids))) def train_sentencepiece_tokenizer( input_text_path: str | Path, model_prefix: str | Path, vocab_size: int, model_type: str = "bpe", character_coverage: float = 1.0, ) -> Path: input_text_path = resolve_path(input_text_path) model_prefix = resolve_path(model_prefix) model_prefix.parent.mkdir(parents=True, exist_ok=True) if not input_text_path.exists(): raise FileNotFoundError(f"Input text file not found: {input_text_path}") with input_text_path.open("r", encoding="utf-8", errors="ignore") as f: text = f.read(SAMPLE_BYTES) if len(text.strip()) < 100: raise ValueError("Input text is too small. Add more text to data/raw/input.txt before preparing data.") actual_vocab_size = choose_vocab_size(text, vocab_size) if actual_vocab_size != vocab_size: print(f"Requested vocab_size={vocab_size}, using vocab_size={actual_vocab_size} for this dataset size.") spm.SentencePieceTrainer.train( input=str(input_text_path), model_prefix=str(model_prefix), vocab_size=actual_vocab_size, model_type=model_type, character_coverage=character_coverage, bos_id=1, eos_id=2, unk_id=0, pad_id=3, user_defined_symbols=[], byte_fallback=True, split_digits=True, allow_whitespace_only_pieces=True, hard_vocab_limit=False, ) return model_prefix.with_suffix(".model")