from __future__ import annotations

from pathlib import Path
from typing import Iterable, List

import sentencepiece as spm

from .utils import resolve_path


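# Number of characters (not bytes: the corpus is opened in text mode) read
# from the input file to size the vocabulary.
SAMPLE_BYTES = 4 * 1024 * 1024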


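def choose_vocab_size(text: str, requested_vocab_size: int) -> int:
    """Return a vocab size that the sampled text can plausibly support."""
    # Keep the requested size when the sample has at least 64 characters per
    # requested token.
    if len(text) >= requested_vocab_size * 64:
        return requested_vocab_size

    # Otherwise aim for roughly one token per 8 characters, floored at 512 and
    # at the number of distinct characters plus 256 (room for the byte-fallback
    # pieces), and never above the requested size.
    unique_chars = len(set(text))
    lower_bound = max(512, unique_chars + 256)
    sample_limited = max(512, len(text) // 8)
    return min(requested_vocab_size, max(lower_bound, sample_limited))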
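
# Worked example of choose_vocab_size (illustrative numbers, not from a real
# corpus): a 10,000-character sample with 80 distinct characters and
# requested_vocab_size=8000 is below the 8000 * 64 = 512,000-character
# threshold, so the fallback branch applies:
#   lower_bound    = max(512, 80 + 256)          = 512
#   sample_limited = max(512, 10_000 // 8)       = 1250
#   result         = min(8000, max(512, 1250))   = 1250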


class VisdomTokenizer:
    """Thin wrapper around a trained SentencePiece model (.model file)."""
    def __init__(self, model_path: str | Path):
        self.model_path = resolve_path(model_path)
        if not self.model_path.exists():
            raise FileNotFoundError(f"Tokenizer model not found: {self.model_path}")
        self.sp = spm.SentencePieceProcessor(model_file=str(self.model_path))

    @property
    def vocab_size(self) -> int:
        return int(self.sp.vocab_size())

    @property
    def bos_id(self) -> int:
        return int(self.sp.bos_id())

    @property
    def eos_id(self) -> int:
        return int(self.sp.eos_id())

    def encode(self, text: str, add_bos: bool = False, add_eos: bool = False) -> List[int]:
        ids = self.sp.encode(text, out_type=int)
        # SentencePiece reports -1 for a special token the model was trained
        # without, so only add BOS/EOS ids that actually exist.
        if add_bos and self.bos_id >= 0:
            ids = [self.bos_id] + ids
        if add_eos and self.eos_id >= 0:
            ids = ids + [self.eos_id]
        return ids

    def decode(self, ids: Iterable[int]) -> str:
        return self.sp.decode(list(map(int, ids)))
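
# Usage sketch (the model path is hypothetical; any model produced by
# train_sentencepiece_tokenizer below should behave the same way):
#   tok = VisdomTokenizer("data/tokenizer/spm.model")
#   ids = tok.encode("Hello world", add_bos=True, add_eos=True)
#   # With this module's trainer settings, ids[0] == 1 (bos) and ids[-1] == 2 (eos).
#   text = tok.decode(ids)  # bos/eos are control symbols and decode to no text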


def train_sentencepiece_tokenizer(
    input_text_path: str | Path,
    model_prefix: str | Path,
    vocab_size: int,
    model_type: str = "bpe",
    character_coverage: float = 1.0,
) -> Path:
    input_text_path = resolve_path(input_text_path)
    model_prefix = resolve_path(model_prefix)

    if not input_text_path.exists():
        raise FileNotFoundError(f"Input text file not found: {input_text_path}")
    # Create the output directory only once the input is known to exist.
    model_prefix.parent.mkdir(parents=True, exist_ok=True)

    # Size the vocabulary from a sample of the corpus; the trainer below is
    # still given the whole file.
    with input_text_path.open("r", encoding="utf-8", errors="ignore") as f:
        text = f.read(SAMPLE_BYTES)
    if len(text.strip()) < 100:
        raise ValueError(
            f"Input text in {input_text_path} is too small (under 100 characters). "
            "Add more text before preparing data."
        )

    actual_vocab_size = choose_vocab_size(text, vocab_size)
    if actual_vocab_size != vocab_size:
        print(f"Requested vocab_size={vocab_size}, using vocab_size={actual_vocab_size} for this dataset size.")

    spm.SentencePieceTrainer.train(
        input=str(input_text_path),
        model_prefix=str(model_prefix),
        vocab_size=actual_vocab_size,
        model_type=model_type,
        character_coverage=character_coverage,
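        # Fixed special-token ids: unk=0, bos=1, eos=2, pad=3.
        bos_id=1,
        eos_id=2,
        unk_id=0,
        pad_id=3,
        user_defined_symbols=[],
        # Encode characters missing from the vocabulary as UTF-8 byte pieces
        # instead of mapping them to <unk>.
        byte_fallback=True,
        split_digits=True,
        allow_whitespace_only_pieces=True,
        # Treat vocab_size as a soft limit so training does not fail on small
        # corpora that cannot fill the full vocabulary.
        hard_vocab_limit=False,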
    )
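    # SentencePiece appends ".model" to the prefix; Path.with_suffix would
    # instead replace an existing suffix (e.g. "tok.v2" -> "tok.model").
    return model_prefix.parent / f"{model_prefix.name}.model"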
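
# Usage sketch (the output prefix and vocab size are illustrative assumptions;
# data/raw/input.txt is the corpus location this project's error message
# refers to):
#   model_path = train_sentencepiece_tokenizer(
#       "data/raw/input.txt", "data/tokenizer/spm", vocab_size=8000
#   )
#   tok = VisdomTokenizer(model_path)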