# AI/scripts/parallel_bpe.py
# Uploaded by shsplas in commit c5f49b9 ("Upload 15 files", verified); 4.08 kB.
import os
import time
import json
import argparse
import multiprocessing
from collections import defaultdict
from tqdm import tqdm
def get_stats_chunk(ids):
    """Count how often each adjacent pair of token ids occurs in ``ids``.

    Runs inside a worker process during training.

    Args:
        ids: sequence of integer token ids for one chunk of the corpus.

    Returns:
        A ``defaultdict(int)`` mapping ``(left_id, right_id)`` to its count,
        with keys in order of first appearance.
    """
    pair_counts = defaultdict(int)
    successors = iter(ids)
    # Advance the second cursor by one so zip() yields (ids[k], ids[k + 1]).
    next(successors, None)
    for left, right in zip(ids, successors):
        pair_counts[(left, right)] += 1
    return pair_counts
def merge_chunk(args):
    """Replace every occurrence of ``pair`` in a chunk with the token ``idx``.

    Occurrences are matched greedily from left to right without overlap, so
    ``[a, a, a]`` merged on ``(a, a)`` becomes ``[idx, a]``.

    Args:
        args: a ``(ids, pair, idx)`` tuple packed into one argument so it can
            be used with ``Pool.map``. ``ids`` is the list of token ids,
            ``pair`` the ``(left, right)`` ids to merge, and ``idx`` the new id.

    Returns:
        A new list of token ids; ``ids`` itself is not modified.
    """
    ids, pair, idx = args
    last_pos = len(ids) - 1
    merged = []
    consumed = False  # True when the current token was absorbed by the previous merge
    for pos, token in enumerate(ids):
        if consumed:
            consumed = False
            continue
        if pos < last_pos and token == pair[0] and ids[pos + 1] == pair[1]:
            merged.append(idx)
            consumed = True
        else:
            merged.append(token)
    return merged
