"""Token-level N-gram model for context mixing.

Maintains order-1 through order-N context tables with interpolated backoff
smoothing. Used as a fast, lightweight predictor alongside the LLM in an
ensemble. All operations are deterministic for lossless codec symmetry.

Uses flat numpy arrays for inner storage instead of nested Python dicts.
This eliminates millions of small dict objects (~3.6 GB → ~1 GB per worker)
and replaces the O(K) Python iteration loop in predict() with numpy
fancy-indexing calls that run at C level, drastically reducing GIL hold
time with 8+ threads.
"""

import numpy as np

# Parameters of the rolling polynomial hash used for context keys.
_HASH_MULTIPLIER = 49157
_HASH_MASK = 0xFFFFFFFFFFFFFFFF


def _context_hash(context_tokens, order):
    """Deterministic 64-bit hash of the last *order* tokens.

    Replaces tuple(context_tokens[-order:]) as dict key, eliminating
    per-token tuple allocations and reducing GC pressure.

    Args:
        context_tokens: Indexable sequence of token IDs. Elements may be
            Python ints or NumPy integer scalars.
        order: Number of trailing tokens to hash. If it exceeds the
            sequence length, only the available tokens are hashed.

    Returns:
        A Python int in the range [0, 2**64).
    """
    end = len(context_tokens)
    # Clamp the start: a negative index would either raise IndexError or
    # silently wrap around and hash the wrong tokens.
    start = max(end - order, 0)
    h = 0
    for i in range(start, end):
        # int() keeps the arithmetic in arbitrary-precision Python ints.
        # A NumPy scalar would turn `h` into a fixed-width integer, and
        # masking that with a 64-bit unsigned constant raises.
        h = (h * _HASH_MULTIPLIER + int(context_tokens[i])) & _HASH_MASK
    return h


