File size: 4,195 Bytes
0861a59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
"""

Utility functions for tokenizer-related operations.

"""
import torch
import logging
from typing import Dict, List, Any, Union, Optional
from transformers import AutoTokenizer

logger = logging.getLogger(__name__)

def get_special_tokens_mask(tokenizer, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
    """
    Retrieve special tokens mask for a sequence (or sequence pair).

    Args:
        tokenizer: Tokenizer to use
        token_ids_0: First token IDs
        token_ids_1: Second token IDs (for pairs)
        already_has_special_tokens: Whether token_ids already contain special tokens

    Returns:
        List of 1s and 0s, where 1 indicates a special token
    """
    # The previous three-way branch made the exact same delegated call in
    # every case (only the argument values differed, and they were always
    # equal to the parameters received here). A single forwarding call is
    # behaviorally identical and removes the redundancy.
    return tokenizer.get_special_tokens_mask(
        token_ids_0,
        token_ids_1=token_ids_1,
        already_has_special_tokens=already_has_special_tokens,
    )

def add_tokens_to_tokenizer(tokenizer, new_tokens):
    """
    Add new tokens to tokenizer vocabulary.

    Args:
        tokenizer: Tokenizer to modify
        new_tokens: List of new tokens to add

    Returns:
        Number of tokens added
    """
    # Delegate to the tokenizer's own add_tokens, which reports how many
    # of the supplied tokens were actually new to the vocabulary.
    added_count = tokenizer.add_tokens(new_tokens)
    return added_count
    
def format_batch_for_model(
    batch: Dict[str, torch.Tensor],
    device: Optional[torch.device] = None
) -> Dict[str, torch.Tensor]:
    """
    Format a batch for model input, moving tensors to the specified device.

    Non-tensor values (e.g. lists of strings) are passed through unchanged.

    Args:
        batch: Dictionary of tensors (non-tensor values are kept as-is)
        device: Device to move tensors to; defaults to CUDA when available,
            otherwise CPU

    Returns:
        Formatted batch dictionary with tensors on the target device
    """
    # Annotation fixed: the parameter accepts None, so it must be Optional.
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Move only tensors; leave metadata entries untouched.
    return {
        k: v.to(device) if isinstance(v, torch.Tensor) else v
        for k, v in batch.items()
    }

def batch_encode_plus(
    tokenizer,
    texts: List[str],
    batch_size: int = 32,
    max_length: int = 512,
    return_tensors: str = "pt",
    **kwargs
) -> List[Dict[str, torch.Tensor]]:
    """
    Encode a large batch of texts in smaller chunks.

    Args:
        tokenizer: Tokenizer to use
        texts: List of texts to encode
        batch_size: Size of each processing batch
        max_length: Maximum sequence length
        return_tensors: Return format ('pt' for PyTorch)
        **kwargs: Additional encoding parameters

    Returns:
        List of encoded batches, one entry per chunk of ``batch_size`` texts
    """
    # Walk the text list in fixed-size strides and encode each slice;
    # the final chunk may be shorter than batch_size.
    chunk_starts = range(0, len(texts), batch_size)
    return [
        tokenizer(
            texts[start:start + batch_size],
            max_length=max_length,
            padding="max_length",
            truncation=True,
            return_tensors=return_tensors,
            **kwargs,
        )
        for start in chunk_starts
    ]

def get_tokenizer_info(tokenizer) -> Dict[str, Any]:
    """
    Get information about a tokenizer.

    Args:
        tokenizer: Tokenizer to inspect

    Returns:
        Dictionary with the tokenizer's vocabulary size, model name/path,
        and any special tokens it defines
    """
    # Special-token attribute names commonly exposed by HF tokenizers.
    known_special = (
        "pad_token", "unk_token", "sep_token",
        "cls_token", "mask_token", "bos_token", "eos_token",
    )

    # Keep only the special tokens this tokenizer actually defines.
    present_special = {
        name: getattr(tokenizer, name)
        for name in known_special
        if getattr(tokenizer, name, None) is not None
    }

    return {
        "vocab_size": len(tokenizer),
        "model_name": getattr(tokenizer, "name_or_path", None),
        "special_tokens": present_special,
    }