File size: 4,082 Bytes
c5f49b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import argparse
import json
import multiprocessing
import os
import time
from collections import Counter, defaultdict

from tqdm import tqdm

def get_stats_chunk(ids):
    """Count how often each adjacent pair of ids occurs in ``ids``.

    Called through ``Pool.map`` by the trainer, so it stays a module-level
    function (picklable by reference).

    Args:
        ids: sequence of token ids.

    Returns:
        Counter mapping ``(ids[k], ids[k + 1])`` to its occurrence count, with
        keys in first-seen order. Empty when ``ids`` has fewer than two items.
    """
    # Counter tallies the zip stream in a C loop (in CPython) instead of a
    # Python-level ``+= 1`` per pair; it is a dict subclass, so callers that
    # iterate ``.items()`` see the same pairs, counts, and order as before.
    return Counter(zip(ids, ids[1:]))


def merge_chunk(args):
    """Replace every non-overlapping occurrence of a pair with a new id.

    Takes a single tuple so it can be passed directly to ``Pool.map``.

    Args:
        args: ``(ids, pair, idx)`` — the id sequence to scan, the two-id pair
            to look for, and the id that replaces each match.

    Returns:
        A new list; matches are consumed left to right, so ``[a, a, a]`` with
        pair ``(a, a)`` becomes ``[idx, a]``.
    """
    sequence, target, replacement = args
    merged = []
    size = len(sequence)
    pos = 0
    while pos < size:
        is_match = (
            pos + 1 < size
            and sequence[pos] == target[0]
            and sequence[pos + 1] == target[1]
        )
        if is_match:
            merged.append(replacement)
            pos += 2  # both halves of the pair are consumed
            continue
        merged.append(sequence[pos])
        pos += 1
    return merged


