File size: 3,572 Bytes
4f2b2f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
from datasets import load_dataset
from transformers import GPT2TokenizerFast
from collections import Counter
import matplotlib.pyplot as plt
import numpy as np
import os

# Parameters
# A snapshot plot is saved each time the number of processed examples reaches
# a multiple of LOG_INTERVAL (and once more when the dataset is exhausted).
LOG_INTERVAL = 1_000_000
# Number of examples sliced from the dataset and tokenized per loop iteration.
BATCH_SIZE = 1000
# Directory that receives every PNG written by this script.
SAVE_DIR = "length_dist_plots-2"

# Create output directory (no error if it already exists)
os.makedirs(SAVE_DIR, exist_ok=True)

# Load the full "openwebtext" train split. The original note says the data is
# cached locally; num_proc=64 is passed straight through to load_dataset.
dataset = load_dataset("openwebtext", split="train", num_proc=64)
# Every example is processed, so the example count doubles as the loop bound.
MAX_EXAMPLES = len(dataset)

# Load the pretrained "gpt2" fast tokenizer
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

# Initialize counter and length list
# counter: token length -> number of examples with that length (histogram).
# length_list: one token length per processed example, used for percentiles.
counter = Counter()
length_list = []

# Process in batches: tokenize BATCH_SIZE texts at a time and accumulate their
# token lengths in `counter` (histogram) and `length_list` (raw values).
# A snapshot plot is written at every LOG_INTERVAL boundary and once more
# after the last batch.
for start in range(0, MAX_EXAMPLES, BATCH_SIZE):
    end = min(start + BATCH_SIZE, MAX_EXAMPLES)
    batch = dataset[start:end]["text"]
    # No truncation and no special tokens: measure the raw token count.
    encodings = tokenizer(batch, truncation=False, add_special_tokens=False)
    lengths = [len(ids) for ids in encodings["input_ids"]]
    counter.update(lengths)
    length_list.extend(lengths)

    # Only snapshot at interval boundaries or when the dataset is exhausted.
    if end % LOG_INTERVAL != 0 and end != MAX_EXAMPLES:
        continue

    # Build the "<n>M" part of the file name. Whole millions keep the
    # original integer form ("3M"). Any other count (typically the final,
    # partial snapshot) gets an exact decimal form ("8.013769M") so it cannot
    # overwrite the last whole-million snapshot, which floor division did.
    if end % 1_000_000 == 0:
        count_label = f"{end // 1_000_000}M"
    else:
        count_label = f"{end / 1_000_000:.6f}".rstrip("0") + "M"

    x, y = zip(*sorted(counter.items()))

    # Compute percentiles for the examples seen so far. One vectorized call
    # handles all requested percentiles instead of one pass per value.
    lengths_np = np.array(length_list)
    percentile_points = [10, 20, 30, 40, 50, 55, 60, 65, 70, 75, 80, 85, 90, 95, 99]
    results = {
        p: int(v)
        for p, v in zip(percentile_points, np.percentile(lengths_np, percentile_points))
    }

    # Plot the histogram on log-log axes
    plt.figure(figsize=(12, 6))
    plt.bar(x, y, color="skyblue")
    plt.yscale("log")
    plt.xscale("log")
    plt.xlabel("Token Length")
    plt.ylabel("Frequency (log scale)")
    plt.title(f"Token Length Distribution (Up to {end:,} Examples)")
    plt.grid(True, linestyle="--", alpha=0.5)

    # Annotate selected percentiles with a vertical line and a label placed
    # one decade below the tallest bar.
    label_height = max(y) / 10
    for p in [50, 60, 70, 80, 90, 95, 99]:
        val = results[p]
        plt.axvline(val, color="red", linestyle="--", linewidth=1.5)
        plt.text(
            val + 10,
            label_height,
            f"{p}%",
            rotation=90,
            color="red",
            fontsize=10,
            verticalalignment="center",
        )

    plt.tight_layout()
    filename = os.path.join(SAVE_DIR, f"length_dist_{count_label}.png")
    plt.savefig(filename)
    # Close the figure so snapshots do not accumulate open figures in memory.
    plt.close()
    print(f"✅ Saved plot: {filename}")

# --- Final Percentiles ---
print("\n📊 Computing Final Percentiles...")

# Percentile cut-offs reported over every collected token length.
all_percentiles = [0, 1, 5, 10, 25, 50, 75, 90, 95, 99, 100]
lengths_np = np.array(length_list)

# Map each percentile to its value, truncated to a whole number of tokens.
final_results = {}
for pct in all_percentiles:
    final_results[pct] = int(np.percentile(lengths_np, pct))

print("\n📊 Token Length Percentiles:")
for pct, n_tokens in final_results.items():
    print(f"  {pct:>3}%: {n_tokens:,} tokens")

# --- Save final plot with annotations ---
# Histogram data: x = distinct token lengths (ascending), y = example counts.
x, y = zip(*sorted(counter.items()))
# Every processed example adds exactly one count to the histogram, so the
# frequencies sum to the real sample total. The title uses this value
# instead of a hard-coded figure that can disagree with the data.
total_samples = sum(y)
# Percentile labels sit one decade below the tallest bar.
label_height = max(y) / 10

plt.figure(figsize=(12, 6))
plt.bar(x, y, width=5, color="skyblue", edgecolor="black")
plt.yscale("log")
plt.xscale("log")
plt.xlabel("Token Length")
plt.ylabel("Frequency (log scale)")
plt.title(
    "Final Token Length Distribution "
    f"({total_samples / 1_000_000:.3f}M Samples)"
)
plt.grid(True, linestyle="--", alpha=0.5)

# Mark the median and the upper-tail percentiles computed from the full set.
for p in [50, 95, 99]:
    val = final_results[p]
    plt.axvline(val, color="red", linestyle="--", linewidth=1.5)
    plt.text(
        val + 10,
        label_height,
        f"{p}%",
        rotation=90,
        color="red",
        fontsize=10,
        verticalalignment="center",
    )

plt.tight_layout()
final_path = os.path.join(SAVE_DIR, "length_dist_final_annotated.png")
plt.savefig(final_path)
plt.close()
print(f"\n✅ Final annotated plot saved: {final_path}")