"""Tokenizer wrapper around SentencePiece for Ogma.""" from __future__ import annotations from collections.abc import Sequence from pathlib import Path from typing import Any import numpy as np __all__ = ["OgmaTokenizer"] # Number of special tokens reserved at the start of the vocabulary N_SPECIAL = 7 SPECIAL_TOKENS = ["", "", "", "", "[QRY]", "[DOC]", "[SYM]"] class OgmaTokenizer: """Wrapper around SentencePiece with special token handling. Special token layout: 0: , 1: , 2: , 3: , 4: [QRY], 5: [DOC], 6: [SYM] Regular tokens start at index 7. """ def __init__(self, model_path: str | Path) -> None: import sentencepiece as spm # type: ignore[import-untyped] self.sp = spm.SentencePieceProcessor() self.sp.Load(str(model_path)) self._pad_id = 0 self._unk_id = 1 self._bos_id = 2 self._eos_id = 3 @property def vocab_size(self) -> int: """Total vocab size including special tokens.""" return int(self.sp.GetPieceSize()) + N_SPECIAL @property def pad_id(self) -> int: return self._pad_id def encode( self, text: str, max_length: int = 512, add_special_tokens: bool = True, ) -> list[int]: """Encode text to token IDs. Args: text: Input text string. max_length: Maximum number of tokens. add_special_tokens: Whether to add BOS/EOS. Returns: List of token IDs (offset by N_SPECIAL). """ ids = self.sp.Encode(text) # Offset by N_SPECIAL to reserve space for special tokens ids = [i + N_SPECIAL for i in ids] if add_special_tokens: ids = [self._bos_id] + ids + [self._eos_id] return ids[:max_length] def decode(self, ids: list[int]) -> str: """Decode token IDs back to text. Args: ids: Token IDs. Returns: Decoded text string. """ # Remove special tokens and un-offset regular_ids = [ i - N_SPECIAL for i in ids if i >= N_SPECIAL ] return self.sp.Decode(regular_ids) # type: ignore[no-any-return] def batch_encode( self, texts: list[str], max_length: int = 512, padding: bool = True, ) -> dict[str, np.ndarray[Any, np.dtype[np.int32]]]: """Batch encode texts with padding. Args: texts: List of input texts. max_length: Maximum sequence length. padding: Whether to pad to max_length. Returns: Dict with 'input_ids' and 'attention_mask' as numpy arrays. """ encoded = [self.encode(t, max_length) for t in texts] if padding: max_len = min(max(len(e) for e in encoded), max_length) input_ids = np.full( (len(texts), max_len), self._pad_id, dtype=np.int32 ) attention_mask = np.zeros( (len(texts), max_len), dtype=np.int32 ) for i, ids in enumerate(encoded): length = min(len(ids), max_len) input_ids[i, :length] = ids[:length] attention_mask[i, :length] = 1 else: max_len = max_length input_ids = np.array( [e + [self._pad_id] * (max_len - len(e)) for e in encoded], dtype=np.int32, ) attention_mask = np.array( [[1] * len(e) + [0] * (max_len - len(e)) for e in encoded], dtype=np.int32, ) return {"input_ids": input_ids, "attention_mask": attention_mask} @staticmethod def train( corpus_files: Sequence[str | Path], output_path: str | Path, vocab_size: int = 30_000, character_coverage: float = 0.9999, ) -> None: """Train a SentencePiece tokenizer. Args: corpus_files: Paths to text corpus files (one sentence per line). output_path: Output path for the model file (without extension). vocab_size: Target vocabulary size (excluding special tokens). character_coverage: Character coverage for training. """ import sentencepiece as spm input_str = ",".join(str(f) for f in corpus_files) spm.SentencePieceTrainer.Train( input=input_str, model_prefix=str(output_path), vocab_size=vocab_size, model_type="unigram", character_coverage=character_coverage, byte_fallback=True, pad_id=-1, # We handle padding ourselves bos_id=-1, eos_id=-1, unk_id=0, )