Spaces:
Configuration error
Configuration error
| """ | |
| Tokenization Controls Module | |
| ============================== | |
| Tokenizer selection, token counting, truncation, and splitting. | |
| Supports tiktoken (OpenAI) and HuggingFace tokenizers. | |
| """ | |
| from dataclasses import dataclass | |
| from typing import Dict, List, Any, Optional | |
| import pandas as pd | |
| import numpy as np | |
@dataclass
class TokenizationConfig:
    """Configuration for tokenization controls.

    Declared as a dataclass so that individual settings can be overridden
    by keyword at construction time, e.g.
    ``TokenizationConfig(tokenizer_name="gpt2", max_total_tokens=512)``.
    Calling ``TokenizationConfig()`` with no arguments yields the defaults
    listed below.
    """

    # "tiktoken" selects the tiktoken backend; any other value is treated
    # as a HuggingFace model name by get_tokenizer().
    tokenizer_name: str = "tiktoken"
    # Encoding name handed to tiktoken.get_encoding(); only read when
    # tokenizer_name == "tiktoken".
    tiktoken_encoding: str = "cl100k_base"
    # Token budget per sample (presumably applied by callers of the
    # truncate/split helpers -- not read anywhere in this module's
    # visible code).
    max_total_tokens: int = 2048
    # Flags presumably consumed by the caller to choose between truncating
    # and splitting over-long samples -- TODO confirm against call sites.
    truncate_long: bool = False
    split_long: bool = False
    # Overlap tokens when splitting.
    split_overlap: int = 50
def get_tokenizer(config: "TokenizationConfig"):
    """Return a tokenizer-like object for the given configuration.

    Args:
        config: Settings object. ``config.tokenizer_name`` selects the
            backend: the literal ``"tiktoken"`` uses tiktoken with
            ``config.tiktoken_encoding``; any other value is passed to
            ``transformers.AutoTokenizer.from_pretrained`` as the model name.

    Returns:
        For tiktoken, the object returned by ``tiktoken.get_encoding``.
        Otherwise, the object returned by ``AutoTokenizer.from_pretrained``.

    Raises:
        ImportError: If the package backing the selected tokenizer is not
            installed. The original import failure is attached as
            ``__cause__``.
    """
    if config.tokenizer_name == "tiktoken":
        # Only the import is guarded: an error raised while *loading* the
        # encoding must surface as-is instead of being reported as a
        # missing package.
        try:
            import tiktoken
        except ImportError as err:
            raise ImportError(
                "tiktoken is required. Install with: pip install tiktoken"
            ) from err
        return tiktoken.get_encoding(config.tiktoken_encoding)

    try:
        from transformers import AutoTokenizer
    except ImportError as err:
        raise ImportError("transformers is required for HF tokenizers.") from err
    return AutoTokenizer.from_pretrained(config.tokenizer_name)
def count_tokens(text: str, tokenizer, is_tiktoken: bool = True) -> int:
    """Return the number of tokens ``tokenizer`` produces for ``text``.

    Args:
        text: Value to tokenize. Anything that is not a ``str``, and any
            string that is empty or whitespace-only, counts as 0 tokens
            without calling the tokenizer.
        tokenizer: Object exposing ``encode``.
        is_tiktoken: When true, ``encode(text)`` is called with no extra
            arguments; otherwise ``encode(text, add_special_tokens=False)``
            is used.

    Returns:
        Length of the sequence returned by ``tokenizer.encode``.
    """
    has_content = isinstance(text, str) and bool(text.strip())
    if not has_content:
        return 0

    if is_tiktoken:
        encoded = tokenizer.encode(text)
    else:
        encoded = tokenizer.encode(text, add_special_tokens=False)
    return len(encoded)
