# panini-tokenizer/src/tokenizer.py
"""
Panini Tokenizer V3 - Morphology-Aware Sanskrit Tokenizer
HuggingFace PreTrainedTokenizer compatible.
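
Example usage (illustrative; the import path, splits, and IDs below are
assumptions that depend on how the package is installed and on the Vidyut
data backing the analyzer):

    from src.tokenizer import PaniniTokenizerV3

    tok = PaniniTokenizerV3()
    tok.tokenize("rAmaH gacCati")      # e.g. ['▁rAma', 'H', '▁gacCati']
    ids = tok.encode("rAmaH gacCati")  # wrapped with <s> ... </s>
    tok.decode(ids)                    # approximately "rAmaH gacCati"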
"""
import json
import os
from typing import Dict, List, Optional, Tuple, Union
# HuggingFace imports
try:
from transformers import PreTrainedTokenizer
from transformers.tokenization_utils_base import AddedToken
HAS_TRANSFORMERS = True
except ImportError:
HAS_TRANSFORMERS = False
PreTrainedTokenizer = object # Fallback
from .analyzer import VidyutAnalyzer, MorphParse
from .splitter import SamasaSplitter, CompoundSplit
class PaniniTokenizerV3(PreTrainedTokenizer if HAS_TRANSFORMERS else object):
"""
Morphology-aware Sanskrit tokenizer using Vidyut.
Pipeline:
1. Vidyut analysis → extract morphological structure
2. Compound splitting → split at samāsa boundaries
3. Vibhakti separation → separate inflection from stem
4. Dynamic vocab → Kosha-backed vocabulary
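    Illustrative result (hypothetical; the actual split depends on the Vidyut
    kosha): "rAmasya" → stem "rAma" + vibhakti "sya" → tokens ['▁rAma', 'sya'].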
"""
# Special tokens
vocab_files_names = {"vocab_file": "vocab.json"}
model_input_names = ["input_ids", "attention_mask"]
def __init__(
self,
vocab_file: Optional[str] = None,
unk_token: str = "<unk>",
bos_token: str = "<s>",
eos_token: str = "</s>",
pad_token: str = "<pad>",
sep_token: str = "<sep>",
cls_token: str = "<cls>",
mask_token: str = "<mask>",
add_prefix_space: bool = True,
freeze_vocab: bool = False,
**kwargs
):
# Initialize special tokens
self.add_prefix_space = add_prefix_space
self.freeze_vocab = freeze_vocab # Prevent vocab explosion during training
# Core components
self.analyzer = VidyutAnalyzer(preload_cache=True)
self.splitter = SamasaSplitter(self.analyzer)
# Vocabulary
self._vocab: Dict[str, int] = {}
self._id_to_token: Dict[int, str] = {}
# Load or build vocab
if vocab_file and os.path.exists(vocab_file):
self._load_vocab(vocab_file)
else:
self._build_initial_vocab()
# Call parent init if using transformers
if HAS_TRANSFORMERS:
super().__init__(
unk_token=unk_token,
bos_token=bos_token,
eos_token=eos_token,
pad_token=pad_token,
sep_token=sep_token,
cls_token=cls_token,
mask_token=mask_token,
add_prefix_space=add_prefix_space,
**kwargs
)
def _build_initial_vocab(self):
"""Build initial vocabulary with special tokens and common morphemes."""
# Special tokens first (IDs 0-7)
special = ["<unk>", "<s>", "</s>", "<pad>", "<sep>", "<cls>", "<mask>", "▁"]
for i, tok in enumerate(special):
self._vocab[tok] = i
self._id_to_token[i] = tok
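        # Resulting base layout: <unk>=0, <s>=1, </s>=2, <pad>=3, <sep>=4,
        # <cls>=5, <mask>=6, ▁=7; morpheme entries are appended after these.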
# Common vibhakti endings
vibhaktis = [
"H", "m", "am", "At", "Aya", "asya", "e", "O", "ayoH",
"AH", "An", "eByo", "EH", "ezu", "ena", "ABym",
"A", "AyAH", "AyAm", "ayA", "Ani", "AnAm",
"sya", "ya", "aH", "iH", "uH",
]
# Common pratyayas
pratyayas = [
"tvA", "ya", "ta", "tavat", "at", "Ana", "tum",
"ti", "ana", "aka", "in", "tf", "tva", "tA",
"maya", "vat", "mat", "ika", "Iya",
]
# Common upasargas
upasargas = [
"pra", "parA", "apa", "sam", "anu", "ava", "nis", "nir",
"vi", "A", "ni", "aDi", "api", "ati", "su", "ut", "ud",
"aBi", "prati", "pari", "upa", "dur", "dus",
]
# Add morphemes to vocab
next_id = len(self._vocab)
for morpheme_list in [vibhaktis, pratyayas, upasargas]:
for m in morpheme_list:
if m not in self._vocab:
self._vocab[m] = next_id
self._id_to_token[next_id] = m
next_id += 1
# Also add with space prefix
spaced = "▁" + m
if spaced not in self._vocab:
self._vocab[spaced] = next_id
self._id_to_token[next_id] = spaced
next_id += 1
print(f" PaniniTokenizerV3: Initial vocab size = {len(self._vocab)}")
def _load_vocab(self, vocab_file: str):
"""Load vocabulary from JSON file."""
with open(vocab_file, "r", encoding="utf-8") as f:
self._vocab = json.load(f)
self._id_to_token = {v: k for k, v in self._vocab.items()}
print(f" PaniniTokenizerV3: Loaded vocab size = {len(self._vocab)}")
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
"""Save vocabulary to directory."""
if not os.path.isdir(save_directory):
os.makedirs(save_directory, exist_ok=True)
vocab_file = os.path.join(
save_directory,
(filename_prefix + "-" if filename_prefix else "") + "vocab.json"
)
with open(vocab_file, "w", encoding="utf-8") as f:
json.dump(self._vocab, f, ensure_ascii=False, indent=2)
return (vocab_file,)
def save_pretrained(self, save_directory: str, **kwargs):
"""
Save the tokenizer to a directory (HuggingFace compatible).
Creates: vocab.json, tokenizer_config.json, special_tokens_map.json
"""
os.makedirs(save_directory, exist_ok=True)
# 1. Save vocabulary
vocab_file = os.path.join(save_directory, "vocab.json")
with open(vocab_file, "w", encoding="utf-8") as f:
json.dump(self._vocab, f, ensure_ascii=False, indent=2)
# 2. Save tokenizer config
config = {
"tokenizer_class": "PaniniTokenizerV3",
"vocab_size": len(self._vocab),
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
"pad_token": "<pad>",
"sep_token": "<sep>",
"cls_token": "<cls>",
"mask_token": "<mask>",
"add_prefix_space": self.add_prefix_space,
"freeze_vocab": self.freeze_vocab,
}
config_file = os.path.join(save_directory, "tokenizer_config.json")
with open(config_file, "w", encoding="utf-8") as f:
json.dump(config, f, ensure_ascii=False, indent=2)
# 3. Save special tokens map
special_tokens = {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
"pad_token": "<pad>",
"sep_token": "<sep>",
"cls_token": "<cls>",
"mask_token": "<mask>",
}
special_file = os.path.join(save_directory, "special_tokens_map.json")
with open(special_file, "w", encoding="utf-8") as f:
json.dump(special_tokens, f, ensure_ascii=False, indent=2)
print(f"✅ Saved PaniniTokenizerV3 to {save_directory}/")
print(f" vocab.json: {len(self._vocab)} tokens")
return save_directory
@classmethod
def from_pretrained(cls, pretrained_path: str, **kwargs):
"""
Load a tokenizer from a directory (HuggingFace compatible).
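
        Example (illustrative; the directory name is a placeholder):

            tok = PaniniTokenizerV3()
            tok.save_pretrained("panini-tok")
            tok = PaniniTokenizerV3.from_pretrained("panini-tok")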
"""
vocab_file = os.path.join(pretrained_path, "vocab.json")
config_file = os.path.join(pretrained_path, "tokenizer_config.json")
# Load config if exists
config = {}
if os.path.exists(config_file):
with open(config_file, "r", encoding="utf-8") as f:
config = json.load(f)
# Create tokenizer
tokenizer = cls(
vocab_file=vocab_file,
freeze_vocab=config.get("freeze_vocab", True),
add_prefix_space=config.get("add_prefix_space", True),
**kwargs
)
print(f"✅ Loaded PaniniTokenizerV3 from {pretrained_path}/")
print(f" vocab.json: {len(tokenizer._vocab)} tokens")
return tokenizer
@property
def vocab_size(self) -> int:
return len(self._vocab)
def get_vocab(self) -> Dict[str, int]:
return dict(self._vocab)
def _add_to_vocab(self, token: str) -> int:
"""Dynamically add a token to vocabulary."""
if token in self._vocab:
return self._vocab[token]
new_id = len(self._vocab)
self._vocab[token] = new_id
self._id_to_token[new_id] = token
return new_id
def _convert_token_to_id(self, token: str) -> int:
"""Convert token to ID, adding to vocab if needed (dynamic vocab)."""
if token in self._vocab:
return self._vocab[token]
# Freeze mode: return unk_id for unknown tokens (prevents vocab explosion)
if self.freeze_vocab:
return self._vocab.get("<unk>", 0)
# Dynamic vocab: add new tokens
return self._add_to_vocab(token)
def _convert_id_to_token(self, index: int) -> str:
"""Convert ID to token."""
return self._id_to_token.get(index, self.unk_token)
def _tokenize_word(self, word: str) -> List[str]:
"""
Tokenize a single word using morphological analysis.
        Grammar-safe pipeline (Rules A, B, C):
        1. Parse with Vidyut (derivational collapse) → stem + vibhakti
        2. Iterative samāsa splitting of the stem (Rule B)
        3. Stems found in the kosha stay atomic; no SentencePiece fallback (Rule A)
        4. Verbal forms (tiṅanta/kṛdanta) are emitted whole (Rule C)
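        Illustrative outcome (hypothetical splits; the real result depends on the
        kosha and splitter data): a nominal compound such as "rAmaputram" might
        come out as ['▁rAma', 'putra', 'm'], while a finite verb like "gacCati"
        is kept whole as ['▁gacCati'].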
"""
if not word:
return []
        # Rule C: verbal forms (tiṅanta/kṛdanta) are atomic
# If word ends with verbal suffix, emit as single token without splitting
if self.analyzer._is_verb_form(word):
return ["▁" + word]
# Step 1: Get morphological parse (Derivational Collapse)
parse = self.analyzer.get_best_parse(word)
stem = parse.token_form()
# Rule A: If stem is valid in Kosha, DO NOT SPLIT further with SP
# Check if it's a compound that needs splitting
# Step 2: Iterative Samasa Splitting (Rule B)
# We split the *stem* recursively
final_tokens = []
# If the analyzer says it's a compound OR it looks like one
# We try to split it repeatedly
current_components = [stem]
# Helper: merge adjacent tokens that form known compounds
def merge_known_compounds(parts):
"""Merge adjacent parts that together form a known compound."""
merged = []
i = 0
while i < len(parts):
if i + 1 < len(parts):
# Try merging with Sandhi normalization
left = parts[i]
right = parts[i + 1]
# Handle vowel Sandhi: pratyag + AtmA → pratyagAtman
if left.endswith('A'):
candidate = left[:-1] + 'a' + right # AtmA → Atma + next
else:
candidate = left + right
# Also try: left ends with 'a' consumed by right starting with 'A'
# pratyag + AtmA → check if pratyagAtma or pratyagAtman in kosha
candidates = [candidate]
if left.endswith('A') and not right.startswith(('a', 'A', 'i', 'I', 'u', 'U', 'e', 'E', 'o', 'O')):
# Right starts with consonant but might have lost initial vowel
candidates.append(left + 'A' + right) # pratyagA + bhAsa
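                    # NOTE: only `candidate` is checked against the kosha below;
                    # the alternates collected in `candidates` are not consulted yet.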
if self.analyzer._in_kosha(candidate):
merged.append(candidate)
i += 2
continue
# Try with Atman ending
atman_candidate = left[:-1] + 'an' if left.endswith('A') else left + 'an'
if right.endswith('A'):
atman_full = atman_candidate + right[:-1] + 'a'
else:
atman_full = atman_candidate
if len(atman_candidate) > 3 and self.analyzer._in_kosha(atman_candidate):
merged.append(atman_candidate)
# Still need to process right
merged.append(right)
i += 2
continue
merged.append(parts[i])
i += 1
return merged
# Iterative splitting until fixed point
MAX_PASSES = 6 # Increased for deep compounds
for _ in range(MAX_PASSES):
new_components = []
changed = False
# Split pass
for comp in current_components:
# Try to split this component
split_res = self.splitter.split(comp)
if split_res.is_compound and len(split_res.components) > 1:
new_components.extend(split_res.components)
changed = True
else:
# Sandhi restoration retry: if starts with consonant, NO split found,
# AND token is NOT valid (it's an OOV leftover from previous split),
# try prepending 'A' (initial vowel eaten in Sandhi)
# FIXED: Use _is_valid_stem (includes pratyaya stripping) not just _in_kosha
if (len(comp) > 3 and
comp[0] not in 'aAiIuUeEoO' and
not self.splitter._is_valid_stem(comp)): # Guard: only for truly invalid OOV
restored = 'A' + comp
restored_res = self.splitter.split(restored)
if restored_res.is_compound and len(restored_res.components) > 1:
# Map result back: first component keeps A prefix
new_components.extend(restored_res.components)
changed = True
continue
new_components.append(comp)
# Merge pass: merge adjacent tokens that form known compounds
merged_components = merge_known_compounds(new_components)
if len(merged_components) != len(new_components):
changed = True
if not changed:
break
current_components = merged_components
# Add tokens with spacing
for i, comp in enumerate(current_components):
# Rule A Violation Check:
# If 'comp' is in Kosha, use it AS IS.
# Only fall back to char/subword if it's garbage.
prefix = "▁" if i == 0 else ""
if self.analyzer._in_kosha(comp):
# Valid stem -> Atomic Token
final_tokens.append(prefix + comp)
else:
# OOV -> Only then maybe SP (but here we just keep as is for now)
# Ideally we want to mark it or maybe split chars if desperate
final_tokens.append(prefix + comp)
        # Append vibhakti if separated (usually only for the last component, and only if not already present)
if parse.vibhakti and final_tokens:
last_token = final_tokens[-1].lstrip('▁')
# Guard: don't double-append if last token already ends with vibhakti
if not last_token.endswith(parse.vibhakti):
final_tokens.append(parse.vibhakti)
return final_tokens
def tokenize(self, text: str, **kwargs) -> List[str]:
"""
Tokenize text into morphological tokens.
This is the main entry point for tokenization.
"""
if not text:
return []
# Split on whitespace
words = text.split()
all_tokens = []
        for word in words:
word_tokens = self._tokenize_word(word)
all_tokens.extend(word_tokens)
return all_tokens
def _encode_impl(self, text: str) -> List[int]:
"""Internal encode implementation."""
tokens = self.tokenize(text)
return [self._convert_token_to_id(t) for t in tokens]
def encode(
self,
text: Union[str, List[str]],
add_special_tokens: bool = True,
**kwargs
) -> List[int]:
"""Encode text to token IDs."""
if isinstance(text, list):
text = " ".join(text)
ids = self._encode_impl(text)
if add_special_tokens:
bos_id = self._vocab.get("<s>", 1)
eos_id = self._vocab.get("</s>", 2)
ids = [bos_id] + ids + [eos_id]
return ids
def decode(
self,
token_ids: List[int],
skip_special_tokens: bool = True,
**kwargs
) -> str:
"""Decode token IDs back to text."""
        special_ids = {0, 1, 2, 3, 4, 5, 6}  # IDs of <unk> <s> </s> <pad> <sep> <cls> <mask> in the base vocab
tokens = []
for tid in token_ids:
if skip_special_tokens and tid in special_ids:
continue
token = self._convert_id_to_token(tid)
tokens.append(token)
# Join tokens, handling space prefix
text = ""
for t in tokens:
if t.startswith("▁"):
text += " " + t[1:]
else:
text += t
return text.strip()
def convert_tokens_to_string(self, tokens: List[str]) -> str:
"""Convert token list back to string."""
text = ""
for t in tokens:
if t.startswith("▁"):
text += " " + t[1:]
else:
text += t
return text.strip()
# --- CONVENIENCE FUNCTION ---
def create_tokenizer(vocab_path: Optional[str] = None) -> PaniniTokenizerV3:
"""Create a PaniniTokenizerV3 instance."""
return PaniniTokenizerV3(vocab_file=vocab_path)
# --- TEST ---
if __name__ == "__main__":
print("\n" + "="*60)
print(" Testing PaniniTokenizerV3")
print("="*60)
tokenizer = PaniniTokenizerV3()
test_cases = [
"rAmaH gacCati",
"hfdpadmagataM paramAtma",
"sopAdhikapratyagAtmAbhAsabhedAbhedavicAraH",
]
for text in test_cases:
tokens = tokenizer.tokenize(text)
ids = tokenizer.encode(text, add_special_tokens=False)
decoded = tokenizer.decode(ids)
print(f"\n Input: {text}")
print(f" Tokens: {tokens}")
print(f" IDs: {ids[:10]}..." if len(ids) > 10 else f" IDs: {ids}")
print(f" Decoded: {decoded}")