| import os |
| import time |
| import json |
| import argparse |
| import multiprocessing |
| from collections import defaultdict |
| from tqdm import tqdm |
|
|
def get_stats_chunk(ids):
    """Count how often each adjacent pair of token ids occurs in one chunk.

    Args:
        ids: sequence of token ids for a single chunk.

    Returns:
        A ``defaultdict(int)`` mapping ``(left_id, right_id)`` to its
        occurrence count; empty when *ids* has fewer than two elements.
    """
    pair_counts = defaultdict(int)
    for pos in range(1, len(ids)):
        pair_counts[(ids[pos - 1], ids[pos])] += 1
    return pair_counts
|
|
|
|
def merge_chunk(args):
    """Replace every non-overlapping occurrence of a pair with a new token id.

    Occurrences are matched greedily from left to right, so for a pair like
    ``(a, a)`` the run ``a a a`` becomes ``new a``.

    Args:
        args: tuple ``(ids, pair, idx)``, packed into one argument so the
            function can be used with ``Pool.map``. ``ids`` is the chunk's
            token list, ``pair`` the ``(left_id, right_id)`` to merge and
            ``idx`` the id that replaces it.

    Returns:
        A new list of token ids; the input list is not modified.
    """
    ids, pair, idx = args
    merged = []
    last = len(ids) - 1
    pos = 0
    while pos <= last:
        token = ids[pos]
        if pos < last and token == pair[0] and ids[pos + 1] == pair[1]:
            merged.append(idx)
            pos += 2
            continue
        merged.append(token)
        pos += 1
    return merged
