""" BPE tokenizer for resonance-200m. Uses HuggingFace tokenizers (Rust backend) for fast training + encoding. Saves merge rules in binary format compatible with C inference. Replaces naive Python BPE (O(n²) per merge = days on 200MB). Rust backend: minutes. """ import struct import os import json import numpy as np def _byte_to_unicode(): """GPT-2 byte-to-unicode mapping (ByteLevel pre-tokenizer).""" bs = (list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))) cs = bs[:] n = 0 for b in range(256): if b not in bs: bs.append(b) cs.append(256 + n) n += 1 return {b: chr(c) for b, c in zip(bs, cs)} class BPETokenizer: """BPE tokenizer. 256 byte tokens + learned merges. Rust backend for speed. Binary format for C inference.""" def __init__(self, max_merges=15936): self.max_merges = max_merges self.merges = [] # (a, b, new_id) — C format self.vocab_size = 256 self._hf_tok = None self._remap_lut = None # numpy LUT: HF_id → our_id def train(self, text_bytes, num_merges=None, report_every=2000): """Learn BPE merges using Rust backend. Minutes, not days.""" from tokenizers import Tokenizer, models, trainers, pre_tokenizers, decoders if num_merges is None: num_merges = self.max_merges num_merges = min(num_merges, self.max_merges) target_vocab = 256 + num_merges print(f" [BPE] Training {num_merges} merges on {len(text_bytes)} bytes (Rust backend)...") tok = Tokenizer(models.BPE()) tok.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False) tok.decoder = decoders.ByteLevel() trainer = trainers.BpeTrainer( vocab_size=target_vocab, min_frequency=2, special_tokens=[], initial_alphabet=pre_tokenizers.ByteLevel.alphabet(), show_progress=True, ) text = text_bytes.decode('utf-8', errors='replace') lines = text.split('\n') del text tok.train_from_iterator(lines, trainer=trainer) del lines self._hf_tok = tok # Extract merges in our (a, b, new_id) format for C inference data = json.loads(tok.to_str()) hf_merges = data['model']['merges'] hf_vocab = data['model']['vocab'] b2u = _byte_to_unicode() # str → our_id mapping for merge conversion str_to_our = {} for bv in range(256): str_to_our[b2u[bv]] = bv self.merges = [] for i, ms in enumerate(hf_merges): if i >= num_merges: break # HF tokenizers >=0.20 returns lists ['a','b'], older returns "a b" if isinstance(ms, list): if len(ms) != 2: continue a_str, b_str = ms[0], ms[1] else: parts = ms.split(' ', 1) if len(parts) != 2: continue a_str, b_str = parts[0], parts[1] if a_str not in str_to_our or b_str not in str_to_our: continue a_id = str_to_our[a_str] b_id = str_to_our[b_str] new_id = 256 + len(self.merges) self.merges.append((a_id, b_id, new_id)) str_to_our[a_str + b_str] = new_id if (i + 1) % report_every == 0: print(f" [BPE] {i + 1}/{len(hf_merges)} merges converted") self.vocab_size = 256 + len(self.merges) # Build HF→our remap LUT (numpy vectorized lookup) hf_to_our = {} for bv in range(256): uc = b2u[bv] if uc in hf_vocab: hf_to_our[hf_vocab[uc]] = bv for tok_str, our_id in str_to_our.items(): if tok_str in hf_vocab and our_id >= 256: hf_to_our[hf_vocab[tok_str]] = our_id max_hf = max(hf_to_our.keys()) + 1 if hf_to_our else 256 self._remap_lut = np.arange(max_hf, dtype=np.int32) for hf_id, our_id in hf_to_our.items(): self._remap_lut[hf_id] = our_id self._hf_to_our = hf_to_our print(f" [BPE] Done: {len(self.merges)} merges, vocab={self.vocab_size}") def encode(self, text): """Encode text to our token IDs. 
    def encode(self, text):
        """Encode text to our token IDs. Fast (Rust + numpy remap)."""
        if isinstance(text, bytes):
            text = text.decode('utf-8', errors='replace')
        if self._hf_tok is not None and self._remap_lut is not None:
            hf_ids = np.array(self._hf_tok.encode(text).ids, dtype=np.int32)
            return self._remap_lut[hf_ids].tolist()

        # Slow fallback (binary-only load, no HF JSON)
        if isinstance(text, str):
            text = text.encode('utf-8', errors='replace')
        ids = list(text)
        for a, b, new_id in self.merges:
            new_ids = []
            i = 0
            while i < len(ids):
                if i < len(ids) - 1 and ids[i] == a and ids[i + 1] == b:
                    new_ids.append(new_id)
                    i += 2
                else:
                    new_ids.append(ids[i])
                    i += 1
            ids = new_ids
        return ids

    def decode(self, ids):
        """Decode token IDs to bytes."""
        vocab = {}
        for i in range(256):
            vocab[i] = bytes([i])
        for a, b, new_id in self.merges:
            vocab[new_id] = vocab[a] + vocab[b]
        out = b''
        for tid in ids:
            out += vocab.get(tid, b'?')
        return out

    def save(self, path):
        """Save binary merges (C) + HF JSON + ID map."""
        with open(path, 'wb') as f:
            # The source is truncated mid-call here; the layout below is an
            # assumed completion: little-endian u32 merge count, followed by
            # one (a, b, new_id) u32 triple per merge, in training order.
            f.write(struct.pack('<I', len(self.merges)))
            for a, b, new_id in self.merges:
                f.write(struct.pack('<III', a, b, new_id))
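
if __name__ == "__main__":
    # Minimal smoke-test sketch (added for illustration, not in the original):
    # trains a tiny merge table on a repeated corpus, then checks that
    # encode/decode round-trips losslessly. Requires the `tokenizers` package.
    sample = b"the quick brown fox jumps over the lazy dog. " * 200
    tok = BPETokenizer(max_merges=64)
    tok.train(sample)
    ids = tok.encode("the quick brown fox")
    assert tok.decode(ids) == b"the quick brown fox"
    print(f"round trip OK: {len(ids)} tokens, vocab={tok.vocab_size}")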