| """Adaptive context mixer for model blending. | |
| Combines multiple probability distributions using linear mixing | |
| with online multiplicative weight updates. The key idea: | |
| models that predict well get higher weights, adapted after every | |
| token. | |
| All operations are deterministic for lossless codec symmetry. | |
| Uses numpy instead of torch for CPU tensor operations to minimize | |
| per-operation dispatch overhead. | |
| """ | |
| import math | |
| import numpy as np | |


class ContextMixer:
    """Linear context mixer with online weight adaptation.

    Linear mixing computes a weighted average of input distributions:

        mixed[i] = sum(w_j * p_j[i])

    This preserves the dominant model's confidence: if the LLM puts
    0.90 on a token and has weight 0.85, it contributes 0.765 to the
    mix. Unlike geometric mixing, near-uniform secondary models don't
    flatten the distribution.

    Weights are updated using the exponential weights algorithm
    (multiplicative updates based on per-model log-loss). Models that
    predict accurately gain weight automatically.
    """

    # Floor for probabilities before taking log, to avoid -inf.
    PROB_FLOOR = 1e-8

    def __init__(self, num_models: int, lr: float = 0.5,
                 initial_weights: list[float] | None = None,
                 vocab_size: int = 49152):
        """Initialize the mixer.

        Args:
            num_models: Number of models to mix.
            lr: Learning rate for weight adaptation. Higher values
                make the mixer react faster to model performance.
                0 = static equal weights; 1 = full Bayesian updating.
            initial_weights: Starting weights for each model. If None,
                uses LLM-dominant defaults: first model gets 0.85,
                rest share 0.15 equally. This prevents uninformed
                secondary models from diluting the LLM early on.
            vocab_size: Vocabulary size for pre-allocating mix buffers.
        """
        self.num_models = num_models
        self.lr = lr
        if initial_weights is not None:
            assert len(initial_weights) == num_models
            self._initial_weights = list(initial_weights)
        else:
            # LLM-dominant: first model (LLM) gets 0.85, remaining
            # 0.15 split equally among secondary models,
            # e.g. num_models=3 -> [0.85, 0.075, 0.075].
            if num_models == 1:
                self._initial_weights = [1.0]
            else:
                secondary_w = 0.15 / (num_models - 1)
                self._initial_weights = (
                    [0.85] + [secondary_w] * (num_models - 1)
                )
        self._init_from_weights(self._initial_weights)
        # Pre-allocated buffers for zero-alloc mixing.
        self._mix_buf = np.zeros(vocab_size, dtype=np.float64)
        self._scale_buf = np.zeros(vocab_size, dtype=np.float64)

    def _init_from_weights(self, weights: list[float]):
        """Set log_weights and weights from a normalized weight list."""
        total = sum(weights)
        self.weights = [w / total for w in weights]
        # Keep log-space copies for numerically stable updates.
        self.log_weights = [math.log(w + 1e-30) for w in self.weights]

    def reset(self):
        """Reset to initial weights. Call when starting a new sequence."""
        self._init_from_weights(self._initial_weights)

    def mix(self, prob_list: list[np.ndarray]) -> np.ndarray:
        """Combine multiple probability distributions.

        Uses linear mixing:

            mixed[i] = sum(w_j * p_j[i])

        Args:
            prob_list: List of numpy arrays, each shape (vocab_size,),
                each summing to ~1.

        Returns:
            numpy array of shape (vocab_size,) with blended
            probabilities. Note: this is a reused internal buffer, so
            consume or copy it before the next mix() call.
        """
        if len(prob_list) != self.num_models:
            raise ValueError(
                f"Expected {self.num_models} models, got {len(prob_list)}"
            )
        if self.num_models == 1:
            return prob_list[0]
        # Weighted linear combination (in-place, zero-alloc).
        # Avoids a w * probs temporary (384 KB per model per token
        # at the default vocab_size of 49152 float64 entries).
        mixed = self._mix_buf
        mixed[:] = 0
        scale_buf = self._scale_buf
        for w, probs in zip(self.weights, prob_list):
            np.multiply(probs, w, out=scale_buf)
            mixed += scale_buf
        return mixed
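
    # Worked mix() example (toy numbers, extending the docstring case):
    # with weights [0.85, 0.15], an LLM prob of 0.90 and a secondary
    # prob of 0.02 on the same token blend to
    # 0.85 * 0.90 + 0.15 * 0.02 = 0.768.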

    def update(self, actual_token: int, prob_list: list[np.ndarray]):
        """Update weights based on observed token.

        Uses the exponential weights algorithm: each model's weight is
        multiplied by P(actual_token | model)^lr, then renormalized.
        Models that predicted the actual token well gain weight.

        Must be called identically during compression and decompression.

        Args:
            actual_token: The token that was actually observed.
            prob_list: Same list passed to mix() for this token.
        """
        if self.num_models <= 1:
            return
        for i, probs in enumerate(prob_list):
            p = max(float(probs[actual_token]), self.PROB_FLOOR)
            self.log_weights[i] += self.lr * math.log(p)
        # Normalize weights. Re-centering on the max keeps exp() in
        # range and stops log_weights drifting toward -inf over long
        # sequences; relative differences are preserved.
        max_lw = max(self.log_weights)
        self.log_weights = [lw - max_lw for lw in self.log_weights]
        raw = [math.exp(lw) for lw in self.log_weights]
        total = sum(raw)
        self.weights = [w / total for w in raw]
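
    # Worked update() example (toy numbers): with lr=0.5 and weights
    # [0.85, 0.15], if the LLM put 0.9 on the observed token and the
    # secondary model 0.1, the unnormalized weights become
    # 0.85 * 0.9**0.5 ~= 0.806 and 0.15 * 0.1**0.5 ~= 0.047, which
    # renormalize to roughly [0.944, 0.056].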

    def get_weights(self) -> list[float]:
        """Return current mixer weights (for diagnostics)."""
        return list(self.weights)
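

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the class): a minimal loop showing
# how a codec might drive the mixer. The toy vocab size and probabilities
# below are invented for the example, not values from any real model.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    vocab = 4
    mixer = ContextMixer(num_models=2, lr=0.5, vocab_size=vocab)

    # Model 0 (the "LLM") is confident about token 2; model 1 is uniform.
    llm = np.array([0.02, 0.03, 0.90, 0.05])
    secondary = np.full(vocab, 0.25)

    for step in range(3):
        mixed = mixer.mix([llm, secondary])
        # In a real codec, `mixed` would feed an arithmetic coder here.
        # mix() returns a reused internal buffer, so consume it before
        # the next mix() call.
        actual = 2  # pretend token 2 was observed each step
        mixer.update(actual, [llm, secondary])
        print(f"step {step}: P(actual)={mixed[actual]:.4f}, "
              f"weights={[round(w, 3) for w in mixer.get_weights()]}")

    # Decompression must replay the exact same mix()/update() sequence so
    # the mixer stays in lockstep on both sides; reset() starts a new one.
    mixer.reset()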