"""
Check token limits for benchmark data entries.

This script scans the benchmark dataset and flags entries that exceed
the 8192-token limit used by OpenAI embedding models (text-embedding-ada-002,
text-embedding-3-small, text-embedding-3-large).

Uses tiktoken with the cl100k_base encoding, which is the tokenizer used
by OpenAI's embedding models.
"""

import argparse
import json
import sys
from dataclasses import dataclass
from pathlib import Path

import tiktoken

ENCODING_NAME = "cl100k_base"
MAX_TOKENS = 8192


@dataclass
class TokenOverage:
    """Represents an entry that exceeds the token limit."""
    ref: str
    category: str
    field: str
    token_count: int
    char_count: int
    text_preview: str


def count_tokens(text: str, encoding: tiktoken.Encoding) -> int:
    """Count the number of tokens in a text string."""
    return len(encoding.encode(text))
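
# Rough illustration of count_tokens (exact counts depend on the input text;
# the value shown is what cl100k_base typically produces for this string):
#
#   enc = tiktoken.get_encoding(ENCODING_NAME)
#   count_tokens("hello world", enc)  # -> 2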


def check_entry(
    entry: dict,
    encoding: tiktoken.Encoding,
    max_tokens: int = MAX_TOKENS,
    preview_length: int = 100,
) -> list[TokenOverage]:
    """
    Check a single entry for token limit violations.

    Args:
        entry: Dictionary with 'ref', 'he', 'en', 'category' keys
        encoding: tiktoken encoding to use
        max_tokens: Maximum allowed tokens
        preview_length: Number of characters to include in preview

    Returns:
        List of TokenOverage objects for any fields exceeding the limit
    """
    overages = []

    ref = entry.get("ref", "unknown")
    category = entry.get("category", "unknown")

    for field in ["he", "en"]:
        text = entry.get(field, "")
        if not text:
            continue

        token_count = count_tokens(text, encoding)

        if token_count > max_tokens:
            preview = text[:preview_length] + "..." if len(text) > preview_length else text
            overages.append(TokenOverage(
                ref=ref,
                category=category,
                field=field,
                token_count=token_count,
                char_count=len(text),
                text_preview=preview,
            ))

    return overages
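
# Sketch of checking a single entry in isolation. The entry shown here is
# hypothetical; real entries come from the benchmark JSON file:
#
#   enc = tiktoken.get_encoding(ENCODING_NAME)
#   entry = {"ref": "Genesis 1:1", "category": "Tanakh", "he": "...", "en": "..."}
#   for overage in check_entry(entry, enc):
#       print(overage.ref, overage.field, overage.token_count)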


def check_benchmark_data(
    data_path: str,
    max_tokens: int = MAX_TOKENS,
    verbose: bool = False,
) -> tuple[list[TokenOverage], dict]:
    """
    Check all entries in the benchmark dataset for token limit violations.

    Args:
        data_path: Path to the benchmark JSON file
        max_tokens: Maximum allowed tokens (default: 8192)
        verbose: Print progress information

    Returns:
        Tuple of (list of overages, statistics dict)
    """
    if verbose:
        print(f"Loading tokenizer: {ENCODING_NAME}")
    encoding = tiktoken.get_encoding(ENCODING_NAME)

    if verbose:
        print(f"Loading data from: {data_path}")
    with open(data_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    if verbose:
        print(f"Checking {len(data)} entries against the {max_tokens:,}-token limit...")

    all_overages = []
    token_counts_he = []
    token_counts_en = []

    for i, entry in enumerate(data):
        if verbose and (i + 1) % 1000 == 0:
            print(f"  Processed {i + 1}/{len(data)} entries...")

        he_text = entry.get("he", "")
        en_text = entry.get("en", "")

        if he_text:
            token_counts_he.append(count_tokens(he_text, encoding))
        if en_text:
            token_counts_en.append(count_tokens(en_text, encoding))

        overages = check_entry(entry, encoding, max_tokens)
        all_overages.extend(overages)

    stats = {
        "total_entries": len(data),
        "entries_with_overages": len({o.ref for o in all_overages}),
        "total_overages": len(all_overages),
        "he_overages": len([o for o in all_overages if o.field == "he"]),
        "en_overages": len([o for o in all_overages if o.field == "en"]),
        "max_tokens_checked": max_tokens,
    }

    if token_counts_he:
        stats["he_token_stats"] = {
            "min": min(token_counts_he),
            "max": max(token_counts_he),
            "avg": sum(token_counts_he) / len(token_counts_he),
            "total_entries": len(token_counts_he),
        }

    if token_counts_en:
        stats["en_token_stats"] = {
            "min": min(token_counts_en),
            "max": max(token_counts_en),
            "avg": sum(token_counts_en) / len(token_counts_en),
            "total_entries": len(token_counts_en),
        }

    return all_overages, stats
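
# Programmatic use, mirroring what main() does below (the path is the
# script's default and may differ in your layout):
#
#   overages, stats = check_benchmark_data("benchmark_data/benchmark.json", verbose=True)
#   print(stats["total_overages"], "field(s) over the limit")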


def print_report(overages: list[TokenOverage], stats: dict) -> None:
    """Print a formatted report of token limit violations."""
    print("\n" + "=" * 70)
    print("TOKEN LIMIT CHECK REPORT")
    print("=" * 70)

    print("\nDataset Summary:")
    print(f"  Total entries checked: {stats['total_entries']:,}")
    print(f"  Token limit: {stats['max_tokens_checked']:,}")

    if "he_token_stats" in stats:
        he_stats = stats["he_token_stats"]
        print("\nHebrew/Aramaic Token Statistics:")
        print(f"  Min tokens: {he_stats['min']:,}")
        print(f"  Max tokens: {he_stats['max']:,}")
        print(f"  Avg tokens: {he_stats['avg']:.1f}")

    if "en_token_stats" in stats:
        en_stats = stats["en_token_stats"]
        print("\nEnglish Token Statistics:")
        print(f"  Min tokens: {en_stats['min']:,}")
        print(f"  Max tokens: {en_stats['max']:,}")
        print(f"  Avg tokens: {en_stats['avg']:.1f}")

    print("\nOverage Summary:")
    print(f"  Entries exceeding limit: {stats['entries_with_overages']:,}")
    print(f"  Total field overages: {stats['total_overages']:,}")
    print(f"    - Hebrew/Aramaic fields: {stats['he_overages']:,}")
    print(f"    - English fields: {stats['en_overages']:,}")

    if overages:
        print("\n" + "-" * 70)
        print("FLAGGED ENTRIES (exceeding token limit):")
        print("-" * 70)

        by_category: dict[str, list[TokenOverage]] = {}
        for overage in overages:
            by_category.setdefault(overage.category, []).append(overage)

        for category, category_overages in sorted(by_category.items()):
            print(f"\n[{category}] - {len(category_overages)} overage(s)")
            for overage in category_overages:
                print(f"\n  Reference: {overage.ref}")
                print(f"  Field: {overage.field}")
                print(f"  Token count: {overage.token_count:,} (limit: {stats['max_tokens_checked']:,})")
                print(f"  Character count: {overage.char_count:,}")
                print(f"  Preview: {overage.text_preview}")
    else:
        print("\n✓ No entries exceed the token limit!")

    print("\n" + "=" * 70)


def save_report(
    overages: list[TokenOverage],
    stats: dict,
    output_path: str,
) -> None:
    """Save the report to a JSON file."""
    report = {
        "stats": stats,
        "overages": [
            {
                "ref": o.ref,
                "category": o.category,
                "field": o.field,
                "token_count": o.token_count,
                "char_count": o.char_count,
                "text_preview": o.text_preview,
            }
            for o in overages
        ],
    }

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(report, f, ensure_ascii=False, indent=2)

    print(f"\nReport saved to: {output_path}")


def main() -> int:
    parser = argparse.ArgumentParser(
        description="Check benchmark data for entries exceeding OpenAI embedding token limits."
    )
    parser.add_argument(
        "--data",
        "-d",
        type=str,
        default="benchmark_data/benchmark.json",
        help="Path to the benchmark JSON file (default: benchmark_data/benchmark.json)",
    )
    parser.add_argument(
        "--max-tokens",
        "-m",
        type=int,
        default=MAX_TOKENS,
        help=f"Maximum allowed tokens (default: {MAX_TOKENS})",
    )
    parser.add_argument(
        "--output",
        "-o",
        type=str,
        help="Path to save JSON report (optional)",
    )
    parser.add_argument(
        "--verbose",
        "-v",
        action="store_true",
        help="Print progress information",
    )

    args = parser.parse_args()

    if not Path(args.data).exists():
        print(f"Error: Data file not found: {args.data}", file=sys.stderr)
        return 1

    overages, stats = check_benchmark_data(
        args.data,
        max_tokens=args.max_tokens,
        verbose=args.verbose,
    )

    print_report(overages, stats)

    if args.output:
        save_report(overages, stats, args.output)

    return 1 if overages else 0


if __name__ == "__main__":
    sys.exit(main())