"""
Merlin BNN — self-contained inference module.
No PyTorch, no MLX required. Only numpy.
Usage:
from bnn import load_selection_mlp
selector = load_selection_mlp("selection_mlp.npz")
best_idx = selector.select(candidates, signals)
"""
from __future__ import annotations
import numpy as np
from pathlib import Path
# ---------------------------------------------------------------------------
# MLP helpers
# ---------------------------------------------------------------------------
def _relu(x: np.ndarray) -> np.ndarray:
return np.maximum(0.0, x)
def _sigmoid(x: np.ndarray) -> np.ndarray:
return 1.0 / (1.0 + np.exp(-np.clip(x, -30, 30)))
def _mlp_forward(weights: dict, x: np.ndarray) -> float:
    """Run a 3-layer MLP (ReLU, ReLU, sigmoid head) and return a scalar in [0, 1].

    Expects weights keyed "W1"/"b1" .. "W3"/"b3" with shapes compatible
    with x (assumed 1-D feature vector — confirm against the checkpoint).
    """
    a1 = _relu(x @ weights["W1"] + weights["b1"])
    a2 = _relu(a1 @ weights["W2"] + weights["b2"])
    logit = a2 @ weights["W3"] + weights["b3"]
    return float(_sigmoid(logit))
# ---------------------------------------------------------------------------
# Feature extraction
# ---------------------------------------------------------------------------
def _selection_features(candidate: dict, signals: dict) -> np.ndarray:
entropy = candidate.get("entropy", signals.get("mean_entropy", 0.5))
margin = candidate.get("margin", signals.get("mean_margin", 0.3))
top1_prob = candidate.get("top1_prob", signals.get("mean_top1", 0.5))
mean_e = signals.get("mean_entropy", 0.5)
mean_m = signals.get("mean_margin", 0.3)
mean_t1 = signals.get("mean_top1", 0.5)
return np.array([
entropy, margin, top1_prob,
signals.get("diversity", 0.5),
signals.get("consistency", 0.5),
entropy - mean_e, # calibration inversion: >0 → likely correct
margin - mean_m, # calibration inversion: <0 → likely correct
top1_prob - mean_t1, # calibration inversion: <0 → likely correct
], dtype=np.float32)
# ---------------------------------------------------------------------------
# SelectionMLP
# ---------------------------------------------------------------------------
class BNNSelector:
    """Trained SelectionMLP — scores candidates and returns the best index."""

    def __init__(self, weights: dict):
        # Raw checkpoint arrays (W1/b1..W3/b3), consumed by _mlp_forward.
        self._w = weights

    def score(self, candidate: dict, signals: dict) -> float:
        """Return probability in [0, 1] that this candidate is correct."""
        feats = _selection_features(candidate, signals)
        return _mlp_forward(self._w, feats)

    def score_all(self, candidates: list[dict], signals: dict) -> list[float]:
        """Return scores for all candidates."""
        return [self.score(cand, signals) for cand in candidates]

    def select(self, candidates: list[dict], signals: dict) -> int:
        """Return index of the best candidate (first one on ties)."""
        return int(np.argmax(self.score_all(candidates, signals)))
# ---------------------------------------------------------------------------
# SpikeMLP
# ---------------------------------------------------------------------------
class SpikePredictor:
    """Trained SpikeMLP — predicts whether a generation is high-entropy."""

    def __init__(self, weights: dict):
        # Raw checkpoint arrays (W1/b1..W3/b3), consumed by _mlp_forward.
        self._w = weights

    def predict(self, entropy_window: list[float]) -> float:
        """Return probability [0, 1] that generation entropy is above median.

        entropy_window: last 8 per-token entropy values from the running buffer.
        Shorter windows are left-padded with zeros; longer ones keep the
        most recent 8 values.
        """
        recent = list(entropy_window)[-8:]
        deficit = 8 - len(recent)
        if deficit > 0:
            recent = [0.0] * deficit + recent
        x = np.asarray(recent, dtype=np.float32)
        return _mlp_forward(self._w, x)

    def should_intervene(self, entropy_window: list[float], threshold: float = 0.5) -> bool:
        """True when the predicted spike probability exceeds `threshold`."""
        return self.predict(entropy_window) > threshold
# ---------------------------------------------------------------------------
# Loaders
# ---------------------------------------------------------------------------
def load_selection_mlp(path: str | Path) -> BNNSelector:
    """Load a SelectionMLP checkpoint from an .npz file.

    Args:
        path: filesystem path to the .npz checkpoint.

    Returns:
        A BNNSelector wrapping the checkpoint's weight arrays.
    """
    # np.load on an .npz returns a lazily-backed NpzFile that keeps the
    # archive open; use it as a context manager so the file handle is
    # closed once all arrays are materialized into a plain dict.
    with np.load(path) as data:
        weights = {k: data[k] for k in data.files}
    return BNNSelector(weights)
def load_spike_mlp(path: str | Path) -> SpikePredictor:
    """Load a SpikeMLP checkpoint from an .npz file.

    Args:
        path: filesystem path to the .npz checkpoint.

    Returns:
        A SpikePredictor wrapping the checkpoint's weight arrays.
    """
    # np.load on an .npz returns a lazily-backed NpzFile that keeps the
    # archive open; use it as a context manager so the file handle is
    # closed once all arrays are materialized into a plain dict.
    with np.load(path) as data:
        weights = {k: data[k] for k in data.files}
    return SpikePredictor(weights)