"""MorPiece Tokenizer for Hugging Face Transformers"""
import json
import os
from typing import List, Optional, Tuple, Union, Dict, Any
from transformers import PreTrainedTokenizer
from transformers.utils import logging
logger = logging.get_logger(__name__)
class MorPieceTokenizer(PreTrainedTokenizer):
    """MorPiece tokenizer for Hugging Face transformers.

    This tokenizer uses morphological segmentation based on tries and the
    sufficiency principle.
    """

    vocab_files_names = {
        "vocab_file": "vocab.json",
        "tokenizer_file": "tokenizer.json",
    }
    def __init__(
        self,
        vocab_file=None,
        tokenizer_file=None,
        unk_token="<unk>",
        pad_token="<pad>",
        bos_token="<s>",
        eos_token="</s>",
        mask_token="<mask>",
        sep_token="<sep>",
        cls_token="<cls>",
        add_prefix_space=True,
        vocab_size=60000,
        min_frequency=10,
        cutoff=100,
        bf=4,
        use_tokenizers_lib=True,
        **kwargs,
    ):
        # Load the vocabulary before calling the parent constructor:
        # PreTrainedTokenizer.__init__ consults get_vocab() when registering
        # special tokens, so the vocabulary must already be populated or the
        # special tokens would be re-added with conflicting IDs.
        if vocab_file and os.path.exists(vocab_file):
            with open(vocab_file, "r", encoding="utf-8") as f:
                self.vocab_to_id = json.load(f)
        else:
            self.vocab_to_id = {}
        self.id_to_vocab = {v: k for k, v in self.vocab_to_id.items()}

        # Load the tokenizer configuration (the trie of roots).
        if tokenizer_file and os.path.exists(tokenizer_file):
            with open(tokenizer_file, "r", encoding="utf-8") as f:
                tokenizer_config = json.load(f)
            if "model" in tokenizer_config:
                self.roots = tokenizer_config["model"].get("roots", {})
            else:
                self.roots = {}
        else:
            self.roots = {}

        # Store MorPiece-specific parameters. The vocab_size argument is only
        # a training-time target; the effective vocabulary size is exposed
        # through the vocab_size property below.
        self.min_frequency = min_frequency
        self.cutoff = cutoff
        self.bf = bf
        self.use_tokenizers_lib = use_tokenizers_lib

        # Initialize the parent class
        super().__init__(
            unk_token=unk_token,
            pad_token=pad_token,
            bos_token=bos_token,
            eos_token=eos_token,
            mask_token=mask_token,
            sep_token=sep_token,
            cls_token=cls_token,
            add_prefix_space=add_prefix_space,
            **kwargs,
        )

        # Resolve special token IDs from the vocabulary, falling back to the
        # conventional positions 0-6 when a token is missing from vocab.json.
        self.unk_token_id = self.vocab_to_id.get(unk_token, 0)
        self.pad_token_id = self.vocab_to_id.get(pad_token, 1)
        self.bos_token_id = self.vocab_to_id.get(bos_token, 2)
        self.eos_token_id = self.vocab_to_id.get(eos_token, 3)
        self.mask_token_id = self.vocab_to_id.get(mask_token, 4)
        self.sep_token_id = self.vocab_to_id.get(sep_token, 5)
        self.cls_token_id = self.vocab_to_id.get(cls_token, 6)
    @property
    def vocab_size(self) -> int:
        return len(self.vocab_to_id)

    def get_vocab(self) -> Dict[str, int]:
        return self.vocab_to_id.copy()
    def _tokenize(self, text: str, **kwargs) -> List[str]:
        """Tokenize a string using the MorPiece algorithm."""
        # This is a simplified version -- the full MorPiece logic can be
        # integrated here later.
        words = text.strip().split()
        tokens = []
        for word in words:
            if word in self.roots.get("[RSX]", {}):
                # The whole word is a known root: keep it as a single token.
                tokens.append(word)
            else:
                # Use simplified subword tokenization for now.
                tokens.extend(self._tokenize_word(word))
        return tokens
    def _tokenize_word(self, word: str) -> List[str]:
        """Tokenize a single word using MorPiece trie traversal."""
        # Simplified implementation: greedy longest-match against the vocabulary.
        tokens = []
        i = 0
        while i < len(word):
            found = False
            # Try to find the longest match in the vocabulary.
            for j in range(len(word), i, -1):
                subword = word[i:j]
                if subword in self.vocab_to_id:
                    tokens.append(subword)
                    i = j
                    found = True
                    break
            if not found:
                # No vocabulary entry covers this position: emit <unk> and
                # advance by a single character.
                tokens.append(self.unk_token)
                i += 1
        return tokens
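    # Worked example of the greedy longest match above (illustrative only;
    # the word and vocabulary entries are assumptions, not real model data):
    # with a vocabulary containing "un", "break", and "able",
    # _tokenize_word("unbreakable") matches "un" at position 0, "break" at
    # position 2, and "able" at position 7, yielding ["un", "break", "able"];
    # any character with no match becomes one <unk> token and i advances by 1.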
    def _convert_token_to_id(self, token: str) -> int:
        """Convert a token to its ID"""
        return self.vocab_to_id.get(token, self.unk_token_id)

    def _convert_id_to_token(self, index: int) -> str:
        """Convert an ID to its token"""
        return self.id_to_vocab.get(index, self.unk_token)
    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Convert a list of tokens to a string"""
        # Assumed convention: a "++" prefix marks a word-internal continuation
        # piece, which is glued to the previous token; an unprefixed token
        # starts a new word, so a separating space is inserted before it.
        # (The original joined everything with "", which dropped the spaces
        # between words.)
        result = ""
        for token in tokens:
            if token.startswith("++"):
                result += token[2:]  # Remove ++ prefix and attach directly
            else:
                result += (" " if result else "") + token
        return result
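    # Illustrative round trips under the assumed "++" convention:
    #   ["token", "++ize", "++r"] -> "tokenizer"
    #   ["hello", "world"]        -> "hello world"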
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """Save vocabulary to files"""
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "vocab.json",
        )
        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self.vocab_to_id, f, indent=2, sort_keys=True, ensure_ascii=False)
        return (vocab_file,)
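

if __name__ == "__main__":
    # Minimal usage sketch. The file names below are assumptions for
    # illustration; point them at a real vocab.json / tokenizer.json
    # produced by MorPiece training. With no files present, the tokenizer
    # falls back to an empty vocabulary and emits <unk> tokens.
    tokenizer = MorPieceTokenizer(
        vocab_file="vocab.json",
        tokenizer_file="tokenizer.json",
    )
    text = "the unbreakable tokenizer"
    tokens = tokenizer.tokenize(text)              # morphological segmentation
    ids = tokenizer.convert_tokens_to_ids(tokens)  # map tokens to vocab IDs
    print(tokens)
    print(ids)
    print(tokenizer.convert_tokens_to_string(tokens))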