ehartford commited on
Commit
57377d2
·
verified ·
1 Parent(s): 54c08c7

Create PRISMAv2.md

Browse files
Files changed (1) hide show
  1. PRISMAv2.md +487 -0
PRISMAv2.md ADDED
@@ -0,0 +1,487 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## PRISMA V2: Joint Uncertainty Prediction Mechanism — Implementation Specification
2
+
3
+ **Architecture Overview**:
4
+ PRISMA V2 replaces Python-side uncertainty state with a **learned, explicit uncertainty latent** predicted jointly with tokens. At each step, the model predicts both the next token *and* an uncertainty code that conditions the following step. This preserves temporal introspection while remaining fully compatible with stateless inference engines.
5
+
6
+ ---
7
+
8
+ ## **Core Design Principle**
9
+
10
+ > **Uncertainty must be data, not memory.**
11
+
12
+ All information required for the next decoding step is carried explicitly through tensors (tokens, uncertainty codes, or cache), never through mutable module state.
13
+
14
+ ---
15
+
16
+ ## **Differences from Prisma V1 (Detailed)**
17
+
18
+ Prisma V2 is not a minor refactor of Prisma V1. It represents a **fundamental shift in how uncertainty is represented, propagated, and learned**.
19
+
20
+ This section documents those differences precisely.
21
+
22
+ ---
23
+
24
+ ### **1. Source of Uncertainty**
25
+
26
+ **Prisma V1**
27
+
28
+ * Uncertainty is **measured post-hoc** from the model’s output distribution
29
+ * Computed via entropy of logits
30
+ * Acts as an external diagnostic signal
31
+
32
+ ```text
33
+ uncertainty_t = H(P(y_t))
34
+ ```
35
+
36
+ **Prisma V2**
37
+
38
+ * Uncertainty is **predicted by the model itself**
39
+ * Learned as an auxiliary latent variable
40
+ * Acts as an internal representation
41
+
42
+ ```text
43
+ (token_{t+1}, uncertainty_{t+1}) = f(token_t, uncertainty_t)
44
+ ```
45
+
46
+ **Implication**:
47
+ V1 answers *“how uncertain was I?”*
48
+ V2 answers *“how uncertain will I be?”*
49
+
50
+ ---
51
+
52
+ ### **2. State Representation**
53
+
54
+ **Prisma V1**
55
+
56
+ * Uses mutable Python-side state:
57
+
58
+ ```python
59
+ self.prev_uncertainty_code
60
+ ```
61
+
62
+ * State exists **outside** the model’s forward graph
63
+ * Relies on strict step-by-step execution order
64
+
65
+ **Prisma V2**
66
+
67
+ * No mutable state
68
+ * Uncertainty is passed explicitly as a tensor:
69
+
70
+ ```python
71
+ uncertainty_codes: Tensor[B, S]
72
+ ```
73
+
74
+ * Fully contained within the model’s inputs and outputs
75
+
76
+ **Implication**:
77
+ V1 requires engine cooperation.
78
+ V2 requires only tensors.
79
+
80
+ ---
81
+
82
+ ### **3. Runtime Compatibility**
83
+
84
+ | Runtime | Prisma V1 | Prisma V2 |
85
+ | ------------------------ | --------- | --------- |
86
+ | HuggingFace Transformers | ✅ | ✅ |
87
+ | vLLM | ❌ | ✅ |
88
+ | llama.cpp | ❌ | ✅ |
89
+ | MLX | ❌ | ✅ |
90
+ | Tensor Parallel | ⚠️ | ✅ |
91
+
92
+ **Reason**:
93
+
94
+ * V1 violates the stateless decoding assumptions of modern runtimes
95
+ * V2 conforms to them by construction
96
+
97
+ ---
98
+
99
+ ### **4. Temporal Feedback Mechanism**
100
+
101
+ **Prisma V1**
102
+
103
+ * Feedback loop implemented via external buffer
104
+ * Requires padding, truncation, and shifting logic
105
+ * Not visible to KV cache or sampler
106
+
107
+ **Prisma V2**
108
+
109
+ * Feedback loop is **architectural**
110
+ * Uncertainty is predicted one step ahead and injected naturally
111
+ * Temporal alignment is implicit in training and decoding
112
+
113
+ **Implication**:
114
+ V2’s feedback loop is **native**, not simulated.
115
+
116
+ ---
117
+
118
+ ### **5. Learning Dynamics**
119
+
120
+ **Prisma V1**
121
+
122
+ * Uncertainty signal is fixed (entropy)
123
+ * Model can only learn *how to react* to uncertainty
124
+ * Cannot redefine what uncertainty means
125
+
126
+ **Prisma V2**
127
+
128
+ * Uncertainty is supervised initially by entropy, then free to diverge
129
+ * Model can learn:
130
+
131
+ * epistemic uncertainty
132
+ * ambiguity
133
+ * distribution shift
134
+ * task-specific hesitation signals
135
+
136
+ **Implication**:
137
+ V1 teaches *response to uncertainty*.
138
+ V2 teaches *representation of uncertainty*.
139
+
140
+ ---
141
+
142
+ ### **6. Training Complexity**
143
+
144
+ **Prisma V1**
145
+
146
+ * No additional loss
147
+ * Entropy computed every forward
148
+ * Sensitive to tensor parallel sharding
149
+
150
+ **Prisma V2**
151
+
152
+ * Adds a lightweight auxiliary loss
153
+ * Entropy used only as a teacher signal during training
154
+ * No entropy computation at inference
155
+
156
+ **Implication**:
157
+ V2 trades a small training cost for large inference robustness.
158
+
159
+ ---
160
+
161
+ ### **7. Inference Behavior**
162
+
163
+ **Prisma V1**
164
+
165
+ * Uncertainty exists only implicitly
166
+ * Difficult to inspect or intervene at runtime
167
+ * Breaks under batched or reordered decoding
168
+
169
+ **Prisma V2**
170
+
171
+ * Uncertainty is explicit and inspectable
172
+ * Sampler can condition on it
173
+ * Works under any batching or scheduling strategy
174
+
175
+ ---
176
+
177
+ ### **8. Conceptual Framing**
178
+
179
+ **Prisma V1**
180
+
181
+ * Introspection via *measurement*
182
+ * Confidence is something the model observes after the fact
183
+
184
+ **Prisma V2**
185
+
186
+ * Introspection via *prediction*
187
+ * Confidence is something the model reasons about and plans with
188
+
189
+ > Prisma V1 makes the model *aware of its uncertainty.*
190
+ > Prisma V2 makes uncertainty part of the model’s internal world model.
191
+
192
+ ---
193
+
194
+ ### **Summary Table**
195
+
196
+ | Dimension | Prisma V1 | Prisma V2 |
197
+ | ---------------------- | ------------------ | ------------------ |
198
+ | Uncertainty source | Entropy (measured) | Learned latent |
199
+ | State handling | Mutable buffer | Explicit tensor |
200
+ | Runtime support | Limited | Universal |
201
+ | KV cache compatibility | ❌ | ✅ |
202
+ | Tensor parallel | Fragile | Safe |
203
+ | Introspection depth | Reactive | Predictive |
204
+ | Deployment readiness | Research-only | Production-capable |
205
+
206
+ ---
207
+
208
+ ### **Why Prisma V2 Exists**
209
+
210
+ Prisma V1 demonstrated that **temporal uncertainty feedback produces introspective behavior**.
211
+
212
+ Prisma V2 makes that insight **architectural, portable, and deployable**.
213
+
214
+ It is not a workaround.
215
+ It is the correct abstraction boundary.
216
+
217
+ > *Uncertainty must be data, not memory.*
218
+
219
+ ---
220
+
221
+ ## **Core Components to Add**
222
+
223
+ ```python
224
+ # In your CausalLM class
225
+ self.n_uncertainty_levels = 256 # V2: compact, sufficient
226
+ self.uncertainty_embeddings = nn.Embedding(
227
+ self.n_uncertainty_levels,
228
+ hidden_dim
229
+ )
230
+
231
+ # NEW: Uncertainty prediction head
232
+ self.uncertainty_head = nn.Linear(
233
+ hidden_dim,
234
+ self.n_uncertainty_levels,
235
+ bias=False
236
+ )
237
+ ```
238
+
239
+ ---
240
+
241
+ ## **Initialization Details**
242
+
243
+ ### Uncertainty Embeddings
244
+
245
+ * Initialized from `N(0, σ²)` where `σ = config.initializer_range`
246
+
247
+ ### Uncertainty Head (Important)
248
+
249
+ ```python
250
+ self.uncertainty_head.weight.data.zero_()
251
+ ```
252
+
253
+ **Rationale**:
254
+
255
+ * Model initially predicts *neutral uncertainty*
256
+ * Early training behaves identically to the base model
257
+ * Avoids destabilizing LM loss with noisy auxiliary signals
258
+ * Uncertainty pathway is learned gradually
259
+
260
+ ---
261
+
262
+ ## **Forward Pass Modifications (Input Side)**
263
+
264
+ **Location**: *Immediately after token embedding lookup*
265
+
266
+ ```python
267
+ def forward(self, input_ids, uncertainty_codes=None, ...):
268
+ inputs_embeds = self.embed_tokens(input_ids)
269
+
270
+ if uncertainty_codes is not None:
271
+ # uncertainty_codes: [B, S]
272
+ u = self.uncertainty_embeddings(uncertainty_codes)
273
+ inputs_embeds = inputs_embeds + u
274
+
275
+ hidden_states = self.model(
276
+ inputs_embeds=inputs_embeds,
277
+ ...
278
+ ).last_hidden_state
279
+ ```
280
+
281
+ * `uncertainty_codes[t]` conditions token position `t`
282
+ * No padding, truncation, or shifting logic required
283
+ * Temporal alignment is handled by the training and decoding loop
284
+
285
+ ---
286
+
287
+ ## **Forward Pass Modifications (Output Side)**
288
+
289
+ **Location**: *After transformer hidden states*
290
+
291
+ ```python
292
+ logits = self.lm_head(hidden_states)
293
+ uncertainty_logits = self.uncertainty_head(hidden_states)
294
+ ```
295
+
296
+ Returns:
297
+
298
+ ```python
299
+ return {
300
+ "logits": logits, # [B, S, vocab]
301
+ "uncertainty_logits": uncertainty_logits # [B, S, n_uncertainty_levels]
302
+ }
303
+ ```
304
+
305
+ ---
306
+
307
+ ## **Temporal Semantics**
308
+
309
+ | Position | Input | Predicts |
310
+ | -------- | --------------------- | ------------------------ |
311
+ | t | tokenₜ + uncertaintyₜ | tokenₜ₊₁, uncertaintyₜ₊₁ |
312
+
313
+ This preserves the original PRISMA temporal feedback loop without mutable state.
314
+
315
+ ---
316
+
317
+ ## **Training Objective**
318
+
319
+ ### Language Modeling Loss
320
+
321
+ Standard next-token prediction:
322
+
323
+ ```python
324
+ loss_lm = cross_entropy(
325
+ logits[:, :-1],
326
+ labels[:, 1:]
327
+ )
328
+ ```
329
+
330
+ ---
331
+
332
+ ### Uncertainty Prediction Loss
333
+
334
+ Uncertainty is predicted **one step ahead**:
335
+
336
+ ```python
337
+ loss_uncertainty = cross_entropy(
338
+ uncertainty_logits[:, :-1],
339
+ uncertainty_labels[:, 1:]
340
+ )
341
+ ```
342
+
343
+ ---
344
+
345
+ ### Combined Loss
346
+
347
+ ```python
348
+ loss = loss_lm + λ * loss_uncertainty
349
+ ```
350
+
351
+ * Recommended: `λ ≈ 0.1` (to tune)
352
+
353
+ ---
354
+
355
+ ## **Uncertainty Supervision (Teacher Signal)**
356
+
357
+ During training only, entropy is used as a **bootstrap target**, not as the definition of uncertainty.
358
+
359
+ ```python
360
+ with torch.no_grad():
361
+ probs = softmax(logits)
362
+ entropy = -(probs * log(probs)).sum(dim=-1)
363
+ entropy_norm = entropy / log(vocab_size)
364
+ uncertainty_labels = quantize(entropy_norm)
365
+ ```
366
+
367
+ **Important**:
368
+
369
+ * Entropy is a *teacher*, not a constraint
370
+ * The model may learn uncertainty signals that diverge from entropy
371
+ * This is desirable if they correlate better with error or ambiguity
372
+
373
+ ---
374
+
375
+ ## **Single-Pass Training (Preferred)**
376
+
377
+ A second forward pass is **not required**.
378
+
379
+ ```python
380
+ outputs = model(
381
+ input_ids,
382
+ uncertainty_codes=uncertainty_input
383
+ )
384
+
385
+ with torch.no_grad():
386
+ uncertainty_labels = compute_uncertainty_labels(outputs.logits)
387
+
388
+ loss = compute_loss(
389
+ outputs.logits,
390
+ outputs.uncertainty_logits,
391
+ labels,
392
+ uncertainty_labels
393
+ )
394
+ ```
395
+
396
+ ---
397
+
398
+ ## **Inference Loop (All Runtimes)**
399
+
400
+ ```text
401
+ (tokenₜ, uncertaintyₜ) → model → (tokenₜ₊₁, uncertaintyₜ₊₁)
402
+ ```
403
+
404
+ ### Neutral Start
405
+
406
+ ```python
407
+ uncertainty_code = n_uncertainty_levels // 2
408
+ ```
409
+
410
+ ---
411
+
412
+ ## **Runtime Integration**
413
+
414
+ | Runtime | Integration |
415
+ | ---------------- | ---------------------------------------------------- |
416
+ | **Transformers** | Custom `generate()` tracks `uncertainty_code` tensor |
417
+ | **vLLM** | Sampler tracks one uncertainty code per request |
418
+ | **llama.cpp** | Store uncertainty code in `llama_context` |
419
+ | **MLX** | Works directly (pure tensor graph) |
420
+
421
+ No runtime relies on Python-side mutable state.
422
+
423
+ ---
424
+
425
+ ## **Performance Characteristics**
426
+
427
+ | Component | Parameters | FLOPs | Memory | Latency |
428
+ | ----------------------- | ------------------ | ---------- | ---------- | ---------------- |
429
+ | Uncertainty Head | `hidden_dim × 256` | Negligible | Negligible | ~0 |
430
+ | Uncertainty Embedding | `256 × hidden_dim` | 0 | Tiny | ~0 |
431
+ | Entropy (training only) | 0 | `O(B×S×V)` | O(1) | Not in inference |
432
+
433
+ **Inference overhead**: effectively zero
434
+
435
+ ---
436
+
437
+ ## **Theoretical Intuition**
438
+
439
+ PRISMA V2 transforms autoregressive generation from:
440
+
441
+ ```
442
+ P(y_t | x, y_<t)
443
+ ```
444
+
445
+ to:
446
+
447
+ ```
448
+ P(y_t, c_t | x, y_<t, c_<t)
449
+ ```
450
+
451
+ where `c_t` is a learned uncertainty latent.
452
+
453
+ This allows the model to:
454
+
455
+ * Reduce commitment after uncertain predictions
456
+ * Maintain momentum after confident predictions
457
+ * Learn task-specific uncertainty signals
458
+ * Develop introspection without relying on engine-level state
459
+
460
+ ---
461
+
462
+ ## **Why PRISMA V2 Works Everywhere**
463
+
464
+ | Constraint | V1 | V2 |
465
+ | ------------------ | -- | -- |
466
+ | Stateless decoding | ❌ | ✅ |
467
+ | vLLM batching | ❌ | ✅ |
468
+ | llama.cpp KV cache | ❌ | ✅ |
469
+ | Tensor parallel | ⚠️ | ✅ |
470
+ | MLX tracing | ❌ | ✅ |
471
+
472
+ ---
473
+
474
+ ## **What to Watch For**
475
+
476
+ * **Ablation**: remove uncertainty input, measure perplexity / behavior
477
+ * **Calibration**: does predicted uncertainty correlate with error?
478
+ * **Behavioral shifts**: hedging, correction, abstention
479
+ * **Divergence from entropy**: expected and healthy
480
+
481
+ ---
482
+
483
+ ## **Summary**
484
+
485
+ Prisma V2 preserves the introspective insight of Prisma V1 while replacing fragile mutable state with an explicit, learned uncertainty representation. This makes introspection **portable, scalable, and deployable** across all modern inference engines.
486
+
487
+ > *The model no longer measures uncertainty — it learns what uncertainty means.*