# Hub upload metadata (squ11z1's picture — Upload 4 files — 0779c96 verified)
"""
Merlin BNN — self-contained inference module.
No PyTorch, no MLX required. Only numpy.
Usage:
from bnn import load_selection_mlp
selector = load_selection_mlp("selection_mlp.npz")
best_idx = selector.select(candidates, signals)
"""
from __future__ import annotations
import numpy as np
from pathlib import Path
# ---------------------------------------------------------------------------
# MLP helpers
# ---------------------------------------------------------------------------
def _relu(x: np.ndarray) -> np.ndarray:
return np.maximum(0.0, x)
def _sigmoid(x: np.ndarray) -> np.ndarray:
return 1.0 / (1.0 + np.exp(-np.clip(x, -30, 30)))
def _mlp_forward(weights: dict, x: np.ndarray) -> float:
h = _relu(x @ weights["W1"] + weights["b1"])
h = _relu(h @ weights["W2"] + weights["b2"])
return float(_sigmoid(h @ weights["W3"] + weights["b3"]))
# ---------------------------------------------------------------------------
# Feature extraction
# ---------------------------------------------------------------------------
def _selection_features(candidate: dict, signals: dict) -> np.ndarray:
entropy = candidate.get("entropy", signals.get("mean_entropy", 0.5))
margin = candidate.get("margin", signals.get("mean_margin", 0.3))
top1_prob = candidate.get("top1_prob", signals.get("mean_top1", 0.5))
mean_e = signals.get("mean_entropy", 0.5)
mean_m = signals.get("mean_margin", 0.3)
mean_t1 = signals.get("mean_top1", 0.5)
return np.array([
entropy, margin, top1_prob,
signals.get("diversity", 0.5),
signals.get("consistency", 0.5),
entropy - mean_e, # calibration inversion: >0 → likely correct
margin - mean_m, # calibration inversion: <0 → likely correct
top1_prob - mean_t1, # calibration inversion: <0 → likely correct
], dtype=np.float32)
# ---------------------------------------------------------------------------
# SelectionMLP
# ---------------------------------------------------------------------------
class BNNSelector:
    """Trained SelectionMLP — scores candidates and returns the best index."""

    def __init__(self, weights: dict):
        # Checkpoint arrays keyed W1/b1, W2/b2, W3/b3 (as loaded from .npz).
        self._w = weights

    def score(self, candidate: dict, signals: dict) -> float:
        """Return probability in [0, 1] that this candidate is correct."""
        return _mlp_forward(self._w, _selection_features(candidate, signals))

    def score_all(self, candidates: list[dict], signals: dict) -> list[float]:
        """Return scores for all candidates."""
        return [self.score(candidate, signals) for candidate in candidates]

    def select(self, candidates: list[dict], signals: dict) -> int:
        """Return index of the best candidate."""
        return int(np.argmax(self.score_all(candidates, signals)))
# ---------------------------------------------------------------------------
# SpikeMLP
# ---------------------------------------------------------------------------
class SpikePredictor:
    """Trained SpikeMLP — predicts whether a generation is high-entropy."""

    def __init__(self, weights: dict):
        # Checkpoint arrays keyed W1/b1, W2/b2, W3/b3 (as loaded from .npz).
        self._w = weights

    def predict(self, entropy_window: list[float]) -> float:
        """Return probability [0, 1] that generation entropy is above median.

        entropy_window: last 8 per-token entropy values from the running buffer.
        """
        recent = list(entropy_window)[-8:]
        # Left-pad short windows with zeros so the input is always length 8.
        padded = [0.0] * (8 - len(recent)) + recent
        return _mlp_forward(self._w, np.array(padded, dtype=np.float32))

    def should_intervene(self, entropy_window: list[float], threshold: float = 0.5) -> bool:
        """True when the predicted spike probability exceeds *threshold*."""
        return self.predict(entropy_window) > threshold
# ---------------------------------------------------------------------------
# Loaders
# ---------------------------------------------------------------------------
def load_selection_mlp(path: str | Path) -> BNNSelector:
    """Load SelectionMLP checkpoint from .npz file.

    Copies all arrays out inside a context manager: np.load on an .npz
    returns an NpzFile that otherwise keeps the zip handle open.
    """
    with np.load(path) as data:
        weights = {key: data[key] for key in data.files}
    return BNNSelector(weights)
def load_spike_mlp(path: str | Path) -> SpikePredictor:
    """Load SpikeMLP checkpoint from .npz file.

    Copies all arrays out inside a context manager: np.load on an .npz
    returns an NpzFile that otherwise keeps the zip handle open.
    """
    with np.load(path) as data:
        weights = {key: data[key] for key in data.files}
    return SpikePredictor(weights)