"""
Merlin BNN — self-contained inference module.

No PyTorch, no MLX required. Only numpy.

Usage:
    from bnn import load_selection_mlp

    selector = load_selection_mlp("selection_mlp.npz")
    best_idx = selector.select(candidates, signals)
"""

from __future__ import annotations

from pathlib import Path

import numpy as np


# ---------------------------------------------------------------------------
# MLP helpers
# ---------------------------------------------------------------------------

def _relu(x: np.ndarray) -> np.ndarray:
    """Element-wise rectified linear unit, ``max(0, x)``."""
    return np.maximum(0.0, x)


def _sigmoid(x: np.ndarray) -> np.ndarray:
    """Element-wise logistic sigmoid.

    The input is clipped to ``[-30, 30]`` first so ``np.exp`` cannot overflow.
    """
    return 1.0 / (1.0 + np.exp(-np.clip(x, -30, 30)))


def _mlp_forward(weights: dict[str, np.ndarray], x: np.ndarray) -> float:
    """Run the three-layer MLP stored in *weights* on feature vector *x*.

    Computes ``sigmoid(relu(relu(x @ W1 + b1) @ W2 + b2) @ W3 + b3)``.

    Args:
        weights: mapping containing the arrays ``W1``, ``b1``, ``W2``,
            ``b2``, ``W3`` and ``b3``.
        x: 1-D input feature vector.

    Returns:
        The single sigmoid output as a Python ``float``.

    Raises:
        KeyError: if one of the six weight arrays is missing.
        ValueError: if the output layer does not produce exactly one value.
    """
    h = _relu(x @ weights["W1"] + weights["b1"])
    h = _relu(h @ weights["W2"] + weights["b2"])
    out = _sigmoid(h @ weights["W3"] + weights["b3"])
    # The output head typically has shape (1,). Calling float() on an
    # ndim > 0 array is deprecated since NumPy 1.25, so take the scalar
    # out with .item(), which accepts any size-1 array.
    return float(np.asarray(out).item())


# ---------------------------------------------------------------------------
# Feature extraction
# ---------------------------------------------------------------------------

def _selection_features(candidate: dict, signals: dict) -> np.ndarray:
    """Build the 8-element float32 feature vector scored by the SelectionMLP.

    Layout: ``[entropy, margin, top1_prob, diversity, consistency,
    entropy - mean_entropy, margin - mean_margin, top1_prob - mean_top1]``.

    A per-candidate value missing from *candidate* falls back to the matching
    ``mean_*`` entry of *signals*, and then to a fixed default. In either case
    the corresponding difference feature comes out as 0.
    """
    entropy = candidate.get("entropy", signals.get("mean_entropy", 0.5))
    margin = candidate.get("margin", signals.get("mean_margin", 0.3))
    top1_prob = candidate.get("top1_prob", signals.get("mean_top1", 0.5))
    mean_e = signals.get("mean_entropy", 0.5)
    mean_m = signals.get("mean_margin", 0.3)
    mean_t1 = signals.get("mean_top1", 0.5)
    return np.array([
        entropy,
        margin,
        top1_prob,
        signals.get("diversity", 0.5),
        signals.get("consistency", 0.5),
        entropy - mean_e,      # calibration inversion: >0 → likely correct
        margin - mean_m,       # calibration inversion: <0 → likely correct
        top1_prob - mean_t1,   # calibration inversion: <0 → likely correct
    ], dtype=np.float32)


# ---------------------------------------------------------------------------
# SelectionMLP
# ---------------------------------------------------------------------------

class BNNSelector:
    """Trained SelectionMLP — scores candidates and returns the best index."""

    def __init__(self, weights: dict[str, np.ndarray]):
        self._w = weights

    def score(self, candidate: dict, signals: dict) -> float:
        """Return probability in [0, 1] that this candidate is correct."""
        x = _selection_features(candidate, signals)
        return _mlp_forward(self._w, x)

    def score_all(self, candidates: list[dict], signals: dict) -> list[float]:
        """Return scores for all candidates, in input order."""
        return [self.score(c, signals) for c in candidates]

    def select(self, candidates: list[dict], signals: dict) -> int:
        """Return the index of the highest-scoring candidate.

        Ties resolve to the earliest index (``np.argmax`` semantics).

        Raises:
            ValueError: if *candidates* is empty.
        """
        scores = self.score_all(candidates, signals)
        # Checked after scoring so any iterable (even a generator) works.
        if not scores:
            raise ValueError("select() requires at least one candidate")
        return int(np.argmax(scores))


# ---------------------------------------------------------------------------
# SpikeMLP
# ---------------------------------------------------------------------------

class SpikePredictor:
    """Trained SpikeMLP — predicts whether a generation is high-entropy."""

    # Number of trailing per-token entropy values the network consumes. It
    # must match the row count of the checkpoint's W1.
    window: int = 8

    def __init__(self, weights: dict[str, np.ndarray]):
        self._w = weights

    def predict(self, entropy_window: list[float]) -> float:
        """Return probability [0, 1] that generation entropy is above median.

        entropy_window: the most recent per-token entropy values from the
        running buffer. Only the last ``window`` values are used. Shorter
        inputs are left-padded with 0.0, so the newest value always sits
        in the final slot.
        """
        n = self.window
        recent = list(entropy_window)[-n:]
        padded = [0.0] * (n - len(recent)) + recent
        return _mlp_forward(self._w, np.array(padded, dtype=np.float32))

    def should_intervene(self, entropy_window: list[float], threshold: float = 0.5) -> bool:
        """Return True when ``predict(entropy_window)`` is strictly above *threshold*."""
        return self.predict(entropy_window) > threshold


# ---------------------------------------------------------------------------
# Loaders
# ---------------------------------------------------------------------------

def _load_npz_weights(path: str | Path) -> dict[str, np.ndarray]:
    """Read every array stored in the ``.npz`` archive at *path* into memory.

    For an ``.npz``, ``np.load`` returns an ``NpzFile`` that holds the
    underlying file open. Using it as a context manager closes that handle
    once every member has been read. ``allow_pickle=False`` (NumPy's
    default) is passed explicitly so object arrays in a checkpoint are
    never unpickled.
    """
    with np.load(path, allow_pickle=False) as data:
        return {name: data[name] for name in data.files}


def load_selection_mlp(path: str | Path) -> BNNSelector:
    """Load SelectionMLP checkpoint from .npz file."""
    return BNNSelector(_load_npz_weights(path))


def load_spike_mlp(path: str | Path) -> SpikePredictor:
    """Load SpikeMLP checkpoint from .npz file."""
    return SpikePredictor(_load_npz_weights(path))