# File size: 4,082 Bytes
import os
import time
import json
import argparse
import multiprocessing
from collections import defaultdict
from tqdm import tqdm
def get_stats_chunk(ids):
    """Count how often each adjacent pair of token ids occurs in one chunk.

    Args:
        ids: Sliceable sequence of token ids (a list everywhere in this file).

    Returns:
        A ``defaultdict(int)`` mapping ``(left, right)`` tuples to their number
        of occurrences. It is empty when ``ids`` has fewer than two items.
    """
    pair_counts = defaultdict(int)
    # Zipping the sequence with itself shifted by one yields every neighbour pair.
    for left, right in zip(ids, ids[1:]):
        pair_counts[left, right] += 1
    return pair_counts
def merge_chunk(args):
    """Replace each occurrence of a token pair in one chunk with a new id.

    The scan runs left to right and never re-uses a token, so overlapping
    candidates are resolved greedily (``[5, 5, 5]`` with pair ``(5, 5)``
    becomes ``[new, 5]``).

    Args:
        args: A ``(ids, pair, idx)`` tuple, packed into one argument the way
            ``pool.map`` passes it in this file. ``ids`` is the token list,
            ``pair`` the two ids to merge and ``idx`` the replacement id.

    Returns:
        A new list with the merges applied; ``ids`` itself is not modified.
    """
    ids, pair, idx = args
    first, second = pair[0], pair[1]
    last = len(ids) - 1
    merged = []
    pos = 0
    while pos <= last:
        token = ids[pos]
        # A match needs a right-hand neighbour, hence the ``pos < last`` guard.
        if pos < last and token == first and ids[pos + 1] == second:
            merged.append(idx)
            pos += 2
            continue
        merged.append(token)
        pos += 1
    return merged
