# Prisma V2 — ToDo
## Goal

Evolve PRISMA from a Python-stateful uncertainty feedback mechanism (V1) into a fully stateless, runtime-portable architecture compatible with:

- Hugging Face Transformers
- vLLM
- llama.cpp
- MLX

while preserving the core idea:

> Condition future predictions on the model’s own uncertainty to reduce confident errors and hallucinations.
## Background (Why V2 Is Needed)

Prisma V1 relies on mutable per-step Python state:

```python
self.prev_uncertainty_code
```

This works in Hugging Face `generate()`, but fails in modern inference engines:

- vLLM — batched, reordered, stateless execution
- llama.cpp — compiled C/C++ loop with fixed KV-cache semantics
- MLX — pure functional graph execution

All target runtimes require decoding to be purely functional:

> All information needed for the next step must be carried explicitly in tensors, tokens, or the KV cache — not in Python object state.
## Core Design Change (Prisma V2)

Replace measured uncertainty with predicted uncertainty.

Instead of:

- computing entropy post hoc from logits
- storing uncertainty in a mutable buffer

the model learns to predict uncertainty directly.

At each position, the model outputs:

- next-token logits (existing)
- uncertainty logits for the next position (new)

Uncertainty becomes a learned latent variable, not a side effect.
## Architecture Changes
### 1. Add an uncertainty prediction head

```python
self.n_uncertainty_levels = 256  # V2: smaller, sufficient
self.uncertainty_head = nn.Linear(hidden_size, self.n_uncertainty_levels, bias=False)
```
### 2. Uncertainty head initialization (important)

The uncertainty head is added to a pretrained model. To avoid destabilizing early training, use zero initialization:

```python
self.uncertainty_head.weight.data.zero_()
```

Rationale:

- The model initially predicts neutral uncertainty everywhere
- Early training behaves identically to the base model
- The uncertainty signal is learned gradually, and only where it is useful

Alternative initializations (small random, copying from `lm_head`) are left for experimentation.
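The effect of zero initialization can be checked directly: a zero-weight head emits all-zero logits, i.e. a uniform distribution over uncertainty levels, regardless of the hidden state. A small sketch (the sizes here are illustrative, not prescriptive):

```python
import torch
import torch.nn as nn

hidden_size, n_uncertainty_levels = 64, 256

# Zero-initialized head, as described above.
head = nn.Linear(hidden_size, n_uncertainty_levels, bias=False)
head.weight.data.zero_()

hidden = torch.randn(2, 5, hidden_size)  # [batch, seq, hidden]
logits = head(hidden)                    # all zeros, for any input
probs = torch.softmax(logits, dim=-1)    # uniform over the 256 levels

assert torch.all(logits == 0)
assert torch.allclose(probs, torch.full_like(probs, 1 / n_uncertainty_levels))
```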
### 3. Keep uncertainty embeddings (input side)

```python
self.uncertainty_embeddings = nn.Embedding(
    self.n_uncertainty_levels,
    hidden_size,
)
```
4. Modify forward signature
def forward(
input_ids: torch.Tensor,
uncertainty_codes: Optional[torch.Tensor] = None, # [B, S]
**kwargs
):
uncertainty_codes[t]conditions tokent- No hidden buffers
- No mutable Python state
## Forward Pass Logic (V2)

1. Embed tokens
2. If `uncertainty_codes` is provided:
   - look up `uncertainty_embeddings`
   - add them to `inputs_embeds`
3. Run the transformer
4. Produce:
   - `logits` (next token)
   - `uncertainty_logits` (next uncertainty)

```python
return {
    "logits": logits,
    "uncertainty_logits": uncertainty_logits,
}
```
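These steps can be sketched end to end. `PrismaV2Sketch` and its single-linear `backbone` are illustrative stand-ins, not the real implementation, which would wrap a full pretrained transformer:

```python
import torch
import torch.nn as nn
from typing import Optional

class PrismaV2Sketch(nn.Module):
    """Minimal sketch of the V2 forward pass (stand-in backbone)."""

    def __init__(self, vocab_size=100, hidden_size=32, n_uncertainty_levels=256):
        super().__init__()
        self.token_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.uncertainty_embeddings = nn.Embedding(n_uncertainty_levels, hidden_size)
        self.backbone = nn.Linear(hidden_size, hidden_size)  # stand-in transformer
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        self.uncertainty_head = nn.Linear(hidden_size, n_uncertainty_levels, bias=False)
        self.uncertainty_head.weight.data.zero_()  # zero init, as above

    def forward(self, input_ids, uncertainty_codes: Optional[torch.Tensor] = None):
        inputs_embeds = self.token_embeddings(input_ids)
        if uncertainty_codes is not None:
            # Condition each position on its uncertainty code.
            inputs_embeds = inputs_embeds + self.uncertainty_embeddings(uncertainty_codes)
        hidden = self.backbone(inputs_embeds)
        return {
            "logits": self.lm_head(hidden),
            "uncertainty_logits": self.uncertainty_head(hidden),
        }

model = PrismaV2Sketch()
ids = torch.randint(0, 100, (2, 7))
codes = torch.randint(0, 256, (2, 7))
out = model(ids, uncertainty_codes=codes)
# shapes: logits [2, 7, 100], uncertainty_logits [2, 7, 256]
```

Note that the only inputs are tensors (`input_ids`, `uncertainty_codes`); nothing is stashed on the module between steps.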
## Temporal Semantics
| Position | Input | Predicts |
|---|---|---|
| t | tokenₜ + uncertaintyₜ | tokenₜ₊₁, uncertaintyₜ₊₁ |
This preserves the original PRISMA temporal feedback loop without mutable state.
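During teacher forcing, the table's alignment amounts to a right-shift: the uncertainty at position t becomes the conditioning input at position t+1. A sketch, where `NEUTRAL` is an assumed placeholder code for position 0 (which has no preceding uncertainty):

```python
import torch

NEUTRAL = 128  # assumed midpoint of 256 levels; position 0 has no predecessor

uncertainty_labels = torch.tensor([[10, 200, 55, 3]])   # per-position targets
uncertainty_input = torch.full_like(uncertainty_labels, NEUTRAL)
uncertainty_input[:, 1:] = uncertainty_labels[:, :-1]   # shift right by one

print(uncertainty_input.tolist())  # [[128, 10, 200, 55]]
```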
## Training Plan

### Uncertainty supervision (teacher signal)

During training, entropy is used only as a teacher signal, not as the definition of uncertainty:

```
entropy = -∑ p log p
normalized = entropy / log(vocab_size)
uncertainty_label = quantize(normalized)
```
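The quantization above can be made concrete. `quantize_entropy` is the helper name already used by the training snippet in this document, and 256 levels follow the V2 resolution choice:

```python
import torch
import torch.nn.functional as F

def quantize_entropy(logits: torch.Tensor, n_levels: int = 256) -> torch.Tensor:
    """Map per-position predictive entropy to integer codes in [0, n_levels)."""
    log_probs = F.log_softmax(logits, dim=-1)
    probs = log_probs.exp()
    entropy = -(probs * log_probs).sum(dim=-1)  # -∑ p log p
    vocab_size = logits.shape[-1]
    normalized = entropy / torch.log(torch.tensor(float(vocab_size)))  # in [0, 1]
    return (normalized * n_levels).long().clamp(0, n_levels - 1)

logits = torch.randn(2, 5, 100)   # [batch, seq, vocab]
codes = quantize_entropy(logits)  # [2, 5], values in [0, 255]
```

Uniform logits map to the highest code (maximum entropy), a sharply peaked distribution to code 0.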
### Single-pass training (preferred)

A second forward pass is not required:

```python
outputs = model(
    input_ids,
    uncertainty_codes=uncertainty_input,
)

with torch.no_grad():
    uncertainty_labels = quantize_entropy(outputs["logits"])

loss = (
    loss_lm(outputs["logits"], labels)
    + λ * loss_uncertainty(outputs["uncertainty_logits"], uncertainty_labels)
)
```
Key points:
- Entropy is a bootstrap target
- The model is free to learn uncertainty representations that diverge from entropy over time
- This allows uncertainty to correlate better with error than raw entropy does
### Loss definition

```
loss = loss_lm + λ * loss_uncertainty
```

- `loss_lm`: standard next-token cross-entropy
- `loss_uncertainty`: cross-entropy over uncertainty codes
- λ ≈ 0.1 (to tune)
## Inference (All Runtimes)

### Decode loop (conceptual)

```
(tokenₜ, uncertaintyₜ) → model → (tokenₜ₊₁, uncertaintyₜ₊₁)
```
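This loop can be sketched in plain Python. `DummyModel` is a stand-in with the assumed V2 interface, and the loop omits KV caching and real sampling for brevity; the point is that the only cross-step state is a pair of tensors:

```python
import torch
import torch.nn as nn

class DummyModel(nn.Module):
    """Stand-in with the assumed V2 interface (logits + uncertainty_logits)."""

    def __init__(self, vocab_size=100, n_levels=256, hidden=16):
        super().__init__()
        self.tok = nn.Embedding(vocab_size, hidden)
        self.unc = nn.Embedding(n_levels, hidden)
        self.lm_head = nn.Linear(hidden, vocab_size, bias=False)
        self.uncertainty_head = nn.Linear(hidden, n_levels, bias=False)

    def forward(self, input_ids, uncertainty_codes=None):
        h = self.tok(input_ids)
        if uncertainty_codes is not None:
            h = h + self.unc(uncertainty_codes)
        return {"logits": self.lm_head(h), "uncertainty_logits": self.uncertainty_head(h)}

@torch.no_grad()
def decode(model, first_token, steps, neutral_code=128):
    """Greedy decode; everything the next step needs is carried explicitly."""
    token = torch.tensor([[first_token]])
    code = torch.tensor([[neutral_code]])
    out_tokens = []
    for _ in range(steps):
        out = model(token, uncertainty_codes=code)
        token = out["logits"][:, -1:].argmax(dim=-1)             # tokenₜ₊₁
        code = out["uncertainty_logits"][:, -1:].argmax(dim=-1)  # uncertaintyₜ₊₁
        out_tokens.append(token.item())
    return out_tokens

tokens = decode(DummyModel(), first_token=0, steps=4)
# tokens: 4 generated token ids in [0, 100)
```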
### Runtime responsibilities

- Transformers: a custom `generate()` tracks the uncertainty tensor
- vLLM: the sampler tracks an `uncertainty_code` per request
- llama.cpp: store one small uncertainty code in `llama_context`
- MLX: works naturally (pure tensor graph)

No runtime needs to preserve Python object state.
## Compatibility Matrix
| Runtime | Prisma V1 | Prisma V2 |
|---|---|---|
| Transformers | ✅ | ✅ |
| vLLM | ❌ | ✅ |
| llama.cpp | ❌ | ✅ |
| MLX | ❌ | ✅ |
| Tensor Parallel | ⚠️ | ✅ |
## Design Decisions

### Classification vs. regression
Chosen: classification (quantized uncertainty)
Reasons:
- Stable training
- Matches embedding lookup
- Discrete semantics
- Easier runtime handling
Regression remains an experimental alternative.
### Uncertainty resolution
- V1: 65,536 levels (overkill)
- V2: 256 levels (sufficient, efficient, portable)
## Known Limitations
- Uncertainty reflects model confidence, not correctness
- Learned uncertainty may differ from Shannon entropy
- Does not guarantee abstention or correctness
- Behavior depends on training data and loss weighting
## Open Questions / Future Work

- Tune `n_uncertainty_levels`
- Tune λ (uncertainty loss weight)
- Explore uncertainty-aware decoding strategies
- Compare uncertainty prediction vs. uncertainty tokens
- Investigate bootstrapping without entropy supervision
## Definition of Done (Prisma V2)
- No mutable per-step Python state
- Uncertainty passed explicitly as tensor
- Works in Transformers, vLLM, llama.cpp, MLX
- Zero-init uncertainty head
- Single-pass training loop
- Updated model card + documentation
- Reference implementation + example
## Guiding Principle (V2)

> Uncertainty must be data, not memory.