""" Utility functions for tokenizer-related operations. """ import torch import logging from typing import Dict, List, Any, Union, Optional from transformers import AutoTokenizer logger = logging.getLogger(__name__) def get_special_tokens_mask(tokenizer, token_ids_0, token_ids_1=None, already_has_special_tokens=False): """ Retrieve special tokens mask. Args: tokenizer: Tokenizer to use token_ids_0: First token IDs token_ids_1: Second token IDs (for pairs) already_has_special_tokens: Whether token_ids already contain special tokens Returns: List of 1s and 0s, where 1 indicates a special token """ if already_has_special_tokens: return tokenizer.get_special_tokens_mask( token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True ) if token_ids_1 is None: return tokenizer.get_special_tokens_mask( token_ids_0, token_ids_1=None, already_has_special_tokens=False ) return tokenizer.get_special_tokens_mask( token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=False ) def add_tokens_to_tokenizer(tokenizer, new_tokens): """ Add new tokens to tokenizer vocabulary. Args: tokenizer: Tokenizer to modify new_tokens: List of new tokens to add Returns: Number of tokens added """ return tokenizer.add_tokens(new_tokens) def format_batch_for_model( batch: Dict[str, torch.Tensor], device: torch.device = None ) -> Dict[str, torch.Tensor]: """ Format a batch for model input, moving tensors to specified device. Args: batch: Dictionary of tensors device: Device to move tensors to Returns: Formatted batch dictionary """ if device is None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") formatted_batch = {} for k, v in batch.items(): if isinstance(v, torch.Tensor): formatted_batch[k] = v.to(device) else: formatted_batch[k] = v return formatted_batch def batch_encode_plus( tokenizer, texts: List[str], batch_size: int = 32, max_length: int = 512, return_tensors: str = "pt", **kwargs ) -> List[Dict[str, torch.Tensor]]: """ Encode a large batch of texts in smaller chunks. Args: tokenizer: Tokenizer to use texts: List of texts to encode batch_size: Size of each processing batch max_length: Maximum sequence length return_tensors: Return format ('pt' for PyTorch) **kwargs: Additional encoding parameters Returns: List of encoded batches """ batches = [] for i in range(0, len(texts), batch_size): batch_texts = texts[i:i + batch_size] encoded = tokenizer( batch_texts, max_length=max_length, padding="max_length", truncation=True, return_tensors=return_tensors, **kwargs ) batches.append(encoded) return batches def get_tokenizer_info(tokenizer) -> Dict[str, Any]: """ Get information about a tokenizer. Args: tokenizer: Tokenizer to inspect Returns: Dictionary with tokenizer information """ info = { "vocab_size": len(tokenizer), "model_name": getattr(tokenizer, "name_or_path", None), "special_tokens": {} } # Get special token attributes if available special_tokens = [ "pad_token", "unk_token", "sep_token", "cls_token", "mask_token", "bos_token", "eos_token" ] for token_name in special_tokens: token_value = getattr(tokenizer, f"{token_name}", None) if token_value is not None: info["special_tokens"][token_name] = token_value return info