"""SentencePiece tokenizer wrapper for HuggingFace compatibility.""" from typing import Optional, List, Union class SentencePieceTokenizerWrapper: """Wrapper to make SentencePiece tokenizer compatible with HuggingFace interface.""" def __init__(self, sp_processor): """ Initialize wrapper. Args: sp_processor: sentencepiece.SentencePieceProcessor instance """ self.sp = sp_processor self.vocab_size = self.sp.vocab_size() self.pad_token_id = self.sp.pad_id() self.eos_token_id = self.sp.eos_id() self.bos_token_id = self.sp.bos_id() self.unk_token_id = self.sp.unk_id() def __call__(self, text, **kwargs): """ Tokenize text. Args: text: Input text or list of texts **kwargs: Additional arguments (truncation, max_length, padding, return_attention_mask) Returns: Dict with input_ids and attention_mask """ # Handle both single string and list of strings is_single = isinstance(text, str) texts = [text] if is_single else text max_length = kwargs.get('max_length', None) padding = kwargs.get('padding', None) truncation = kwargs.get('truncation', False) return_attention_mask = kwargs.get('return_attention_mask', True) # Tokenize all texts all_input_ids = [] for t in texts: tokens = self.sp.encode(t, out_type=int) # Truncate if needed if truncation and max_length and len(tokens) > max_length: tokens = tokens[:max_length] all_input_ids.append(tokens) # Padding if padding or max_length: target_length = max_length or max(len(ids) for ids in all_input_ids) if all_input_ids else 0 padded_input_ids = [] padded_attention_masks = [] for ids in all_input_ids: pad_length = target_length - len(ids) if pad_length > 0: padded_ids = ids + [self.pad_token_id] * pad_length else: padded_ids = ids[:target_length] padded_input_ids.append(padded_ids) attention_mask = [1] * len(ids) + [0] * (target_length - len(ids)) padded_attention_masks.append(attention_mask) result = { "input_ids": padded_input_ids if not is_single else padded_input_ids[0], } if return_attention_mask: result["attention_mask"] = padded_attention_masks if not is_single else padded_attention_masks[0] else: result = { "input_ids": all_input_ids[0] if is_single else all_input_ids, } if return_attention_mask: attention_masks = [[1] * len(ids) for ids in all_input_ids] result["attention_mask"] = attention_masks[0] if is_single else attention_masks return result def encode(self, text, return_tensors=None, **kwargs): """Encode text to token IDs.""" result = self(text, **kwargs) input_ids = result["input_ids"] if return_tensors == "pt": import torch # Ensure input_ids is a 1D list of ints if isinstance(input_ids[0], list): input_ids = input_ids[0] return torch.tensor(input_ids, dtype=torch.long).unsqueeze(0) return input_ids def encode_plus(self, text, **kwargs): """Encode text with additional information (HuggingFace compatibility).""" return self(text, **kwargs) def decode(self, token_ids, skip_special_tokens=False, **kwargs): """Decode token IDs to text.""" if hasattr(token_ids, 'tolist'): # Handle torch tensors token_ids = token_ids.tolist() # Handle various input formats if isinstance(token_ids, (list, tuple)): if len(token_ids) > 0 and isinstance(token_ids[0], (list, tuple)): token_ids = token_ids[0] # Ensure it's a list of ints if not isinstance(token_ids, list): token_ids = [int(t) for t in token_ids] return self.sp.decode(token_ids)