Nacrith-GPU / ngram_model.py
robtacconelli's picture
Upload 11 files
5b8133e verified
"""Token-level N-gram model for context mixing.
Maintains order-1 through order-N context tables with interpolated
backoff smoothing. Used as a fast, lightweight predictor alongside
the LLM in an ensemble. All operations are deterministic for
lossless codec symmetry.
Uses flat numpy arrays for inner storage instead of nested Python
dicts. This eliminates millions of small dict objects (~3.6 GB →
~1 GB per worker) and replaces the O(K) Python iteration loop in
predict() with a single numpy fancy-indexing call that runs at
C level, drastically reducing GIL hold time with 8+ threads.
"""
import numpy as np
def _context_hash(context_tokens, order):
    """Fold the trailing *order* tokens into a deterministic 64-bit key.

    Each token is mixed in with a multiply-by-49157-then-add step, and the
    accumulator is truncated to 64 bits after every step. Using a plain int
    as the dict key avoids building a tuple(context_tokens[-order:]) per
    lookup. When order <= 4 and every token ID lies in [0, 49157), the
    result is a base-49157 number below 2**64, so distinct contexts of the
    same order cannot collide.
    """
    mask64 = 0xFFFFFFFFFFFFFFFF
    base = len(context_tokens) - order
    key = 0
    for step in range(order):
        key = (key * 49157 + context_tokens[base + step]) & mask64
    return key