|
|
|
|
class ParallelBPETokenizer:
    """Byte-level BPE tokenizer that counts and merges pairs on a process pool.

    Attributes:
        merges: dict mapping each merged ``(left_id, right_id)`` pair to the
            token id it was assigned, in the order the merges were learned.
        vocab: dict mapping every token id to the bytes it expands to; ids
            0-255 are the single-byte tokens.
    """

    def __init__(self):
        self.merges = {}
        self.vocab = {i: bytes([i]) for i in range(256)}

    def train(self, text, vocab_size, pct_bpe=1.0, workers=None, verbose=True):
        """Learn up to ``vocab_size - 256`` merges from the UTF-8 bytes of *text*.

        Args:
            text: training corpus (str).
            vocab_size: target vocabulary size; must be at least 256, the
                size of the byte alphabet.
            pct_bpe: fraction of *text* (counted in characters, clamped to
                ``[0.01, 1.0]``) that is used for training.
            workers: number of worker processes. ``None`` means
                ``os.cpu_count() - 1``; any value below 1 is treated as 1.
            verbose: print progress messages and show a progress bar.

        Returns:
            ``self.merges`` (the same dict object, updated in place).

        Raises:
            ValueError: if *vocab_size* is below 256.
        """
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if vocab_size < 256:
            raise ValueError(f"vocab_size must be >= 256, got {vocab_size}")
        num_merges = vocab_size - 256

        if verbose:
            print("Pre-processing text...")
        text_subset = text[: int(len(text) * max(0.01, min(1.0, pct_bpe)))]
        ids = list(text_subset.encode("utf-8"))

        if workers is None:
            workers = (os.cpu_count() or 4) - 1
        # Clamp so workers=0 (or negative) cannot divide by zero below.
        num_procs = max(1, workers)
        chunk_len = max(1, len(ids) // num_procs)
        chunks = [ids[i : i + chunk_len] for i in range(0, len(ids), chunk_len)]

        if verbose:
            print(f"Using {num_procs} workers for {len(ids)} bytes...")

        # Only pay for spawning a pool when there is actual work to do.
        if num_merges > 0 and chunks:
            with multiprocessing.Pool(num_procs) as pool:
                self._merge_loop(pool, chunks, num_merges, verbose)

        if verbose:
            print(f"Training complete. Vocab size: {len(self.vocab)}")
        return self.merges

    def _merge_loop(self, pool, chunks, num_merges, verbose):
        """Run up to *num_merges* count-then-merge rounds over *chunks* on *pool*."""
        # NOTE: pairs that straddle two chunks are never counted or merged,
        # so results can differ slightly from single-process BPE on the same
        # bytes.
        # NOTE(review): every round pickles all chunks to the workers twice
        # (count + merge); for large corpora that transfer may dominate.
        # NOTE(review): ids restart at 256 on every train() call, so training
        # the same instance twice overwrites earlier vocab entries.
        for i in tqdm(range(num_merges), desc="Training BPE", disable=not verbose):
            totals = defaultdict(int)
            for stat in pool.map(get_stats_chunk, chunks):
                for pair, count in stat.items():
                    totals[pair] += count

            # No chunk has two or more tokens left, so nothing can be merged.
            if not totals:
                break

            best_pair = max(totals, key=totals.get)
            new_idx = 256 + i

            merge_args = [(chunk, best_pair, new_idx) for chunk in chunks]
            chunks = pool.map(merge_chunk, merge_args)

            self.merges[best_pair] = new_idx
            self.vocab[new_idx] = self.vocab[best_pair[0]] + self.vocab[best_pair[1]]

            if verbose and i % 20 == 0:
                try:
                    decoded = self.vocab[new_idx].decode("utf-8")
                except UnicodeDecodeError:
                    # Token holds a partial multi-byte character; skip logging it.
                    pass
                else:
                    tqdm.write(f"Merged {best_pair} -> {new_idx} ('{decoded}')")

    def save(self, filename):
        """Write the merges and vocab to *filename* as JSON.

        Merge pairs are stored as ``"left,right"`` string keys. Vocab bytes are
        decoded as latin-1, which maps each byte 0-255 to one code point, so
        they can be re-encoded losslessly. JSON turns the integer vocab ids
        into string keys.
        """
        save_merges = {f"{p[0]},{p[1]}": idx for p, idx in self.merges.items()}
        save_vocab = {idx: b.decode("latin1") for idx, b in self.vocab.items()}

        with open(filename, "w", encoding="utf-8") as f:
            json.dump({"merges": save_merges, "vocab": save_vocab}, f)
        print(f"Saved tokenizer to {filename}")
|
|
|
|
def parse_args(argv=None):
    """Parse command-line options for the BPE trainer.

    Args:
        argv: list of argument strings to parse (without the program name).
            ``None`` (the default) parses ``sys.argv[1:]``, exactly as
            ``ArgumentParser.parse_args`` does.

    Returns:
        ``argparse.Namespace`` with ``input``, ``vocab_size``, ``pct_bpe``,
        ``workers`` and ``output`` attributes.
    """
    p = argparse.ArgumentParser(description="Multi-core BPE trainer")
    p.add_argument(
        "--input",
        default=os.path.join("data", "jarvis_train.txt"),
        help="path to the UTF-8 training text",
    )
    p.add_argument(
        "--vocab-size",
        type=int,
        default=2048,
        help="target vocabulary size (must be >= 256)",
    )
    p.add_argument(
        "--pct-bpe",
        type=float,
        default=1.0,
        help="fraction of the input text used for training",
    )
    p.add_argument(
        "--workers",
        type=int,
        default=max(1, (os.cpu_count() or 4) - 1),
        help="number of worker processes",
    )
    p.add_argument(
        "--output",
        default="bpe_vocab.json",
        help="path of the JSON file the tokenizer is saved to",
    )
    return p.parse_args(argv)
|
|
|
|
if __name__ == "__main__":
    # Needed when the script is frozen into a Windows executable; a no-op
    # otherwise. Must run before any Pool is created.
    multiprocessing.freeze_support()
    args = parse_args()
    print("--- MULTI-CORE BPE TRAINER ---")

    # isfile (not exists) so a directory path is rejected here instead of
    # crashing in open() below. SystemExit(msg) prints msg to stderr and
    # exits with status 1.
    if not os.path.isfile(args.input):
        raise SystemExit(f"Error: {args.input} not found.")

    # errors="ignore" silently drops any bytes that are not valid UTF-8.
    with open(args.input, "r", encoding="utf-8", errors="ignore") as f:
        text = f.read()

    tokenizer = ParallelBPETokenizer()
    # perf_counter is monotonic, so the measurement is unaffected by clock changes.
    start_time = time.perf_counter()
    tokenizer.train(
        text,
        vocab_size=args.vocab_size,
        pct_bpe=args.pct_bpe,
        workers=args.workers,
    )
    elapsed = time.perf_counter() - start_time

    print(f"Total time: {elapsed:.2f} seconds")
    tokenizer.save(args.output)
|
|