class ParallelBPETokenizer:
    """Byte-level BPE tokenizer whose training runs pair counting and merging in a process pool.

    Attributes:
        merges: maps a merged id pair ``(left, right)`` to the new id it became,
            in the order the merges were learned.
        vocab: maps every token id to the raw bytes it stands for; ids 0-255
            are the single bytes.
    """

    def __init__(self):
        # (left_id, right_id) -> merged id, inserted in learning order.
        self.merges = {}
        # token id -> bytes; starts as the 256 single-byte tokens.
        self.vocab = {i: bytes([i]) for i in range(256)}

    def train(self, text, vocab_size, pct_bpe=1.0, workers=None, verbose=True):
        """Learn byte-pair merges from ``text`` until the vocabulary reaches ``vocab_size``.

        The leading ``pct_bpe`` fraction of ``text`` (clamped to [0.01, 1.0]) is
        UTF-8 encoded and split into contiguous chunks. Each round counts
        adjacent pairs in every chunk in parallel, picks the most frequent pair
        overall, and merges it in every chunk in parallel. Training stops early
        once no chunk has two adjacent ids left. Any state from a previous
        ``train`` call is discarded first.

        NOTE: pairs that straddle a chunk boundary are neither counted nor
        merged, so pair frequencies are approximate when more than one chunk
        is used.

        Args:
            text: training corpus.
            vocab_size: target vocabulary size; must be at least 256 (the byte
                alphabet). Up to ``vocab_size - 256`` merges are learned.
            pct_bpe: fraction of ``text`` (by characters) to train on.
            workers: number of worker processes; defaults to one less than
                ``os.cpu_count()``. Values below 1 are treated as 1.
            verbose: print progress messages, the progress bar, and every
                20th learned merge.

        Returns:
            The learned ``merges`` dict (the same object as ``self.merges``).

        Raises:
            ValueError: if ``vocab_size`` is less than 256.
        """
        if vocab_size < 256:
            # Explicit check rather than ``assert``, which ``python -O`` strips.
            raise ValueError(f"vocab_size must be at least 256, got {vocab_size}")
        num_merges = vocab_size - 256

        # Reset so a second call cannot mix earlier merges with ids re-issued from 256.
        self.merges = {}
        self.vocab = {i: bytes([i]) for i in range(256)}

        if verbose:
            print("Pre-processing text...")
        fraction = max(0.01, min(1.0, pct_bpe))
        ids = list(text[: int(len(text) * fraction)].encode("utf-8"))

        num_procs = workers if workers is not None else (os.cpu_count() or 4) - 1
        # Clamp: zero would divide by zero when chunking, and Pool() rejects < 1.
        num_procs = max(1, num_procs)
        chunks = self._split_into_chunks(ids, num_procs)

        if verbose:
            print(f"Using {num_procs} workers for {len(ids)} bytes...")

        # With no merges requested or no input bytes there is nothing to do,
        # so avoid spawning worker processes at all.
        if num_merges > 0 and chunks:
            self._learn_merges(chunks, num_merges, num_procs, verbose)

        if verbose:
            print(f"Training complete. Vocab size: {len(self.vocab)}")
        return self.merges

    @staticmethod
    def _split_into_chunks(ids, num_procs):
        """Split ``ids`` into consecutive slices of ``len(ids) // num_procs`` ids (at least 1)."""
        chunk_len = max(1, len(ids) // num_procs)
        return [ids[i : i + chunk_len] for i in range(0, len(ids), chunk_len)]

    @staticmethod
    def _sum_pair_counts(chunk_stats):
        """Add up per-chunk pair counts, keeping first-seen pair order."""
        totals = defaultdict(int)
        for stat in chunk_stats:
            for pair, count in stat.items():
                totals[pair] += count
        return totals

    def _learn_merges(self, chunks, num_merges, num_procs, verbose):
        """Run up to ``num_merges`` count-then-merge rounds over ``chunks`` in a process pool."""
        with multiprocessing.Pool(num_procs) as pool:
            for i in tqdm(range(num_merges), desc="Training BPE", disable=not verbose):
                totals = self._sum_pair_counts(pool.map(get_stats_chunk, chunks))
                if not totals:
                    break  # no chunk has two adjacent ids left

                # ``max`` returns the first maximal key, so ties go to the
                # pair that was inserted into ``totals`` first.
                best_pair = max(totals, key=totals.get)
                new_idx = 256 + i

                merge_args = [(chunk, best_pair, new_idx) for chunk in chunks]
                chunks = pool.map(merge_chunk, merge_args)

                self.merges[best_pair] = new_idx
                self.vocab[new_idx] = self.vocab[best_pair[0]] + self.vocab[best_pair[1]]

                if verbose and i % 20 == 0:
                    self._log_merge(best_pair, new_idx)

    def _log_merge(self, pair, idx):
        """Print one learned merge via ``tqdm.write`` when its bytes are valid UTF-8."""
        try:
            decoded = self.vocab[idx].decode("utf-8")
        except UnicodeDecodeError:
            # Token is not valid UTF-8 on its own (e.g. it splits a multi-byte
            # character); there is nothing readable to show, so skip it.
            return
        tqdm.write(f"Merged {pair} -> {idx} ('{decoded}')")

    def save(self, filename):
        """Write the merges and vocabulary to ``filename`` as a JSON object.

        Merge pairs are stored as ``"left,right"`` string keys. Vocab bytes are
        decoded as latin-1, which maps each byte 0-255 to exactly one code
        point, so ``value.encode("latin1")`` recovers the original bytes. JSON
        object keys are strings, so vocab ids are written as strings.
        """
        save_merges = {f"{p[0]},{p[1]}": idx for p, idx in self.merges.items()}
        save_vocab = {idx: b.decode("latin1") for idx, b in self.vocab.items()}

        with open(filename, "w", encoding="utf-8") as f:
            json.dump({"merges": save_merges, "vocab": save_vocab}, f)
        print(f"Saved tokenizer to {filename}")


def parse_args():
    """Define the trainer's command-line options and parse ``sys.argv``.

    Returns:
        argparse.Namespace with ``input``, ``vocab_size``, ``pct_bpe``,
        ``workers`` and ``output`` attributes.
    """
    default_input = os.path.join("data", "jarvis_train.txt")
    default_workers = max(1, (os.cpu_count() or 4) - 1)

    # (flag, add_argument keyword arguments), registered in this order.
    options = (
        ("--input", {"default": default_input}),
        ("--vocab-size", {"type": int, "default": 2048}),
        ("--pct-bpe", {"type": float, "default": 1.0}),
        ("--workers", {"type": int, "default": default_workers}),
        ("--output", {"default": "bpe_vocab.json"}),
    )

    parser = argparse.ArgumentParser(description="Multi-core BPE trainer")
    for flag, kwargs in options:
        parser.add_argument(flag, **kwargs)
    return parser.parse_args()


if __name__ == "__main__":
    # Needed when the script is frozen into a Windows executable; the
    # multiprocessing docs state it has no effect otherwise.
    multiprocessing.freeze_support()
    args = parse_args()
    print("--- MULTI-CORE BPE TRAINER ---")

    if not os.path.exists(args.input):
        # SystemExit with a string prints it to stderr and exits with status 1.
        raise SystemExit(f"Error: {args.input} not found.")

    # errors="ignore" drops undecodable bytes instead of aborting on them.
    with open(args.input, "r", encoding="utf-8", errors="ignore") as f:
        text = f.read()

    tokenizer = ParallelBPETokenizer()
    # perf_counter is monotonic, so the measurement is unaffected by clock changes.
    start_time = time.perf_counter()
    tokenizer.train(
        text,
        vocab_size=args.vocab_size,
        pct_bpe=args.pct_bpe,
        workers=args.workers,
    )
    elapsed = time.perf_counter() - start_time

    print(f"Total time: {elapsed:.2f} seconds")
    tokenizer.save(args.output)