# Prisma V2 — ToDo.md

## Goal

Evolve PRISMA from a **Python-stateful uncertainty feedback mechanism** (V1) into a **fully stateless, runtime-portable architecture** compatible with:

* Hugging Face Transformers
* vLLM
* llama.cpp
* MLX

while preserving the core idea:

> **Condition future predictions on the model’s own uncertainty to reduce confident errors and hallucinations.**

---

## Background (Why V2 Is Needed)

Prisma V1 relies on mutable per-step Python state:

```python
self.prev_uncertainty_code
```

This works in Hugging Face `generate()`, but fails in modern inference engines:

* **vLLM** — batched, reordered, stateless execution
* **llama.cpp** — compiled C/C++ loop with fixed KV cache semantics
* **MLX** — pure functional graph execution

To run on all target runtimes, decoding must be **purely functional**:

> All information needed for the next step must be carried explicitly in tensors, tokens, or KV cache — not Python object state.
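The failure mode can be reproduced with a toy example. In this sketch a made-up integer function `toy_step` stands in for one decode step, and `StatefulV1` and `decode_stateless` are hypothetical names. A single V1-style object keeps `prev_uncertainty_code` between calls. When it serves two interleaved requests, as a batching engine might schedule them, request B's uncertainty leaks into request A. Passing `(token, code)` explicitly avoids this.

```python
def toy_step(token: int, code: int) -> tuple[int, int]:
    """Hypothetical stand-in for one decode step: (token_t, code_t) -> (token_t+1, code_t+1)."""
    next_token = (token * 31 + code) % 1000
    next_code = next_token % 256
    return next_token, next_code


class StatefulV1:
    """V1 style: the uncertainty code lives on the object between calls."""

    def __init__(self) -> None:
        self.prev_uncertainty_code = 0

    def step(self, token: int) -> int:
        next_token, self.prev_uncertainty_code = toy_step(token, self.prev_uncertainty_code)
        return next_token


def decode_stateless(token: int, code: int, n_steps: int) -> list[int]:
    """V2 style: the code is an explicit input and output of every step."""
    tokens = []
    for _ in range(n_steps):
        token, code = toy_step(token, code)
        tokens.append(token)
    return tokens


# One shared V1 object serving requests A and B, interleaved step by step.
shared = StatefulV1()
a_token, b_token, a_out = 1, 2, []
for _ in range(3):
    a_token = shared.step(a_token)
    a_out.append(a_token)
    b_token = shared.step(b_token)

print(decode_stateless(1, 0, 3))  # request A on its own:     [31, 992, 976]
print(a_out)                      # request A interleaved with B: [31, 54, 843]
```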

---

## Core Design Change (Prisma V2)

### Replace measured uncertainty with **predicted uncertainty**

Instead of:

* computing entropy post-hoc from logits
* storing uncertainty in a mutable buffer

**The model learns to predict uncertainty directly.**

At each position, the model outputs:

1. **Next-token logits** (existing)
2. **Uncertainty logits for the next position** (new)

Uncertainty becomes a **learned latent variable**, not a side effect.

---

## Architecture Changes

### 1. Add an uncertainty prediction head

```python
self.n_uncertainty_levels = 256  # V2: smaller, sufficient
self.uncertainty_head = nn.Linear(hidden_size, self.n_uncertainty_levels, bias=False)
```

### 2. Uncertainty head initialization (important)

The uncertainty head is added to a pretrained model. To avoid destabilizing early training, use **zero initialization**:

```python
nn.init.zeros_(self.uncertainty_head.weight)  # every uncertainty logit starts at 0
```

**Rationale:**

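* All uncertainty logits start at zero, so the model initially predicts a uniform distribution over uncertainty levels everywhere
* The head does not touch the next-token logits, and with zero weights it sends no gradient into the backbone on the first step, so early training starts from the base model's behavior (this holds exactly only if the input-side uncertainty embeddings do not perturb `inputs_embeds` either)
* The uncertainty signal is learned gradually, only where it is useful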

Alternative initializations (small random, copying from `lm_head`) are left for experimentation.
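These properties can be checked on a standalone head. The following is a minimal PyTorch sketch with toy sizes, where `hidden_size = 64` is an arbitrary assumption and random tensors stand in for backbone hidden states. Because the weights are zero and there is no bias, the logits are all zero and the cross-entropy is exactly `log(256)`. The gradient sent back to the hidden states is zero at this step. The head's own weight gradient is not zero, so the head can still learn.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
hidden_size, n_uncertainty_levels = 64, 256  # toy sizes for illustration

uncertainty_head = nn.Linear(hidden_size, n_uncertainty_levels, bias=False)
nn.init.zeros_(uncertainty_head.weight)

# Stand-in for backbone hidden states [B, S, H]; requires_grad exposes the backbone-side gradient.
hidden = torch.randn(2, 5, hidden_size, requires_grad=True)
uncertainty_logits = uncertainty_head(hidden)  # all zeros
targets = torch.randint(0, n_uncertainty_levels, (2, 5))

loss = F.cross_entropy(uncertainty_logits.flatten(0, 1), targets.flatten())
loss.backward()

print(loss.item(), math.log(n_uncertainty_levels))          # both ≈ 5.545
print(hidden.grad.abs().max().item())                       # 0.0: no signal into the backbone yet
print(uncertainty_head.weight.grad.abs().max().item() > 0)  # True: the head itself still learns
```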

---

### 3. Keep uncertainty embeddings (input side)

```python
self.uncertainty_embeddings = nn.Embedding(
    self.n_uncertainty_levels,
    hidden_size
)
```

---

### 4. Modify forward signature

```python
from typing import Optional

import torch


def forward(
    self,
    input_ids: torch.Tensor,                           # [B, S] token ids
    uncertainty_codes: Optional[torch.Tensor] = None,  # [B, S] integer codes
    **kwargs,
):
    ...
```

* `uncertainty_codes[:, t]` conditions the token at position `t`
* No hidden buffers
* No mutable Python state

---

## Forward Pass Logic (V2)

1. Embed tokens
2. If `uncertainty_codes` provided:

   * lookup `uncertainty_embeddings`
   * add to `inputs_embeds`
3. Run transformer
4. Produce:

   * `logits` (next token)
   * `uncertainty_logits` (next uncertainty)

```python
return {
    "logits": logits,
    "uncertainty_logits": uncertainty_logits,
}
```
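The four steps can be combined into a small `nn.Module`. This is a sketch, not a reference implementation. A tiny `nn.TransformerEncoder` with a causal mask stands in for the pretrained backbone. The class name `PrismaV2Sketch` and all sizes are hypothetical.

```python
from typing import Optional

import torch
import torch.nn as nn


class PrismaV2Sketch(nn.Module):
    """Toy causal LM with Prisma V2 input embeddings and two output heads (hypothetical)."""

    def __init__(self, vocab_size: int = 100, hidden_size: int = 32,
                 n_uncertainty_levels: int = 256, n_layers: int = 2, n_heads: int = 4):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.uncertainty_embeddings = nn.Embedding(n_uncertainty_levels, hidden_size)
        layer = nn.TransformerEncoderLayer(hidden_size, n_heads, dim_feedforward=4 * hidden_size,
                                           batch_first=True)
        self.backbone = nn.TransformerEncoder(layer, n_layers)  # stand-in for the pretrained stack
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        self.uncertainty_head = nn.Linear(hidden_size, n_uncertainty_levels, bias=False)
        nn.init.zeros_(self.uncertainty_head.weight)

    def forward(self, input_ids: torch.Tensor,
                uncertainty_codes: Optional[torch.Tensor] = None) -> dict:
        # 1. Embed tokens: [B, S] -> [B, S, H]
        inputs_embeds = self.embed_tokens(input_ids)
        # 2. Add each position's uncertainty embedding when codes are given
        if uncertainty_codes is not None:
            inputs_embeds = inputs_embeds + self.uncertainty_embeddings(uncertainty_codes)
        # 3. Run the backbone; the -inf upper triangle keeps position t from seeing later positions
        seq_len = input_ids.size(1)
        causal_mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
        hidden = self.backbone(inputs_embeds, mask=causal_mask)
        # 4. Both heads read the same hidden state at every position
        return {
            "logits": self.lm_head(hidden),                       # [B, S, vocab_size]
            "uncertainty_logits": self.uncertainty_head(hidden),  # [B, S, n_uncertainty_levels]
        }
```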

---

## Temporal Semantics

| Position | Input                 | Predicts                 |
| -------- | --------------------- | ------------------------ |
| t        | tokenₜ + uncertaintyₜ | tokenₜ₊₁, uncertaintyₜ₊₁ |

This preserves the original PRISMA temporal feedback loop without mutable state.
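Under this table, the code predicted at position `t` conditions position `t + 1`. A sequence of per-position target codes therefore becomes the conditioning input when shifted right by one. This is a minimal sketch. The helper name is hypothetical, and so is the `start_code` given to position 0, because the plan does not say which code the first position receives. Using 0 is an arbitrary choice.

```python
import torch


def shift_codes_right(target_codes: torch.Tensor, start_code: int = 0) -> torch.Tensor:
    """[B, S] codes predicted at positions 0..S-1 -> [B, S] conditioning codes.

    target_codes[:, t] is uncertainty_{t+1}, so it is fed at position t + 1;
    position 0 has no predecessor and receives start_code (an arbitrary default).
    """
    inputs = torch.full_like(target_codes, start_code)
    inputs[:, 1:] = target_codes[:, :-1]
    return inputs


print(shift_codes_right(torch.tensor([[5, 7, 9]])).tolist())  # [[0, 5, 7]]
```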

---

## Training Plan

### Uncertainty supervision (teacher signal)

During training, entropy is used **only as a teacher signal**, not as the definition of uncertainty.

```text
entropy           = -Σ p · log p                (sum over the vocabulary)
normalized        = entropy / log(vocab_size)   (in [0, 1])
uncertainty_label = quantize(normalized)
```
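The following sketch gives a runnable version of this teacher signal. The plan does not define `quantize`, so this version assumes uniform bins over normalized entropy. It reuses the name `quantize_entropy` from the single-pass training loop.

```python
import math

import torch
import torch.nn.functional as F


def quantize_entropy(logits: torch.Tensor, n_levels: int = 256) -> torch.Tensor:
    """Next-token logits [..., V] -> integer uncertainty codes [...] in [0, n_levels).

    Assumes uniform bins over entropy normalized by log(V).
    """
    probs = F.softmax(logits.float(), dim=-1)
    # torch.special.entr(p) = -p * log(p), with entr(0) = 0, so masked (-inf) logits are safe.
    entropy = torch.special.entr(probs).sum(dim=-1)    # nats, in [0, log V]
    normalized = entropy / math.log(logits.size(-1))   # in [0, 1]
    codes = torch.floor(normalized * n_levels).long()
    return codes.clamp(0, n_levels - 1)                # normalized == 1 maps to the top code


vocab_size = 1000
uniform = torch.zeros(vocab_size)                      # maximum entropy
peaked = torch.zeros(vocab_size)
peaked[0] = 50.0                                       # nearly one-hot
two_way = torch.full((vocab_size,), float("-inf"))
two_way[:2] = 0.0                                      # entropy = log 2
print(quantize_entropy(torch.stack([uniform, peaked, two_way])).tolist())  # [255, 0, 25]
```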

---

### Single-pass training (preferred)

A second forward pass is **not required**.

```python
outputs = model(
    input_ids,
    uncertainty_codes=uncertainty_input
)

with torch.no_grad():  # entropy labels are a fixed target; no gradient flows through them
    uncertainty_labels = quantize_entropy(outputs["logits"])

loss = (
    loss_lm(outputs["logits"], labels)
    + λ * loss_uncertainty(outputs["uncertainty_logits"], uncertainty_labels)
)
```

**Key point:**

* Entropy is a **bootstrap target**
* The model is free to learn uncertainty representations that diverge from entropy over time
* This may let uncertainty correlate better with *error* than raw entropy does

---

### Loss definition

```python
loss = loss_lm + λ * loss_uncertainty
```

* `loss_lm`: standard next-token cross-entropy
* `loss_uncertainty`: cross-entropy over uncertainty codes
* λ ≈ 0.1 (to tune)
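Below is a self-contained sketch of the combined loss in PyTorch. It makes several assumptions:

* The forward pass returns `[B, S, ·]` logits.
* `labels` are already aligned with `logits`, so position `t` holds the token that follows it.
* The teacher signal uses uniform entropy bins.

The function name `prisma_v2_loss` and the default `lam=0.1` (the λ above) are illustrative. With uniform logits and a zero-initialized head, the loss reduces to `log(V) + λ · log(256)`, which makes a quick sanity check.

```python
import math

import torch
import torch.nn.functional as F


def prisma_v2_loss(logits: torch.Tensor, uncertainty_logits: torch.Tensor,
                   labels: torch.Tensor, lam: float = 0.1) -> torch.Tensor:
    """loss = loss_lm + λ * loss_uncertainty from one forward pass.

    logits: [B, S, V]; uncertainty_logits: [B, S, L]; labels: [B, S], where
    labels[:, t] is the token that follows position t.
    """
    n_levels = uncertainty_logits.size(-1)

    # Teacher signal: quantized normalized entropy of the model's own logits.
    # no_grad: the labels are a fixed target, not something to optimize through.
    with torch.no_grad():
        entropy = torch.special.entr(F.softmax(logits.float(), dim=-1)).sum(dim=-1)
        normalized = entropy / math.log(logits.size(-1))
        # .long() truncates, which equals floor here because normalized >= 0
        uncertainty_labels = (normalized * n_levels).long().clamp(0, n_levels - 1)

    # F.cross_entropy expects [N, C] scores and [N] targets, so flatten batch and sequence.
    loss_lm = F.cross_entropy(logits.flatten(0, 1), labels.flatten())
    loss_uncertainty = F.cross_entropy(uncertainty_logits.flatten(0, 1),
                                       uncertainty_labels.flatten())
    return loss_lm + lam * loss_uncertainty
```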

---

## Inference (All Runtimes)

### Decode loop (conceptual)

```text
(tokenₜ, uncertaintyₜ) → model → (tokenₜ₊₁, uncertaintyₜ₊₁)
```

### Runtime responsibilities

* **Transformers**: custom `generate()` tracks uncertainty tensor
* **vLLM**: sampler tracks `uncertainty_code` per request
* **llama.cpp**: store one small uncertainty code in `llama_context`
* **MLX**: works naturally (pure tensor graph)

No runtime needs to preserve Python object state.
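The decode loop can be sketched with all state held in two tensors. None of the following assumptions come from the plan:

* A toy `ToyPrisma` module, which mixes positions with a running mean instead of attention, stands in for the real model.
* The next code is chosen greedily with `argmax`, like the next token.
* Prompt positions get code 0.
* The whole prefix is re-run at every step instead of using a KV cache.

Because nothing is stored on the model, rows of a batch cannot leak uncertainty into each other.

```python
import torch
import torch.nn as nn


class ToyPrisma(nn.Module):
    """Hypothetical stand-in with the V2 interface: (input_ids, uncertainty_codes) -> dict."""

    def __init__(self, vocab_size: int = 50, hidden_size: int = 16, n_levels: int = 256):
        super().__init__()
        self.tok = nn.Embedding(vocab_size, hidden_size)
        self.unc = nn.Embedding(n_levels, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size)
        self.uncertainty_head = nn.Linear(hidden_size, n_levels)

    def forward(self, input_ids, uncertainty_codes):
        h = self.tok(input_ids) + self.unc(uncertainty_codes)
        # Causal running mean over positions: position t only sees positions <= t.
        steps = torch.arange(1, input_ids.size(1) + 1, dtype=h.dtype).view(1, -1, 1)
        h = torch.tanh(h.cumsum(dim=1) / steps)
        return {"logits": self.lm_head(h), "uncertainty_logits": self.uncertainty_head(h)}


@torch.no_grad()
def greedy_decode(model, input_ids, uncertainty_codes, max_new_tokens):
    """Decode state is just the (tokens, codes) pair of [B, S] tensors; the model holds none."""
    for _ in range(max_new_tokens):
        out = model(input_ids, uncertainty_codes)
        next_token = out["logits"][:, -1].argmax(dim=-1, keepdim=True)             # token_{t+1}
        next_code = out["uncertainty_logits"][:, -1].argmax(dim=-1, keepdim=True)  # uncertainty_{t+1}
        input_ids = torch.cat([input_ids, next_token], dim=1)
        uncertainty_codes = torch.cat([uncertainty_codes, next_code], dim=1)
    return input_ids, uncertainty_codes


torch.manual_seed(0)
model = ToyPrisma().eval()
prompts = torch.tensor([[1, 2, 3], [4, 5, 6]])
tokens, codes = greedy_decode(model, prompts, torch.zeros_like(prompts), max_new_tokens=5)
print(tokens.shape, codes.shape)  # torch.Size([2, 8]) torch.Size([2, 8])
```

A runtime with a KV cache would feed only the newest `(token, code)` pair per step. The state that has to survive between steps is still just one small integer per sequence.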

---

## Compatibility Matrix

| Runtime         | Prisma V1 | Prisma V2 |
| --------------- | --------- | --------- |
| Transformers    | ✅         | ✅         |
| vLLM            | ❌         | ✅         |
| llama.cpp       | ❌         | ✅         |
| MLX             | ❌         | ✅         |
| Tensor Parallel | ⚠️        | ✅         |

---

## Design Decisions

### Classification vs regression

**Chosen:** classification (quantized uncertainty)

Reasons:

* Stable training
* Matches embedding lookup
* Discrete semantics
* Easier runtime handling

Regression remains an experimental alternative.

---

### Uncertainty resolution

* V1: 65,536 levels (overkill)
* V2: **256 levels** (sufficient, efficient, portable)
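A little arithmetic makes the efficiency difference concrete. The sketch assumes a hypothetical `hidden_size` of 4096 and assumes V1 used one embedding row per level. It counts the input-side uncertainty embedding table (`n_levels × hidden_size`) and the bytes needed to store one code per position.

```python
import math


def embedding_params(n_levels: int, hidden_size: int) -> int:
    """Parameters in an uncertainty embedding table of shape [n_levels, hidden_size]."""
    return n_levels * hidden_size


def bytes_per_code(n_levels: int) -> int:
    """Smallest whole number of bytes that can hold codes 0 .. n_levels - 1."""
    return math.ceil((n_levels - 1).bit_length() / 8)


hidden_size = 4096  # hypothetical model width
for name, n_levels in [("V1", 65_536), ("V2", 256)]:
    print(name, embedding_params(n_levels, hidden_size), bytes_per_code(n_levels))
# V1 268435456 2
# V2 1048576 1
```

Under these assumptions, adding V2's equally sized `256 × 4096` uncertainty head still leaves its whole uncertainty path 128× smaller than V1's embedding table alone. Each code also fits in a single byte.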

---

## Known Limitations

* Uncertainty reflects **model confidence**, not correctness
* Learned uncertainty may differ from Shannon entropy
* Does not guarantee abstention or correctness
* Behavior depends on training data and loss weighting

---

## Open Questions / Future Work

* Tune `n_uncertainty_levels`
* Tune λ (uncertainty loss weight)
* Explore uncertainty-aware decoding strategies
* Compare uncertainty prediction vs uncertainty tokens
* Investigate bootstrapping without entropy supervision

---

## Definition of Done (Prisma V2)

* [ ] No mutable per-step Python state
* [ ] Uncertainty passed explicitly as tensor
* [ ] Works in Transformers, vLLM, llama.cpp, MLX
* [ ] Zero-init uncertainty head
* [ ] Single-pass training loop
* [ ] Updated model card + documentation
* [ ] Reference implementation + example

---

### Guiding Principle (V2)

> **Uncertainty must be data, not memory.**