# api/backend/tokenizer_utils.py
"""
Tokenizer utilities for extracting BPE/SentencePiece metadata.
Provides functions to:
- Extract subword pieces from tokens
- Calculate byte lengths
- Identify multi-split identifiers (≥3 subwords)
- Detect tokenization artifacts
"""

import logging
import re
from typing import Any, Dict, List

logger = logging.getLogger(__name__)


class TokenizerMetadata:
    """Extracts and analyzes tokenization metadata."""

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        # Detect tokenizer type once so later calls can branch on it
        self.tokenizer_type = self._detect_tokenizer_type()

    def _detect_tokenizer_type(self) -> str:
        """Detect whether the tokenizer uses BPE, SentencePiece, or other."""
        tokenizer_name = self.tokenizer.__class__.__name__.lower()
        if 'sentencepiece' in tokenizer_name:
            return 'sentencepiece'
        elif 'gpt2' in tokenizer_name or 'codegen' in tokenizer_name:
            return 'bpe'
        elif 'llama' in tokenizer_name:
            # Llama tokenizers are SentencePiece-based
            return 'sentencepiece'
        else:
            return 'unknown'

    def get_subword_pieces(self, token_id: int) -> List[str]:
        """
        Extract subword pieces for a token ID.

        For BPE (GPT-2/CodeGen):
        - Tokens may carry a 'Ġ' prefix marking a leading space
        - Example: token_id=1234 → "Ġuser" → ["user"]

        For SentencePiece (Llama):
        - Tokens may carry a '▁' prefix marking a leading space
        - Example: token_id=5678 → "▁name" → ["name"]

        Returns:
            List of subword pieces (cleaned of special marker characters)
        """
        try:
            # Use the raw vocabulary piece so the 'Ġ'/'▁' markers survive;
            # decode() would already have turned them into plain spaces
            token_str = self.tokenizer.convert_ids_to_tokens(token_id)
            # Strip the space-marker characters
            if self.tokenizer_type == 'bpe':
                cleaned = token_str.replace('Ġ', '')  # GPT-2 space marker
            elif self.tokenizer_type == 'sentencepiece':
                cleaned = token_str.replace('▁', '')  # SentencePiece space marker
            else:
                cleaned = token_str
            # For compound identifiers, split on underscores/camelCase
            pieces = self._split_identifier(cleaned)
            return pieces if pieces else [cleaned]
        except Exception as e:
            logger.warning(f"Failed to extract subword pieces for token_id {token_id}: {e}")
            return []

    def _split_identifier(self, text: str) -> List[str]:
        """
        Split an identifier into components.

        Examples:
        - "get_user_data" → ["get", "user", "data"]
        - "getUserData" → ["get", "User", "Data"]
        - "process" → ["process"]
        """
        # Split on underscores
        if '_' in text:
            return [p for p in text.split('_') if p]
        # Split camelCase (insert _ before capitals, then split)
        camel_split = re.sub(r'([a-z])([A-Z])', r'\1_\2', text)
        if '_' in camel_split:
            return [p for p in camel_split.split('_') if p]
        # Single piece
        return [text]
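
    # Illustrative behavior of the regex above: it only splits at lower→upper
    # boundaries, so runs of capitals stay together (a known limitation of
    # this heuristic, not a bug):
    #   _split_identifier("getUserData")     → ["get", "User", "Data"]
    #   _split_identifier("parseHTTPHeader") → ["parse", "HTTPHeader"]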

    def get_byte_length(self, token_id: int) -> int:
        """Get the byte length of a token's decoded text (UTF-8)."""
        try:
            token_str = self.tokenizer.decode([token_id])
            return len(token_str.encode('utf-8'))
        except Exception as e:
            logger.warning(f"Failed to get byte length for token_id {token_id}: {e}")
            return 0

    def is_multi_split_identifier(self, token_ids: List[int], window_size: int = 5) -> List[bool]:
        """
        Identify sequences of ≥3 tokens that form a single identifier.

        This detects cases like:
        - ["process", "_", "user"] (3 tokens for process_user)
        - ["get", "User", "Data"] (3 tokens for getUserData)

        Args:
            token_ids: List of token IDs
            window_size: Size of the sliding window to check (default 5)

        Returns:
            Boolean list indicating whether each token is part of a
            multi-split identifier
        """
        flags = [False] * len(token_ids)
        for i in range(len(token_ids)):
            # Look ahead up to window_size tokens
            window_end = min(i + window_size, len(token_ids))
            window_tokens = token_ids[i:window_end]
            # Decode the window; strip surrounding whitespace so a leading
            # space marker ('Ġ'/'▁') does not disqualify the window
            window_text = self.tokenizer.decode(window_tokens).strip()
            # Heuristic: an identifier contains underscores or camelCase
            # and no internal spaces
            if self._is_identifier(window_text):
                pieces = self._split_identifier(window_text)
                if len(pieces) >= 3:
                    # Mark every token covered by the window
                    for j in range(i, window_end):
                        flags[j] = True
        return flags
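
    # Illustrative flags (token IDs are hypothetical) for a sequence whose
    # four tokens decode to "getUserData(":
    #   ["get", "User", "Data", "("] → [True, True, True, True]
    # Note the window marks *every* token it covers, so punctuation that
    # falls inside a flagged window is marked along with the identifier.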

    def _is_identifier(self, text: str) -> bool:
        """Check whether text looks like a code identifier."""
        # Identifiers contain no spaces
        if ' ' in text:
            return False
        # Must contain letters (not just punctuation)
        if not any(c.isalpha() for c in text):
            return False
        # Underscores or capital letters suggest snake_case/camelCase
        return '_' in text or any(c.isupper() for c in text)

    def analyze_tokens(self, token_ids: List[int]) -> List[Dict[str, Any]]:
        """
        Comprehensive analysis of a token sequence.

        Returns a list of dictionaries with:
        - token_id: int
        - text: str (decoded token)
        - bpe_pieces: List[str] (subword pieces)
        - byte_length: int
        - is_multi_split: bool (part of a multi-split identifier)
        - num_pieces: int (number of subword pieces)
        """
        multi_split_flags = self.is_multi_split_identifier(token_ids)
        results = []
        for i, token_id in enumerate(token_ids):
            pieces = self.get_subword_pieces(token_id)
            byte_len = self.get_byte_length(token_id)
            text = self.tokenizer.decode([token_id])
            results.append({
                'token_id': token_id,
                'text': text,
                'bpe_pieces': pieces,
                'byte_length': byte_len,
                'is_multi_split': multi_split_flags[i],
                'num_pieces': len(pieces)
            })
        return results
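
    # Shape of one analysis entry (values are illustrative and assume a
    # GPT-2-style vocabulary where some ID maps to the piece "Ġuser"):
    #   {'token_id': 1234, 'text': ' user', 'bpe_pieces': ['user'],
    #    'byte_length': 5, 'is_multi_split': False, 'num_pieces': 1}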


def get_tokenizer_stats(tokenizer, text: str) -> Dict[str, Any]:
    """
    Get tokenization statistics for a given text.

    Returns:
        Dictionary with:
        - num_tokens: Total tokens
        - avg_bytes_per_token: Average UTF-8 bytes per token
        - num_multi_split: Number of tokens in multi-split identifiers
        - tokenization_ratio: Characters / tokens
        - analysis: Per-token output of TokenizerMetadata.analyze_tokens()
    """
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    metadata = TokenizerMetadata(tokenizer)
    analysis = metadata.analyze_tokens(token_ids)
    total_bytes = sum(t['byte_length'] for t in analysis)
    num_multi_split = sum(1 for t in analysis if t['is_multi_split'])
    return {
        'num_tokens': len(token_ids),
        'avg_bytes_per_token': total_bytes / len(token_ids) if token_ids else 0,
        'num_multi_split': num_multi_split,
        'tokenization_ratio': len(text) / len(token_ids) if token_ids else 0,
        'analysis': analysis
    }
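
# Sketch of the returned stats for "def process_user_data(x):" under a
# hypothetical BPE segmentation into 9 tokens (real numbers depend entirely
# on the tokenizer's vocabulary):
#   {'num_tokens': 9, 'avg_bytes_per_token': 2.78, 'num_multi_split': 5,
#    'tokenization_ratio': 2.78, 'analysis': [...]}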


def flag_risk_hotspots(token_analysis: List[Dict[str, Any]], entropy_threshold: float = 1.5) -> List[int]:
    """
    Flag tokens that are risk hotspots based on tokenization.

    The full criterion is: the token is part of a multi-split identifier
    (≥3 subwords) AND the model is uncertain about it (high entropy).
    Entropy must be supplied externally by the instrumentation layer, so
    this function only checks the tokenization half; `entropy_threshold`
    is kept in the signature for callers that apply the entropy filter.

    Args:
        token_analysis: Output from TokenizerMetadata.analyze_tokens()
        entropy_threshold: Entropy threshold in nats (default 1.5; unused here)

    Returns:
        List of indices of flagged tokens
    """
    flagged = []
    for i, token in enumerate(token_analysis):
        if token['is_multi_split'] and token['num_pieces'] >= 3:
            flagged.append(i)
    return flagged
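
# Intended composition with the instrumentation layer (`entropies` is a
# hypothetical per-token list supplied by that layer, aligned 1:1 with
# `token_analysis`):
#   hotspots = flag_risk_hotspots(analysis)
#   risky = [i for i in hotspots if entropies[i] > 1.5]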


# Example usage
if __name__ == "__main__":
    # With an actual tokenizer this would look like:
    # from transformers import AutoTokenizer
    # tokenizer = AutoTokenizer.from_pretrained("Salesforce/codegen-350M-mono")
    #
    # metadata = TokenizerMetadata(tokenizer)
    # stats = get_tokenizer_stats(tokenizer, "def process_user_data(user_name):")
    # print(stats)
    print("Tokenizer utilities module loaded successfully")