"""MorPiece Tokenizer for Hugging Face Transformers"""
import json
import os
from typing import List, Optional, Tuple, Union, Dict
from transformers import PreTrainedTokenizer
from transformers.utils import logging
from transformers.tokenization_utils_base import BatchEncoding
logger = logging.get_logger(__name__)
class MorPieceTokenizer(PreTrainedTokenizer):
"""MorPiece tokenizer for Hugging Face transformers.
This tokenizer uses morphological segmentation based on tries and the sufficiency principle.
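    Example (a minimal sketch; assumes the ``vocab.json`` and
    ``tokenizer.json`` shipped with this checkpoint are available locally):
        tokenizer = MorPieceTokenizer(vocab_file="vocab.json", tokenizer_file="tokenizer.json")
        ids = tokenizer("hello world")["input_ids"]
        text = tokenizer.decode(ids, skip_special_tokens=True)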
"""
vocab_files_names = {
"vocab_file": "vocab.json",
"tokenizer_file": "tokenizer.json",
"added_tokens_file": "added_tokens.json",
"special_tokens_map_file": "special_tokens_map.json",
}
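    # from_pretrained() resolves these file names inside the model repository
    # and passes the local paths to __init__ via the matching keyword arguments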
def __init__(
self,
vocab_file=None,
tokenizer_file=None,
added_tokens_file=None,
special_tokens_map_file=None,
unk_token="<unk>",
pad_token="<pad>",
bos_token="<s>",
eos_token="</s>",
mask_token="<mask>",
sep_token="<sep>",
cls_token="<cls>",
add_prefix_space=True,
vocab_size=60000,
min_frequency=10,
cutoff=100,
bf=4,
use_tokenizers_lib=True,
**kwargs
):
self.vocab_to_id = {}
self.id_to_vocab = {}
# Load special tokens map first if available
special_tokens_map = {}
if special_tokens_map_file and os.path.exists(special_tokens_map_file):
with open(special_tokens_map_file, "r", encoding="utf-8") as f:
special_tokens_map = json.load(f)
logger.info(f"Loaded special tokens map from {special_tokens_map_file}")
# Override default special tokens with those from special_tokens_map.json
unk_token = self._get_token_from_map(special_tokens_map, "unk_token", unk_token)
pad_token = self._get_token_from_map(special_tokens_map, "pad_token", pad_token)
bos_token = self._get_token_from_map(special_tokens_map, "bos_token", bos_token)
eos_token = self._get_token_from_map(special_tokens_map, "eos_token", eos_token)
mask_token = self._get_token_from_map(special_tokens_map, "mask_token", mask_token)
sep_token = self._get_token_from_map(special_tokens_map, "sep_token", sep_token)
cls_token = self._get_token_from_map(special_tokens_map, "cls_token", cls_token)
        # Load the vocabulary *before* calling the parent constructor, so that
        # vocab_size is already correct when the base class registers the
        # special tokens
        if vocab_file and os.path.exists(vocab_file):
            with open(vocab_file, "r", encoding="utf-8") as f:
                self.vocab_to_id = json.load(f)
            logger.info(f"Loaded vocabulary with {len(self.vocab_to_id)} tokens from {vocab_file}")
        else:
            self.vocab_to_id = {}
            logger.warning("No vocabulary file found, starting with empty vocabulary")
        # Load added tokens
        self.added_tokens = {}
        if added_tokens_file and os.path.exists(added_tokens_file):
            with open(added_tokens_file, "r", encoding="utf-8") as f:
                added_tokens_data = json.load(f)
            self.added_tokens = self._parse_added_tokens(added_tokens_data)
            logger.info(f"Loaded {len(self.added_tokens)} added tokens from {added_tokens_file}")
        # Initialize the parent class
        super().__init__(
            unk_token=unk_token,
            pad_token=pad_token,
            bos_token=bos_token,
            eos_token=eos_token,
            mask_token=mask_token,
            sep_token=sep_token,
            cls_token=cls_token,
            add_prefix_space=add_prefix_space,
            **kwargs
        )
        # Store MorPiece specific parameters; the vocab_size argument is
        # accepted for config compatibility, but the vocab_size property
        # below always reports the actual vocabulary size
        self.min_frequency = min_frequency
        self.cutoff = cutoff
        self.bf = bf
        self.use_tokenizers_lib = use_tokenizers_lib
        # Add special tokens to vocabulary if they're not already there
        self._ensure_special_tokens_in_vocab()
        # Add added tokens to vocabulary
        self._add_tokens_to_vocab(self.added_tokens)
        # Create reverse mapping
        self.id_to_vocab = {v: k for k, v in self.vocab_to_id.items()}
# Load tokenizer configuration
if tokenizer_file and os.path.exists(tokenizer_file):
with open(tokenizer_file, "r", encoding="utf-8") as f:
tokenizer_config = json.load(f)
if "model" in tokenizer_config:
self.roots = tokenizer_config["model"].get("roots", {})
else:
self.roots = {}
logger.info(f"Loaded tokenizer configuration from {tokenizer_file}")
else:
self.roots = {}
        # Special token IDs now resolve through the base-class properties
        # (e.g. self.unk_token_id), since _ensure_special_tokens_in_vocab()
        # guarantees that every special token has a vocabulary entry
def _get_token_from_map(self, special_tokens_map: Dict, key: str, default: str) -> str:
"""Extract token from special tokens map, handling both string and dict formats"""
if key in special_tokens_map:
token_data = special_tokens_map[key]
if isinstance(token_data, dict):
return token_data.get("content", default)
else:
return str(token_data)
return default
def _parse_added_tokens(self, added_tokens_data: Union[List, Dict]) -> Dict[str, Dict]:
"""Parse added tokens from JSON data"""
added_tokens = {}
if isinstance(added_tokens_data, list):
# Format: [{"id": 0, "content": "<token>", "single_word": false, ...}, ...]
for token_info in added_tokens_data:
if isinstance(token_info, dict) and "content" in token_info:
content = token_info["content"]
added_tokens[content] = {
"id": token_info.get("id"),
"single_word": token_info.get("single_word", False),
"lstrip": token_info.get("lstrip", False),
"rstrip": token_info.get("rstrip", False),
"normalized": token_info.get("normalized", True),
"special": token_info.get("special", False),
}
elif isinstance(added_tokens_data, dict):
# Format: {"<token1>": {"id": 0, "single_word": false, ...}, ...}
for token, token_info in added_tokens_data.items():
if isinstance(token_info, dict):
added_tokens[token] = token_info
else:
# Simple format: {"<token>": id}
added_tokens[token] = {"id": token_info}
return added_tokens
def _ensure_special_tokens_in_vocab(self):
"""Ensure all special tokens are in the vocabulary"""
special_tokens = [
self.unk_token, self.pad_token, self.bos_token,
self.eos_token, self.mask_token, self.sep_token, self.cls_token
]
# Find the next available ID
next_id = max(self.vocab_to_id.values()) + 1 if self.vocab_to_id else 0
for token in special_tokens:
if token and str(token) not in self.vocab_to_id:
self.vocab_to_id[str(token)] = next_id
logger.info(f"Added special token '{token}' with id {next_id}")
next_id += 1
    def _add_tokens_to_vocab(self, tokens_dict: Dict[str, Dict]):
        """Add tokens from added_tokens.json to vocabulary"""
        # Track used IDs directly from the vocabulary: the id_to_vocab reverse
        # map is only built after this method runs, so it cannot be consulted here
        used_ids = set(self.vocab_to_id.values())
        for token, token_info in tokens_dict.items():
            if token not in self.vocab_to_id:
                token_id = token_info.get("id")
                if token_id is not None and token_id in used_ids:
                    logger.warning(f"Token ID {token_id} already in use, "
                                   f"assigning new ID for token '{token}'")
                    token_id = None
                if token_id is None:
                    # Assign next available ID
                    token_id = max(used_ids) + 1 if used_ids else 0
                self.vocab_to_id[token] = token_id
                used_ids.add(token_id)
                logger.info(f"Added token '{token}' with id {token_id}")
@property
def vocab_size(self) -> int:
return len(self.vocab_to_id)
def get_vocab(self) -> Dict[str, int]:
return self.vocab_to_id.copy()
def get_added_vocab(self) -> Dict[str, int]:
"""Return dictionary of added tokens"""
added_vocab = {}
for token, token_info in self.added_tokens.items():
if token in self.vocab_to_id:
added_vocab[token] = self.vocab_to_id[token]
return added_vocab
def _tokenize(self, text: str, **kwargs) -> List[str]:
"""Tokenize a string using MorPiece algorithm"""
# This is a simplified version - you may want to integrate the full MorPiece logic
words = text.strip().split()
tokens = []
for word in words:
if word in self.all_special_tokens:
tokens.append(word)
elif word in self.added_tokens:
tokens.append(word)
else:
# Use simplified tokenization for now
tokens.extend(self._tokenize_word(word))
return tokens
def _tokenize_word(self, word: str) -> List[str]:
"""Tokenize a single word using MorPiece trie traversal"""
# Simplified implementation
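        # Greedy longest-match walk: with a vocabulary containing "un",
        # "happi" and "ness", "unhappiness" segments as ["un", "happi", "ness"];
        # a position with no match at any length falls back to the unk token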
tokens = []
i = 0
while i < len(word):
found = False
# Try to find longest match in vocabulary (including added tokens)
for j in range(len(word), i, -1):
subword = word[i:j]
if subword in self.vocab_to_id:
tokens.append(subword)
i = j
found = True
break
if not found:
tokens.append(str(self.unk_token))
i += 1
return tokens
def _convert_token_to_id(self, token: str) -> int:
"""Convert a token to its ID"""
return self.vocab_to_id.get(token, self.unk_token_id)
def _convert_id_to_token(self, index: int) -> str:
"""Convert an ID to its token"""
return self.id_to_vocab.get(index, str(self.unk_token))
    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Convert a list of tokens to a string.

        Assumes the MorPiece convention that continuation pieces carry a '++'
        prefix (analogous to WordPiece's '##'): '++' tokens are glued to the
        previous token, while all other tokens start a new whitespace-separated
        word, e.g. ["un", "++happiness"] -> "unhappiness".
        """
        words = []
        for token in tokens:
            if token.startswith('++'):
                piece = token[2:]  # Remove ++ prefix
                if words:
                    words[-1] += piece
                else:
                    words.append(piece)
            else:
                words.append(token)
        return ' '.join(words)
def _encode_plus(
self,
text: Union[str, List[str]],
text_pair: Optional[Union[str, List[str]]] = None,
add_special_tokens: bool = True,
padding: Union[bool, str] = False,
truncation: Union[bool, str] = False,
max_length: Optional[int] = None,
stride: int = 0,
is_split_into_words: bool = False,
pad_to_multiple_of: Optional[int] = None,
return_tensors: Optional[str] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
return_overflowing_tokens: bool = False,
return_special_tokens_mask: bool = False,
return_offsets_mapping: bool = False,
return_length: bool = False,
verbose: bool = True,
**kwargs
) -> BatchEncoding:
"""
Tokenize and prepare for the model a sequence or a pair of sequences.
"""
# Handle the case where text is already a list of tokens
if isinstance(text, list):
if is_split_into_words:
tokens = []
for word in text:
tokens.extend(self._tokenize(word))
else:
tokens = text
else:
tokens = self._tokenize(text)
# Convert tokens to IDs
ids = self.convert_tokens_to_ids(tokens)
# Handle text_pair if provided
pair_ids = []
if text_pair is not None:
if isinstance(text_pair, list):
if is_split_into_words:
pair_tokens = []
for word in text_pair:
pair_tokens.extend(self._tokenize(word))
else:
pair_tokens = text_pair
else:
pair_tokens = self._tokenize(text_pair)
pair_ids = self.convert_tokens_to_ids(pair_tokens)
# Add special tokens if requested
if add_special_tokens:
if pair_ids:
ids = [self.bos_token_id] + ids + [self.sep_token_id] + pair_ids + [self.eos_token_id]
else:
ids = [self.bos_token_id] + ids
else:
# When add_special_tokens=False, just concatenate without special tokens
if pair_ids:
ids = ids + pair_ids
# For single text, ids remains as is
# Create the encoding dict
encoding_dict = {"input_ids": ids}
# Add attention mask
if return_attention_mask or return_attention_mask is None:
encoding_dict["attention_mask"] = [1] * len(ids)
# Add token type ids if we have a pair
if return_token_type_ids or (return_token_type_ids is None and pair_ids):
if pair_ids:
if add_special_tokens:
# [BOS] + text + [SEP] + text_pair + [EOS]
token_type_ids = [0] * (len(tokens) + 2) + [1] * (len(pair_tokens) + 1)
else:
# text + text_pair (no special tokens)
token_type_ids = [0] * len(tokens) + [1] * len(pair_tokens)
else:
token_type_ids = [0] * len(ids)
encoding_dict["token_type_ids"] = token_type_ids
        # Handle offset mapping if requested
        if return_offsets_mapping:
            # Simple implementation: offsets are computed over the
            # concatenated tokens, so they ignore whitespace stripped by
            # _tokenize() and are only approximate
            offsets = []
            char_pos = 0
            for token in tokens:
                offsets.append((char_pos, char_pos + len(token)))
                char_pos += len(token)
            pair_offsets = []
            if pair_ids:
                # Offsets for the second sequence restart at 0, relative to text_pair
                char_pos = 0
                for token in pair_tokens:
                    pair_offsets.append((char_pos, char_pos + len(token)))
                    char_pos += len(token)
            if add_special_tokens:
                # Special tokens get the conventional (0, 0) offset; note that
                # single sequences only receive a leading BOS token above
                if pair_ids:
                    encoding_dict["offset_mapping"] = [(0, 0)] + offsets + [(0, 0)] + pair_offsets + [(0, 0)]
                else:
                    encoding_dict["offset_mapping"] = [(0, 0)] + offsets
            else:
                encoding_dict["offset_mapping"] = offsets + pair_offsets
        # Handle special tokens mask if requested
        if return_special_tokens_mask:
            if add_special_tokens:
                if pair_ids:
                    # [BOS] + text + [SEP] + text_pair + [EOS]
                    special_tokens_mask = [1] + [0] * len(tokens) + [1] + [0] * len(pair_tokens) + [1]
                else:
                    # Single sequences only get a leading [BOS]
                    special_tokens_mask = [1] + [0] * len(tokens)
            else:
                # No special tokens added, so all tokens are regular tokens
                special_tokens_mask = [0] * len(ids)
            encoding_dict["special_tokens_mask"] = special_tokens_mask
# Handle padding and truncation
if max_length is not None and truncation:
if len(ids) > max_length:
encoding_dict["input_ids"] = ids[:max_length]
if "attention_mask" in encoding_dict:
encoding_dict["attention_mask"] = encoding_dict["attention_mask"][:max_length]
if "token_type_ids" in encoding_dict:
encoding_dict["token_type_ids"] = encoding_dict["token_type_ids"][:max_length]
if "offset_mapping" in encoding_dict:
encoding_dict["offset_mapping"] = encoding_dict["offset_mapping"][:max_length]
if "special_tokens_mask" in encoding_dict:
encoding_dict["special_tokens_mask"] = encoding_dict["special_tokens_mask"][:max_length]
# Handle padding
if padding and max_length is not None:
current_length = len(encoding_dict["input_ids"])
if current_length < max_length:
pad_length = max_length - current_length
encoding_dict["input_ids"].extend([self.pad_token_id] * pad_length)
if "attention_mask" in encoding_dict:
encoding_dict["attention_mask"].extend([0] * pad_length)
if "token_type_ids" in encoding_dict:
encoding_dict["token_type_ids"].extend([0] * pad_length)
if "offset_mapping" in encoding_dict:
encoding_dict["offset_mapping"].extend([(0, 0)] * pad_length)
if "special_tokens_mask" in encoding_dict:
encoding_dict["special_tokens_mask"].extend([1] * pad_length)
# Return length if requested
if return_length:
encoding_dict["length"] = len(encoding_dict["input_ids"])
return BatchEncoding(encoding_dict, tensor_type=return_tensors)
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str, ...]:
"""Save vocabulary and token files"""
if not os.path.isdir(save_directory):
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return ()
files = []
# Save vocabulary
vocab_file = os.path.join(
save_directory,
(filename_prefix + "-" if filename_prefix else "") + "vocab.json"
)
with open(vocab_file, "w", encoding="utf-8") as f:
json.dump(self.vocab_to_id, f, indent=2, sort_keys=True, ensure_ascii=False)
files.append(vocab_file)
# Save added tokens
if self.added_tokens:
added_tokens_file = os.path.join(
save_directory,
(filename_prefix + "-" if filename_prefix else "") + "added_tokens.json"
)
# Convert back to list format expected by HF
added_tokens_list = []
for token, info in self.added_tokens.items():
token_data = {"content": token, **info}
if token in self.vocab_to_id:
token_data["id"] = self.vocab_to_id[token]
added_tokens_list.append(token_data)
with open(added_tokens_file, "w", encoding="utf-8") as f:
json.dump(added_tokens_list, f, indent=2, ensure_ascii=False)
files.append(added_tokens_file)
# Save special tokens map
special_tokens_file = os.path.join(
save_directory,
(filename_prefix + "-" if filename_prefix else "") + "special_tokens_map.json"
)
special_tokens_map = {
"unk_token": str(self.unk_token),
"pad_token": str(self.pad_token),
"bos_token": str(self.bos_token),
"eos_token": str(self.eos_token),
"mask_token": str(self.mask_token),
"sep_token": str(self.sep_token),
"cls_token": str(self.cls_token),
}
with open(special_tokens_file, "w", encoding="utf-8") as f:
json.dump(special_tokens_map, f, indent=2, ensure_ascii=False)
files.append(special_tokens_file)
return tuple(files)
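# A minimal smoke test, kept as a sketch: the file names below mirror
# vocab_files_names and are assumed to sit in the working directory; if they
# are absent, the tokenizer starts from an empty vocabulary containing only
# the special tokens added in __init__.
if __name__ == "__main__":
    tokenizer = MorPieceTokenizer(vocab_file="vocab.json", tokenizer_file="tokenizer.json")
    encoding = tokenizer("hello world", return_special_tokens_mask=True)
    print("input_ids:", encoding["input_ids"])
    print("tokens:", tokenizer.convert_ids_to_tokens(encoding["input_ids"]))
    print("decoded:", tokenizer.decode(encoding["input_ids"], skip_special_tokens=True))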