# Auto-FineTune-Ops / preprocessing/tokenization.py
# Author: aneeb15 — commit d4398e6 ("Initial release of Auto-FineTune-Ops")
"""
Tokenization Controls Module
==============================
Tokenizer selection, token counting, truncation, and splitting.
Supports tiktoken (OpenAI) and HuggingFace tokenizers.
"""
from dataclasses import dataclass
from typing import Dict, List, Any, Optional
import pandas as pd
import numpy as np
@dataclass
class TokenizationConfig:
    """Configuration for tokenization controls.

    Field order and defaults define the dataclass-generated ``__init__``
    signature, so they are part of this class's public interface.
    """
    # "tiktoken" selects the tiktoken backend in get_tokenizer; any other value
    # is passed to transformers' AutoTokenizer.from_pretrained as a model name/path.
    tokenizer_name: str = "tiktoken" # "tiktoken" or HF model name
    # Encoding name passed to tiktoken.get_encoding; unused for HF tokenizers.
    tiktoken_encoding: str = "cl100k_base" # for tiktoken
    # Per-sample token budget. NOTE(review): not read in this module; presumably
    # the caller passes it as max_tokens to truncate_samples / split_long_samples.
    max_total_tokens: int = 2048
    # NOTE(review): the two flags below are not read in this module; presumably
    # the caller uses them to choose between truncation and splitting.
    truncate_long: bool = False
    split_long: bool = False
    # NOTE(review): presumably passed as `overlap` to split_long_samples — confirm.
    split_overlap: int = 50 # overlap tokens when splitting
def get_tokenizer(config: "TokenizationConfig"):
    """
    Build the tokenizer selected by ``config``.

    Args:
        config: Only ``tokenizer_name`` and ``tiktoken_encoding`` are read.

    Returns:
        When ``config.tokenizer_name == "tiktoken"``: the encoding returned by
        ``tiktoken.get_encoding(config.tiktoken_encoding)``.
        Otherwise: the tokenizer returned by
        ``transformers.AutoTokenizer.from_pretrained(config.tokenizer_name)``.

    Raises:
        ImportError: if the backend package itself cannot be imported; the
            underlying import error is chained as ``__cause__``. Errors raised
            while constructing the tokenizer propagate unchanged, so e.g. a
            missing optional dependency of a specific HF tokenizer is not
            misreported as "transformers is required".
    """
    if config.tokenizer_name == "tiktoken":
        # Only the import is guarded: wrapping get_encoding too would relabel
        # unrelated ImportErrors as a missing tiktoken package.
        try:
            import tiktoken
        except ImportError as err:
            raise ImportError("tiktoken is required. Install with: pip install tiktoken") from err
        return tiktoken.get_encoding(config.tiktoken_encoding)

    try:
        from transformers import AutoTokenizer
    except ImportError as err:
        raise ImportError("transformers is required for HF tokenizers.") from err
    return AutoTokenizer.from_pretrained(config.tokenizer_name)
def count_tokens(text: str, tokenizer, is_tiktoken: bool = True) -> int:
    """
    Return how many tokens ``tokenizer`` produces for ``text``.

    Non-string values (e.g. None/NaN cells) and empty or whitespace-only
    strings count as 0 without calling the tokenizer.

    Args:
        text: Text to measure.
        tokenizer: A tiktoken encoding when ``is_tiktoken`` is True, otherwise
            a HuggingFace tokenizer.
        is_tiktoken: Selects which ``encode`` calling convention is used.

    Returns:
        The token count as an int.
    """
    if not isinstance(text, str) or not text.strip():
        return 0
    if is_tiktoken:
        # By default tiktoken raises ValueError if the text contains a
        # special-token string such as "<|endoftext|>"; count it as plain text.
        return len(tokenizer.encode(text, disallowed_special=()))
    # Count content tokens only, without the special tokens the HF tokenizer
    # would otherwise add.
    return len(tokenizer.encode(text, add_special_tokens=False))
def compute_token_stats(
    df: pd.DataFrame,
    columns: List[str],
    tokenizer,
    is_tiktoken: bool = True,
) -> Dict[str, Dict[str, float]]:
    """
    Summarise per-cell token counts for each requested column of ``df``.

    Names in ``columns`` that are not columns of ``df`` are skipped. Every
    other name maps to a dict with keys, in this order:
    'min', 'max', 'mean' (rounded to 1 decimal), 'median' and 'p95'
    (both converted with ``int``), and 'total'. Per-cell counts come from
    ``count_tokens``. When ``df`` has no rows every statistic is 0.
    """
    def _cell_tokens(cell):
        return count_tokens(cell, tokenizer, is_tiktoken)

    summary = {}
    for name in (c for c in columns if c in df.columns):
        per_cell = df[name].apply(_cell_tokens)
        if len(per_cell):
            entry = {
                'min': int(per_cell.min()),
                'max': int(per_cell.max()),
                'mean': round(float(per_cell.mean()), 1),
                'median': int(per_cell.median()),
                'p95': int(np.percentile(per_cell, 95)),
            }
        else:
            # Empty column: reductions such as min/percentile are undefined.
            entry = dict.fromkeys(('min', 'max', 'mean', 'median', 'p95'), 0)
        entry['total'] = int(per_cell.sum())
        summary[name] = entry
    return summary
