| """ | |
| Panini Tokenizer V3 - Morphology-Aware Sanskrit Tokenizer | |
| HuggingFace PreTrainedTokenizer compatible. | |
| """ | |
| import json | |
| import os | |
| from typing import Dict, List, Optional, Tuple, Union | |
| from collections import OrderedDict | |
| # HuggingFace imports | |
| try: | |
| from transformers import PreTrainedTokenizer | |
| from transformers.tokenization_utils_base import AddedToken | |
| HAS_TRANSFORMERS = True | |
| except ImportError: | |
| HAS_TRANSFORMERS = False | |
| PreTrainedTokenizer = object # Fallback | |
| from .analyzer import VidyutAnalyzer, MorphParse | |
| from .splitter import SamasaSplitter, CompoundSplit | |
class PaniniTokenizerV3(PreTrainedTokenizer if HAS_TRANSFORMERS else object):
    """
    Morphology-aware Sanskrit tokenizer using Vidyut.

    Pipeline:
        1. Vidyut analysis     → extract morphological structure
        2. Compound splitting  → split at samāsa boundaries
        3. Vibhakti separation → separate inflection from stem
        4. Dynamic vocab       → Kosha-backed vocabulary
    """

    # HuggingFace integration attributes
    vocab_files_names = {"vocab_file": "vocab.json"}
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file: Optional[str] = None,
        unk_token: str = "<unk>",
        bos_token: str = "<s>",
        eos_token: str = "</s>",
        pad_token: str = "<pad>",
        sep_token: str = "<sep>",
        cls_token: str = "<cls>",
        mask_token: str = "<mask>",
        add_prefix_space: bool = True,
        freeze_vocab: bool = False,
        **kwargs
    ):
        self.add_prefix_space = add_prefix_space
        self.freeze_vocab = freeze_vocab  # Prevent vocab explosion during training

        # Core components
        self.analyzer = VidyutAnalyzer(preload_cache=True)
        self.splitter = SamasaSplitter(self.analyzer)

        # Vocabulary
        self._vocab: Dict[str, int] = {}
        self._id_to_token: Dict[int, str] = {}

        # Load or build vocab (must exist before super().__init__,
        # which may query vocab_size)
        if vocab_file and os.path.exists(vocab_file):
            self._load_vocab(vocab_file)
        else:
            self._build_initial_vocab()

        # Call parent init if using transformers
        if HAS_TRANSFORMERS:
            super().__init__(
                unk_token=unk_token,
                bos_token=bos_token,
                eos_token=eos_token,
                pad_token=pad_token,
                sep_token=sep_token,
                cls_token=cls_token,
                mask_token=mask_token,
                add_prefix_space=add_prefix_space,
                **kwargs
            )
        else:
            # Without transformers, store the special tokens directly so that
            # _convert_id_to_token() and decode() can still reference them
            self.unk_token = unk_token
            self.bos_token = bos_token
            self.eos_token = eos_token
            self.pad_token = pad_token
            self.sep_token = sep_token
            self.cls_token = cls_token
            self.mask_token = mask_token

    def _build_initial_vocab(self):
        """Build initial vocabulary with special tokens and common morphemes."""
        # Special tokens first (IDs 0-7)
        special = ["<unk>", "<s>", "</s>", "<pad>", "<sep>", "<cls>", "<mask>", "▁"]
        for i, tok in enumerate(special):
            self._vocab[tok] = i
            self._id_to_token[i] = tok

        # Common vibhakti (case) endings, in SLP1 transliteration
        vibhaktis = [
            "H", "m", "am", "At", "Aya", "asya", "e", "O", "ayoH",
            "AH", "An", "eByo", "EH", "ezu", "ena", "AByAm",
            "A", "AyAH", "AyAm", "ayA", "Ani", "AnAm",
            "sya", "ya", "aH", "iH", "uH",
        ]
        # Common pratyayas (derivational suffixes)
        pratyayas = [
            "tvA", "ya", "ta", "tavat", "at", "Ana", "tum",
            "ti", "ana", "aka", "in", "tf", "tva", "tA",
            "maya", "vat", "mat", "ika", "Iya",
        ]
        # Common upasargas (verbal prefixes)
        upasargas = [
            "pra", "parA", "apa", "sam", "anu", "ava", "nis", "nir",
            "vi", "A", "ni", "aDi", "api", "ati", "su", "ut", "ud",
            "aBi", "prati", "pari", "upa", "dur", "dus",
        ]

        # Add morphemes to vocab, both bare and with the word-initial marker
        next_id = len(self._vocab)
        for morpheme_list in [vibhaktis, pratyayas, upasargas]:
            for m in morpheme_list:
                if m not in self._vocab:
                    self._vocab[m] = next_id
                    self._id_to_token[next_id] = m
                    next_id += 1
                # Also add with space prefix
                spaced = "▁" + m
                if spaced not in self._vocab:
                    self._vocab[spaced] = next_id
                    self._id_to_token[next_id] = spaced
                    next_id += 1
        print(f"  PaniniTokenizerV3: Initial vocab size = {len(self._vocab)}")

    def _load_vocab(self, vocab_file: str):
        """Load vocabulary from JSON file."""
        with open(vocab_file, "r", encoding="utf-8") as f:
            self._vocab = json.load(f)
        self._id_to_token = {v: k for k, v in self._vocab.items()}
        print(f"  PaniniTokenizerV3: Loaded vocab size = {len(self._vocab)}")

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """Save vocabulary to directory."""
        if not os.path.isdir(save_directory):
            os.makedirs(save_directory, exist_ok=True)
        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "vocab.json"
        )
        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self._vocab, f, ensure_ascii=False, indent=2)
        return (vocab_file,)

    def save_pretrained(self, save_directory: str, **kwargs):
        """
        Save the tokenizer to a directory (HuggingFace compatible).

        Creates: vocab.json, tokenizer_config.json, special_tokens_map.json
        """
        os.makedirs(save_directory, exist_ok=True)

        # 1. Save vocabulary
        vocab_file = os.path.join(save_directory, "vocab.json")
        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self._vocab, f, ensure_ascii=False, indent=2)

        # 2. Save tokenizer config
        config = {
            "tokenizer_class": "PaniniTokenizerV3",
            "vocab_size": len(self._vocab),
            "unk_token": "<unk>",
            "bos_token": "<s>",
            "eos_token": "</s>",
            "pad_token": "<pad>",
            "sep_token": "<sep>",
            "cls_token": "<cls>",
            "mask_token": "<mask>",
            "add_prefix_space": self.add_prefix_space,
            "freeze_vocab": self.freeze_vocab,
        }
        config_file = os.path.join(save_directory, "tokenizer_config.json")
        with open(config_file, "w", encoding="utf-8") as f:
            json.dump(config, f, ensure_ascii=False, indent=2)

        # 3. Save special tokens map
        special_tokens = {
            "unk_token": "<unk>",
            "bos_token": "<s>",
            "eos_token": "</s>",
            "pad_token": "<pad>",
            "sep_token": "<sep>",
            "cls_token": "<cls>",
            "mask_token": "<mask>",
        }
        special_file = os.path.join(save_directory, "special_tokens_map.json")
        with open(special_file, "w", encoding="utf-8") as f:
            json.dump(special_tokens, f, ensure_ascii=False, indent=2)

        print(f"✅ Saved PaniniTokenizerV3 to {save_directory}/")
        print(f"   vocab.json: {len(self._vocab)} tokens")
        return save_directory
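
    # Usage sketch (directory name is arbitrary):
    #   tok = PaniniTokenizerV3()
    #   tok.save_pretrained("panini_tok/")
    #   # → panini_tok/{vocab.json, tokenizer_config.json, special_tokens_map.json}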

    @classmethod
    def from_pretrained(cls, pretrained_path: str, **kwargs):
        """
        Load a tokenizer from a directory (HuggingFace compatible).
        """
        vocab_file = os.path.join(pretrained_path, "vocab.json")
        config_file = os.path.join(pretrained_path, "tokenizer_config.json")

        # Load config if it exists
        config = {}
        if os.path.exists(config_file):
            with open(config_file, "r", encoding="utf-8") as f:
                config = json.load(f)

        # Create tokenizer
        tokenizer = cls(
            vocab_file=vocab_file,
            freeze_vocab=config.get("freeze_vocab", True),
            add_prefix_space=config.get("add_prefix_space", True),
            **kwargs
        )
        print(f"✅ Loaded PaniniTokenizerV3 from {pretrained_path}/")
        print(f"   vocab.json: {len(tokenizer._vocab)} tokens")
        return tokenizer
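
    # Round-trip sketch: from_pretrained defaults freeze_vocab to True, so a
    # reloaded tokenizer maps unseen tokens to <unk> instead of growing:
    #   tok.save_pretrained("panini_tok/")
    #   tok2 = PaniniTokenizerV3.from_pretrained("panini_tok/")
    #   assert tok2.get_vocab() == tok.get_vocab()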

    @property
    def vocab_size(self) -> int:
        return len(self._vocab)

    def get_vocab(self) -> Dict[str, int]:
        return dict(self._vocab)

    def _add_to_vocab(self, token: str) -> int:
        """Dynamically add a token to the vocabulary."""
        if token in self._vocab:
            return self._vocab[token]
        new_id = len(self._vocab)
        self._vocab[token] = new_id
        self._id_to_token[new_id] = token
        return new_id

    def _convert_token_to_id(self, token: str) -> int:
        """Convert token to ID, adding to vocab if needed (dynamic vocab)."""
        if token in self._vocab:
            return self._vocab[token]
        # Freeze mode: return unk_id for unknown tokens (prevents vocab explosion)
        if self.freeze_vocab:
            return self._vocab.get("<unk>", 0)
        # Dynamic vocab: add new tokens
        return self._add_to_vocab(token)
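
    # Behavior sketch (token strings and IDs are illustrative):
    #   tok.freeze_vocab = False
    #   tok._convert_token_to_id("▁himAlaya")   # unseen → appended, new ID
    #   tok.freeze_vocab = True
    #   tok._convert_token_to_id("▁agnihotra")  # unseen → <unk> (ID 0)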

    def _convert_id_to_token(self, index: int) -> str:
        """Convert ID to token."""
        return self._id_to_token.get(index, self.unk_token)

    def _tokenize_word(self, word: str) -> List[str]:
        """
        Tokenize a single word using morphological analysis.

        Grammar-safe pipeline (Rules A, B, C):
            1. Parse with Vidyut (derivational collapse)
            2. Iterative samāsa splitting
            3. No subword (SP) fallback for valid stems
        """
        if not word:
            return []

        # Verbal forms (tiṅanta/kṛdanta) are atomic: if the word is a verb
        # form, emit it as a single token without splitting
        if self.analyzer._is_verb_form(word):
            return ["▁" + word]

        # Step 1: Get morphological parse (derivational collapse)
        parse = self.analyzer.get_best_parse(word)
        stem = parse.token_form()

        # Rule A: if the stem is valid in the kosha, do NOT split it further
        # with a subword model; only check whether it is a compound.
        # Step 2: Iterative samāsa splitting (Rule B), applied repeatedly
        # to the *stem* and its components
        final_tokens = []
        current_components = [stem]

        # Helper: merge adjacent tokens that form known compounds
        def merge_known_compounds(parts):
            """Merge adjacent parts that together form a known compound."""
            merged = []
            i = 0
            while i < len(parts):
                if i + 1 < len(parts):
                    # Try merging with sandhi normalization
                    left = parts[i]
                    right = parts[i + 1]
                    # Handle vowel sandhi: pratyag + AtmA → pratyagAtman
                    if left.endswith('A'):
                        candidate = left[:-1] + 'a' + right  # AtmA → Atma + next
                    else:
                        candidate = left + right
                    # Also consider the un-shortened join: left may keep its
                    # final 'A' when the right part starts with a consonant
                    # (it may have lost an initial vowel), e.g. pratyagA + BAsa
                    candidates = [candidate]
                    if left.endswith('A') and not right.startswith(('a', 'A', 'i', 'I', 'u', 'U', 'e', 'E', 'o', 'O')):
                        candidates.append(left + right)
                    # Check each sandhi candidate against the kosha
                    matched = next((c for c in candidates if self.analyzer._in_kosha(c)), None)
                    if matched is not None:
                        merged.append(matched)
                        i += 2
                        continue
                    # Try restoring an -an stem (e.g. AtmA → Atman)
                    atman_candidate = left[:-1] + 'an' if left.endswith('A') else left + 'an'
                    if len(atman_candidate) > 3 and self.analyzer._in_kosha(atman_candidate):
                        merged.append(atman_candidate)
                        # Still need to process right
                        merged.append(right)
                        i += 2
                        continue
                merged.append(parts[i])
                i += 1
            return merged
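
        # Illustrative: for parts like ["AtmA", "BAsa"] the helper first tries
        # the direct sandhi joins ("AtmaBAsa", "AtmABAsa") against the kosha;
        # if none match, it restores the -an stem ("AtmA" → "Atman") and keeps
        # the right part separate. Outcomes depend on kosha contents.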

        # Iterative splitting until fixed point
        MAX_PASSES = 6  # Increased for deep compounds
        for _ in range(MAX_PASSES):
            new_components = []
            changed = False

            # Split pass
            for comp in current_components:
                # Try to split this component
                split_res = self.splitter.split(comp)
                if split_res.is_compound and len(split_res.components) > 1:
                    new_components.extend(split_res.components)
                    changed = True
                else:
                    # Sandhi restoration retry: if the component starts with a
                    # consonant, no split was found, AND the token is NOT a
                    # valid stem (an OOV leftover from a previous split), try
                    # prepending 'A' (an initial vowel eaten by sandhi).
                    # Uses _is_valid_stem (includes pratyaya stripping) rather
                    # than just _in_kosha, so valid stems are never mangled.
                    if (len(comp) > 3 and
                            comp[0] not in 'aAiIuUeEoO' and
                            not self.splitter._is_valid_stem(comp)):
                        restored = 'A' + comp
                        restored_res = self.splitter.split(restored)
                        if restored_res.is_compound and len(restored_res.components) > 1:
                            # Map result back: the first component keeps the A prefix
                            new_components.extend(restored_res.components)
                            changed = True
                            continue
                    new_components.append(comp)

            # Merge pass: merge adjacent tokens that form known compounds
            merged_components = merge_known_compounds(new_components)
            if len(merged_components) != len(new_components):
                changed = True
            if not changed:
                break
            current_components = merged_components

        # Add tokens with word-boundary spacing
        for i, comp in enumerate(current_components):
            # Rule A check: if 'comp' is in the kosha it is a valid stem and
            # is emitted as an atomic token. OOV components are currently also
            # kept whole; a char/subword fallback could be added here if needed.
            prefix = "▁" if i == 0 else ""
            final_tokens.append(prefix + comp)

        # Append vibhakti if it was separated (it belongs to the last
        # component, and only if not already present there)
        if parse.vibhakti and final_tokens:
            last_token = final_tokens[-1].lstrip('▁')
            # Guard: don't double-append if the last token already ends with it
            if not last_token.endswith(parse.vibhakti):
                final_tokens.append(parse.vibhakti)
        return final_tokens
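
    # End-to-end sketch of the pipeline above (hypothetical output; real
    # splits depend on the Vidyut analyzer and kosha):
    #   _tokenize_word("hfdpadmagatam")
    #   → stem "hfdpadmagata", iteratively split into ["hfd", "padma", "gata"],
    #     emitted as ["▁hfd", "padma", "gata"] plus the separated vibhakti "m".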

    def tokenize(self, text: str, **kwargs) -> List[str]:
        """
        Tokenize text into morphological tokens.

        This is the main entry point for tokenization.
        """
        if not text:
            return []
        # Split on whitespace, then tokenize each word morphologically
        all_tokens = []
        for word in text.split():
            all_tokens.extend(self._tokenize_word(word))
        return all_tokens

    def _encode_impl(self, text: str) -> List[int]:
        """Internal encode implementation."""
        tokens = self.tokenize(text)
        return [self._convert_token_to_id(t) for t in tokens]

    def encode(
        self,
        text: Union[str, List[str]],
        add_special_tokens: bool = True,
        **kwargs
    ) -> List[int]:
        """Encode text to token IDs."""
        if isinstance(text, list):
            text = " ".join(text)
        ids = self._encode_impl(text)
        if add_special_tokens:
            bos_id = self._vocab.get("<s>", 1)
            eos_id = self._vocab.get("</s>", 2)
            ids = [bos_id] + ids + [eos_id]
        return ids
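
    # Usage sketch:
    #   tok.encode("rAmaH gacCati")                            # [<s>, ..., </s>]
    #   tok.encode("rAmaH gacCati", add_special_tokens=False)  # content IDs only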

    def decode(
        self,
        token_ids: List[int],
        skip_special_tokens: bool = True,
        **kwargs
    ) -> str:
        """Decode token IDs back to text."""
        # IDs 0-6 are <unk>, <s>, </s>, <pad>, <sep>, <cls>, <mask>
        # (see _build_initial_vocab)
        special_ids = {0, 1, 2, 3, 4, 5, 6}
        tokens = []
        for tid in token_ids:
            if skip_special_tokens and tid in special_ids:
                continue
            tokens.append(self._convert_id_to_token(tid))
        # Join tokens, turning the "▁" word-boundary marker into a space
        text = ""
        for t in tokens:
            if t.startswith("▁"):
                text += " " + t[1:]
            else:
                text += t
        return text.strip()
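
    # Round-trip sketch: decode(encode(text)) should recover the input up to
    # morphological segmentation (a separated stem + vibhakti are rejoined,
    # since the vibhakti token carries no "▁" marker):
    #   ids = tok.encode("rAmaH gacCati", add_special_tokens=False)
    #   tok.decode(ids)   # e.g. "rAmaH gacCati"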

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Convert a token list back to a string."""
        text = ""
        for t in tokens:
            if t.startswith("▁"):
                text += " " + t[1:]
            else:
                text += t
        return text.strip()

# --- CONVENIENCE FUNCTION ---
def create_tokenizer(vocab_path: Optional[str] = None) -> PaniniTokenizerV3:
    """Create a PaniniTokenizerV3 instance."""
    return PaniniTokenizerV3(vocab_file=vocab_path)

# --- TEST ---
if __name__ == "__main__":
    print("\n" + "=" * 60)
    print("  Testing PaniniTokenizerV3")
    print("=" * 60)

    tokenizer = PaniniTokenizerV3()
    test_cases = [
        "rAmaH gacCati",
        "hfdpadmagataM paramAtma",
        "sopAdhikapratyagAtmAbhAsabhedAbhedavicAraH",
    ]
    for text in test_cases:
        tokens = tokenizer.tokenize(text)
        ids = tokenizer.encode(text, add_special_tokens=False)
        decoded = tokenizer.decode(ids)
        print(f"\n  Input:   {text}")
        print(f"  Tokens:  {tokens}")
        if len(ids) > 10:
            print(f"  IDs:     {ids[:10]}...")
        else:
            print(f"  IDs:     {ids}")
        print(f"  Decoded: {decoded}")