ehartford committed (verified)
Commit 54c08c7 · Parent(s): 6c0624c

Create PRISMA.md

Files changed (1): PRISMA.md (+274 lines)
## PRISMA: 16-Bit Temporal Introspection Mechanism - Implementation Specification

**Architecture Overview**: PRISMA adds a cross-step feedback loop where model uncertainty from the *previous* forward pass modulates input embeddings for the *current* step. This enables introspective behavior without modifying internal transformer layers.

---

### **Core Components to Add**

```python
# 1. In your model class (e.g., LlamaModel, MistralModel)
self.uncertainty_embeddings = nn.Embedding(65536, hidden_dim)  # 2^16 uncertainty codes
self.register_buffer('prev_uncertainty_code', None)  # [batch, prev_seq_len]
```

---

### **Initialization Details**

- **Embedding Table**: Initialize weights from N(0, σ²) where σ = `config.initializer_range` (typically 0.02)
- **Buffer**: `prev_uncertainty_code` starts as `None` and is lazily initialized on the first forward pass
- **Device/Dtype**: The buffer automatically inherits the model's device; ensure `uncertainty_embeddings` runs in the same dtype as the model (typically bfloat16)

---

### **Forward Pass Modifications (Input Side)**

**Location**: *Immediately after input embedding lookup, before transformer layers*

```python
# Pseudocode for model forward()
# Requires `torch` and `torch.nn.functional as F` in scope
def forward(self, input_ids, inputs_embeds=None, ...):
    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    # === PRISMA INJECTION POINT ===
    batch_size, seq_len = inputs_embeds.shape[:2]

    # Handle uncertainty state initialization
    if self.prev_uncertainty_code is None or self.prev_uncertainty_code.shape[0] != batch_size:
        # First pass or batch size changed: use neutral uncertainty
        uncertainty_code = torch.full(
            (batch_size, seq_len), 32768,  # N/2 = neutral
            dtype=torch.long, device=inputs_embeds.device
        )
    else:
        # Pad or truncate to match the current sequence length
        prev_len = self.prev_uncertainty_code.shape[1]
        if prev_len < seq_len:
            padding = torch.full(
                (batch_size, seq_len - prev_len), 32768,
                dtype=torch.long, device=inputs_embeds.device
            )
            uncertainty_code = torch.cat([self.prev_uncertainty_code, padding], dim=1)
        else:
            uncertainty_code = self.prev_uncertainty_code[:, :seq_len]

    # Lookup and shift embeddings (position t gets uncertainty from t-1)
    uncertainty_embeds = self.uncertainty_embeddings(uncertainty_code)  # [B, S, D]
    uncertainty_shifted = F.pad(
        uncertainty_embeds[:, :-1, :], (0, 0, 1, 0), value=0.0
    )  # First position gets zero
    # Note: this full-sequence path assumes no KV cache; with cached decoding
    # (seq_len == 1) the right-shift zeroes the injection, so take the *last*
    # previous code and skip the shift instead

    # Inject into main embeddings
    inputs_embeds = inputs_embeds + uncertainty_shifted
    # === END PRISMA INJECTION ===

    # Proceed to transformer layers as normal
    hidden_states = self.layers(inputs_embeds, ...)
    return hidden_states
```

---

### **Forward Pass Modifications (Output Side)**

**Location**: *In your CausalLM class (e.g., LlamaForCausalLM) after computing logits*

```python
# Pseudocode for CausalLM forward()
# Requires `math` and `torch` in scope
def forward(self, ..., labels=None, return_dict=True):
    outputs = self.model(...)
    hidden_states = outputs.last_hidden_state
    logits = self.lm_head(hidden_states)

    # === PRISMA UNCERTAINTY COMPUTATION ===
    # Runs during both training and inference
    with torch.no_grad():
        # Detach to avoid gradient flow into the uncertainty mechanism
        probs = logits.detach().softmax(dim=-1)  # [B, S, V]

        # Compute entropy per position
        log_probs = torch.log(probs.clamp(min=1e-9))
        entropy = -(probs * log_probs).sum(dim=-1)  # [B, S]

        # Normalize by the entropy of a uniform distribution
        max_entropy = math.log(probs.size(-1))
        entropy_norm = (entropy / max_entropy).clamp(0.0, 1.0)

        # Quantize to 16-bit integer codes [0, 65535]
        self.model.prev_uncertainty_code = (
            entropy_norm * 65535
        ).long().clamp(0, 65535)
    # === END PRISMA COMPUTATION ===

    # Compute loss and return outputs as normal
    loss = None
    if labels is not None:
        loss = self.loss_function(logits, labels)

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
    )
```

---

### **Generation Loop Integration**

**Required**: Reset uncertainty state between generation runs

```python
# Add this method to your CausalLM class
def reset_uncertainty(self):
    """Clear uncertainty state; call before each new generation"""
    self.model.prev_uncertainty_code = None

# In your generation code:
model.reset_uncertainty()  # Essential!
outputs = model.generate(**inputs)
```

---

### **Key Implementation Notes for Arbitrary Models**

| Model Type | Integration Points |
|------------|-------------------|
| **Standard Decoder (Llama, Mistral)** | Inject in `forward()` after `self.embed_tokens()`; compute uncertainty in `ForCausalLM.forward()` |
| **Encoder-Decoder (T5)** | Inject in the decoder embedding; compute uncertainty from decoder output logits |
| **Vision-Language (LLaVA, DeepSeek-VL)** | Inject *after* multimodal projections; ensure `prev_uncertainty_code` tracks *text token positions only* |
| **MoE Models (Mixtral)** | Inject before expert routing; uncertainty overhead is negligible compared to MoE computation |

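The vision-language row above doesn't show how to restrict codes to text positions. A framework-free sketch of that masking over plain lists (the function name and the boolean-mask convention are illustrative, not part of the spec; real code would do the same with tensors):

```python
NEUTRAL = 32768  # N/2, the spec's neutral code

def mask_to_text_positions(codes, is_text):
    """Keep uncertainty codes only at text-token positions.

    codes:   list of 16-bit codes, one per sequence position
    is_text: parallel list of bools; False for image/projection tokens
    Non-text positions fall back to the neutral code so the injected
    embedding carries no spurious confidence signal there.
    """
    return [c if t else NEUTRAL for c, t in zip(codes, is_text)]

# Image tokens (positions 0-1) get neutralized; text codes pass through
print(mask_to_text_positions([100, 200, 40000, 65000], [False, False, True, True]))
# → [32768, 32768, 40000, 65000]
```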
---

### **Edge Cases & State Management**

1. **Dynamic Sequence Lengths**: The padding/truncation logic ensures `prev_uncertainty_code` always matches the current `seq_len`
2. **Batch Size Changes**: When the batch size changes mid-generation, reinitialize with neutral codes
3. **KV Cache**: `prev_uncertainty_code` *does not* participate in the KV cache; it's purely a side-channel. Note that with cached decoding the current input is one token long, so carry the *last* previous code forward rather than truncating to the input length
4. **Gradient Checkpointing**: The mechanism is checkpointing-safe; embeddings are recomputed during backward
5. **Multi-GPU**: `uncertainty_embeddings` is part of the model parameters and gets sharded automatically; `prev_uncertainty_code` stays on the same device as the model

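The pad/truncate rule in point 1 is easy to get off by one. Here it is as a framework-free sketch over plain lists (function name illustrative; tensor shapes in the real code follow the same rule):

```python
NEUTRAL = 32768  # N/2 = neutral code

def align_codes(prev_codes, seq_len):
    """Pad with neutral codes or truncate so len(result) == seq_len."""
    if prev_codes is None:
        return [NEUTRAL] * seq_len  # first pass: all neutral
    if len(prev_codes) < seq_len:
        return prev_codes + [NEUTRAL] * (seq_len - len(prev_codes))
    return prev_codes[:seq_len]

print(align_codes(None, 3))          # [32768, 32768, 32768]
print(align_codes([1, 2], 4))        # [1, 2, 32768, 32768]
print(align_codes([1, 2, 3, 4], 2))  # [1, 2]
```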
---

### **Performance Characteristics**

| Component | Parameters | FLOPs | Memory | Latency |
|-----------|------------|-------|--------|---------|
| Uncertainty Embeddings | `65,536 × hidden_dim` | 0 | ~268MB in bf16 (if d=2048) | Negligible |
| Entropy Computation | 0 | `O(B×S×V)` | `O(B×S×V)` transient | <0.1ms |
| Embedding Addition | 0 | `O(B×S×D)` | O(1) | <0.01ms |

**Total Overhead**: <1% additional compute; memory is dominated by the embedding table (~134M parameters at d=2048, roughly 2% of a 7B model)

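The table's parameter figure is simple arithmetic; as a sanity check (bfloat16 at 2 bytes per parameter assumed):

```python
n_levels = 2 ** 16       # 65,536 uncertainty codes
hidden_dim = 2048
params = n_levels * hidden_dim
bytes_bf16 = params * 2  # bfloat16 is 2 bytes/param

print(params)                     # 134217728 (~134M parameters)
print(round(bytes_bf16 / 2**20))  # 256 MiB (~268 MB decimal)
```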
---

### **Theoretical Intuition**

PRISMA transforms autoregressive generation from a **memoryless process** P(y_t | x, y_<t) into a **stateful process** P(y_t | x, y_<t, c_<t), where c_t = H(P(y_t)) is the quantized uncertainty of step t. The model learns to use this "confidence memory" to:
- Be more **cautious** after uncertain predictions (high c_t → modified embeddings)
- Maintain **momentum** after confident predictions (low c_t → near-zero injection)
- Develop **meta-cognitive strategies** without architectural depth changes

The 16-bit quantization provides sufficient resolution (65,536 levels) to capture subtle confidence gradations while maintaining a computationally efficient lookup table.

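The quantizer H described above maps normalized entropy to a 16-bit code. A framework-free sketch of the mapping, including the ≈1.5e-05 per-level resolution it implies (function name illustrative):

```python
import math

N_LEVELS = 2 ** 16  # 65,536

def quantize(entropy, vocab_size):
    """Raw entropy -> normalized [0, 1] -> integer code in [0, 65535]."""
    e_norm = min(max(entropy / math.log(vocab_size), 0.0), 1.0)
    return min(int(e_norm * (N_LEVELS - 1)), N_LEVELS - 1)

# Uniform distribution over a 32k vocab: maximum entropy -> top code
vocab = 32000
uniform_entropy = math.log(vocab)
print(quantize(uniform_entropy, vocab))  # 65535
print(quantize(0.0, vocab))              # 0 (fully confident)
print(1 / (N_LEVELS - 1))                # resolution: one level per ~1.5e-05 of entropy
```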
# Recipe: Add 16-bit Uncertainty Feedback to Any Model

## Goal

Feed model confidence from the previous step back into the next token's embedding using entropy-based uncertainty. The signal rides along through the network, gaining strength until the model can act on it explicitly.

---

## 1. Add persistent uncertainty state

**Where:** model `__init__`

```python
self.n_bits = 16
self.n_levels = 2 ** self.n_bits
self.uncertainty_embed = nn.Embedding(self.n_levels, hidden_size)
self.register_buffer("prev_uncertainty_code", None)
```

---

## 2. Inject uncertainty into input embeddings

**Where:** model `forward`, immediately after `inputs_embeds` is created

```python
B, T, D = inputs_embeds.shape

# Assumes a full-sequence forward; with a KV cache (T == 1), carry the *last*
# previous code forward and skip the shift below, or the injection is always zero
if self.prev_uncertainty_code is None or self.prev_uncertainty_code.shape[0] != B:
    code = torch.full((B, T), self.n_levels // 2, dtype=torch.long, device=inputs_embeds.device)
else:
    prev = self.prev_uncertainty_code
    if prev.shape[1] >= T:
        code = prev[:, :T]
    else:
        code = F.pad(prev, (0, T - prev.shape[1]), value=self.n_levels // 2)

u = self.uncertainty_embed(code)
u = F.pad(u[:, :-1], (0, 0, 1, 0))  # shift right: position i gets uncertainty from i-1
inputs_embeds = inputs_embeds + u
```

---

## 3. Compute uncertainty from logits

**Where:** LM head `forward`, after logits are computed

Note: If your buffer lives on an inner model (e.g., `self.model`), update `self.model.prev_uncertainty_code` instead.

```python
# Requires `math` in scope (import math)
with torch.no_grad():
    probs = logits.softmax(dim=-1)
    entropy = -(probs * torch.log(probs.clamp_min(1e-9))).sum(dim=-1)
    entropy = entropy / math.log(probs.size(-1))  # normalize to [0, 1]
    self.prev_uncertainty_code = (
        entropy * (self.n_levels - 1)
    ).long().clamp(0, self.n_levels - 1)
```

---

## 4. Reset hook

```python
def reset_uncertainty(self):
    self.prev_uncertainty_code = None
```

Call before each new generation or when switching between unrelated sequences.

---

## 5. Generation rules

- **Do NOT** clear `prev_uncertainty_code` between decoding steps within a sequence
- **DO** clear it between unrelated sequences or batches

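The two rules can be made concrete with a toy driver loop. This is a framework-free sketch: `UncertaintyState` stands in for the model, and writing `[step]` stands in for the real entropy-derived codes.

```python
class UncertaintyState:
    """Minimal stand-in for the model-side buffer and reset hook."""
    def __init__(self):
        self.prev_uncertainty_code = None

    def reset_uncertainty(self):
        self.prev_uncertainty_code = None

state = UncertaintyState()

for sequence in [["The", "cat"], ["An", "unrelated", "prompt"]]:
    state.reset_uncertainty()                 # DO: clear between sequences
    for step, token in enumerate(sequence):
        # ... forward pass would run here; real codes come from the logits ...
        state.prev_uncertainty_code = [step]  # kept across steps (do NOT clear)
    print(sequence[0], state.prev_uncertainty_code)
# → The [1]
#   An [2]
```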
---

## Why this works

The uncertainty signal rides along in the residual stream from the first layer. Early on, it competes with stronger signals: token semantics, position, attention patterns. But because it correlates with prediction difficulty, the model learns to preserve it.

By approximately two-thirds of the way through the network, the signal has accumulated enough relative strength to influence decisions explicitly. The model doesn't just *feel* uncertain; it can *act* on uncertainty: hedge, qualify, or change course.

You don't inject at a specific layer because you don't know where introspection should live. The model discovers that for itself. You just ensure the information is present from the start.

---

## What to watch for

- **Ablation test:** Zero out the uncertainty injection and measure the perplexity change. If it hurts, the signal is being used.
- **Attention probe:** Check whether high-uncertainty positions receive more attention in later layers.
- **Behavioral test:** Does the model hedge more after high-entropy predictions? Does it recover better from mistakes?
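The ablation test reduces to a perplexity comparison. The arithmetic, on made-up illustrative per-token probabilities (not real measurements):

```python
import math

def perplexity(token_probs):
    """exp of the mean negative log-likelihood over the sequence."""
    nll = [-math.log(p) for p in token_probs]
    return math.exp(sum(nll) / len(nll))

# Hypothetical per-token probabilities with and without the injection
with_prisma = [0.50, 0.40, 0.60, 0.45]
ablated     = [0.40, 0.30, 0.50, 0.35]  # uncertainty injection zeroed out

print(round(perplexity(with_prisma), 2))  # 2.07
print(round(perplexity(ablated), 2))      # 2.63
# If ablation raises perplexity, the model is using the signal
```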

If uncertainty is truly integrated, the model's *behavior* will reflect its confidence, not because you trained it to say "I'm not sure," but because knowing its own uncertainty became useful for prediction.