File size: 4,406 Bytes
0779c96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
"""
Merlin BNN — self-contained inference module.

No PyTorch, no MLX required. Only numpy.

Usage:
    from bnn import load_selection_mlp

    selector = load_selection_mlp("selection_mlp.npz")
    best_idx = selector.select(candidates, signals)
"""
from __future__ import annotations

import numpy as np
from pathlib import Path


# ---------------------------------------------------------------------------
# MLP helpers
# ---------------------------------------------------------------------------

def _relu(x: np.ndarray) -> np.ndarray:
    return np.maximum(0.0, x)


def _sigmoid(x: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-np.clip(x, -30, 30)))


def _mlp_forward(weights: dict, x: np.ndarray) -> float:
    h = _relu(x @ weights["W1"] + weights["b1"])
    h = _relu(h @ weights["W2"] + weights["b2"])
    return float(_sigmoid(h @ weights["W3"] + weights["b3"]))


# ---------------------------------------------------------------------------
# Feature extraction
# ---------------------------------------------------------------------------

def _selection_features(candidate: dict, signals: dict) -> np.ndarray:
    entropy   = candidate.get("entropy",   signals.get("mean_entropy", 0.5))
    margin    = candidate.get("margin",    signals.get("mean_margin",  0.3))
    top1_prob = candidate.get("top1_prob", signals.get("mean_top1",    0.5))
    mean_e  = signals.get("mean_entropy", 0.5)
    mean_m  = signals.get("mean_margin",  0.3)
    mean_t1 = signals.get("mean_top1",    0.5)
    return np.array([
        entropy, margin, top1_prob,
        signals.get("diversity",    0.5),
        signals.get("consistency",  0.5),
        entropy   - mean_e,   # calibration inversion: >0 → likely correct
        margin    - mean_m,   # calibration inversion: <0 → likely correct
        top1_prob - mean_t1,  # calibration inversion: <0 → likely correct
    ], dtype=np.float32)


# ---------------------------------------------------------------------------
# SelectionMLP
# ---------------------------------------------------------------------------

class BNNSelector:
    """Wraps trained SelectionMLP weights; ranks candidates by predicted correctness."""

    def __init__(self, weights: dict):
        # Raw weight arrays as loaded from the .npz checkpoint.
        self._w = weights

    def score(self, candidate: dict, signals: dict) -> float:
        """Probability in [0, 1] that *candidate* is the correct answer."""
        features = _selection_features(candidate, signals)
        return _mlp_forward(self._w, features)

    def select(self, candidates: list[dict], signals: dict) -> int:
        """Index of the highest-scoring candidate."""
        return int(np.argmax(self.score_all(candidates, signals)))

    def score_all(self, candidates: list[dict], signals: dict) -> list[float]:
        """Score every candidate, preserving input order."""
        return [self.score(c, signals) for c in candidates]


# ---------------------------------------------------------------------------
# SpikeMLP
# ---------------------------------------------------------------------------

class SpikePredictor:
    """Wraps trained SpikeMLP weights; flags likely high-entropy generations."""

    def __init__(self, weights: dict):
        # Raw weight arrays as loaded from the .npz checkpoint.
        self._w = weights

    def predict(self, entropy_window: list[float]) -> float:
        """Return probability [0, 1] that generation entropy is above median.

        entropy_window: last 8 per-token entropy values from the running buffer.
        """
        # Keep at most the 8 most recent values, left-padding with zeros
        # so the MLP always sees a fixed-length input.
        recent = list(entropy_window)[-8:]
        padded = [0.0] * (8 - len(recent)) + recent
        return _mlp_forward(self._w, np.array(padded, dtype=np.float32))

    def should_intervene(self, entropy_window: list[float], threshold: float = 0.5) -> bool:
        """True when the predicted spike probability exceeds *threshold*."""
        return self.predict(entropy_window) > threshold


# ---------------------------------------------------------------------------
# Loaders
# ---------------------------------------------------------------------------

def load_selection_mlp(path: str | Path) -> BNNSelector:
    """Load a SelectionMLP checkpoint from a .npz file.

    path: filesystem path to the .npz archive written at training time.
    Returns a ready-to-use :class:`BNNSelector`.
    """
    # np.load on a .npz returns a lazy NpzFile that keeps the archive handle
    # open; materialize the arrays inside a `with` so the file is closed
    # instead of leaked. (Requires *path* to be a .npz archive, as documented.)
    with np.load(path) as data:
        weights = {key: data[key] for key in data.files}
    return BNNSelector(weights)


def load_spike_mlp(path: str | Path) -> SpikePredictor:
    """Load a SpikeMLP checkpoint from a .npz file.

    path: filesystem path to the .npz archive written at training time.
    Returns a ready-to-use :class:`SpikePredictor`.
    """
    # np.load on a .npz returns a lazy NpzFile that keeps the archive handle
    # open; materialize the arrays inside a `with` so the file is closed
    # instead of leaked. (Requires *path* to be a .npz archive, as documented.)
    with np.load(path) as data:
        weights = {key: data[key] for key in data.files}
    return SpikePredictor(weights)