""" Check token limits for benchmark data entries. This script scans the benchmark dataset and flags entries that exceed the 8192 token limit used by OpenAI embedding models (text-embedding-ada-002, text-embedding-3-small, text-embedding-3-large). Uses tiktoken with the cl100k_base encoding, which is the tokenizer used by OpenAI's embedding models. """ import json import argparse from pathlib import Path from dataclasses import dataclass import tiktoken # OpenAI embedding models use cl100k_base encoding ENCODING_NAME = "cl100k_base" MAX_TOKENS = 8192 @dataclass class TokenOverage: """Represents an entry that exceeds the token limit.""" ref: str category: str field: str # 'he', 'en', or 'combined' token_count: int char_count: int text_preview: str # First N characters of the text def count_tokens(text: str, encoding: tiktoken.Encoding) -> int: """Count the number of tokens in a text string.""" return len(encoding.encode(text)) def check_entry( entry: dict, encoding: tiktoken.Encoding, max_tokens: int = MAX_TOKENS, preview_length: int = 100 ) -> list[TokenOverage]: """ Check a single entry for token limit violations. Args: entry: Dictionary with 'ref', 'he', 'en', 'category' keys encoding: tiktoken encoding to use max_tokens: Maximum allowed tokens preview_length: Number of characters to include in preview Returns: List of TokenOverage objects for any fields exceeding the limit """ overages = [] ref = entry.get("ref", "unknown") category = entry.get("category", "unknown") for field in ["he", "en"]: text = entry.get(field, "") if not text: continue token_count = count_tokens(text, encoding) if token_count > max_tokens: preview = text[:preview_length] + "..." if len(text) > preview_length else text overages.append(TokenOverage( ref=ref, category=category, field=field, token_count=token_count, char_count=len(text), text_preview=preview )) return overages def check_benchmark_data( data_path: str, max_tokens: int = MAX_TOKENS, verbose: bool = False ) -> tuple[list[TokenOverage], dict]: """ Check all entries in the benchmark dataset for token limit violations. 


def check_entry(
    entry: dict,
    encoding: tiktoken.Encoding,
    max_tokens: int = MAX_TOKENS,
    preview_length: int = 100,
) -> list[TokenOverage]:
    """
    Check a single entry for token limit violations.

    Args:
        entry: Dictionary with 'ref', 'he', 'en', 'category' keys
        encoding: tiktoken encoding to use
        max_tokens: Maximum allowed tokens
        preview_length: Number of characters to include in preview

    Returns:
        List of TokenOverage objects for any fields exceeding the limit
    """
    overages = []
    ref = entry.get("ref", "unknown")
    category = entry.get("category", "unknown")

    for field in ["he", "en"]:
        text = entry.get(field, "")
        if not text:
            continue

        token_count = count_tokens(text, encoding)
        if token_count > max_tokens:
            preview = (text[:preview_length] + "...") if len(text) > preview_length else text
            overages.append(TokenOverage(
                ref=ref,
                category=category,
                field=field,
                token_count=token_count,
                char_count=len(text),
                text_preview=preview,
            ))

    return overages


def check_benchmark_data(
    data_path: str,
    max_tokens: int = MAX_TOKENS,
    verbose: bool = False,
) -> tuple[list[TokenOverage], dict]:
    """
    Check all entries in the benchmark dataset for token limit violations.

    Args:
        data_path: Path to the benchmark JSON file
        max_tokens: Maximum allowed tokens (default: 8192)
        verbose: Print progress information

    Returns:
        Tuple of (list of overages, statistics dict)
    """
    # Load the encoding
    if verbose:
        print(f"Loading tokenizer: {ENCODING_NAME}")
    encoding = tiktoken.get_encoding(ENCODING_NAME)

    # Load the data
    if verbose:
        print(f"Loading data from: {data_path}")
    with open(data_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    if verbose:
        print(f"Checking {len(data)} entries for token limit ({max_tokens} tokens)...")

    # Check all entries
    all_overages = []
    token_counts_he = []
    token_counts_en = []

    for i, entry in enumerate(data):
        if verbose and (i + 1) % 1000 == 0:
            print(f"  Processed {i + 1}/{len(data)} entries...")

        # Count tokens for statistics
        he_text = entry.get("he", "")
        en_text = entry.get("en", "")
        if he_text:
            token_counts_he.append(count_tokens(he_text, encoding))
        if en_text:
            token_counts_en.append(count_tokens(en_text, encoding))

        # Check for overages
        all_overages.extend(check_entry(entry, encoding, max_tokens))

    # Compute statistics
    stats = {
        "total_entries": len(data),
        "entries_with_overages": len({o.ref for o in all_overages}),
        "total_overages": len(all_overages),
        "he_overages": sum(1 for o in all_overages if o.field == "he"),
        "en_overages": sum(1 for o in all_overages if o.field == "en"),
        "max_tokens_checked": max_tokens,
    }

    if token_counts_he:
        stats["he_token_stats"] = {
            "min": min(token_counts_he),
            "max": max(token_counts_he),
            "avg": sum(token_counts_he) / len(token_counts_he),
            "total_entries": len(token_counts_he),
        }
    if token_counts_en:
        stats["en_token_stats"] = {
            "min": min(token_counts_en),
            "max": max(token_counts_en),
            "avg": sum(token_counts_en) / len(token_counts_en),
            "total_entries": len(token_counts_en),
        }

    return all_overages, stats
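

# Illustrative programmatic use (a sketch, assuming the benchmark JSON is a
# list of dicts with "ref", "he", "en", and "category" keys, which is the
# shape check_entry expects):
#
#   overages, stats = check_benchmark_data("benchmark_data/benchmark.json")
#   for o in overages:
#       print(o.ref, o.field, o.token_count)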


def print_report(overages: list[TokenOverage], stats: dict) -> None:
    """Print a formatted report of token limit violations."""
    print("\n" + "=" * 70)
    print("TOKEN LIMIT CHECK REPORT")
    print("=" * 70)

    print("\nDataset Summary:")
    print(f"  Total entries checked: {stats['total_entries']:,}")
    print(f"  Token limit: {stats['max_tokens_checked']:,}")

    if "he_token_stats" in stats:
        he_stats = stats["he_token_stats"]
        print("\nHebrew/Aramaic Token Statistics:")
        print(f"  Min tokens: {he_stats['min']:,}")
        print(f"  Max tokens: {he_stats['max']:,}")
        print(f"  Avg tokens: {he_stats['avg']:.1f}")

    if "en_token_stats" in stats:
        en_stats = stats["en_token_stats"]
        print("\nEnglish Token Statistics:")
        print(f"  Min tokens: {en_stats['min']:,}")
        print(f"  Max tokens: {en_stats['max']:,}")
        print(f"  Avg tokens: {en_stats['avg']:.1f}")

    print("\nOverage Summary:")
    print(f"  Entries exceeding limit: {stats['entries_with_overages']:,}")
    print(f"  Total field overages: {stats['total_overages']:,}")
    print(f"    - Hebrew/Aramaic fields: {stats['he_overages']:,}")
    print(f"    - English fields: {stats['en_overages']:,}")

    if overages:
        print("\n" + "-" * 70)
        print("FLAGGED ENTRIES (exceeding token limit):")
        print("-" * 70)

        # Group by category
        by_category: dict[str, list[TokenOverage]] = {}
        for overage in overages:
            by_category.setdefault(overage.category, []).append(overage)

        for category, category_overages in sorted(by_category.items()):
            print(f"\n[{category}] - {len(category_overages)} overage(s)")
            for overage in category_overages:
                print(f"\n  Reference: {overage.ref}")
                print(f"  Field: {overage.field}")
                print(f"  Token count: {overage.token_count:,} (limit: {stats['max_tokens_checked']:,})")
                print(f"  Character count: {overage.char_count:,}")
                print(f"  Preview: {overage.text_preview}")
    else:
        print("\n✓ No entries exceed the token limit!")

    print("\n" + "=" * 70)


def save_report(
    overages: list[TokenOverage],
    stats: dict,
    output_path: str,
) -> None:
    """Save the report to a JSON file."""
    report = {
        "stats": stats,
        "overages": [
            {
                "ref": o.ref,
                "category": o.category,
                "field": o.field,
                "token_count": o.token_count,
                "char_count": o.char_count,
                "text_preview": o.text_preview,
            }
            for o in overages
        ],
    }
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(report, f, ensure_ascii=False, indent=2)
    print(f"\nReport saved to: {output_path}")


def main() -> int:
    parser = argparse.ArgumentParser(
        description="Check benchmark data for entries exceeding OpenAI embedding token limits."
    )
    parser.add_argument(
        "--data", "-d",
        type=str,
        default="benchmark_data/benchmark.json",
        help="Path to the benchmark JSON file (default: benchmark_data/benchmark.json)",
    )
    parser.add_argument(
        "--max-tokens", "-m",
        type=int,
        default=MAX_TOKENS,
        help=f"Maximum allowed tokens (default: {MAX_TOKENS})",
    )
    parser.add_argument(
        "--output", "-o",
        type=str,
        help="Path to save JSON report (optional)",
    )
    parser.add_argument(
        "--verbose", "-v",
        action="store_true",
        help="Print progress information",
    )
    args = parser.parse_args()

    # Check if the data file exists
    if not Path(args.data).exists():
        print(f"Error: Data file not found: {args.data}")
        return 1

    # Run the check
    overages, stats = check_benchmark_data(
        args.data,
        max_tokens=args.max_tokens,
        verbose=args.verbose,
    )

    # Print report
    print_report(overages, stats)

    # Save report if requested
    if args.output:
        save_report(overages, stats, args.output)

    # Return exit code based on whether overages were found
    return 1 if overages else 0


if __name__ == "__main__":
    raise SystemExit(main())
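
# Example invocations (the filename check_token_limits.py is an assumption;
# since main() returns non-zero when any entry exceeds the limit, the script
# can double as a CI gate):
#
#   python check_token_limits.py --verbose
#   python check_token_limits.py -d benchmark_data/benchmark.json -o token_report.json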