""" Tokenizer utilities for extracting BPE/SentencePiece metadata. Provides functions to: - Extract subword pieces from tokens - Calculate byte lengths - Identify multi-split identifiers (≥3 subwords) - Detect tokenization artifacts """ from typing import List, Tuple, Dict, Optional import re import logging logger = logging.getLogger(__name__) class TokenizerMetadata: """Extracts and analyzes tokenization metadata""" def __init__(self, tokenizer): self.tokenizer = tokenizer # Detect tokenizer type self.tokenizer_type = self._detect_tokenizer_type() def _detect_tokenizer_type(self) -> str: """Detect whether tokenizer uses BPE, SentencePiece, or other""" tokenizer_name = self.tokenizer.__class__.__name__.lower() if 'sentencepiece' in tokenizer_name: return 'sentencepiece' elif 'gpt2' in tokenizer_name or 'codegen' in tokenizer_name: return 'bpe' elif 'llama' in tokenizer_name: return 'sentencepiece' else: return 'unknown' def get_subword_pieces(self, token_id: int) -> List[str]: """ Extract subword pieces for a token ID. For BPE (GPT-2/CodeGen): - Tokens may contain 'Ġ' prefix for spaces - Example: token_id=1234 → "Ġuser" → ["user"] For SentencePiece (Llama): - Tokens may contain '▁' prefix for spaces - Example: token_id=5678 → "▁name" → ["name"] Returns: List of subword pieces (cleaned of special characters) """ try: # Decode single token token_str = self.tokenizer.decode([token_id]) # Clean special characters if self.tokenizer_type == 'bpe': # Remove 'Ġ' (GPT-2 space marker) cleaned = token_str.replace('Ġ', '') elif self.tokenizer_type == 'sentencepiece': # Remove '▁' (SentencePiece space marker) cleaned = token_str.replace('▁', '') else: cleaned = token_str # For compound identifiers, split on underscores/camelCase pieces = self._split_identifier(cleaned) return pieces if pieces else [cleaned] except Exception as e: logger.warning(f"Failed to extract subword pieces for token_id {token_id}: {e}") return [] def _split_identifier(self, text: str) -> List[str]: """ Split identifier into components. Examples: - "get_user_data" → ["get", "user", "data"] - "getUserData" → ["get", "User", "Data"] - "process" → ["process"] """ # Split on underscores if '_' in text: return [p for p in text.split('_') if p] # Split camelCase (insert _ before capitals, then split) camel_split = re.sub(r'([a-z])([A-Z])', r'\1_\2', text) if '_' in camel_split: return [p for p in camel_split.split('_') if p] # Single token return [text] def get_byte_length(self, token_id: int) -> int: """Get byte length of token (UTF-8 encoding)""" try: token_str = self.tokenizer.decode([token_id]) return len(token_str.encode('utf-8')) except Exception as e: logger.warning(f"Failed to get byte length for token_id {token_id}: {e}") return 0 def is_multi_split_identifier(self, token_ids: List[int], window_size: int = 5) -> List[bool]: """ Identify sequences of ≥3 tokens that form a single identifier. This detects cases like: - ["process", "_", "user"] (3 tokens for process_user) - ["get", "User", "Data"] (3 tokens for getUserData) Args: token_ids: List of token IDs window_size: Size of sliding window to check (default 5) Returns: Boolean array indicating if each token is part of multi-split identifier """ flags = [False] * len(token_ids) for i in range(len(token_ids)): # Look ahead up to window_size tokens window_end = min(i + window_size, len(token_ids)) window_tokens = token_ids[i:window_end] # Decode window window_text = self.tokenizer.decode(window_tokens) # Check if this looks like an identifier # Heuristic: contains underscores or camelCase, no spaces if self._is_identifier(window_text): # Count pieces pieces = self._split_identifier(window_text) if len(pieces) >= 3: # Mark all tokens in window as part of multi-split for j in range(i, window_end): flags[j] = True return flags def _is_identifier(self, text: str) -> bool: """Check if text looks like a code identifier""" # No spaces (identifiers don't have spaces) if ' ' in text: return False # Contains letters (not just punctuation) if not any(c.isalpha() for c in text): return False # Contains underscore or camelCase if '_' in text or any(c.isupper() for c in text): return True return False def analyze_tokens(self, token_ids: List[int]) -> List[Dict[str, any]]: """ Comprehensive analysis of token sequence. Returns list of dictionaries with: - token_id: int - text: str (decoded token) - bpe_pieces: List[str] (subword pieces) - byte_length: int - is_multi_split: bool (part of multi-split identifier) """ multi_split_flags = self.is_multi_split_identifier(token_ids) results = [] for i, token_id in enumerate(token_ids): pieces = self.get_subword_pieces(token_id) byte_len = self.get_byte_length(token_id) text = self.tokenizer.decode([token_id]) results.append({ 'token_id': token_id, 'text': text, 'bpe_pieces': pieces, 'byte_length': byte_len, 'is_multi_split': multi_split_flags[i], 'num_pieces': len(pieces) }) return results def get_tokenizer_stats(tokenizer, text: str) -> Dict[str, any]: """ Get tokenization statistics for a given text. Returns: Dictionary with: - num_tokens: Total tokens - avg_bytes_per_token: Average bytes per token - num_multi_split: Number of tokens in multi-split identifiers - tokenization_ratio: Characters / tokens """ token_ids = tokenizer.encode(text, add_special_tokens=False) metadata = TokenizerMetadata(tokenizer) analysis = metadata.analyze_tokens(token_ids) total_bytes = sum(t['byte_length'] for t in analysis) num_multi_split = sum(1 for t in analysis if t['is_multi_split']) return { 'num_tokens': len(token_ids), 'avg_bytes_per_token': total_bytes / len(token_ids) if token_ids else 0, 'num_multi_split': num_multi_split, 'tokenization_ratio': len(text) / len(token_ids) if token_ids else 0, 'analysis': analysis } def flag_risk_hotspots(token_analysis: List[Dict[str, any]], entropy_threshold: float = 1.5) -> List[int]: """ Flag tokens that are risk hotspots based on tokenization + entropy. A token is flagged if: - It's part of a multi-split identifier (≥3 subwords) - AND has high entropy (model is uncertain) Args: token_analysis: Output from TokenizerMetadata.analyze_tokens() entropy_threshold: Entropy threshold (default 1.5 nats) Returns: List of indices of flagged tokens Note: Entropy must be provided externally (from instrumentation layer) This function only checks the tokenization criterion. """ flagged = [] for i, token in enumerate(token_analysis): if token['is_multi_split'] and token['num_pieces'] >= 3: flagged.append(i) return flagged # Example usage if __name__ == "__main__": # This would be used with an actual tokenizer # from transformers import AutoTokenizer # tokenizer = AutoTokenizer.from_pretrained("Salesforce/codegen-350M-mono") # # metadata = TokenizerMetadata(tokenizer) # stats = get_tokenizer_stats(tokenizer, "def process_user_data(user_name):") # print(stats) print("Tokenizer utilities module loaded successfully")