class ParallelBPETokenizer:
    """Byte-level BPE tokenizer whose training passes run in a process pool.

    Attributes:
        merges: maps a learned pair of token ids ``(left, right)`` to the id of
            the merged token, in the order the merges were learned.
        vocab: maps every token id to the bytes it stands for; ids 0-255 are
            the single bytes.
    """

    def __init__(self):
        self.merges = {}
        self.vocab = {i: bytes([i]) for i in range(256)}

    def train(self, text, vocab_size, pct_bpe=1.0, workers=None, verbose=True):
        """Learn ``vocab_size - 256`` BPE merges from ``text``.

        The (optionally truncated) text is UTF-8 encoded and split into one
        contiguous chunk per worker. Each iteration counts adjacent pairs in
        the workers, picks the most frequent pair, and merges it in every chunk.
        Training stops early if no pairs remain.

        Note: chunks are processed independently, so a pair straddling a chunk
        boundary is neither counted nor merged.

        Args:
            text: training corpus.
            vocab_size: target vocabulary size, including the 256 byte tokens.
            pct_bpe: fraction of ``text`` (in characters) to train on, clamped
                to ``[0.01, 1.0]``.
            workers: number of worker processes. Defaults to
                ``os.cpu_count() - 1``. Values below 1 are treated as 1.
            verbose: print progress messages and show the progress bar.

        Returns:
            The learned ``merges`` dict.

        Raises:
            ValueError: if ``vocab_size`` is smaller than 256.
        """
        if vocab_size < 256:
            raise ValueError(f"vocab_size must be >= 256, got {vocab_size}")
        # New ids are assigned from 256 upward, so state from a previous
        # train() call would collide with them; always start fresh.
        self.merges = {}
        self.vocab = {i: bytes([i]) for i in range(256)}
        num_merges = vocab_size - 256
        if verbose:
            print("Pre-processing text...")
        text_subset = text[: int(len(text) * max(0.01, min(1.0, pct_bpe)))]
        ids = list(text_subset.encode("utf-8"))
        if workers is None:
            workers = (os.cpu_count() or 4) - 1
        num_procs = max(1, workers)
        # Ceiling division: at most num_procs chunks (floor division can
        # leave a tiny extra chunk at the end).
        chunk_len = max(1, -(-len(ids) // num_procs))
        chunks = [ids[i : i + chunk_len] for i in range(0, len(ids), chunk_len)]
        if verbose:
            print(f"Using {num_procs} workers for {len(ids)} bytes...")
        with multiprocessing.Pool(num_procs) as pool:
            for i in tqdm(range(num_merges), desc="Training BPE", disable=not verbose):
                chunk_stats = pool.map(get_stats_chunk, chunks)
                totals = defaultdict(int)
                for stat in chunk_stats:
                    for pair, count in stat.items():
                        totals[pair] += count
                if not totals:
                    break  # every chunk is a single token; nothing left to merge
                best_pair = max(totals, key=totals.get)
                new_idx = 256 + i
                merge_args = [(chunk, best_pair, new_idx) for chunk in chunks]
                chunks = pool.map(merge_chunk, merge_args)
                self.merges[best_pair] = new_idx
                self.vocab[new_idx] = self.vocab[best_pair[0]] + self.vocab[best_pair[1]]
                if verbose and i % 20 == 0:
                    # A merged token may end mid UTF-8 sequence; show replacement
                    # characters instead of silently skipping the log line.
                    decoded = self.vocab[new_idx].decode("utf-8", errors="replace")
                    tqdm.write(f"Merged {best_pair} -> {new_idx} ('{decoded}')")
        if verbose:
            print(f"Training complete. Vocab size: {len(self.vocab)}")
        return self.merges

    def save(self, filename):
        """Write the merges and vocab to ``filename`` as JSON, overwriting it.

        Merge keys are stored as ``"left,right"`` strings. Vocab bytes are
        decoded as latin-1, which maps each byte to exactly one character and
        so round-trips losslessly. JSON turns the integer vocab keys into
        strings.
        """
        save_merges = {f"{p[0]},{p[1]}": idx for p, idx in self.merges.items()}
        save_vocab = {idx: b.decode("latin1") for idx, b in self.vocab.items()}
        with open(filename, "w", encoding="utf-8") as f:
            json.dump({"merges": save_merges, "vocab": save_vocab}, f)
        print(f"Saved tokenizer to {filename}")
def parse_args(argv=None):
    """Parse and validate command-line options for the BPE trainer.

    Args:
        argv: argument list to parse. ``None`` (the default) means
            ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``input``, ``vocab_size``, ``pct_bpe``,
        ``workers`` and ``output``.

    Raises:
        SystemExit: via ``ArgumentParser.error`` (exit status 2) when an option
            is malformed, ``--vocab-size`` is below 256, or ``--workers`` is
            below 1.
    """
    p = argparse.ArgumentParser(description="Multi-core BPE trainer")
    p.add_argument(
        "--input",
        default=os.path.join("data", "jarvis_train.txt"),
        help="UTF-8 text file to train on",
    )
    p.add_argument(
        "--vocab-size",
        type=int,
        default=2048,
        help="target vocabulary size, including the 256 byte tokens",
    )
    p.add_argument(
        "--pct-bpe",
        type=float,
        default=1.0,
        help="fraction of the input text to train on (clamped to 0.01-1.0)",
    )
    p.add_argument(
        "--workers",
        type=int,
        default=max(1, (os.cpu_count() or 4) - 1),
        help="number of worker processes",
    )
    p.add_argument("--output", default="bpe_vocab.json", help="JSON file to write")
    args = p.parse_args(argv)
    # Reject values that would otherwise fail deep inside training
    # (AssertionError / ZeroDivisionError) with a clear usage message.
    if args.vocab_size < 256:
        p.error("--vocab-size must be at least 256 (the byte alphabet)")
    if args.workers < 1:
        p.error("--workers must be at least 1")
    return args
if __name__ == "__main__":
    # Needed when the script is frozen into a Windows executable; no effect otherwise.
    multiprocessing.freeze_support()
    args = parse_args()
    print("--- MULTI-CORE BPE TRAINER ---")
    # Read inside try instead of pre-checking os.path.exists(): a directory
    # passes exists() but fails open(), and the file can vanish in between.
    # errors="ignore" drops undecodable bytes rather than aborting.
    try:
        with open(args.input, "r", encoding="utf-8", errors="ignore") as f:
            text = f.read()
    except FileNotFoundError:
        # SystemExit with a message prints it to stderr and exits with status 1.
        raise SystemExit(f"Error: {args.input} not found.") from None
    except OSError as exc:
        raise SystemExit(f"Error: could not read {args.input}: {exc}") from None
    tokenizer = ParallelBPETokenizer()
    # perf_counter is monotonic, so the elapsed time is unaffected by clock changes.
    start_time = time.perf_counter()
    tokenizer.train(
        text,
        vocab_size=args.vocab_size,
        pct_bpe=args.pct_bpe,
        workers=args.workers,
    )
    elapsed = time.perf_counter() - start_time
    print(f"Total time: {elapsed:.2f} seconds")
    tokenizer.save(args.output)