# File size: 4,811 Bytes
# 6efaeab
"""Tokenizer wrapper around SentencePiece for Ogma."""

from __future__ import annotations

from collections.abc import Sequence
from pathlib import Path
from typing import Any

import numpy as np

__all__ = ["OgmaTokenizer"]

# Special tokens that occupy the first slots of the vocabulary, in ID order
SPECIAL_TOKENS = ["<pad>", "<unk>", "<s>", "</s>", "[QRY]", "[DOC]", "[SYM]"]
# Count of reserved special-token IDs; regular token IDs are shifted by this
N_SPECIAL = len(SPECIAL_TOKENS)


class OgmaTokenizer:
    """Wrapper around SentencePiece with special token handling.

    Special token layout:
        0: <pad>, 1: <unk>, 2: <s>, 3: </s>,
        4: [QRY], 5: [DOC], 6: [SYM]
    Regular tokens start at index 7.
    """

    def __init__(self, model_path: str | Path) -> None:
        import sentencepiece as spm  # type: ignore[import-untyped]

        self.sp = spm.SentencePieceProcessor()
        self.sp.Load(str(model_path))
        self._pad_id = 0
        self._unk_id = 1
        self._bos_id = 2
        self._eos_id = 3

    @property
    def vocab_size(self) -> int:
        """Total vocab size including special tokens."""
        return int(self.sp.GetPieceSize()) + N_SPECIAL

    @property
    def pad_id(self) -> int:
        return self._pad_id

    def encode(
        self,
        text: str,
        max_length: int = 512,
        add_special_tokens: bool = True,
    ) -> list[int]:
        """Encode text to token IDs.

        Args:
            text: Input text string.
            max_length: Maximum number of tokens.
            add_special_tokens: Whether to add BOS/EOS.

        Returns:
            List of token IDs (offset by N_SPECIAL).
        """
        ids = self.sp.Encode(text)
        # Offset by N_SPECIAL to reserve space for special tokens
        ids = [i + N_SPECIAL for i in ids]

        if add_special_tokens:
            ids = [self._bos_id] + ids + [self._eos_id]

        return ids[:max_length]

    def decode(self, ids: list[int]) -> str:
        """Decode token IDs back to text.

        Args:
            ids: Token IDs.

        Returns:
            Decoded text string.
        """
        # Remove special tokens and un-offset
        regular_ids = [
            i - N_SPECIAL for i in ids if i >= N_SPECIAL
        ]
        return self.sp.Decode(regular_ids)  # type: ignore[no-any-return]

    def batch_encode(
        self,
        texts: list[str],
        max_length: int = 512,
        padding: bool = True,
    ) -> dict[str, np.ndarray[Any, np.dtype[np.int32]]]:
        """Batch encode texts with padding.

        Args:
            texts: List of input texts.
            max_length: Maximum sequence length.
            padding: Whether to pad to max_length.

        Returns:
            Dict with 'input_ids' and 'attention_mask' as numpy arrays.
        """
        encoded = [self.encode(t, max_length) for t in texts]

        if padding:
            max_len = min(max(len(e) for e in encoded), max_length)
            input_ids = np.full(
                (len(texts), max_len), self._pad_id, dtype=np.int32
            )
            attention_mask = np.zeros(
                (len(texts), max_len), dtype=np.int32
            )
            for i, ids in enumerate(encoded):
                length = min(len(ids), max_len)
                input_ids[i, :length] = ids[:length]
                attention_mask[i, :length] = 1
        else:
            max_len = max_length
            input_ids = np.array(
                [e + [self._pad_id] * (max_len - len(e)) for e in encoded],
                dtype=np.int32,
            )
            attention_mask = np.array(
                [[1] * len(e) + [0] * (max_len - len(e)) for e in encoded],
                dtype=np.int32,
            )

        return {"input_ids": input_ids, "attention_mask": attention_mask}

    @staticmethod
    def train(
        corpus_files: Sequence[str | Path],
        output_path: str | Path,
        vocab_size: int = 30_000,
        character_coverage: float = 0.9999,
    ) -> None:
        """Train a SentencePiece tokenizer.

        Args:
            corpus_files: Paths to text corpus files (one sentence per line).
            output_path: Output path for the model file (without extension).
            vocab_size: Target vocabulary size (excluding special tokens).
            character_coverage: Character coverage for training.
        """
        import sentencepiece as spm

        input_str = ",".join(str(f) for f in corpus_files)
        spm.SentencePieceTrainer.Train(
            input=input_str,
            model_prefix=str(output_path),
            vocab_size=vocab_size,
            model_type="unigram",
            character_coverage=character_coverage,
            byte_fallback=True,
            pad_id=-1,  # We handle padding ourselves
            bos_id=-1,
            eos_id=-1,
            unk_id=0,
        )