PRISMA V2: Joint Uncertainty Prediction Mechanism — Implementation Specification
Architecture Overview: PRISMA V2 replaces Python-side uncertainty state with a learned, explicit uncertainty latent predicted jointly with tokens. At each step, the model predicts both the next token and an uncertainty code that conditions the following step. This preserves temporal introspection while remaining fully compatible with stateless inference engines.
Core Design Principle
Uncertainty must be data, not memory.
All information required for the next decoding step is carried explicitly through tensors (tokens, uncertainty codes, or cache), never through mutable module state.
Differences from Prisma V1 (Detailed)
Prisma V2 is not a minor refactor of Prisma V1. It represents a fundamental shift in how uncertainty is represented, propagated, and learned.
This section documents those differences precisely.
1. Source of Uncertainty
Prisma V1
- Uncertainty is measured post-hoc from the model’s output distribution
- Computed via entropy of logits
- Acts as an external diagnostic signal
uncertainty_t = H(P(y_t))
Prisma V2
- Uncertainty is predicted by the model itself
- Learned as an auxiliary latent variable
- Acts as an internal representation
(token_{t+1}, uncertainty_{t+1}) = f(token_t, uncertainty_t)
Implication: V1 answers “how uncertain was I?” V2 answers “how uncertain will I be?”
2. State Representation
Prisma V1
- Uses mutable Python-side state:
self.prev_uncertainty_code
- State exists outside the model’s forward graph
- Relies on strict step-by-step execution order
Prisma V2
- No mutable state
- Uncertainty is passed explicitly as a tensor:
uncertainty_codes: Tensor[B, S]
- Fully contained within the model’s inputs and outputs
Implication: V1 requires engine cooperation. V2 requires only tensors.
3. Runtime Compatibility
| Runtime | Prisma V1 | Prisma V2 |
|---|---|---|
| HuggingFace Transformers | ✅ | ✅ |
| vLLM | ❌ | ✅ |
| llama.cpp | ❌ | ✅ |
| MLX | ❌ | ✅ |
| Tensor Parallel | ⚠️ | ✅ |
Reason:
- V1 violates the stateless decoding assumptions of modern runtimes
- V2 conforms to them by construction
4. Temporal Feedback Mechanism
Prisma V1
- Feedback loop implemented via external buffer
- Requires padding, truncation, and shifting logic
- Not visible to KV cache or sampler
Prisma V2
- Feedback loop is architectural
- Uncertainty is predicted one step ahead and injected naturally
- Temporal alignment is implicit in training and decoding
Implication: V2’s feedback loop is native, not simulated.
5. Learning Dynamics
Prisma V1
- Uncertainty signal is fixed (entropy)
- Model can only learn how to react to uncertainty
- Cannot redefine what uncertainty means
Prisma V2
- Uncertainty is supervised initially by entropy, then free to diverge
Model can learn:
- epistemic uncertainty
- ambiguity
- distribution shift
- task-specific hesitation signals
Implication: V1 teaches response to uncertainty. V2 teaches representation of uncertainty.
6. Training Complexity
Prisma V1
- No additional loss
- Entropy computed every forward
- Sensitive to tensor parallel sharding
Prisma V2
- Adds a lightweight auxiliary loss
- Entropy used only as a teacher signal during training
- No entropy computation at inference
Implication: V2 trades a small training cost for large inference robustness.
7. Inference Behavior
Prisma V1
- Uncertainty exists only implicitly
- Difficult to inspect or intervene at runtime
- Breaks under batched or reordered decoding
Prisma V2
- Uncertainty is explicit and inspectable
- Sampler can condition on it
- Works under any batching or scheduling strategy
8. Conceptual Framing
Prisma V1
- Introspection via measurement
- Confidence is something the model observes after the fact
Prisma V2
- Introspection via prediction
- Confidence is something the model reasons about and plans with
Prisma V1 makes the model aware of its uncertainty. Prisma V2 makes uncertainty part of the model’s internal world model.
Summary Table
| Dimension | Prisma V1 | Prisma V2 |
|---|---|---|
| Uncertainty source | Entropy (measured) | Learned latent |
| State handling | Mutable buffer | Explicit tensor |
| Runtime support | Limited | Universal |
| KV cache compatibility | ❌ | ✅ |
| Tensor parallel | Fragile | Safe |
| Introspection depth | Reactive | Predictive |
| Deployment readiness | Research-only | Production-capable |
Why Prisma V2 Exists
Prisma V1 demonstrated that temporal uncertainty feedback produces introspective behavior.
Prisma V2 makes that insight architectural, portable, and deployable.
It is not a workaround. It is the correct abstraction boundary.
Uncertainty must be data, not memory.
Core Components to Add
# In your CausalLM class
self.n_uncertainty_levels = 256 # V2: compact, sufficient
self.uncertainty_embeddings = nn.Embedding(
    self.n_uncertainty_levels,
    hidden_dim
)

# NEW: Uncertainty prediction head
self.uncertainty_head = nn.Linear(
    hidden_dim,
    self.n_uncertainty_levels,
    bias=False
)
Initialization Details
Uncertainty Embeddings
- Initialized from N(0, σ²), where σ = config.initializer_range
Uncertainty Head (Important)
self.uncertainty_head.weight.data.zero_()
Rationale:
- Model initially predicts neutral uncertainty
- Early training behaves identically to the base model
- Avoids destabilizing LM loss with noisy auxiliary signals
- Uncertainty pathway is learned gradually
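A minimal sketch of the two initializations side by side (the sizes and `initializer_range` value are illustrative stand-ins for the model config):

```python
import torch
import torch.nn as nn

hidden_dim = 64            # illustrative; the real value comes from the config
n_uncertainty_levels = 256
initializer_range = 0.02   # stand-in for config.initializer_range

# Embedding: small normal init, matching the base model's other embeddings
uncertainty_embeddings = nn.Embedding(n_uncertainty_levels, hidden_dim)
nn.init.normal_(uncertainty_embeddings.weight, mean=0.0, std=initializer_range)

# Head: zero init, so uncertainty logits start uniform and the LM loss is
# undisturbed at the beginning of training
uncertainty_head = nn.Linear(hidden_dim, n_uncertainty_levels, bias=False)
uncertainty_head.weight.data.zero_()
```

With the zero-initialized head, every position initially receives identical (all-zero) uncertainty logits, which is exactly the "neutral uncertainty" behavior described above.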
Forward Pass Modifications (Input Side)
Location: Immediately after token embedding lookup
def forward(self, input_ids, uncertainty_codes=None, ...):
    inputs_embeds = self.embed_tokens(input_ids)
    if uncertainty_codes is not None:
        # uncertainty_codes: [B, S]
        u = self.uncertainty_embeddings(uncertainty_codes)
        inputs_embeds = inputs_embeds + u
    hidden_states = self.model(
        inputs_embeds=inputs_embeds,
        ...
    ).last_hidden_state
- uncertainty_codes[t] conditions token position t
- No padding, truncation, or shifting logic required
- Temporal alignment is handled by the training and decoding loop
Forward Pass Modifications (Output Side)
Location: After transformer hidden states
logits = self.lm_head(hidden_states)
uncertainty_logits = self.uncertainty_head(hidden_states)
Returns:
return {
    "logits": logits,                          # [B, S, vocab]
    "uncertainty_logits": uncertainty_logits,  # [B, S, n_uncertainty_levels]
}
Temporal Semantics
| Position | Input | Predicts |
|---|---|---|
| t | tokenₜ + uncertaintyₜ | tokenₜ₊₁, uncertaintyₜ₊₁ |
This preserves the original PRISMA temporal feedback loop without mutable state.
Training Objective
Language Modeling Loss
Standard next-token prediction:
loss_lm = cross_entropy(
    logits[:, :-1],
    labels[:, 1:]
)
Uncertainty Prediction Loss
Uncertainty is predicted one step ahead:
loss_uncertainty = cross_entropy(
    uncertainty_logits[:, :-1],
    uncertainty_labels[:, 1:]
)
Combined Loss
loss = loss_lm + λ * loss_uncertainty
- Recommended: λ ≈ 0.1 (to be tuned)
Uncertainty Supervision (Teacher Signal)
During training only, entropy is used as a bootstrap target, not as the definition of uncertainty.
with torch.no_grad():
    probs = torch.softmax(logits, dim=-1)
    entropy = -(probs * torch.log(probs + 1e-9)).sum(dim=-1)  # eps guards log(0)
    entropy_norm = entropy / math.log(vocab_size)             # normalize to [0, 1]
    uncertainty_labels = quantize(entropy_norm)
Important:
- Entropy is a teacher, not a constraint
- The model may learn uncertainty signals that diverge from entropy
- This is desirable if they correlate better with error or ambiguity
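`quantize` is referenced above but not defined; one plausible sketch maps normalized entropy in [0, 1] onto the 256 discrete codes by uniform binning (the function name and binning scheme are assumptions, not part of the spec):

```python
import torch

def quantize(entropy_norm: torch.Tensor, n_levels: int = 256) -> torch.Tensor:
    """Map normalized entropy in [0, 1] to integer codes in [0, n_levels - 1].

    Uniform binning: 0.0 -> code 0 (fully confident),
    1.0 -> code n_levels - 1 (maximally uncertain).
    """
    codes = (entropy_norm.clamp(0.0, 1.0) * (n_levels - 1)).round().long()
    return codes
```

Non-uniform (e.g. quantile-based) binning is an equally valid choice if entropies cluster near zero in practice; the only hard requirement is that codes land in `[0, n_uncertainty_levels - 1]` so they are valid embedding indices.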
Single-Pass Training (Preferred)
A second forward pass is not required.
outputs = model(
    input_ids,
    uncertainty_codes=uncertainty_input
)

with torch.no_grad():
    uncertainty_labels = compute_uncertainty_labels(outputs.logits)

loss = compute_loss(
    outputs.logits,
    outputs.uncertainty_logits,
    labels,
    uncertainty_labels
)
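`compute_loss` is not spelled out in this spec; a sketch consistent with the shift-by-one losses defined earlier (with reshaping added so `F.cross_entropy` accepts the `[B, S, C]` tensors) might look like:

```python
import torch
import torch.nn.functional as F

def compute_loss(logits, uncertainty_logits, labels, uncertainty_labels, lam=0.1):
    # Next-token LM loss: position t predicts token t+1.
    loss_lm = F.cross_entropy(
        logits[:, :-1].reshape(-1, logits.size(-1)),
        labels[:, 1:].reshape(-1),
    )
    # Uncertainty loss: position t predicts the uncertainty code for step t+1.
    loss_u = F.cross_entropy(
        uncertainty_logits[:, :-1].reshape(-1, uncertainty_logits.size(-1)),
        uncertainty_labels[:, 1:].reshape(-1),
    )
    return loss_lm + lam * loss_u
```

In practice an `ignore_index` for padding positions would be passed to both cross-entropy calls; it is omitted here for brevity.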
Inference Loop (All Runtimes)
(tokenₜ, uncertaintyₜ) → model → (tokenₜ₊₁, uncertaintyₜ₊₁)
Neutral Start
uncertainty_code = n_uncertainty_levels // 2
Runtime Integration
| Runtime | Integration |
|---|---|
| Transformers | Custom generate() tracks uncertainty_code tensor |
| vLLM | Sampler tracks one uncertainty code per request |
| llama.cpp | Store uncertainty code in llama_context |
| MLX | Works directly (pure tensor graph) |
No runtime relies on Python-side mutable state.
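For the Transformers row above, a minimal greedy loop could look like the sketch below (assumptions: no KV cache, argmax for both heads, and the model interface matching the forward signature defined earlier):

```python
import torch

@torch.no_grad()
def generate(model, input_ids, max_new_tokens, n_uncertainty_levels=256):
    # Neutral start: every prompt position gets the midpoint code.
    codes = torch.full_like(input_ids, n_uncertainty_levels // 2)
    for _ in range(max_new_tokens):
        out = model(input_ids, uncertainty_codes=codes)
        # Both heads read from the last position; both outputs are appended,
        # so the next step is conditioned on the predicted uncertainty.
        next_token = out["logits"][:, -1].argmax(dim=-1, keepdim=True)
        next_code = out["uncertainty_logits"][:, -1].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=1)
        codes = torch.cat([codes, next_code], dim=1)
    return input_ids, codes
```

Note that the only per-step "state" is the `codes` tensor itself, which is exactly the property that makes the scheme portable across runtimes.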
Performance Characteristics
| Component | Parameters | FLOPs | Memory | Latency |
|---|---|---|---|---|
| Uncertainty Head | hidden_dim × 256 | Negligible | Negligible | ~0 |
| Uncertainty Embedding | 256 × hidden_dim | 0 | Tiny | ~0 |
| Entropy (training only) | 0 | O(B×S×V) | O(1) | Not in inference |
Inference overhead: effectively zero
Theoretical Intuition
PRISMA V2 transforms autoregressive generation from:
P(y_t | x, y_<t)
to:
P(y_t, c_t | x, y_<t, c_<t)
where c_t is a learned uncertainty latent.
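Expanding the single-step notation into a full-sequence factorization (one consistent reading of the symbols above):

```latex
P(y_{1:T}, c_{1:T} \mid x) \;=\; \prod_{t=1}^{T} P(y_t, c_t \mid x,\, y_{<t},\, c_{<t})
```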
This allows the model to:
- Reduce commitment after uncertain predictions
- Maintain momentum after confident predictions
- Learn task-specific uncertainty signals
- Develop introspection without relying on engine-level state
Why PRISMA V2 Works Everywhere
| Constraint | V1 | V2 |
|---|---|---|
| Stateless decoding | ❌ | ✅ |
| vLLM batching | ❌ | ✅ |
| llama.cpp KV cache | ❌ | ✅ |
| Tensor parallel | ⚠️ | ✅ |
| MLX tracing | ❌ | ✅ |
What to Watch For
- Ablation: remove uncertainty input, measure perplexity / behavior
- Calibration: does predicted uncertainty correlate with error?
- Behavioral shifts: hedging, correction, abstention
- Divergence from entropy: expected and healthy
Summary
Prisma V2 preserves the introspective insight of Prisma V1 while replacing fragile mutable state with an explicit, learned uncertainty representation. This makes introspection portable, scalable, and deployable across all modern inference engines.
The model no longer measures uncertainty — it learns what uncertainty means.