class NgramModel:
    """Interpolated N-gram model operating on token IDs.

    Uses iterative interpolation: higher-order models are progressively
    blended with lower orders, weighted by context frequency. Unseen
    contexts fall back smoothly to lower orders down to unigram.

    The model updates online after each observed token, so it adapts to
    the specific document being compressed.

    Inner storage uses flat numpy arrays indexed by slot number. The outer
    dict (context_hash → slot) preserves insertion order for deterministic
    FIFO eviction.
    """

    # Smoothing constant for interpolation weights.
    ESCAPE = 5

    # Maximum context entries per order.
    MAX_TABLE_ENTRIES = 500_000

    # Maximum unique continuations per context.
    MAX_INNER_ENTRIES = 64

    def __init__(self, max_order: int = 4, vocab_size: int = 49152):
        """Allocate the count tables for orders 0..max_order.

        Args:
            max_order: Highest context length (in tokens) that is modelled.
            vocab_size: Number of distinct token IDs; sets the length of
                the unigram table and of the predicted distribution.
        """
        self.max_order = max_order
        self.vocab_size = vocab_size

        # Order-0 (unigram) counts: dense array for fast vector ops.
        self.unigram_counts = np.zeros(vocab_size, dtype=np.float64)
        self.total_unigram = 0

        # Every per-order list below has a None placeholder at index 0 so
        # that it can be indexed directly by the order (1..N).

        # Order 1..N: context_hash → slot_index.
        # Python dict preserves insertion order for FIFO eviction.
        self._slot_map: list = [None] + [dict() for _ in range(max_order)]

        # Flat inner storage per order. Each context maps to a "slot"
        # containing up to MAX_INNER_ENTRIES (token_id, count) pairs.
        # Entries within a slot are kept in insertion order so that
        # argmin tie-breaking matches the old dict-based behavior.
        table_shape = (self.MAX_TABLE_ENTRIES, self.MAX_INNER_ENTRIES)
        self._inner_ids: list = [None] + [
            np.empty(table_shape, dtype=np.int32)
            for _ in range(max_order)
        ]
        self._inner_counts: list = [None] + [
            np.empty(table_shape, dtype=np.int32)
            for _ in range(max_order)
        ]
        # Number of valid (token_id, count) pairs in each slot.
        self._inner_sizes: list = [None] + [
            np.zeros(self.MAX_TABLE_ENTRIES, dtype=np.int16)
            for _ in range(max_order)
        ]
        # Sum of the counts stored in each slot.
        self._ctx_totals: list = [None] + [
            np.zeros(self.MAX_TABLE_ENTRIES, dtype=np.int32)
            for _ in range(max_order)
        ]

        # Slot allocation: sequential counter + free list for recycling.
        self._next_slot = [0] * (max_order + 1)
        self._free_slots: list = [None] + [[] for _ in range(max_order)]

        # Vocabulary-sized scratch vector. predict() no longer needs it,
        # but it stays allocated in case other code relies on the
        # attribute.
        self._buf = np.zeros(vocab_size, dtype=np.float64)
        # Output buffer reused by every predict() call.
        self._probs = np.zeros(vocab_size, dtype=np.float64)

    def reset(self):
        """Reset all counts. Call when starting a new sequence."""
        self.unigram_counts[:] = 0
        self.total_unigram = 0
        self._slot_map = [None] + [dict() for _ in range(self.max_order)]
        self._next_slot = [0] * (self.max_order + 1)
        self._free_slots = [None] + [[] for _ in range(self.max_order)]
        # The id/count tables are not cleared: a slot's contents are only
        # read up to its recorded size, which is zeroed here.
        for order in range(1, self.max_order + 1):
            self._inner_sizes[order][:] = 0
            self._ctx_totals[order][:] = 0

    def predict(self, context_tokens: list[int]) -> np.ndarray:
        """Predict next-token distribution given context.

        Starts from the Laplace-smoothed unigram distribution and, for
        each order with a known context, blends in that order's relative
        frequencies with weight ``total / (total + ESCAPE)``.

        Args:
            context_tokens: Preceding token IDs (Python ints or NumPy
                integers).

        Returns:
            numpy array of shape (vocab_size,) with probabilities summing
            to ~1. This is the model's internal buffer and is overwritten
            by the next predict() call; copy it if it must be kept.
        """
        # Start with unigram (Laplace-smoothed)
        probs = self._probs
        np.add(self.unigram_counts, 1.0, out=probs)
        probs /= (self.total_unigram + self.vocab_size)

        available = len(context_tokens)
        for order in range(1, self.max_order + 1):
            if available < order:
                break

            ctx = _context_hash(context_tokens, order)
            slot = self._slot_map[order].get(ctx)
            if slot is None:
                continue

            total = int(self._ctx_totals[order][slot])
            if total == 0:
                continue

            lam = total / (total + self.ESCAPE)

            size = int(self._inner_sizes[order][slot])
            ids = self._inner_ids[order][slot, :size]
            cts = self._inner_counts[order][slot, :size]

            # Blend: probs = lam * order_k + (1 - lam) * probs.
            # order_k is non-zero only at the stored ids, so after the
            # dense rescale only those entries need an addition. Ids in a
            # slot are unique, which makes the fancy-indexed in-place add
            # well defined, and the slot total equals the sum of its
            # counts, so this matches a dense scatter-normalize-blend
            # exactly while skipping several vocabulary-sized passes.
            probs *= (1.0 - lam)
            probs[ids] += (cts / total) * lam

        return probs

    def _alloc_slot(self, order: int) -> int:
        """Get a free slot index, recycling evicted slots first."""
        recycled = self._free_slots[order]
        if recycled:
            return recycled.pop()
        slot = self._next_slot[order]
        self._next_slot[order] = slot + 1
        return slot

    def _record_continuation(self, order: int, slot: int, token: int):
        """Count *token* as a continuation of an already-known context."""
        ids = self._inner_ids[order][slot]
        counts = self._inner_counts[order][slot]
        size = int(self._inner_sizes[order][slot])

        # Search for the token among the valid entries (vectorized).
        mask = ids[:size] == token
        if mask.any():
            # Token exists: increment its count.
            counts[int(np.argmax(mask))] += 1
            self._ctx_totals[order][slot] += 1
            return

        if size < self.MAX_INNER_ENTRIES:
            # New token, space available: append.
            ids[size] = token
            counts[size] = 1
            self._inner_sizes[order][slot] = size + 1
            self._ctx_totals[order][slot] += 1
            return

        # Slot is full. This reproduces "add the new entry with count 1,
        # then evict the first entry holding the minimum count":
        #   * if some existing entry has count 1, the oldest such entry is
        #     evicted and the new one is appended; the total gains 1 and
        #     loses 1, so it is unchanged;
        #   * otherwise the new entry would itself be the sole minimum and
        #     be evicted immediately, so nothing changes at all.
        min_idx = int(np.argmin(counts[:size]))
        if int(counts[min_idx]) != 1:
            return
        # Shift left to keep insertion order, so that argmin tie-breaking
        # keeps selecting the oldest entry.
        if min_idx < size - 1:
            ids[min_idx:size - 1] = ids[min_idx + 1:size]
            counts[min_idx:size - 1] = counts[min_idx + 1:size]
        ids[size - 1] = token
        counts[size - 1] = 1

    def update(self, context_tokens: list[int], actual_token: int):
        """Update counts after observing a token.

        Must be called identically during compression and decompression
        to maintain codec symmetry.

        Args:
            context_tokens: Context that preceded the token.
            actual_token: The token that was actually observed.
        """
        # Update unigram
        self.unigram_counts[actual_token] += 1
        self.total_unigram += 1

        # Update higher orders
        available = len(context_tokens)
        for order in range(1, self.max_order + 1):
            if available < order:
                break

            ctx = _context_hash(context_tokens, order)
            slot_map = self._slot_map[order]

            slot = slot_map.get(ctx)
            if slot is not None:
                self._record_continuation(order, slot, actual_token)
                continue

            # New context. If the table is full, evict the oldest context
            # (first key in insertion order) and recycle its slot.
            if len(slot_map) >= self.MAX_TABLE_ENTRIES:
                evict_ctx = next(iter(slot_map))
                self._free_slots[order].append(slot_map.pop(evict_ctx))

            slot = self._alloc_slot(order)
            slot_map[ctx] = slot
            self._inner_ids[order][slot, 0] = actual_token
            self._inner_counts[order][slot, 0] = 1
            self._inner_sizes[order][slot] = 1
            self._ctx_totals[order][slot] = 1