def compute_token_stats(
    df: pd.DataFrame,
    columns: List[str],
    tokenizer,
    is_tiktoken: bool = True,
) -> Dict[str, Dict[str, float]]:
    """Summarise per-row token counts for each requested column.

    Args:
        df: Source frame.
        columns: Column names to analyse; names absent from ``df`` are
            skipped silently.
        tokenizer: Tokenizer forwarded to ``count_tokens``.
        is_tiktoken: Forwarded to ``count_tokens``.

    Returns:
        Mapping of column name to a dict with keys ``min``, ``max``,
        ``mean`` (rounded to one decimal), ``median`` and ``p95`` (both
        truncated to int) and ``total``. For a column with no rows every
        statistic other than ``total`` is reported as 0.
    """

    def _n_tokens(value):
        return count_tokens(value, tokenizer, is_tiktoken)

    summary = {}
    for name in (c for c in columns if c in df.columns):
        lengths = df[name].apply(_n_tokens)
        total = int(lengths.sum())

        if len(lengths) > 0:
            entry = {
                'min': int(lengths.min()),
                'max': int(lengths.max()),
                'mean': round(float(lengths.mean()), 1),
                'median': int(lengths.median()),
                'p95': int(np.percentile(lengths, 95)),
                'total': total,
            }
        else:
            # No rows: the aggregates are undefined, so report zeros.
            entry = {
                'min': 0,
                'max': 0,
                'mean': 0,
                'median': 0,
                'p95': 0,
                'total': total,
            }
        summary[name] = entry
    return summary
def truncate_samples(
    df: pd.DataFrame,
    col: str,
    max_tokens: int,
    tokenizer,
    is_tiktoken: bool = True,
) -> pd.DataFrame:
    """Return a copy of ``df`` with ``col`` cut down to ``max_tokens`` tokens.

    Args:
        df: Source frame; it is copied, not modified.
        col: Name of the text column to truncate.
        max_tokens: Maximum number of tokens kept per cell.
        tokenizer: Object exposing ``encode`` and ``decode``.
        is_tiktoken: When false, ``encode`` is called with
            ``add_special_tokens=False``.

    Returns:
        The copied frame. Cells that are not strings, or that already fit
        within ``max_tokens``, are left exactly as they were; longer cells
        are replaced by the decoded first ``max_tokens`` tokens.
    """
    # The two backends differ only in the keyword passed to encode().
    encode_kwargs = {} if is_tiktoken else {"add_special_tokens": False}

    def _clip(value):
        if not isinstance(value, str):
            return value
        ids = tokenizer.encode(value, **encode_kwargs)
        if len(ids) <= max_tokens:
            return value
        return tokenizer.decode(ids[:max_tokens])

    result = df.copy()
    result[col] = result[col].apply(_clip)
    return result
def split_long_samples(
    df: pd.DataFrame,
    col: str,
    max_tokens: int,
    tokenizer,
    is_tiktoken: bool = True,
    overlap: int = 50,
) -> pd.DataFrame:
    """Split rows whose text exceeds ``max_tokens`` into multiple rows.

    Every chunk after the first starts ``overlap`` tokens before the end of
    the previous chunk. All other columns are copied onto each chunk row.

    Args:
        df: Source frame; it is not modified.
        col: Name of the text column to split.
        max_tokens: Maximum number of tokens per chunk; must be >= 1.
        tokenizer: Object exposing ``encode`` and ``decode``.
        is_tiktoken: When false, ``encode`` is called with
            ``add_special_tokens=False``.
        overlap: Number of tokens shared between consecutive chunks; must
            be >= 0. When ``overlap >= max_tokens`` the window still
            advances by one token per chunk.

    Returns:
        A new frame with a fresh 0..n-1 index. Rows whose ``col`` value is
        not a string, or already fits, are passed through unchanged. An
        empty input yields an empty frame with the same columns.

    Raises:
        ValueError: If ``max_tokens`` < 1 or ``overlap`` < 0. Either value
            would otherwise silently drop text.
    """
    if max_tokens < 1:
        raise ValueError("max_tokens must be a positive integer")
    if overlap < 0:
        raise ValueError("overlap must be non-negative")

    encode_kwargs = {} if is_tiktoken else {"add_special_tokens": False}
    # Guarantee forward progress even when overlap >= max_tokens.
    step = max(1, max_tokens - overlap)

    new_rows = []
    for _, row in df.iterrows():
        text = row[col]
        if not isinstance(text, str):
            new_rows.append(row)
            continue

        tokens = tokenizer.encode(text, **encode_kwargs)
        if len(tokens) <= max_tokens:
            new_rows.append(row)
            continue

        for start in range(0, len(tokens), step):
            new_row = row.copy()
            new_row[col] = tokenizer.decode(tokens[start:start + max_tokens])
            new_rows.append(new_row)
            # Stop once a chunk reaches the end of the text; another
            # iteration would only re-emit tokens already covered by
            # this chunk's tail.
            if start + max_tokens >= len(tokens):
                break

    if not new_rows:
        # pd.DataFrame([]) would lose the column labels.
        return df.iloc[:0].reset_index(drop=True)
    return pd.DataFrame(new_rows).reset_index(drop=True)