VISDOM / src /tokenizer.py
VishalPreetham's picture
Upload folder using huggingface_hub
18be545 verified
Raw
History Blame Contribute Delete
3.13 kB
from __future__ import annotations
from pathlib import Path
from typing import Iterable, List
import sentencepiece as spm
from .utils import resolve_path
# Length of the input prefix sampled to choose the vocab size (4 MiB).
# NOTE: train_sentencepiece_tokenizer opens the file in text mode, so
# f.read(SAMPLE_BYTES) limits *characters*, not bytes.
SAMPLE_BYTES = 4 << 20
def choose_vocab_size(text: str, requested_vocab_size: int) -> int:
    """Return a vocab size that the sampled ``text`` can reasonably support.

    If the sample has at least 64 characters per requested piece, the
    requested size is returned unchanged. Otherwise the result is the
    smaller of ``requested_vocab_size`` and the larger of:

    * a floor of ``max(512, distinct_characters + 256)``, and
    * a length-based estimate of ``max(512, len(text) // 8)``.
    """
    if len(text) < requested_vocab_size * 64:
        # Small sample: shrink the vocab. Keep room for every distinct
        # character plus 256 extra pieces, and never go below 512.
        char_floor = max(512, len(set(text)) + 256)
        length_estimate = max(512, len(text) // 8)
        return min(requested_vocab_size, max(char_floor, length_estimate))
    return requested_vocab_size
class VisdomTokenizer:
    """Wrapper around a trained SentencePiece model loaded from disk."""

    def __init__(self, model_path: str | Path):
        """Load the SentencePiece model file at ``model_path``.

        The path is normalized with ``resolve_path`` and kept on
        ``self.model_path``; the loaded processor is kept on ``self.sp``.

        Raises:
            FileNotFoundError: if the resolved model file does not exist.
        """
        path = resolve_path(model_path)
        self.model_path = path
        if not path.exists():
            raise FileNotFoundError(f"Tokenizer model not found: {self.model_path}")
        self.sp = spm.SentencePieceProcessor(model_file=str(path))

    @property
    def vocab_size(self) -> int:
        """Number of pieces in the loaded model."""
        size = self.sp.vocab_size()
        return int(size)

    @property
    def bos_id(self) -> int:
        """BOS token id reported by the model (``encode`` skips it when negative)."""
        return int(self.sp.bos_id())

    @property
    def eos_id(self) -> int:
        """EOS token id reported by the model (``encode`` skips it when negative)."""
        return int(self.sp.eos_id())

    def encode(
        self, text: str, add_bos: bool = False, add_eos: bool = False
    ) -> List[int]:
        """Encode ``text`` into integer token ids.

        Args:
            text: Input string.
            add_bos: Prepend the BOS id when the model defines one (id >= 0).
            add_eos: Append the EOS id when the model defines one (id >= 0).

        Returns:
            The list of token ids.
        """
        body = self.sp.encode(text, out_type=int)
        prefix = [self.bos_id] if add_bos and self.bos_id >= 0 else []
        suffix = [self.eos_id] if add_eos and self.eos_id >= 0 else []
        return prefix + body + suffix

    def decode(self, ids: Iterable[int]) -> str:
        """Decode token ids back into text; each id is coerced with ``int()`` first."""
        return self.sp.decode([int(token) for token in ids])
def train_sentencepiece_tokenizer(
    input_text_path: str | Path,
    model_prefix: str | Path,
    vocab_size: int,
    model_type: str = "bpe",
    character_coverage: float = 1.0,
) -> Path:
    """Train a SentencePiece model on a UTF-8 text file.

    The vocabulary size is first adjusted by ``choose_vocab_size`` using a
    sample of the input: up to ``SAMPLE_BYTES`` characters (the file is read
    in text mode, so the limit counts characters). SentencePiece is then
    trained on the whole file.

    Args:
        input_text_path: Training corpus; normalized with ``resolve_path``.
        model_prefix: Output prefix passed to SentencePiece; normalized with
            ``resolve_path``. Its parent directories are created.
        vocab_size: Requested vocabulary size. It may be reduced for small inputs.
        model_type: SentencePiece model type, e.g. ``"bpe"``.
        character_coverage: SentencePiece ``character_coverage`` option.

    Returns:
        Path of the trained model file, ``<model_prefix>.model``.

    Raises:
        FileNotFoundError: if the input file does not exist.
        ValueError: if the sampled text, after stripping surrounding
            whitespace, is shorter than 100 characters.
    """
    input_text_path = resolve_path(input_text_path)
    model_prefix = resolve_path(model_prefix)

    # Validate the input before creating any output directories, so a failed
    # call leaves no stray directories behind.
    if not input_text_path.exists():
        raise FileNotFoundError(f"Input text file not found: {input_text_path}")
    with input_text_path.open("r", encoding="utf-8", errors="ignore") as f:
        # Text mode: this reads up to SAMPLE_BYTES characters, not bytes.
        text = f.read(SAMPLE_BYTES)
    if len(text.strip()) < 100:
        raise ValueError(
            "Input text is too small. "
            f"Add more text to {input_text_path} before preparing data."
        )

    model_prefix.parent.mkdir(parents=True, exist_ok=True)

    actual_vocab_size = choose_vocab_size(text, vocab_size)
    if actual_vocab_size != vocab_size:
        print(f"Requested vocab_size={vocab_size}, using vocab_size={actual_vocab_size} for this dataset size.")

    spm.SentencePieceTrainer.train(
        input=str(input_text_path),
        model_prefix=str(model_prefix),
        vocab_size=actual_vocab_size,
        model_type=model_type,
        character_coverage=character_coverage,
        bos_id=1,
        eos_id=2,
        unk_id=0,
        pad_id=3,
        user_defined_symbols=[],
        byte_fallback=True,
        split_digits=True,
        allow_whitespace_only_pieces=True,
        hard_vocab_limit=False,
    )
    # The model file is written as str(model_prefix) + ".model". Append the
    # suffix to the full name: Path.with_suffix would *replace* an existing
    # suffix (e.g. "tok.v1" -> "tok.model" instead of "tok.v1.model").
    return model_prefix.with_name(f"{model_prefix.name}.model")