class NgramModel:
    """Interpolated N-gram model operating on token IDs.

    Uses iterative interpolation: higher-order models are progressively
    blended with lower orders, weighted by context frequency. Unseen
    contexts fall back smoothly to lower orders down to unigram.

    The model updates online after each observed token, so it adapts
    to the specific document being compressed.

    Inner storage uses flat numpy arrays indexed by slot number.
    The outer dict (context_hash -> slot) preserves insertion order
    for deterministic FIFO eviction.
    """

    # Smoothing constant for interpolation weights:
    # an order with `total` observations gets weight total / (total + ESCAPE).
    ESCAPE = 5
    # Maximum context entries per order; the oldest is evicted beyond this.
    MAX_TABLE_ENTRIES = 500_000
    # Maximum unique continuations per context.
    MAX_INNER_ENTRIES = 64

    def __init__(self, max_order: int = 4, vocab_size: int = 49152):
        self.max_order = max_order
        self.vocab_size = vocab_size
        # Order-0 (unigram) counts: dense array for fast vector ops.
        self.unigram_counts = np.zeros(vocab_size, dtype=np.float64)
        self.total_unigram = 0
        # Order 1..N: context_hash -> slot_index. Index 0 of every
        # per-order list is a None placeholder so lists index by order.
        # Python dict preserves insertion order for FIFO eviction.
        self._slot_map: list = [None] + [dict() for _ in range(max_order)]
        # Flat inner storage per order. Each context maps to a "slot"
        # containing up to MAX_INNER_ENTRIES (token_id, count) pairs.
        # Entries within a slot are kept in insertion order so that
        # argmin tie-breaking matches the old dict-based behavior.
        # np.empty is fine: only the first _inner_sizes[slot] entries
        # of a row are ever read.
        self._inner_ids: list = [None] + [
            np.empty((self.MAX_TABLE_ENTRIES, self.MAX_INNER_ENTRIES),
                     dtype=np.int32)
            for _ in range(max_order)
        ]
        self._inner_counts: list = [None] + [
            np.empty((self.MAX_TABLE_ENTRIES, self.MAX_INNER_ENTRIES),
                     dtype=np.int32)
            for _ in range(max_order)
        ]
        # Number of valid entries in each slot's row.
        self._inner_sizes: list = [None] + [
            np.zeros(self.MAX_TABLE_ENTRIES, dtype=np.int16)
            for _ in range(max_order)
        ]
        # Sum of counts in each slot (number of observations of the context).
        self._ctx_totals: list = [None] + [
            np.zeros(self.MAX_TABLE_ENTRIES, dtype=np.int32)
            for _ in range(max_order)
        ]
        # Slot allocation: sequential counter + free list for recycling.
        self._next_slot = [0] * (max_order + 1)
        self._free_slots: list = [None] + [[] for _ in range(max_order)]
        # Dense scratch buffer. predict() no longer needs it; it is kept
        # so the attribute set of the object is unchanged.
        self._buf = np.zeros(vocab_size, dtype=np.float64)
        # Output buffer reused (and returned) by every predict() call.
        self._probs = np.zeros(vocab_size, dtype=np.float64)

    def reset(self) -> None:
        """Reset all counts. Call when starting a new sequence."""
        self.unigram_counts[:] = 0
        self.total_unigram = 0
        self._slot_map = [None] + [dict() for _ in range(self.max_order)]
        self._next_slot = [0] * (self.max_order + 1)
        self._free_slots = [None] + [[] for _ in range(self.max_order)]
        # Row contents in _inner_ids/_inner_counts need no clearing:
        # a zero size makes them unreachable until overwritten.
        for order in range(1, self.max_order + 1):
            self._inner_sizes[order][:] = 0
            self._ctx_totals[order][:] = 0

    def predict(self, context_tokens: list[int]) -> np.ndarray:
        """Predict next-token distribution given context.

        Starts from a Laplace-smoothed unigram distribution. For each
        order whose context has been seen, it blends in that context's
        empirical continuation distribution with weight
        lam = total / (total + ESCAPE).

        Each blend rescales the full vocabulary once. The order's own
        contribution is then added with a K-element fancy-indexed update
        (K <= MAX_INNER_ENTRIES), so no dense vocab-sized buffer is built
        per order. The per-element arithmetic is the same as in the dense
        formulation (int -> float64, divide by the exact count sum,
        multiply by lam, add), so the output is bit-for-bit unchanged.

        Args:
            context_tokens: List of preceding token IDs.

        Returns:
            numpy array of shape (vocab_size,) with probabilities summing
            to ~1. NOTE: this is the model's internal output buffer. The
            next call to predict() overwrites it, so copy the result if
            it must outlive that call.
        """
        # Start with unigram (Laplace-smoothed)
        probs = self._probs
        np.add(self.unigram_counts, 1.0, out=probs)
        probs /= (self.total_unigram + self.vocab_size)

        n_ctx = len(context_tokens)
        for order in range(1, self.max_order + 1):
            if n_ctx < order:
                break
            ctx = _context_hash(context_tokens, order)
            slot = self._slot_map[order].get(ctx)
            if slot is None:
                continue
            total = int(self._ctx_totals[order][slot])
            if total == 0:
                continue
            lam = total / (total + self.ESCAPE)

            size = int(self._inner_sizes[order][slot])
            ids = self._inner_ids[order][slot, :size]
            # Normalized continuation distribution for this context,
            # pre-scaled by lam. The integer counts sum exactly in
            # float64, so the divisor equals the dense-buffer sum.
            weights = self._inner_counts[order][slot, :size].astype(np.float64)
            weights /= weights.sum()
            weights *= lam

            # Blend: probs = (1 - lam) * probs + lam * order_k.
            # Tokens outside `ids` have zero order-k mass, so only the
            # rescale affects them. `ids` are unique within a slot, so the
            # fancy-indexed += applies each weight exactly once.
            probs *= (1.0 - lam)
            probs[ids] += weights
        return probs

    def _alloc_slot(self, order: int) -> int:
        """Get a free slot index, recycling evicted slots first."""
        if self._free_slots[order]:
            return self._free_slots[order].pop()
        slot = self._next_slot[order]
        self._next_slot[order] += 1
        return slot

    def update(self, context_tokens: list[int], actual_token: int) -> None:
        """Update counts after observing a token.

        Must be called identically during compression and decompression
        to maintain codec symmetry.

        Args:
            context_tokens: Context that preceded the token.
            actual_token: The token that was actually observed.
        """
        # Update unigram
        self.unigram_counts[actual_token] += 1
        self.total_unigram += 1

        # Update higher orders
        n_ctx = len(context_tokens)
        for order in range(1, self.max_order + 1):
            if n_ctx < order:
                break
            ctx = _context_hash(context_tokens, order)
            slot_map = self._slot_map[order]
            slot = slot_map.get(ctx)

            if slot is None:
                # New context. If the table is full, evict the oldest
                # context (FIFO by dict insertion order) and recycle its
                # slot, which _alloc_slot hands straight back.
                if len(slot_map) >= self.MAX_TABLE_ENTRIES:
                    evict_ctx = next(iter(slot_map))
                    evict_slot = slot_map.pop(evict_ctx)
                    self._free_slots[order].append(evict_slot)
                slot = self._alloc_slot(order)
                slot_map[ctx] = slot
                self._inner_ids[order][slot, 0] = actual_token
                self._inner_counts[order][slot, 0] = 1
                self._inner_sizes[order][slot] = 1
                self._ctx_totals[order][slot] = 1
                continue

            size = int(self._inner_sizes[order][slot])
            ids = self._inner_ids[order][slot]        # row view: writes persist
            counts = self._inner_counts[order][slot]  # row view: writes persist

            # Search for actual_token (numpy vectorized)
            mask = ids[:size] == actual_token
            if mask.any():
                # Token exists: increment its count
                idx = int(np.argmax(mask))
                counts[idx] += 1
                self._ctx_totals[order][slot] += 1
            elif size < self.MAX_INNER_ENTRIES:
                # New token, space available: append
                ids[size] = actual_token
                counts[size] = 1
                self._inner_sizes[order][slot] = size + 1
                self._ctx_totals[order][slot] += 1
            elif int(counts[:size].min()) == 1:
                # Slot is full and some entry has count 1. This simulates
                # the original "add new (count 1), then evict the first
                # minimum" behavior: the oldest count-1 entry is dropped
                # and the new token is appended with count 1. The context
                # total is unchanged (+1 - 1).
                # The shift keeps insertion order so argmin tie-breaking
                # (first minimum = oldest) matches the original dict
                # behavior.
                min_idx = int(np.argmin(counts[:size]))
                if min_idx < size - 1:
                    ids[min_idx:size - 1] = ids[min_idx + 1:size]
                    counts[min_idx:size - 1] = counts[min_idx + 1:size]
                ids[size - 1] = actual_token
                counts[size - 1] = 1
            # else: every existing count is > 1. The new entry would be
            # the sole minimum and be evicted immediately, so entries and
            # total are left unchanged.