from datasets import load_dataset from transformers import GPTNeoXTokenizerFast from collections import Counter import matplotlib.pyplot as plt import numpy as np import os # Parameters MAX_EXAMPLES = 4_096_000 LOG_INTERVAL = 1_000_000 BATCH_SIZE = 1000 SAVE_DIR = "length_dist_plots-2" # Create output directory os.makedirs(SAVE_DIR, exist_ok=True) # Load full dataset (cached) dataset = load_dataset("mlfoundations/dclm-baseline-1.0", split="train", num_proc=64) dataset = dataset.select(range(MAX_EXAMPLES)) # Load tokenizer tokenizer = GPTNeoXTokenizerFast.from_pretrained("EleutherAI/gpt-neox-20b") # Initialize counter and length list counter = Counter() length_list = [] # Process in batches for start in range(0, MAX_EXAMPLES, BATCH_SIZE): end = min(start + BATCH_SIZE, MAX_EXAMPLES) batch = dataset[start:end]["text"] encodings = tokenizer(batch, truncation=False, add_special_tokens=False) lengths = [len(ids) for ids in encodings["input_ids"]] counter.update(lengths) length_list.extend(lengths) # Save plot every 1M if (end % LOG_INTERVAL == 0) or (end == MAX_EXAMPLES): count_millions = end // 1_000_000 x, y = zip(*sorted(counter.items())) # Compute percentiles for current subset lengths_np = np.array(length_list) results = { p: int(np.percentile(lengths_np, p)) for p in [10, 20, 30, 40, 50, 55, 60, 65, 70, 75, 80, 85, 90, 95, 99] } # Plot plt.figure(figsize=(12, 6)) plt.bar(x, y, color="skyblue") plt.yscale("log") plt.xscale("log") plt.xlabel("Token Length") plt.ylabel("Frequency (log scale)") plt.title(f"Token Length Distribution (Up to {end:,} Examples)") plt.grid(True, linestyle="--", alpha=0.5) # Annotate percentiles for p in [50, 60, 70, 80, 90, 95, 99]: val = results[p] plt.axvline(val, color="red", linestyle="--", linewidth=1.5) plt.text( val + 10, max(y) / 10, f"{p}%", rotation=90, color="red", fontsize=10, verticalalignment="center", ) plt.tight_layout() filename = os.path.join(SAVE_DIR, f"length_dist_{count_millions}M.png") plt.savefig(filename) plt.close() print(f"āœ… Saved plot: {filename}") # --- Final Percentiles --- print("\nšŸ“Š Computing Final Percentiles...") lengths_np = np.array(length_list) all_percentiles = [0, 1, 5, 10, 25, 50, 75, 90, 95, 99, 100] final_results = {p: int(np.percentile(lengths_np, p)) for p in all_percentiles} print("\nšŸ“Š Token Length Percentiles:") for p in all_percentiles: print(f" {p:>3}%: {final_results[p]:,} tokens") # --- Save final plot with annotations --- x, y = zip(*sorted(counter.items())) plt.figure(figsize=(12, 6)) plt.bar(x, y, width=5, color="skyblue", edgecolor="black") plt.yscale("log") plt.xscale("log") plt.xlabel("Token Length") plt.ylabel("Frequency (log scale)") plt.title("Final Token Length Distribution (4.096M Samples)") plt.grid(True, linestyle="--", alpha=0.5) for p in [50, 95, 99]: val = final_results[p] plt.axvline(val, color="red", linestyle="--", linewidth=1.5) plt.text( val + 10, max(y) / 10, f"{p}%", rotation=90, color="red", fontsize=10, verticalalignment="center", ) plt.tight_layout() final_path = os.path.join(SAVE_DIR, "length_dist_final_annotated.png") plt.savefig(final_path) plt.close() print(f"\nāœ… Final annotated plot saved: {final_path}")