| from __future__ import annotations |
|
|
| from pathlib import Path |
| from typing import Iterable, List |
|
|
| import sentencepiece as spm |
|
|
| from .utils import resolve_path |
|
|
|
|
# Upper bound on how much of the input file is sampled when sizing the
# vocabulary. It is passed to a text-mode ``read`` in
# ``train_sentencepiece_tokenizer``, so it effectively limits characters read.
SAMPLE_BYTES = 4 * 1024**2
|
|
|
|
def choose_vocab_size(text: str, requested_vocab_size: int) -> int:
    """Pick a vocabulary size that the amount of sample text can support.

    The requested size is returned unchanged when ``text`` holds at least 64
    characters per requested vocabulary entry. Otherwise the size is reduced
    to the larger of two floors (each at least 512): the number of distinct
    characters plus 256, or one entry per 8 characters of text. The result
    never exceeds ``requested_vocab_size``.

    Args:
        text: Sample of the training text.
        requested_vocab_size: Vocabulary size the caller asked for.

    Returns:
        The vocabulary size to train with.
    """
    n_chars = len(text)
    has_enough_text = n_chars >= 64 * requested_vocab_size
    if has_enough_text:
        return requested_vocab_size

    # Floor derived from the alphabet: distinct characters plus 256 extra slots.
    alphabet_floor = max(512, len(set(text)) + 256)
    # Floor derived from the sample length: one entry per 8 characters.
    length_floor = max(512, n_chars // 8)

    reduced = max(alphabet_floor, length_floor)
    return requested_vocab_size if requested_vocab_size <= reduced else reduced
|
|
|
|
class VisdomTokenizer:
    """Wrapper around a SentencePiece model loaded from a file on disk."""

    def __init__(self, model_path: str | Path):
        """Resolve ``model_path`` and load the SentencePiece model stored there.

        Raises:
            FileNotFoundError: If the resolved model path does not exist.
        """
        resolved = resolve_path(model_path)
        self.model_path = resolved
        if not resolved.exists():
            raise FileNotFoundError(f"Tokenizer model not found: {self.model_path}")
        self.sp = spm.SentencePieceProcessor(model_file=str(resolved))

    @property
    def vocab_size(self) -> int:
        """Vocabulary size reported by the loaded model, as an ``int``."""
        size = self.sp.vocab_size()
        return int(size)

    @property
    def bos_id(self) -> int:
        """Beginning-of-sequence id reported by the loaded model."""
        token_id = self.sp.bos_id()
        return int(token_id)

    @property
    def eos_id(self) -> int:
        """End-of-sequence id reported by the loaded model."""
        token_id = self.sp.eos_id()
        return int(token_id)

    def encode(self, text: str, add_bos: bool = False, add_eos: bool = False) -> List[int]:
        """Encode ``text`` to token ids, optionally wrapped in BOS/EOS ids.

        A requested BOS/EOS id is only added when the model reports a
        non-negative id for it.
        """
        token_ids = self.sp.encode(text, out_type=int)
        if add_bos:
            bos = self.bos_id
            if bos >= 0:
                token_ids = [bos] + token_ids
        if add_eos:
            eos = self.eos_id
            if eos >= 0:
                token_ids = token_ids + [eos]
        return token_ids

    def decode(self, ids: Iterable[int]) -> str:
        """Decode an iterable of token ids back into text."""
        # Coerce every element to a plain int before handing the list over.
        return self.sp.decode([int(token) for token in ids])
|
|
|
|
def train_sentencepiece_tokenizer(
    input_text_path: str | Path,
    model_prefix: str | Path,
    vocab_size: int,
    model_type: str = "bpe",
    character_coverage: float = 1.0,
) -> Path:
    """Train a SentencePiece model on a text file and return the model path.

    Both paths are passed through ``resolve_path``. A leading sample of the
    input (at most ``SAMPLE_BYTES`` units of a text-mode read) is used to
    check that there is enough text and to let ``choose_vocab_size`` shrink
    the requested vocabulary for small datasets. Training itself runs on the
    whole input file.

    The trainer is configured with ``unk_id=0``, ``bos_id=1``, ``eos_id=2``,
    ``pad_id=3``, byte fallback, digit splitting, whitespace-only pieces
    allowed, and a soft vocabulary limit.

    Args:
        input_text_path: Text file to train on.
        model_prefix: Output prefix handed to the SentencePiece trainer.
        vocab_size: Requested vocabulary size; may be reduced for small input.
        model_type: SentencePiece model type.
        character_coverage: SentencePiece character coverage.

    Returns:
        Path of the trained model file, i.e. ``model_prefix`` with ``.model``
        appended.

    Raises:
        FileNotFoundError: If the resolved input file does not exist.
        ValueError: If the sampled text has fewer than 100 non-blank-trimmed
            characters.
    """
    input_text_path = resolve_path(input_text_path)
    model_prefix = resolve_path(model_prefix)

    # Validate the input first so a failed call does not leave an empty
    # output directory behind.
    if not input_text_path.exists():
        raise FileNotFoundError(f"Input text file not found: {input_text_path}")

    # Text-mode read: the limit counts decoded characters; undecodable bytes
    # are dropped rather than raising.
    with input_text_path.open("r", encoding="utf-8", errors="ignore") as f:
        text = f.read(SAMPLE_BYTES)
    if len(text.strip()) < 100:
        # Report the file that was actually read instead of a fixed path.
        raise ValueError(
            f"Input text is too small. Add more text to {input_text_path} before preparing data."
        )

    actual_vocab_size = choose_vocab_size(text, vocab_size)
    if actual_vocab_size != vocab_size:
        print(f"Requested vocab_size={vocab_size}, using vocab_size={actual_vocab_size} for this dataset size.")

    model_prefix.parent.mkdir(parents=True, exist_ok=True)

    spm.SentencePieceTrainer.train(
        input=str(input_text_path),
        model_prefix=str(model_prefix),
        vocab_size=actual_vocab_size,
        model_type=model_type,
        character_coverage=character_coverage,
        bos_id=1,
        eos_id=2,
        unk_id=0,
        pad_id=3,
        user_defined_symbols=[],
        byte_fallback=True,
        split_digits=True,
        allow_whitespace_only_pieces=True,
        hard_vocab_limit=False,
    )

    # The trainer writes "<model_prefix>.model". Append to the name instead of
    # using with_suffix(), which would replace an existing dotted suffix
    # (e.g. "tok.v1" -> "tok.model" rather than "tok.v1.model").
    return model_prefix.with_name(model_prefix.name + ".model")
|
|