## PRISMA V2: Joint Uncertainty Prediction Mechanism — Implementation Specification
**Architecture Overview**:
PRISMA V2 replaces Python-side uncertainty state with a **learned, explicit uncertainty latent** predicted jointly with tokens. At each step, the model predicts both the next token *and* an uncertainty code that conditions the following step. This preserves temporal introspection while remaining fully compatible with stateless inference engines.
---
## **Core Design Principle**
> **Uncertainty must be data, not memory.**
All information required for the next decoding step is carried explicitly through tensors (tokens, uncertainty codes, or cache), never through mutable module state.
---
## **Differences from Prisma V1 (Detailed)**
Prisma V2 is not a minor refactor of Prisma V1. It represents a **fundamental shift in how uncertainty is represented, propagated, and learned**.
This section documents those differences precisely.
---
### **1. Source of Uncertainty**
**Prisma V1**
* Uncertainty is **measured post-hoc** from the model’s output distribution
* Computed via entropy of logits
* Acts as an external diagnostic signal
```text
uncertainty_t = H(P(y_t))
```
**Prisma V2**
* Uncertainty is **predicted by the model itself**
* Learned as an auxiliary latent variable
* Acts as an internal representation
```text
(token_{t+1}, uncertainty_{t+1}) = f(token_t, uncertainty_t)
```
**Implication**:
V1 answers *“how uncertain was I?”*
V2 answers *“how uncertain will I be?”*
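The V1 measurement above is just the normalized entropy of the output logits. A minimal sketch (the helper name `posthoc_uncertainty` is illustrative, not part of either spec):

```python
import torch
import torch.nn.functional as F

def posthoc_uncertainty(logits: torch.Tensor) -> torch.Tensor:
    """V1-style measurement: normalized entropy of the output distribution.

    logits: [B, S, V] -> uncertainty in [0, 1], shape [B, S].
    """
    log_probs = F.log_softmax(logits, dim=-1)          # stable: avoids log(0)
    entropy = -(log_probs.exp() * log_probs).sum(dim=-1)
    return entropy / torch.log(torch.tensor(float(logits.size(-1))))
```

A uniform distribution yields 1.0 (maximum uncertainty); a near-one-hot distribution yields roughly 0.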
---
### **2. State Representation**
**Prisma V1**
* Uses mutable Python-side state:
```python
self.prev_uncertainty_code
```
* State exists **outside** the model’s forward graph
* Relies on strict step-by-step execution order
**Prisma V2**
* No mutable state
* Uncertainty is passed explicitly as a tensor:
```python
uncertainty_codes: Tensor[B, S]
```
* Fully contained within the model’s inputs and outputs
**Implication**:
V1 requires engine cooperation.
V2 requires only tensors.
---
### **3. Runtime Compatibility**
| Runtime | Prisma V1 | Prisma V2 |
| ------------------------ | --------- | --------- |
| HuggingFace Transformers | ✅ | ✅ |
| vLLM | ❌ | ✅ |
| llama.cpp | ❌ | ✅ |
| MLX | ❌ | ✅ |
| Tensor Parallel | ⚠️ | ✅ |
**Reason**:
* V1 violates the stateless decoding assumptions of modern runtimes
* V2 conforms to them by construction
---
### **4. Temporal Feedback Mechanism**
**Prisma V1**
* Feedback loop implemented via external buffer
* Requires padding, truncation, and shifting logic
* Not visible to KV cache or sampler
**Prisma V2**
* Feedback loop is **architectural**
* Uncertainty is predicted one step ahead and injected naturally
* Temporal alignment is implicit in training and decoding
**Implication**:
V2’s feedback loop is **native**, not simulated.
---
### **5. Learning Dynamics**
**Prisma V1**
* Uncertainty signal is fixed (entropy)
* Model can only learn *how to react* to uncertainty
* Cannot redefine what uncertainty means
**Prisma V2**
* Uncertainty is supervised initially by entropy, then free to diverge
* Model can learn:
* epistemic uncertainty
* ambiguity
* distribution shift
* task-specific hesitation signals
**Implication**:
V1 teaches *response to uncertainty*.
V2 teaches *representation of uncertainty*.
---
### **6. Training Complexity**
**Prisma V1**
* No additional loss
* Entropy computed every forward
* Sensitive to tensor parallel sharding
**Prisma V2**
* Adds a lightweight auxiliary loss
* Entropy used only as a teacher signal during training
* No entropy computation at inference
**Implication**:
V2 trades a small training cost for large inference robustness.
---
### **7. Inference Behavior**
**Prisma V1**
* Uncertainty exists only implicitly
* Difficult to inspect or intervene at runtime
* Breaks under batched or reordered decoding
**Prisma V2**
* Uncertainty is explicit and inspectable
* Sampler can condition on it
* Works under any batching or scheduling strategy
---
### **8. Conceptual Framing**
**Prisma V1**
* Introspection via *measurement*
* Confidence is something the model observes after the fact
**Prisma V2**
* Introspection via *prediction*
* Confidence is something the model reasons about and plans with
> Prisma V1 makes the model *aware of its uncertainty.*
> Prisma V2 makes uncertainty part of the model’s internal world model.
---
### **Summary Table**
| Dimension | Prisma V1 | Prisma V2 |
| ---------------------- | ------------------ | ------------------ |
| Uncertainty source | Entropy (measured) | Learned latent |
| State handling | Mutable buffer | Explicit tensor |
| Runtime support | Limited | Universal |
| KV cache compatibility | ❌ | ✅ |
| Tensor parallel | Fragile | Safe |
| Introspection depth | Reactive | Predictive |
| Deployment readiness | Research-only | Production-capable |
---
### **Why Prisma V2 Exists**
Prisma V1 demonstrated that **temporal uncertainty feedback produces introspective behavior**.
Prisma V2 makes that insight **architectural, portable, and deployable**.
It is not a workaround.
It is the correct abstraction boundary.
> *Uncertainty must be data, not memory.*
---
## **Core Components to Add**
```python
# In your CausalLM class (hidden_dim = config.hidden_size)
self.n_uncertainty_levels = 256  # V2: compact, sufficient

# NEW: Uncertainty embedding, added to token embeddings on the input side
self.uncertainty_embeddings = nn.Embedding(
    self.n_uncertainty_levels,
    hidden_dim
)

# NEW: Uncertainty prediction head, applied to hidden states on the output side
self.uncertainty_head = nn.Linear(
    hidden_dim,
    self.n_uncertainty_levels,
    bias=False
)
```
---
## **Initialization Details**
### Uncertainty Embeddings
* Initialized from `N(0, σ²)` where `σ = config.initializer_range`
### Uncertainty Head (Important)
```python
self.uncertainty_head.weight.data.zero_()
```
**Rationale**:
* Model initially predicts *neutral uncertainty*
* Early training behaves identically to the base model
* Avoids destabilizing LM loss with noisy auxiliary signals
* Uncertainty pathway is learned gradually
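Put together, the initialization can be sketched as follows (a sketch, assuming the attribute names from the previous section and an HF-style `config.initializer_range`; the function name is illustrative):

```python
import torch.nn as nn

def init_uncertainty_params(model: nn.Module, initializer_range: float = 0.02) -> None:
    """Initialize the V2 uncertainty pathway.

    Embeddings: N(0, sigma^2) like other embeddings.
    Head: zeros, so early training behaves identically to the base model.
    """
    nn.init.normal_(model.uncertainty_embeddings.weight,
                    mean=0.0, std=initializer_range)
    model.uncertainty_head.weight.data.zero_()
```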
---
## **Forward Pass Modifications (Input Side)**
**Location**: *Immediately after token embedding lookup*
```python
def forward(self, input_ids, uncertainty_codes=None, ...):
inputs_embeds = self.embed_tokens(input_ids)
if uncertainty_codes is not None:
# uncertainty_codes: [B, S]
u = self.uncertainty_embeddings(uncertainty_codes)
inputs_embeds = inputs_embeds + u
hidden_states = self.model(
inputs_embeds=inputs_embeds,
...
).last_hidden_state
```
* `uncertainty_codes[t]` conditions token position `t`
* No padding, truncation, or shifting logic required
* Temporal alignment is handled by the training and decoding loop
---
## **Forward Pass Modifications (Output Side)**
**Location**: *After transformer hidden states*
```python
logits = self.lm_head(hidden_states)
uncertainty_logits = self.uncertainty_head(hidden_states)
```
Returns:
```python
return {
"logits": logits, # [B, S, vocab]
"uncertainty_logits": uncertainty_logits # [B, S, n_uncertainty_levels]
}
```
---
## **Temporal Semantics**
| Position | Input | Predicts |
| -------- | --------------------- | ------------------------ |
| t | tokenₜ + uncertaintyₜ | tokenₜ₊₁, uncertaintyₜ₊₁ |
This preserves the original PRISMA temporal feedback loop without mutable state.
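Under teacher forcing, this alignment fixes the input codes: the loss (see Training Objective below) pairs the prediction at position t with the label at t+1, so the label at position t is exactly the code that conditions position t; only position 0, which nothing precedes, needs a neutral start. A sketch (function name is an illustrative assumption):

```python
import torch

def teacher_forced_codes(uncertainty_labels: torch.Tensor,
                         n_levels: int = 256) -> torch.Tensor:
    """Input codes for teacher-forced training.

    uncertainty_labels: [B, S] integer target codes.
    Returns [B, S] input codes: identical to the labels except that
    position 0 gets the neutral code (n_levels // 2).
    """
    codes = uncertainty_labels.clone()
    codes[:, 0] = n_levels // 2
    return codes
```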
---
## **Training Objective**
### Language Modeling Loss
Standard next-token prediction:
```python
# cross_entropy expects the class dim in position 1, so flatten [B, S, V] first
loss_lm = F.cross_entropy(
    logits[:, :-1].reshape(-1, logits.size(-1)),
    labels[:, 1:].reshape(-1)
)
```
---
### Uncertainty Prediction Loss
Uncertainty is predicted **one step ahead**:
```python
# Same flattening; n_uncertainty_levels is the class dimension here
loss_uncertainty = F.cross_entropy(
    uncertainty_logits[:, :-1].reshape(-1, n_uncertainty_levels),
    uncertainty_labels[:, 1:].reshape(-1)
)
```
---
### Combined Loss
```python
loss = loss_lm + λ * loss_uncertainty
```
* Recommended: `λ ≈ 0.1` (to tune)
---
## **Uncertainty Supervision (Teacher Signal)**
During training only, entropy is used as a **bootstrap target**, not as the definition of uncertainty.
```python
with torch.no_grad():
    # log_softmax avoids the log(0) issue of the naive softmax/log formulation
    log_probs = F.log_softmax(logits, dim=-1)
    entropy = -(log_probs.exp() * log_probs).sum(dim=-1)  # [B, S]
    entropy_norm = entropy / math.log(vocab_size)         # normalized to [0, 1]
    uncertainty_labels = quantize(entropy_norm)           # [B, S] integer codes
```
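The `quantize` helper is not pinned down above; one minimal choice, assuming 256 uniform bins over [0, 1]:

```python
import torch

def quantize(entropy_norm: torch.Tensor, n_levels: int = 256) -> torch.Tensor:
    """Map normalized entropy in [0, 1] to integer codes in {0, ..., n_levels - 1}."""
    return (entropy_norm.clamp(0.0, 1.0) * (n_levels - 1)).round().long()
```

Note that entropy 0.5 maps to code 128, which matches the neutral start code (`n_uncertainty_levels // 2`) used at inference.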
**Important**:
* Entropy is a *teacher*, not a constraint
* The model may learn uncertainty signals that diverge from entropy
* This is desirable if they correlate better with error or ambiguity
---
## **Single-Pass Training (Preferred)**
A second forward pass is **not required**.
```python
# uncertainty_input: [B, S] codes conditioning each position
# (e.g. teacher-forced from cached labels, or neutral codes early in training)
outputs = model(
    input_ids,
    uncertainty_codes=uncertainty_input
)

# Labels come from the same forward's logits; no second pass needed
with torch.no_grad():
    uncertainty_labels = compute_uncertainty_labels(outputs["logits"])

loss = compute_loss(
    outputs["logits"],
    outputs["uncertainty_logits"],
    labels,
    uncertainty_labels
)
```
---
## **Inference Loop (All Runtimes)**
```text
(tokenₜ, uncertaintyₜ) → model → (tokenₜ₊₁, uncertaintyₜ₊₁)
```
### Neutral Start
```python
uncertainty_code = n_uncertainty_levels // 2
```
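The full stateless loop can be sketched end-to-end: everything the next step needs travels in tensors. A sketch assuming the forward interface from the sections above, greedy sampling, and no KV cache for brevity (real integrations would carry only the last code per sequence):

```python
import torch

@torch.no_grad()
def generate(model, input_ids, max_new_tokens, n_uncertainty_levels=256):
    """Stateless V2 decoding: token and uncertainty code both travel as tensors."""
    B, S = input_ids.shape
    # Neutral start: every prompt position gets the midpoint code
    codes = torch.full((B, S), n_uncertainty_levels // 2, dtype=torch.long)

    for _ in range(max_new_tokens):
        out = model(input_ids, uncertainty_codes=codes)
        next_token = out["logits"][:, -1].argmax(dim=-1, keepdim=True)
        next_code = out["uncertainty_logits"][:, -1].argmax(dim=-1, keepdim=True)
        # The predicted code becomes data for the next step, never module state
        input_ids = torch.cat([input_ids, next_token], dim=1)
        codes = torch.cat([codes, next_code], dim=1)

    return input_ids, codes
```

Because the loop touches only its own arguments and return values, it works unchanged under batching, request reordering, or tensor parallelism.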
---
## **Runtime Integration**
| Runtime | Integration |
| ---------------- | ---------------------------------------------------- |
| **Transformers** | Custom `generate()` tracks `uncertainty_code` tensor |
| **vLLM** | Sampler tracks one uncertainty code per request |
| **llama.cpp** | Store uncertainty code in `llama_context` |
| **MLX** | Works directly (pure tensor graph) |
No runtime relies on Python-side mutable state.
---
## **Performance Characteristics**
| Component | Parameters | FLOPs | Memory | Latency |
| ----------------------- | ------------------ | ---------- | ---------- | ---------------- |
| Uncertainty Head | `hidden_dim × 256` | Negligible | Negligible | ~0 |
| Uncertainty Embedding | `256 × hidden_dim` | 0 | Tiny | ~0 |
| Entropy (training only) | 0 | `O(B×S×V)` | O(1) | Not in inference |
**Inference overhead**: effectively zero
---
## **Theoretical Intuition**
PRISMA V2 transforms autoregressive generation from:
```
P(y_t | x, y_<t)
```
to:
```
P(y_t, c_t | x, y_<t, c_<t)
```
where `c_t` is a learned uncertainty latent.
This allows the model to:
* Reduce commitment after uncertain predictions
* Maintain momentum after confident predictions
* Learn task-specific uncertainty signals
* Develop introspection without relying on engine-level state
---
## **Why PRISMA V2 Works Everywhere**
| Constraint | V1 | V2 |
| ------------------ | -- | -- |
| Stateless decoding | ❌ | ✅ |
| vLLM batching | ❌ | ✅ |
| llama.cpp KV cache | ❌ | ✅ |
| Tensor parallel | ⚠️ | ✅ |
| MLX tracing | ❌ | ✅ |
---
## **What to Watch For**
* **Ablation**: remove uncertainty input, measure perplexity / behavior
* **Calibration**: does predicted uncertainty correlate with error?
* **Behavioral shifts**: hedging, correction, abstention
* **Divergence from entropy**: expected and healthy
---
## **Summary**
Prisma V2 preserves the introspective insight of Prisma V1 while replacing fragile mutable state with an explicit, learned uncertainty representation. This makes introspection **portable, scalable, and deployable** across all modern inference engines.
> *The model no longer measures uncertainty — it learns what uncertainty means.*