"""
Tokenizer utilities for extracting BPE/SentencePiece metadata.

Provides functions to:
- Extract subword pieces from tokens
- Calculate byte lengths
- Identify multi-split identifiers (≥3 subwords)
- Detect tokenization artifacts
"""

from typing import Any, Dict, List
import re
import logging

logger = logging.getLogger(__name__)


class TokenizerMetadata:
    """Extracts and analyzes tokenization metadata"""

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        # Detect tokenizer type
        self.tokenizer_type = self._detect_tokenizer_type()

    def _detect_tokenizer_type(self) -> str:
        """Detect whether tokenizer uses BPE, SentencePiece, or other"""
        tokenizer_name = self.tokenizer.__class__.__name__.lower()

        if 'sentencepiece' in tokenizer_name:
            return 'sentencepiece'
        elif 'gpt2' in tokenizer_name or 'codegen' in tokenizer_name:
            return 'bpe'
        elif 'llama' in tokenizer_name:
            return 'sentencepiece'
        else:
            return 'unknown'

    def get_subword_pieces(self, token_id: int) -> List[str]:
        """
        Extract subword pieces for a token ID.

        For BPE (GPT-2/CodeGen):
        - Tokens may contain 'Ġ' prefix for spaces
        - Example: token_id=1234 → "Ġuser" → ["user"]

        For SentencePiece (Llama):
        - Tokens may contain '▁' prefix for spaces
        - Example: token_id=5678 → "▁name" → ["name"]

        Returns:
            List of subword pieces (cleaned of special characters)
        """
        try:
            # Decode single token
            token_str = self.tokenizer.decode([token_id])

            # Clean special characters. decode() usually renders the space
            # markers as literal spaces, so strip surrounding whitespace too.
            if self.tokenizer_type == 'bpe':
                # Remove 'Ġ' (GPT-2 byte-level BPE space marker)
                cleaned = token_str.replace('Ġ', '').strip()
            elif self.tokenizer_type == 'sentencepiece':
                # Remove '▁' (SentencePiece space marker)
                cleaned = token_str.replace('▁', '').strip()
            else:
                cleaned = token_str.strip()

            # For compound identifiers, split on underscores/camelCase
            pieces = self._split_identifier(cleaned)

            return pieces if pieces else [cleaned]

        except Exception as e:
            logger.warning(f"Failed to extract subword pieces for token_id {token_id}: {e}")
            return []

    def _split_identifier(self, text: str) -> List[str]:
        """
        Split identifier into components.

        Examples:
        - "get_user_data" → ["get", "user", "data"]
        - "getUserData" → ["get", "User", "Data"]
        - "process" → ["process"]
        """
        # Split on underscores
        if '_' in text:
            return [p for p in text.split('_') if p]

        # Split camelCase (insert _ before capitals, then split)
        camel_split = re.sub(r'([a-z])([A-Z])', r'\1_\2', text)
        if '_' in camel_split:
            return [p for p in camel_split.split('_') if p]

        # Single token
        return [text]

    def get_byte_length(self, token_id: int) -> int:
        """Get byte length of token (UTF-8 encoding)"""
        try:
            token_str = self.tokenizer.decode([token_id])
            return len(token_str.encode('utf-8'))
        except Exception as e:
            logger.warning(f"Failed to get byte length for token_id {token_id}: {e}")
            return 0

    def is_multi_split_identifier(self, token_ids: List[int], window_size: int = 5) -> List[bool]:
        """
        Identify sequences of ≥3 tokens that form a single identifier.

        This detects cases like:
        - ["process", "_", "user"] (3 tokens for process_user)
        - ["get", "User", "Data"] (3 tokens for getUserData)

        Args:
            token_ids: List of token IDs
            window_size: Size of sliding window to check (default 5)

        Returns:
            Boolean array indicating if each token is part of multi-split identifier
        """
        flags = [False] * len(token_ids)

        for i in range(len(token_ids)):
            # Look ahead up to window_size tokens
            window_end = min(i + window_size, len(token_ids))
            window_tokens = token_ids[i:window_end]

            # Decode window
            window_text = self.tokenizer.decode(window_tokens)

            # Check if this looks like an identifier
            # Heuristic: contains underscores or camelCase, no spaces
            if self._is_identifier(window_text):
                # Count pieces
                pieces = self._split_identifier(window_text)
                if len(pieces) >= 3:
                    # Mark all tokens in window as part of multi-split
                    for j in range(i, window_end):
                        flags[j] = True

        return flags

    def _is_identifier(self, text: str) -> bool:
        """Check if text looks like a code identifier"""
        # No spaces (identifiers don't have spaces)
        if ' ' in text:
            return False

        # Contains letters (not just punctuation)
        if not any(c.isalpha() for c in text):
            return False

        # Contains underscore or camelCase
        if '_' in text or any(c.isupper() for c in text):
            return True

        return False

    def analyze_tokens(self, token_ids: List[int]) -> List[Dict[str, Any]]:
        """
        Comprehensive analysis of token sequence.

        Returns list of dictionaries with:
        - token_id: int
        - text: str (decoded token)
        - bpe_pieces: List[str] (subword pieces)
        - byte_length: int
        - is_multi_split: bool (part of multi-split identifier)
        """
        multi_split_flags = self.is_multi_split_identifier(token_ids)

        results = []
        for i, token_id in enumerate(token_ids):
            pieces = self.get_subword_pieces(token_id)
            byte_len = self.get_byte_length(token_id)
            text = self.tokenizer.decode([token_id])

            results.append({
                'token_id': token_id,
                'text': text,
                'bpe_pieces': pieces,
                'byte_length': byte_len,
                'is_multi_split': multi_split_flags[i],
                'num_pieces': len(pieces)
            })

        return results


def get_tokenizer_stats(tokenizer, text: str) -> Dict[str, Any]:
    """
    Get tokenization statistics for a given text.

    Returns:
        Dictionary with:
        - num_tokens: Total number of tokens
        - avg_bytes_per_token: Average UTF-8 bytes per token
        - num_multi_split: Number of tokens in multi-split identifiers
        - tokenization_ratio: Characters per token
        - analysis: Per-token details from TokenizerMetadata.analyze_tokens()
    """
    token_ids = tokenizer.encode(text, add_special_tokens=False)

    metadata = TokenizerMetadata(tokenizer)
    analysis = metadata.analyze_tokens(token_ids)

    total_bytes = sum(t['byte_length'] for t in analysis)
    num_multi_split = sum(1 for t in analysis if t['is_multi_split'])

    return {
        'num_tokens': len(token_ids),
        'avg_bytes_per_token': total_bytes / len(token_ids) if token_ids else 0,
        'num_multi_split': num_multi_split,
        'tokenization_ratio': len(text) / len(token_ids) if token_ids else 0,
        'analysis': analysis
    }


def flag_risk_hotspots(token_analysis: List[Dict[str, Any]], entropy_threshold: float = 1.5) -> List[int]:
    """
    Flag tokens that are candidate risk hotspots based on tokenization.

    The full hotspot criterion is:
    - the token is part of a multi-split identifier (≥3 subwords)
    - AND the model is uncertain about it (high entropy)

    This function checks only the tokenization criterion; per-token entropy
    must be provided externally by the instrumentation layer (see
    flag_risk_hotspots_with_entropy below for the combined check).

    Args:
        token_analysis: Output from TokenizerMetadata.analyze_tokens()
        entropy_threshold: Entropy threshold in nats (default 1.5); accepted
            for interface compatibility but not used here

    Returns:
        List of indices of flagged tokens
    """
    flagged = []

    for i, token in enumerate(token_analysis):
        if token['is_multi_split'] and token['num_pieces'] >= 3:
            flagged.append(i)

    return flagged
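

# A minimal sketch of how the tokenization flag could be combined with
# externally supplied per-token entropies, as the flag_risk_hotspots docstring
# describes. The `entropies` argument is an assumed input from the
# instrumentation layer, aligned one-to-one with token_analysis.
def flag_risk_hotspots_with_entropy(token_analysis: List[Dict[str, Any]],
                                    entropies: List[float],
                                    entropy_threshold: float = 1.5) -> List[int]:
    """Flag tokens that are both multi-split (≥3 subwords) and high-entropy."""
    flagged = []
    for i, token in enumerate(token_analysis):
        # Flag only when both criteria hold: the token belongs to a
        # multi-split identifier and the model was uncertain about it.
        if (token['is_multi_split'] and token['num_pieces'] >= 3
                and entropies[i] >= entropy_threshold):
            flagged.append(i)
    return flagged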


# Example usage
if __name__ == "__main__":
    # This would be used with an actual tokenizer
    # from transformers import AutoTokenizer
    # tokenizer = AutoTokenizer.from_pretrained("Salesforce/codegen-350M-mono")
    #
    # metadata = TokenizerMetadata(tokenizer)
    # stats = get_tokenizer_stats(tokenizer, "def process_user_data(user_name):")
    # print(stats)

    print("Tokenizer utilities module loaded successfully")