ogma-micro / tokenizer.py
"""Tokenizer wrapper around SentencePiece for Ogma."""
from __future__ import annotations
from collections.abc import Sequence
from pathlib import Path
from typing import Any
import numpy as np
__all__ = ["OgmaTokenizer"]
# Number of special tokens reserved at the start of the vocabulary
N_SPECIAL = 7
SPECIAL_TOKENS = ["<pad>", "<unk>", "<s>", "</s>", "[QRY]", "[DOC]", "[SYM]"]
class OgmaTokenizer:
"""Wrapper around SentencePiece with special token handling.
Special token layout:
0: <pad>, 1: <unk>, 2: <s>, 3: </s>,
4: [QRY], 5: [DOC], 6: [SYM]
Regular tokens start at index 7.
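
Example (illustrative; ``ogma.model`` is a placeholder path and piece IDs
depend on the trained model, but the special-token IDs are fixed)::

    tok = OgmaTokenizer("ogma.model")
    tok.pad_id      # 0
    tok.vocab_size  # SentencePiece piece count + 7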
"""
def __init__(self, model_path: str | Path) -> None:
import sentencepiece as spm # type: ignore[import-untyped]
self.sp = spm.SentencePieceProcessor()
self.sp.Load(str(model_path))
self._pad_id = 0
self._unk_id = 1
self._bos_id = 2
self._eos_id = 3
@property
def vocab_size(self) -> int:
"""Total vocab size including special tokens."""
return int(self.sp.GetPieceSize()) + N_SPECIAL
@property
def pad_id(self) -> int:
return self._pad_id
def encode(
self,
text: str,
max_length: int = 512,
add_special_tokens: bool = True,
) -> list[int]:
"""Encode text to token IDs.
Args:
text: Input text string.
max_length: Maximum number of tokens.
add_special_tokens: Whether to add BOS/EOS.
Returns:
List of token IDs (offset by N_SPECIAL).
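
Example (sketch; actual piece IDs depend on the trained model and
``tok`` is assumed to be a loaded tokenizer)::

    ids = tok.encode("hello world")
    ids[0] == 2 and ids[-1] == 3            # <s> ... </s> when add_special_tokens=True
    all(i >= N_SPECIAL for i in ids[1:-1])  # regular pieces are offset past the specials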
"""
ids = self.sp.Encode(text)
# Offset by N_SPECIAL to reserve space for special tokens
ids = [i + N_SPECIAL for i in ids]
if add_special_tokens:
# Truncate the piece IDs first so the trailing EOS is never cut off by max_length.
ids = [self._bos_id] + ids[: max_length - 2] + [self._eos_id]
return ids[:max_length]
def decode(self, ids: list[int]) -> str:
"""Decode token IDs back to text.
Args:
ids: Token IDs.
Returns:
Decoded text string.
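
Example (round-trip sketch; assumes ``tok`` is a loaded tokenizer)::

    text = tok.decode(tok.encode("hello world"))
    # Special tokens (<s>, </s>, padding) are stripped before decoding.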
"""
# Remove special tokens and un-offset
regular_ids = [
i - N_SPECIAL for i in ids if i >= N_SPECIAL
]
return self.sp.Decode(regular_ids) # type: ignore[no-any-return]
def batch_encode(
self,
texts: list[str],
max_length: int = 512,
padding: bool = True,
) -> dict[str, np.ndarray[Any, np.dtype[np.int32]]]:
"""Batch encode texts with padding.
Args:
texts: List of input texts.
max_length: Maximum sequence length.
padding: If True, pad to the longest sequence in the batch
    (capped at max_length); if False, pad every sequence to
    exactly max_length.
Returns:
Dict with 'input_ids' and 'attention_mask' as numpy arrays.
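
Example (illustrative shapes only; assumes ``tok`` is a loaded tokenizer)::

    batch = tok.batch_encode(["short query", "a somewhat longer document"], max_length=16)
    batch["input_ids"].shape        # (2, L) with L <= 16 when padding=True
    batch["attention_mask"].sum(1)  # true token count per row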
"""
encoded = [self.encode(t, max_length) for t in texts]
if padding:
max_len = min(max(len(e) for e in encoded), max_length)
input_ids = np.full(
(len(texts), max_len), self._pad_id, dtype=np.int32
)
attention_mask = np.zeros(
(len(texts), max_len), dtype=np.int32
)
for i, ids in enumerate(encoded):
length = min(len(ids), max_len)
input_ids[i, :length] = ids[:length]
attention_mask[i, :length] = 1
else:
max_len = max_length
input_ids = np.array(
[e + [self._pad_id] * (max_len - len(e)) for e in encoded],
dtype=np.int32,
)
attention_mask = np.array(
[[1] * len(e) + [0] * (max_len - len(e)) for e in encoded],
dtype=np.int32,
)
return {"input_ids": input_ids, "attention_mask": attention_mask}
@staticmethod
def train(
corpus_files: Sequence[str | Path],
output_path: str | Path,
vocab_size: int = 30_000,
character_coverage: float = 0.9999,
) -> None:
"""Train a SentencePiece tokenizer.
Args:
corpus_files: Paths to text corpus files (one sentence per line).
output_path: Output path for the model file (without extension).
vocab_size: Target vocabulary size (excluding special tokens).
character_coverage: Character coverage for training.
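
Example (sketch; file names are placeholders)::

    OgmaTokenizer.train(["corpus.txt"], "ogma", vocab_size=30_000)
    tok = OgmaTokenizer("ogma.model")  # SentencePiece writes ogma.model / ogma.vocab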
"""
import sentencepiece as spm
input_str = ",".join(str(f) for f in corpus_files)
spm.SentencePieceTrainer.Train(
input=input_str,
model_prefix=str(output_path),
vocab_size=vocab_size,
model_type="unigram",
character_coverage=character_coverage,
byte_fallback=True,
pad_id=-1,  # Disabled: the wrapper reserves its own padding ID (0)
bos_id=-1,  # Disabled: BOS/EOS are added by OgmaTokenizer.encode()
eos_id=-1,
unk_id=0,   # <unk> stays as SentencePiece piece 0; byte_fallback keeps it rare
)
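# Illustrative smoke test, not part of the original module API: assumes a model
# trained via OgmaTokenizer.train() exists at the placeholder path below.
if __name__ == "__main__":
    tok = OgmaTokenizer("ogma.model")
    batch = tok.batch_encode(["what is ogma?", "Ogma is a small embedding model."])
    print(batch["input_ids"].shape)
    print(batch["attention_mask"].sum(axis=1))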