Rabbinic-Embedding-Bench / check_token_limits.py
"""
Check token limits for benchmark data entries.
This script scans the benchmark dataset and flags entries that exceed the
8192-token input limit of OpenAI's embedding models (text-embedding-ada-002,
text-embedding-3-small, text-embedding-3-large). The limit can be overridden
with --max-tokens.
Uses tiktoken with the cl100k_base encoding, which is the tokenizer used
by OpenAI's embedding models.
"""
import json
import argparse
import sys
from pathlib import Path
from dataclasses import dataclass

import tiktoken

# OpenAI embedding models use cl100k_base encoding
ENCODING_NAME = "cl100k_base"
MAX_TOKENS = 8192


@dataclass
class TokenOverage:
"""Represents an entry that exceeds the token limit."""
ref: str
category: str
    field: str  # 'he' or 'en'
token_count: int
char_count: int
    text_preview: str  # first `preview_length` characters of the text


def count_tokens(text: str, encoding: tiktoken.Encoding) -> int:
"""Count the number of tokens in a text string."""
return len(encoding.encode(text))
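
# Quick sanity check (illustrative; exact counts can vary across tiktoken versions):
#   enc = tiktoken.get_encoding(ENCODING_NAME)
#   count_tokens("Hello world", enc)  # -> 2 under cl100k_base
#   count_tokens("", enc)             # -> 0
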
def check_entry(
entry: dict,
encoding: tiktoken.Encoding,
max_tokens: int = MAX_TOKENS,
preview_length: int = 100
) -> list[TokenOverage]:
"""
Check a single entry for token limit violations.
Args:
entry: Dictionary with 'ref', 'he', 'en', 'category' keys
encoding: tiktoken encoding to use
max_tokens: Maximum allowed tokens
preview_length: Number of characters to include in preview
Returns:
List of TokenOverage objects for any fields exceeding the limit
"""
overages = []
ref = entry.get("ref", "unknown")
category = entry.get("category", "unknown")
for field in ["he", "en"]:
text = entry.get(field, "")
if not text:
continue
token_count = count_tokens(text, encoding)
if token_count > max_tokens:
            preview = (text[:preview_length] + "...") if len(text) > preview_length else text
overages.append(TokenOverage(
ref=ref,
category=category,
field=field,
token_count=token_count,
char_count=len(text),
text_preview=preview
))
return overages
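
# Example call (illustrative; the ref/category values are made up):
#   enc = tiktoken.get_encoding(ENCODING_NAME)
#   check_entry({"ref": "Genesis 1:1", "category": "Tanakh", "he": "...", "en": "..."}, enc)
#   # returns [] when both fields fit within max_tokens
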
def check_benchmark_data(
data_path: str,
max_tokens: int = MAX_TOKENS,
verbose: bool = False
) -> tuple[list[TokenOverage], dict]:
"""
Check all entries in the benchmark dataset for token limit violations.
Args:
data_path: Path to the benchmark JSON file
max_tokens: Maximum allowed tokens (default: 8192)
verbose: Print progress information
Returns:
Tuple of (list of overages, statistics dict)
"""
# Load the encoding
if verbose:
print(f"Loading tokenizer: {ENCODING_NAME}")
encoding = tiktoken.get_encoding(ENCODING_NAME)
# Load the data
if verbose:
print(f"Loading data from: {data_path}")
with open(data_path, "r", encoding="utf-8") as f:
data = json.load(f)
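    # Assumed input shape (based on the keys read below): a JSON array like
    #   [{"ref": "...", "category": "...", "he": "...", "en": "..."}, ...]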
if verbose:
print(f"Checking {len(data)} entries for token limit ({max_tokens} tokens)...")
# Check all entries
all_overages = []
token_counts_he = []
token_counts_en = []
for i, entry in enumerate(data):
if verbose and (i + 1) % 1000 == 0:
print(f" Processed {i + 1}/{len(data)} entries...")
# Count tokens for statistics
he_text = entry.get("he", "")
en_text = entry.get("en", "")
if he_text:
token_counts_he.append(count_tokens(he_text, encoding))
if en_text:
token_counts_en.append(count_tokens(en_text, encoding))
# Check for overages
overages = check_entry(entry, encoding, max_tokens)
all_overages.extend(overages)
# Compute statistics
stats = {
"total_entries": len(data),
"entries_with_overages": len(set(o.ref for o in all_overages)),
"total_overages": len(all_overages),
"he_overages": len([o for o in all_overages if o.field == "he"]),
"en_overages": len([o for o in all_overages if o.field == "en"]),
"max_tokens_checked": max_tokens,
}
if token_counts_he:
stats["he_token_stats"] = {
"min": min(token_counts_he),
"max": max(token_counts_he),
"avg": sum(token_counts_he) / len(token_counts_he),
"total_entries": len(token_counts_he),
}
if token_counts_en:
stats["en_token_stats"] = {
"min": min(token_counts_en),
"max": max(token_counts_en),
"avg": sum(token_counts_en) / len(token_counts_en),
"total_entries": len(token_counts_en),
}
    return all_overages, stats


def print_report(overages: list[TokenOverage], stats: dict) -> None:
"""Print a formatted report of token limit violations."""
print("\n" + "=" * 70)
print("TOKEN LIMIT CHECK REPORT")
print("=" * 70)
print(f"\nDataset Summary:")
print(f" Total entries checked: {stats['total_entries']:,}")
print(f" Token limit: {stats['max_tokens_checked']:,}")
if "he_token_stats" in stats:
he_stats = stats["he_token_stats"]
print(f"\nHebrew/Aramaic Token Statistics:")
print(f" Min tokens: {he_stats['min']:,}")
print(f" Max tokens: {he_stats['max']:,}")
print(f" Avg tokens: {he_stats['avg']:.1f}")
if "en_token_stats" in stats:
en_stats = stats["en_token_stats"]
print(f"\nEnglish Token Statistics:")
print(f" Min tokens: {en_stats['min']:,}")
print(f" Max tokens: {en_stats['max']:,}")
print(f" Avg tokens: {en_stats['avg']:.1f}")
print(f"\nOverage Summary:")
print(f" Entries exceeding limit: {stats['entries_with_overages']:,}")
print(f" Total field overages: {stats['total_overages']:,}")
print(f" - Hebrew/Aramaic fields: {stats['he_overages']:,}")
print(f" - English fields: {stats['en_overages']:,}")
if overages:
print("\n" + "-" * 70)
print("FLAGGED ENTRIES (exceeding token limit):")
print("-" * 70)
# Group by category
        by_category = {}
        for overage in overages:
            by_category.setdefault(overage.category, []).append(overage)
for category, category_overages in sorted(by_category.items()):
print(f"\n[{category}] - {len(category_overages)} overage(s)")
for overage in category_overages:
print(f"\n Reference: {overage.ref}")
print(f" Field: {overage.field}")
print(f" Token count: {overage.token_count:,} (limit: {stats['max_tokens_checked']:,})")
print(f" Character count: {overage.char_count:,}")
print(f" Preview: {overage.text_preview}")
else:
print("\n✓ No entries exceed the token limit!")
print("\n" + "=" * 70)
def save_report(
overages: list[TokenOverage],
stats: dict,
output_path: str
) -> None:
"""Save the report to a JSON file."""
report = {
"stats": stats,
"overages": [
{
"ref": o.ref,
"category": o.category,
"field": o.field,
"token_count": o.token_count,
"char_count": o.char_count,
"text_preview": o.text_preview,
}
for o in overages
]
}
with open(output_path, "w", encoding="utf-8") as f:
json.dump(report, f, ensure_ascii=False, indent=2)
print(f"\nReport saved to: {output_path}")
def main():
parser = argparse.ArgumentParser(
description="Check benchmark data for entries exceeding OpenAI embedding token limits."
)
parser.add_argument(
"--data",
"-d",
type=str,
default="benchmark_data/benchmark.json",
help="Path to the benchmark JSON file (default: benchmark_data/benchmark.json)"
)
parser.add_argument(
"--max-tokens",
"-m",
type=int,
default=MAX_TOKENS,
help=f"Maximum allowed tokens (default: {MAX_TOKENS})"
)
parser.add_argument(
"--output",
"-o",
type=str,
help="Path to save JSON report (optional)"
)
parser.add_argument(
"--verbose",
"-v",
action="store_true",
help="Print progress information"
)
args = parser.parse_args()
# Check if data file exists
if not Path(args.data).exists():
print(f"Error: Data file not found: {args.data}")
return 1
# Run the check
overages, stats = check_benchmark_data(
args.data,
max_tokens=args.max_tokens,
verbose=args.verbose
)
# Print report
print_report(overages, stats)
# Save report if requested
if args.output:
save_report(overages, stats, args.output)
# Return exit code based on whether overages were found
return 1 if overages else 0
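

# Exit-code contract (handy in CI): 0 when all entries fit, 1 when overages
# were found or the data file was missing, e.g.:
#   python check_token_limits.py -d benchmark_data/benchmark.json || echo "overages found"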
if __name__ == "__main__":
    sys.exit(main())