def truncate_samples(
    df: pd.DataFrame,
    col: str,
    max_tokens: int,
    tokenizer,
    is_tiktoken: bool = True,
) -> pd.DataFrame:
    """
    Return a copy of ``df`` in which every string in ``col`` keeps at most
    ``max_tokens`` tokens.

    Non-string cells and texts that already fit are returned unchanged.
    Longer texts are encoded, and their first ``max_tokens`` token ids are
    decoded back to text. For HF tokenizers the decoded text is whatever
    ``decode`` produces for those ids, which may differ from a literal prefix
    of the input. ``df`` itself is not modified.

    Raises:
        ValueError: if ``max_tokens`` is negative.
    """
    if max_tokens < 0:
        # A negative slice bound would drop tokens from the end instead of
        # enforcing a length limit.
        raise ValueError(f"max_tokens must be >= 0, got {max_tokens}")

    def _truncate(text):
        if not isinstance(text, str):
            return text
        if is_tiktoken:
            # Treat special-token strings (e.g. "<|endoftext|>") as plain text
            # rather than letting tiktoken raise ValueError.
            tokens = tokenizer.encode(text, disallowed_special=())
        else:
            tokens = tokenizer.encode(text, add_special_tokens=False)
        if len(tokens) <= max_tokens:
            return text
        head = tokens[:max_tokens]
        if is_tiktoken:
            # The cut can fall inside a multi-byte UTF-8 character; "ignore"
            # drops that incomplete trailing fragment instead of emitting U+FFFD.
            return tokenizer.decode(head, errors="ignore")
        return tokenizer.decode(head)

    out = df.copy()
    out[col] = out[col].apply(_truncate)
    return out
def split_long_samples(
    df: pd.DataFrame,
    col: str,
    max_tokens: int,
    tokenizer,
    is_tiktoken: bool = True,
    overlap: int = 50,
) -> pd.DataFrame:
    """
    Split rows whose text in ``col`` exceeds ``max_tokens`` into multiple rows.

    An over-long text is cut into windows of at most ``max_tokens`` tokens.
    Each window starts ``max(1, max_tokens - overlap)`` tokens after the
    previous one, so consecutive chunks share up to ``overlap`` tokens of
    context. Windowing stops as soon as a chunk reaches the end of the text,
    so no trailing chunk is emitted that lies entirely inside its predecessor.

    Each chunk row copies the other column values of its source row. Rows
    whose value is not a string, or which already fit, are kept unchanged.
    Row order is preserved and the result has a fresh 0..n-1 index. The
    column set and dtypes of ``df`` are preserved, including when ``df`` is
    empty.

    Raises:
        ValueError: if ``max_tokens < 1`` or ``overlap < 0``.
    """
    if max_tokens < 1:
        # With an empty window every over-long row would silently disappear.
        raise ValueError(f"max_tokens must be >= 1, got {max_tokens}")
    if overlap < 0:
        # A negative overlap makes the stride exceed the window, skipping tokens.
        raise ValueError(f"overlap must be >= 0, got {overlap}")

    # overlap >= max_tokens is still accepted; the stride is clamped to 1.
    step = max(1, max_tokens - overlap)

    source_positions = []  # iloc position of the source row for each output row
    texts = []             # value of `col` for each output row

    for pos, text in enumerate(df[col]):
        if not isinstance(text, str):
            source_positions.append(pos)
            texts.append(text)
            continue
        if is_tiktoken:
            # Treat special-token strings as plain text instead of raising.
            tokens = tokenizer.encode(text, disallowed_special=())
        else:
            tokens = tokenizer.encode(text, add_special_tokens=False)
        if len(tokens) <= max_tokens:
            source_positions.append(pos)
            texts.append(text)
            continue

        start = 0
        while True:
            chunk = tokens[start:start + max_tokens]
            if is_tiktoken:
                # Chunk edges may split a multi-byte UTF-8 character; drop the
                # fragment instead of emitting U+FFFD.
                piece = tokenizer.decode(chunk, errors="ignore")
            else:
                piece = tokenizer.decode(chunk)
            source_positions.append(pos)
            texts.append(piece)
            if start + max_tokens >= len(tokens):
                break  # this chunk already reaches the end of the text
            start += step

    # Positional take keeps every column and its dtype (iterrows would not).
    result = df.iloc[source_positions].reset_index(drop=True)
    result[col] = texts
    return result