| """ |
| Merlin BNN — self-contained inference module. |
| |
| No PyTorch, no MLX required. Only numpy. |
| |
| Usage: |
| from bnn import load_selection_mlp |
| |
| selector = load_selection_mlp("selection_mlp.npz") |
| best_idx = selector.select(candidates, signals) |
| """ |
| from __future__ import annotations |
|
|
| import numpy as np |
| from pathlib import Path |
|
|
|
|
| |
| |
| |
|
|
| def _relu(x: np.ndarray) -> np.ndarray: |
| return np.maximum(0.0, x) |
|
|
|
|
| def _sigmoid(x: np.ndarray) -> np.ndarray: |
| return 1.0 / (1.0 + np.exp(-np.clip(x, -30, 30))) |
|
|
|
|
def _mlp_forward(weights: dict, x: np.ndarray) -> float:
    """Forward pass of a 3-layer MLP: two ReLU hidden layers, sigmoid head.

    weights: dict with arrays W1/b1, W2/b2, W3/b3 (an .npz checkpoint).
    Returns a scalar probability in [0, 1].
    """
    a1 = _relu(np.dot(x, weights["W1"]) + weights["b1"])
    a2 = _relu(np.dot(a1, weights["W2"]) + weights["b2"])
    logit = np.dot(a2, weights["W3"]) + weights["b3"]
    return float(_sigmoid(logit))
|
|
|
|
| |
| |
| |
|
|
| def _selection_features(candidate: dict, signals: dict) -> np.ndarray: |
| entropy = candidate.get("entropy", signals.get("mean_entropy", 0.5)) |
| margin = candidate.get("margin", signals.get("mean_margin", 0.3)) |
| top1_prob = candidate.get("top1_prob", signals.get("mean_top1", 0.5)) |
| mean_e = signals.get("mean_entropy", 0.5) |
| mean_m = signals.get("mean_margin", 0.3) |
| mean_t1 = signals.get("mean_top1", 0.5) |
| return np.array([ |
| entropy, margin, top1_prob, |
| signals.get("diversity", 0.5), |
| signals.get("consistency", 0.5), |
| entropy - mean_e, |
| margin - mean_m, |
| top1_prob - mean_t1, |
| ], dtype=np.float32) |
|
|
|
|
| |
| |
| |
|
|
class BNNSelector:
    """Trained SelectionMLP — scores candidates and returns the best index."""

    def __init__(self, weights: dict):
        # weights: W1/b1/W2/b2/W3/b3 arrays from an .npz checkpoint
        self._w = weights

    def score(self, candidate: dict, signals: dict) -> float:
        """Probability in [0, 1] that *candidate* is the correct one."""
        features = _selection_features(candidate, signals)
        return _mlp_forward(self._w, features)

    def select(self, candidates: list[dict], signals: dict) -> int:
        """Index of the highest-scoring candidate (first one on ties)."""
        return int(np.argmax(self.score_all(candidates, signals)))

    def score_all(self, candidates: list[dict], signals: dict) -> list[float]:
        """Score every candidate, preserving input order."""
        return [self.score(cand, signals) for cand in candidates]
|
|
|
|
| |
| |
| |
|
|
class SpikePredictor:
    """Trained SpikeMLP — predicts whether a generation is high-entropy."""

    def __init__(self, weights: dict):
        # weights: W1/b1/W2/b2/W3/b3 arrays from an .npz checkpoint
        self._w = weights

    def predict(self, entropy_window: list[float]) -> float:
        """Return probability [0, 1] that generation entropy is above median.

        entropy_window: last 8 per-token entropy values from the running
        buffer. Shorter windows are left-padded with zeros; longer ones
        keep only the trailing 8 values.
        """
        tail = list(entropy_window)[-8:]
        padded = [0.0] * (8 - len(tail)) + tail
        x = np.asarray(padded, dtype=np.float32)
        return _mlp_forward(self._w, x)

    def should_intervene(self, entropy_window: list[float], threshold: float = 0.5) -> bool:
        """True when the predicted spike probability exceeds *threshold*."""
        return self.predict(entropy_window) > threshold
|
|
|
|
| |
| |
| |
|
|
def load_selection_mlp(path: str | Path) -> BNNSelector:
    """Load a SelectionMLP checkpoint from an .npz file.

    Args:
        path: filesystem path to a numpy .npz archive of weight arrays
            (expected keys: W1/b1/W2/b2/W3/b3 — TODO confirm against trainer).

    Returns:
        BNNSelector wrapping the loaded weight dict.
    """
    # Materialize the arrays inside the context manager so the archive's
    # underlying file handle is closed rather than leaked — np.load on an
    # .npz keeps the zip file open for lazy access.
    with np.load(path) as data:
        weights = {k: data[k] for k in data.files}
    return BNNSelector(weights)
|
|
|
|
def load_spike_mlp(path: str | Path) -> SpikePredictor:
    """Load a SpikeMLP checkpoint from an .npz file.

    Args:
        path: filesystem path to a numpy .npz archive of weight arrays
            (expected keys: W1/b1/W2/b2/W3/b3 — TODO confirm against trainer).

    Returns:
        SpikePredictor wrapping the loaded weight dict.
    """
    # Materialize the arrays inside the context manager so the archive's
    # underlying file handle is closed rather than leaked — np.load on an
    # .npz keeps the zip file open for lazy access.
    with np.load(path) as data:
        weights = {k: data[k] for k in data.files}
    return SpikePredictor(weights)
|
|