Spaces:
No application file
No application file
File size: 3,660 Bytes
4f2b2f4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
from datasets import load_dataset
from transformers import GPTNeoXTokenizerFast
from collections import Counter
import matplotlib.pyplot as plt
import numpy as np
import os
# Parameters
MAX_EXAMPLES = 4_096_000
LOG_INTERVAL = 1_000_000
BATCH_SIZE = 1000
SAVE_DIR = "length_dist_plots-2"
# Create output directory
os.makedirs(SAVE_DIR, exist_ok=True)
# Load full dataset (cached)
dataset = load_dataset("mlfoundations/dclm-baseline-1.0", split="train", num_proc=64)
dataset = dataset.select(range(MAX_EXAMPLES))
# Load tokenizer
tokenizer = GPTNeoXTokenizerFast.from_pretrained("EleutherAI/gpt-neox-20b")
# Initialize counter and length list
counter = Counter()
length_list = []
# Process in batches
for start in range(0, MAX_EXAMPLES, BATCH_SIZE):
end = min(start + BATCH_SIZE, MAX_EXAMPLES)
batch = dataset[start:end]["text"]
encodings = tokenizer(batch, truncation=False, add_special_tokens=False)
lengths = [len(ids) for ids in encodings["input_ids"]]
counter.update(lengths)
length_list.extend(lengths)
# Save plot every 1M
if (end % LOG_INTERVAL == 0) or (end == MAX_EXAMPLES):
count_millions = end // 1_000_000
x, y = zip(*sorted(counter.items()))
# Compute percentiles for current subset
lengths_np = np.array(length_list)
results = {
p: int(np.percentile(lengths_np, p))
for p in [10, 20, 30, 40, 50, 55, 60, 65, 70, 75, 80, 85, 90, 95, 99]
}
# Plot
plt.figure(figsize=(12, 6))
plt.bar(x, y, color="skyblue")
plt.yscale("log")
plt.xscale("log")
plt.xlabel("Token Length")
plt.ylabel("Frequency (log scale)")
plt.title(f"Token Length Distribution (Up to {end:,} Examples)")
plt.grid(True, linestyle="--", alpha=0.5)
# Annotate percentiles
for p in [50, 60, 70, 80, 90, 95, 99]:
val = results[p]
plt.axvline(val, color="red", linestyle="--", linewidth=1.5)
plt.text(
val + 10,
max(y) / 10,
f"{p}%",
rotation=90,
color="red",
fontsize=10,
verticalalignment="center",
)
plt.tight_layout()
filename = os.path.join(SAVE_DIR, f"length_dist_{count_millions}M.png")
plt.savefig(filename)
plt.close()
print(f"✅ Saved plot: {filename}")
# --- Final Percentiles ---
print("\n📊 Computing Final Percentiles...")
lengths_np = np.array(length_list)
all_percentiles = [0, 1, 5, 10, 25, 50, 75, 90, 95, 99, 100]
final_results = {p: int(np.percentile(lengths_np, p)) for p in all_percentiles}
print("\n📊 Token Length Percentiles:")
for p in all_percentiles:
print(f" {p:>3}%: {final_results[p]:,} tokens")
# --- Save final plot with annotations ---
x, y = zip(*sorted(counter.items()))
plt.figure(figsize=(12, 6))
plt.bar(x, y, width=5, color="skyblue", edgecolor="black")
plt.yscale("log")
plt.xscale("log")
plt.xlabel("Token Length")
plt.ylabel("Frequency (log scale)")
plt.title("Final Token Length Distribution (4.096M Samples)")
plt.grid(True, linestyle="--", alpha=0.5)
for p in [50, 95, 99]:
val = final_results[p]
plt.axvline(val, color="red", linestyle="--", linewidth=1.5)
plt.text(
val + 10,
max(y) / 10,
f"{p}%",
rotation=90,
color="red",
fontsize=10,
verticalalignment="center",
)
plt.tight_layout()
final_path = os.path.join(SAVE_DIR, "length_dist_final_annotated.png")
plt.savefig(final_path)
plt.close()
print(f"\n✅ Final annotated plot saved: {final_path}")
|