"""
Tokenization Controls Module
==============================
Tokenizer selection, token counting, truncation, and splitting.
Supports tiktoken (OpenAI) and HuggingFace tokenizers.
"""

from dataclasses import dataclass
from typing import Dict, List, Any, Optional

import pandas as pd
import numpy as np


@dataclass
class TokenizationConfig:
    """Configuration for tokenization controls."""

    tokenizer_name: str = "tiktoken"        # "tiktoken" or HF model name
    tiktoken_encoding: str = "cl100k_base"  # for tiktoken
    max_total_tokens: int = 2048
    truncate_long: bool = False
    split_long: bool = False
    split_overlap: int = 50                 # overlap tokens when splitting


def get_tokenizer(config: TokenizationConfig) -> Any:
    """
    Return a tokenizer-like object.

    For tiktoken: returns the encoding object named by
    ``config.tiktoken_encoding``.
    For HF: returns the AutoTokenizer instance loaded from
    ``config.tokenizer_name``.

    Raises:
        ImportError: if the backing package (tiktoken / transformers) is not
            installed. Errors raised while *constructing* the tokenizer are
            propagated unchanged.
    """
    if config.tokenizer_name == "tiktoken":
        # Only the import is guarded, so a failure inside get_encoding() is
        # not misreported as "tiktoken is not installed".
        try:
            import tiktoken
        except ImportError as err:
            raise ImportError(
                "tiktoken is required. Install with: pip install tiktoken"
            ) from err
        return tiktoken.get_encoding(config.tiktoken_encoding)

    try:
        from transformers import AutoTokenizer
    except ImportError as err:
        raise ImportError("transformers is required for HF tokenizers.") from err
    return AutoTokenizer.from_pretrained(config.tokenizer_name)


def _encode(text: str, tokenizer: Any, is_tiktoken: bool) -> List[Any]:
    """Encode *text* with the call signature matching the tokenizer flavour."""
    if is_tiktoken:
        # NOTE(review): tiktoken's encode() presumably rejects text containing
        # special-token strings such as "<|endoftext|>" by default -- confirm
        # whether raw user data should be encoded with disallowed_special=().
        return tokenizer.encode(text)
    # HF path: do not let automatically added special tokens inflate the count.
    return tokenizer.encode(text, add_special_tokens=False)


def count_tokens(text: str, tokenizer, is_tiktoken: bool = True) -> int:
    """
    Count tokens in a text string.

    Non-string values and empty / whitespace-only strings count as 0 and are
    never passed to the tokenizer.
    """
    if not isinstance(text, str) or not text.strip():
        return 0
    return len(_encode(text, tokenizer, is_tiktoken))


def compute_token_stats(
    df: pd.DataFrame,
    columns: List[str],
    tokenizer,
    is_tiktoken: bool = True,
) -> Dict[str, Dict[str, float]]:
    """
    Compute token statistics for specified columns.

    Columns that are not present in ``df`` are skipped. Returns dict of
    column -> {min, max, mean, median, p95, total}; every value is 0 when the
    column has no rows. ``median`` and ``p95`` are truncated to int and
    ``mean`` is rounded to one decimal place.
    """
    stats: Dict[str, Dict[str, float]] = {}
    for col in columns:
        if col not in df.columns:
            continue

        counts = df[col].apply(lambda t: count_tokens(t, tokenizer, is_tiktoken))

        if counts.empty:
            # min/max/percentile are undefined on an empty series.
            stats[col] = {
                'min': 0, 'max': 0, 'mean': 0, 'median': 0, 'p95': 0, 'total': 0,
            }
            continue

        stats[col] = {
            'min': int(counts.min()),
            'max': int(counts.max()),
            'mean': round(float(counts.mean()), 1),
            'median': int(counts.median()),
            'p95': int(np.percentile(counts, 95)),
            'total': int(counts.sum()),
        }
    return stats


def truncate_samples(
    df: pd.DataFrame,
    col: str,
    max_tokens: int,
    tokenizer,
    is_tiktoken: bool = True,
) -> pd.DataFrame:
    """
    Truncate text in a column to max_tokens.

    Returns a copy; ``df`` itself is not modified. Non-string cells and texts
    that already fit are returned untouched (they are not round-tripped
    through encode/decode).

    Raises:
        ValueError: if ``max_tokens`` is negative (a negative slice bound
            would silently drop tokens from the end instead).
    """
    if max_tokens < 0:
        raise ValueError(f"max_tokens must be >= 0, got {max_tokens}")

    def _truncate(text):
        if not isinstance(text, str):
            return text
        tokens = _encode(text, tokenizer, is_tiktoken)
        if len(tokens) <= max_tokens:
            return text
        return tokenizer.decode(tokens[:max_tokens])

    result = df.copy()
    result[col] = result[col].apply(_truncate)
    return result


def split_long_samples(
    df: pd.DataFrame,
    col: str,
    max_tokens: int,
    tokenizer,
    is_tiktoken: bool = True,
    overlap: int = 50,
) -> pd.DataFrame:
    """
    Split rows whose text exceeds max_tokens into multiple rows.

    Each chunk has `overlap` tokens of context from the previous chunk; all
    other columns are copied onto every chunk row. Rows with non-string text
    or text within the limit are kept as-is. The result has a fresh 0..n-1
    index.

    Splitting stops as soon as a chunk reaches the last token, so no trailing
    chunk consisting only of already-emitted overlap is produced.

    Raises:
        ValueError: if ``max_tokens`` is less than 1 (no non-empty chunk
            could be produced and long rows would be lost).
    """
    if max_tokens < 1:
        raise ValueError(f"max_tokens must be >= 1, got {max_tokens}")

    # Always advance by at least one token (overlap >= max_tokens) and never
    # by more than a full window (negative overlap would skip tokens).
    step = min(max_tokens, max(1, max_tokens - overlap))

    new_rows = []
    for _, row in df.iterrows():
        text = row[col]
        if not isinstance(text, str):
            new_rows.append(row)
            continue

        tokens = _encode(text, tokenizer, is_tiktoken)
        if len(tokens) <= max_tokens:
            new_rows.append(row)
            continue

        for start in range(0, len(tokens), step):
            new_row = row.copy()
            new_row[col] = tokenizer.decode(tokens[start:start + max_tokens])
            new_rows.append(new_row)
            if start + max_tokens >= len(tokens):
                # This window already covers the tail of the text.
                break

    if not new_rows:
        # pd.DataFrame([]) would lose the column labels of an empty input.
        return df.iloc[0:0].reset_index(drop=True)
    return pd.DataFrame(new_rows).reset_index(drop=True)