class ParallelBPETokenizer:
    """Byte-level BPE trainer that counts and merges pairs in a process pool.

    NOTE: training cuts the byte stream into fixed-size chunks that are
    processed independently and never re-joined, so a pair that straddles a
    chunk boundary is neither counted nor merged. The learned merges may
    therefore vary with the number of workers.
    """

    def __init__(self):
        # Maps a (left_id, right_id) pair to the id of the token replacing it.
        self.merges = {}
        # Maps every token id to its byte string; ids 0-255 are single bytes.
        self.vocab = self._base_vocab()

    @staticmethod
    def _base_vocab():
        """Return the initial vocabulary: one entry per possible byte value."""
        return {i: bytes([i]) for i in range(256)}

    def train(self, text, vocab_size, pct_bpe=1.0, workers=None, verbose=True):
        """Learn BPE merges from ``text`` and store them on the instance.

        Any merges left over from an earlier ``train`` call are discarded
        first, because new token ids always start again at 256.

        Args:
            text: Training corpus as a ``str``; it is encoded to UTF-8 bytes.
            vocab_size: Target vocabulary size, at least 256 (the byte
                alphabet). At most ``vocab_size - 256`` merges are learned.
            pct_bpe: Fraction of ``text`` to train on, taken as a prefix and
                clamped to the range ``[0.01, 1.0]``.
            workers: Number of pool processes. ``None`` means CPU count minus
                one (at least 1); values below 1 are treated as 1.
            verbose: Print progress messages when true.

        Returns:
            ``self.merges``, mapping ``(left_id, right_id)`` to the merged id.
        """
        assert vocab_size >= 256, "vocab_size must be at least 256"
        num_merges = vocab_size - 256

        # Start from a clean slate so repeated training cannot mix stale
        # merges with new ones that re-use the same ids.
        self.merges = {}
        self.vocab = self._base_vocab()

        if verbose:
            print("Pre-processing text...")
        fraction = max(0.01, min(1.0, pct_bpe))
        ids = list(text[: int(len(text) * fraction)].encode("utf-8"))

        if workers is None:
            num_procs = max(1, (os.cpu_count() or 4) - 1)
        else:
            # Zero or negative values would otherwise fail in the chunk-size
            # division or inside multiprocessing.Pool.
            num_procs = max(1, workers)

        if verbose:
            print(f"Using {num_procs} workers for {len(ids)} bytes...")

        # Without at least one merge to learn and one pair to count there is
        # no work, so do not spawn worker processes at all.
        if num_merges > 0 and len(ids) > 1:
            self._learn_merges(ids, num_merges, num_procs, verbose)

        if verbose:
            print(f"Training complete. Vocab size: {len(self.vocab)}")
        return self.merges

    def _learn_merges(self, ids, num_merges, num_procs, verbose):
        """Run up to ``num_merges`` count-and-merge rounds over chunked ``ids``."""
        chunk_len = max(1, len(ids) // num_procs)
        chunks = [ids[i : i + chunk_len] for i in range(0, len(ids), chunk_len)]

        with multiprocessing.Pool(num_procs) as pool:
            for step in tqdm(range(num_merges), desc="Training BPE"):
                # Sum the per-chunk pair counts into one global tally.
                totals = defaultdict(int)
                for stat in pool.map(get_stats_chunk, chunks):
                    for pair, count in stat.items():
                        totals[pair] += count
                if not totals:
                    # Every chunk has shrunk to a single token.
                    break

                best_pair = max(totals, key=totals.get)
                new_idx = 256 + step
                chunks = pool.map(
                    merge_chunk,
                    [(chunk, best_pair, new_idx) for chunk in chunks],
                )
                self.merges[best_pair] = new_idx
                self.vocab[new_idx] = (
                    self.vocab[best_pair[0]] + self.vocab[best_pair[1]]
                )

                if verbose and step % 20 == 0:
                    # The message is cosmetic: skip it when the token is not
                    # valid UTF-8 on its own or the console cannot encode it,
                    # but let any other error surface.
                    try:
                        decoded = self.vocab[new_idx].decode("utf-8")
                        tqdm.write(
                            f"Merged {best_pair} -> {new_idx} ('{decoded}')"
                        )
                    except UnicodeError:
                        pass

    def save(self, filename):
        """Write the merges and vocabulary to ``filename`` as JSON.

        Merge pairs are stored under ``"left,right"`` string keys. Vocabulary
        byte strings are decoded as latin1 so that every byte maps to exactly
        one character. JSON object keys are strings, so the integer token ids
        of the vocabulary are written as strings.

        Args:
            filename: Path of the JSON file to create or overwrite.
        """
        save_merges = {
            f"{left},{right}": idx for (left, right), idx in self.merges.items()
        }
        save_vocab = {idx: raw.decode("latin1") for idx, raw in self.vocab.items()}
        with open(filename, "w", encoding="utf-8") as f:
            json.dump({"merges": save_merges, "vocab": save_vocab}, f)
        print(f"Saved tokenizer to {filename}")
def parse_args():
    """Parse the trainer's command-line options from ``sys.argv``.

    Returns:
        An ``argparse.Namespace`` with ``input``, ``vocab_size``, ``pct_bpe``,
        ``workers`` and ``output`` attributes.
    """
    parser = argparse.ArgumentParser(description="Multi-core BPE trainer")
    # Leave one core free by default, but never go below a single worker.
    default_workers = max(1, (os.cpu_count() or 4) - 1)
    options = (
        ("--input", {"default": os.path.join("data", "jarvis_train.txt")}),
        ("--vocab-size", {"type": int, "default": 2048}),
        ("--pct-bpe", {"type": float, "default": 1.0}),
        ("--workers", {"type": int, "default": default_workers}),
        ("--output", {"default": "bpe_vocab.json"}),
    )
    for flag, spec in options:
        parser.add_argument(flag, **spec)
    return parser.parse_args()
if __name__ == "__main__":
    # Script entry point: read the corpus, train the tokenizer, report the
    # elapsed time and save the result.
    # freeze_support() is required for frozen Windows executables that use
    # multiprocessing.
    multiprocessing.freeze_support()
    args = parse_args()
    print("--- MULTI-CORE BPE TRAINER ---")
    if not os.path.exists(args.input):
        # A string argument makes SystemExit print the message to stderr and
        # exit with status 1.
        raise SystemExit(f"Error: {args.input} not found.")
    # errors="ignore" drops bytes that are not valid UTF-8 instead of failing.
    with open(args.input, "r", encoding="utf-8", errors="ignore") as f:
        text = f.read()
    tokenizer = ParallelBPETokenizer()
    # perf_counter() is monotonic, so the measured duration cannot be skewed
    # by system clock adjustments the way time.time() can.
    start_time = time.perf_counter()
    tokenizer.train(
        text,
        vocab_size=args.vocab_size,
        pct_bpe=args.pct_bpe,
        workers=args.workers,
    )
    end_time = time.perf_counter()
    print(f"Total time: {end_time - start_time:.2f} seconds")
    tokenizer.save(args.output)
|