File size: 4,195 Bytes
0861a59 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
"""
Utility functions for tokenizer-related operations.
"""
import torch
import logging
from typing import Dict, List, Any, Union, Optional
from transformers import AutoTokenizer
logger = logging.getLogger(__name__)
def get_special_tokens_mask(tokenizer, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
    """
    Retrieve special tokens mask.

    Args:
        tokenizer: Tokenizer to use
        token_ids_0: First token IDs
        token_ids_1: Second token IDs (for pairs)
        already_has_special_tokens: Whether token_ids already contain special tokens

    Returns:
        List of 1s and 0s, where 1 indicates a special token
    """
    # The original three branches all forwarded identical arguments
    # (token_ids_1=None is the same as omitting a None pair, and the
    # flag always matched the branch condition), so a single delegated
    # call covers every case.
    return tokenizer.get_special_tokens_mask(
        token_ids_0,
        token_ids_1=token_ids_1,
        already_has_special_tokens=already_has_special_tokens,
    )
def add_tokens_to_tokenizer(tokenizer, new_tokens):
    """
    Add new tokens to tokenizer vocabulary.

    Args:
        tokenizer: Tokenizer to modify
        new_tokens: List of new tokens to add

    Returns:
        Number of tokens added
    """
    # Delegate to the tokenizer, which reports how many tokens were
    # actually new (duplicates of existing vocab entries are not counted).
    num_added = tokenizer.add_tokens(new_tokens)
    return num_added
def format_batch_for_model(
    batch: Dict[str, torch.Tensor],
    device: Optional[torch.device] = None
) -> Dict[str, torch.Tensor]:
    """
    Format a batch for model input, moving tensors to specified device.

    Args:
        batch: Dictionary of tensors; non-tensor values are passed
            through unchanged (e.g. labels or metadata strings)
        device: Device to move tensors to; defaults to CUDA when
            available, otherwise CPU

    Returns:
        Formatted batch dictionary
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Move only tensor values; everything else is forwarded as-is.
    return {
        key: value.to(device) if isinstance(value, torch.Tensor) else value
        for key, value in batch.items()
    }
def batch_encode_plus(
    tokenizer,
    texts: List[str],
    batch_size: int = 32,
    max_length: int = 512,
    return_tensors: str = "pt",
    **kwargs
) -> List[Dict[str, torch.Tensor]]:
    """
    Encode a large batch of texts in smaller chunks.

    Args:
        tokenizer: Tokenizer to use
        texts: List of texts to encode
        batch_size: Size of each processing batch
        max_length: Maximum sequence length
        return_tensors: Return format ('pt' for PyTorch)
        **kwargs: Additional encoding parameters

    Returns:
        List of encoded batches
    """
    def _encode_chunk(chunk):
        # One tokenizer call per chunk, padded and truncated to max_length.
        return tokenizer(
            chunk,
            max_length=max_length,
            padding="max_length",
            truncation=True,
            return_tensors=return_tensors,
            **kwargs
        )

    chunk_starts = range(0, len(texts), batch_size)
    return [_encode_chunk(texts[start:start + batch_size]) for start in chunk_starts]
def get_tokenizer_info(tokenizer) -> Dict[str, Any]:
    """
    Get information about a tokenizer.

    Args:
        tokenizer: Tokenizer to inspect (must support len(); may expose
            `name_or_path` and the usual special-token attributes)

    Returns:
        Dictionary with keys:
            "vocab_size": total vocabulary size (len of the tokenizer)
            "model_name": `name_or_path` attribute, or None if absent
            "special_tokens": mapping of special-token attribute name to
                its value, for those that are set (not None)
    """
    # Candidate special-token attributes; only those actually set on
    # this tokenizer are reported.
    special_token_names = (
        "pad_token", "unk_token", "sep_token",
        "cls_token", "mask_token", "bos_token", "eos_token",
    )
    info: Dict[str, Any] = {
        "vocab_size": len(tokenizer),
        "model_name": getattr(tokenizer, "name_or_path", None),
        "special_tokens": {},
    }
    for token_name in special_token_names:
        # Plain attribute lookup; the original wrapped the name in a
        # redundant f-string (f"{token_name}").
        token_value = getattr(tokenizer, token_name, None)
        if token_value is not None:
            info["special_tokens"][token_name] = token_value
    return info
|