squ11z1 committed on
Commit
0779c96
·
verified ·
1 Parent(s): 7e1a1d5

Upload 4 files

Browse files
Files changed (4) hide show
  1. bnn.py +123 -0
  2. requirements.txt +4 -0
  3. selection_mlp.npz +3 -0
  4. spike_mlp.npz +3 -0
bnn.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Merlin BNN — self-contained inference module.
3
+
4
+ No PyTorch, no MLX required. Only numpy.
5
+
6
+ Usage:
7
+ from bnn import load_selection_mlp
8
+
9
+ selector = load_selection_mlp("selection_mlp.npz")
10
+ best_idx = selector.select(candidates, signals)
11
+ """
12
+ from __future__ import annotations
13
+
14
+ import numpy as np
15
+ from pathlib import Path
16
+
17
+
18
+ # ---------------------------------------------------------------------------
19
+ # MLP helpers
20
+ # ---------------------------------------------------------------------------
21
+
22
+ def _relu(x: np.ndarray) -> np.ndarray:
23
+ return np.maximum(0.0, x)
24
+
25
+
26
+ def _sigmoid(x: np.ndarray) -> np.ndarray:
27
+ return 1.0 / (1.0 + np.exp(-np.clip(x, -30, 30)))
28
+
29
+
30
def _mlp_forward(weights: dict, x: np.ndarray) -> float:
    """Forward pass through a fixed 3-layer MLP.

    weights holds the arrays W1/b1, W2/b2, W3/b3; the two hidden layers
    use ReLU and the scalar output is squashed through a sigmoid, so the
    return value is a probability in [0, 1].
    """
    w = weights
    hidden = _relu(w["b1"] + x @ w["W1"])
    hidden = _relu(w["b2"] + hidden @ w["W2"])
    return float(_sigmoid(w["b3"] + hidden @ w["W3"]))
34
+
35
+
36
+ # ---------------------------------------------------------------------------
37
+ # Feature extraction
38
+ # ---------------------------------------------------------------------------
39
+
40
+ def _selection_features(candidate: dict, signals: dict) -> np.ndarray:
41
+ entropy = candidate.get("entropy", signals.get("mean_entropy", 0.5))
42
+ margin = candidate.get("margin", signals.get("mean_margin", 0.3))
43
+ top1_prob = candidate.get("top1_prob", signals.get("mean_top1", 0.5))
44
+ mean_e = signals.get("mean_entropy", 0.5)
45
+ mean_m = signals.get("mean_margin", 0.3)
46
+ mean_t1 = signals.get("mean_top1", 0.5)
47
+ return np.array([
48
+ entropy, margin, top1_prob,
49
+ signals.get("diversity", 0.5),
50
+ signals.get("consistency", 0.5),
51
+ entropy - mean_e, # calibration inversion: >0 → likely correct
52
+ margin - mean_m, # calibration inversion: <0 → likely correct
53
+ top1_prob - mean_t1, # calibration inversion: <0 → likely correct
54
+ ], dtype=np.float32)
55
+
56
+
57
+ # ---------------------------------------------------------------------------
58
+ # SelectionMLP
59
+ # ---------------------------------------------------------------------------
60
+
61
class BNNSelector:
    """Trained SelectionMLP — scores candidates and returns the best index."""

    def __init__(self, weights: dict):
        # Raw MLP weight arrays (W1/b1, W2/b2, W3/b3) from the checkpoint.
        self._w = weights

    def score(self, candidate: dict, signals: dict) -> float:
        """Return probability in [0, 1] that this candidate is correct."""
        features = _selection_features(candidate, signals)
        return _mlp_forward(self._w, features)

    def score_all(self, candidates: list[dict], signals: dict) -> list[float]:
        """Return scores for all candidates, in input order."""
        return [self.score(candidate, signals) for candidate in candidates]

    def select(self, candidates: list[dict], signals: dict) -> int:
        """Return the index of the highest-scoring candidate."""
        # argmax keeps the original first-maximum tie-breaking.
        return int(np.argmax(self.score_all(candidates, signals)))
80
+
81
+
82
+ # ---------------------------------------------------------------------------
83
+ # SpikeMLP
84
+ # ---------------------------------------------------------------------------
85
+
86
class SpikePredictor:
    """Trained SpikeMLP — predicts whether a generation is high-entropy."""

    def __init__(self, weights: dict):
        # Raw MLP weight arrays loaded from the .npz checkpoint.
        self._w = weights

    def predict(self, entropy_window: list[float]) -> float:
        """Return probability [0, 1] that generation entropy is above median.

        entropy_window: last 8 per-token entropy values from the running buffer.
        """
        # Keep the trailing 8 values; left-pad with zeros when shorter.
        window = list(entropy_window)[-8:]
        missing = 8 - len(window)
        if missing:
            window = [0.0] * missing + window
        return _mlp_forward(self._w, np.array(window, dtype=np.float32))

    def should_intervene(self, entropy_window: list[float], threshold: float = 0.5) -> bool:
        """True when the predicted spike probability exceeds *threshold*."""
        return self.predict(entropy_window) > threshold
106
+
107
+
108
+ # ---------------------------------------------------------------------------
109
+ # Loaders
110
+ # ---------------------------------------------------------------------------
111
+
112
def load_selection_mlp(path: str | Path) -> BNNSelector:
    """Load a SelectionMLP checkpoint from a .npz file.

    Args:
        path: Path to the checkpoint (e.g. ``selection_mlp.npz``).

    Returns:
        A ready-to-use :class:`BNNSelector`.
    """
    # np.load on an .npz returns an NpzFile that keeps the zip handle open;
    # use a context manager so the file is closed once arrays are copied out.
    with np.load(path) as data:
        weights = {k: data[k] for k in data.files}
    return BNNSelector(weights)
117
+
118
+
119
def load_spike_mlp(path: str | Path) -> SpikePredictor:
    """Load a SpikeMLP checkpoint from a .npz file.

    Args:
        path: Path to the checkpoint (e.g. ``spike_mlp.npz``).

    Returns:
        A ready-to-use :class:`SpikePredictor`.
    """
    # np.load on an .npz returns an NpzFile that keeps the zip handle open;
    # use a context manager so the file is closed once arrays are copied out.
    with np.load(path) as data:
        weights = {k: data[k] for k in data.files}
    return SpikePredictor(weights)
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ numpy>=1.24
2
+ # Apple Silicon only (for running Falcon H1):
3
+ # mlx>=0.21
4
+ # mlx-lm>=0.31.0
selection_mlp.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e96f525aca7d927f3a65ca4207ee9b461bc3568b52a23057805ec347b7192da
3
+ size 8602
spike_mlp.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:62ab5e2409e20ee50cb953873a684c3dbf353ecc2dc9be24d8493c8efe39166f
3
+ size 4250