Prisma V2 — ToDo.md

Goal

Evolve PRISMA from a Python-stateful uncertainty feedback mechanism (V1) into a fully stateless, runtime-portable architecture compatible with:

  • Hugging Face Transformers
  • vLLM
  • llama.cpp
  • MLX

while preserving the core idea:

Condition future predictions on the model’s own uncertainty to reduce confident errors and hallucinations.


Background (Why V2 Is Needed)

Prisma V1 relies on mutable per-step Python state:

self.prev_uncertainty_code

This works in Hugging Face generate(), but fails in modern inference engines:

  • vLLM — batched, reordered, stateless execution
  • llama.cpp — compiled C/C++ loop with fixed KV cache semantics
  • MLX — pure functional graph execution

All target runtimes require decoding to be purely functional:

All information needed for the next step must be carried explicitly in tensors, tokens, or KV cache — not Python object state.


Core Design Change (Prisma V2)

Replace measured uncertainty with predicted uncertainty

Instead of:

  • computing entropy post-hoc from logits
  • storing uncertainty in a mutable buffer

the model learns to predict uncertainty directly.

At each position, the model outputs:

  1. Next-token logits (existing)
  2. Uncertainty logits for the next position (new)

Uncertainty becomes a learned latent variable, not a side effect.


Architecture Changes

1. Add an uncertainty prediction head

self.n_uncertainty_levels = 256  # V2: smaller, sufficient
self.uncertainty_head = nn.Linear(hidden_size, self.n_uncertainty_levels, bias=False)

2. Uncertainty head initialization (important)

The uncertainty head is added to a pretrained model. To avoid destabilizing early training, use zero initialization:

self.uncertainty_head.weight.data.zero_()

Rationale:

  • Model initially predicts neutral uncertainty everywhere
  • Early training behaves identically to the base model
  • Uncertainty signal is learned gradually, only when useful

Alternative initializations (small random, copying from lm_head) are left for experimentation.
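A minimal sketch of why zero initialization is safe (sizes here are arbitrary demo values, not the model's real dimensions): with zero weights, the head outputs all-zero logits, i.e. a uniform distribution over uncertainty levels, so it contributes nothing until gradients start shaping it.

```python
import torch
import torch.nn as nn

hidden_size, n_uncertainty_levels = 64, 256  # demo sizes (assumption)

uncertainty_head = nn.Linear(hidden_size, n_uncertainty_levels, bias=False)
uncertainty_head.weight.data.zero_()  # zero init: head starts inert

hidden = torch.randn(2, 7, hidden_size)  # [B, S, H] dummy hidden states
logits = uncertainty_head(hidden)        # all zeros
probs = torch.softmax(logits, dim=-1)    # uniform over all 256 levels
```

Because the hidden states are nonzero, the weight gradient is nonzero even at zero init, so the head can still learn.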


3. Keep uncertainty embeddings (input side)

self.uncertainty_embeddings = nn.Embedding(
    self.n_uncertainty_levels,
    hidden_size
)

4. Modify forward signature

def forward(
    self,
    input_ids: torch.Tensor,
    uncertainty_codes: Optional[torch.Tensor] = None,  # [B, S]
    **kwargs
):
  • uncertainty_codes[t] conditions token t
  • No hidden buffers
  • No mutable Python state

Forward Pass Logic (V2)

  1. Embed tokens

  2. If uncertainty_codes provided:

    • lookup uncertainty_embeddings
    • add to inputs_embeds
  3. Run transformer

  4. Produce:

    • logits (next token)
    • uncertainty_logits (next uncertainty)

return {
    "logits": logits,
    "uncertainty_logits": uncertainty_logits,
}
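The four steps above can be sketched end to end. This is a toy stand-in, not the real model: the `nn.TransformerEncoder` here substitutes for the pretrained decoder backbone, and all sizes are demo values.

```python
import torch
import torch.nn as nn
from typing import Optional

class PrismaV2Sketch(nn.Module):
    """Toy stand-in illustrating the stateless V2 forward pass."""

    def __init__(self, vocab_size=1000, hidden_size=64, n_uncertainty_levels=256):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.uncertainty_embeddings = nn.Embedding(n_uncertainty_levels, hidden_size)
        self.backbone = nn.TransformerEncoder(  # stand-in for the pretrained decoder
            nn.TransformerEncoderLayer(hidden_size, nhead=4, batch_first=True),
            num_layers=2,
        )
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        self.uncertainty_head = nn.Linear(hidden_size, n_uncertainty_levels, bias=False)
        self.uncertainty_head.weight.data.zero_()  # zero init (see above)

    def forward(self, input_ids, uncertainty_codes: Optional[torch.Tensor] = None):
        h = self.embed_tokens(input_ids)                       # 1. embed tokens
        if uncertainty_codes is not None:                      # 2. condition on codes
            h = h + self.uncertainty_embeddings(uncertainty_codes)
        h = self.backbone(h)                                   # 3. run transformer
        return {                                               # 4. two output heads
            "logits": self.lm_head(h),
            "uncertainty_logits": self.uncertainty_head(h),
        }
```

Everything the next step needs is in the returned tensors; no attribute on the module changes between calls.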

Temporal Semantics

| Position | Input | Predicts |
|----------|-------|----------|
| t | tokenₜ + uncertaintyₜ | tokenₜ₊₁, uncertaintyₜ₊₁ |

This preserves the original PRISMA temporal feedback loop without mutable state.


Training Plan

Uncertainty supervision (teacher signal)

During training, entropy is used only as a teacher signal, not as the definition of uncertainty.

entropy = -∑ p log p
normalized = entropy / log(vocab_size)
uncertainty_label = quantize(normalized)
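The pseudocode above can be made concrete. A sketch of `quantize_entropy` (the name matches the training snippet in this document); uniform binning of normalized entropy into the code range is an assumption, and quantile-based bins would also fit:

```python
import math
import torch
import torch.nn.functional as F

def quantize_entropy(logits: torch.Tensor, n_levels: int = 256) -> torch.Tensor:
    """Map per-position next-token entropy to a discrete uncertainty code."""
    log_probs = F.log_softmax(logits, dim=-1)
    entropy = -(log_probs.exp() * log_probs).sum(dim=-1)   # -sum p log p
    normalized = entropy / math.log(logits.size(-1))       # in [0, 1]
    # Uniform bins (assumption): 0 = most confident, n_levels-1 = max entropy.
    return (normalized * n_levels).long().clamp(0, n_levels - 1)
```

A uniform distribution over the vocabulary lands in the top bin; a sharply peaked one lands in bin 0.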

Single-pass training (preferred)

A second forward pass is not required.

outputs = model(
    input_ids,
    uncertainty_codes=uncertainty_input
)

with torch.no_grad():
    uncertainty_labels = quantize_entropy(outputs["logits"])

loss = (
    loss_lm(outputs["logits"], labels)
    + λ * loss_uncertainty(outputs["uncertainty_logits"], uncertainty_labels)
)

Key point:

  • Entropy is a bootstrap target
  • The model is free to learn uncertainty representations that diverge from entropy over time
  • This allows uncertainty to correlate better with error than raw entropy does
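One consistent reading of the Temporal Semantics table is that `uncertainty_logits` at position t are scored against the label for position t+1, so the uncertainty loss needs a one-position shift. The exact alignment is not spelled out above; this sketch (with stand-in tensors) shows the shifted cross-entropy under that assumption:

```python
import torch
import torch.nn.functional as F

B, S, n_levels = 2, 8, 256  # demo sizes (assumption)

# Stand-ins for tensors produced during a training step.
uncertainty_logits = torch.randn(B, S, n_levels)         # model output
uncertainty_labels = torch.randint(0, n_levels, (B, S))  # from quantize_entropy

# uncertainty_logits[:, t] predicts the code at position t+1, so the last
# logit position and the first label position are dropped before scoring.
loss_uncertainty = F.cross_entropy(
    uncertainty_logits[:, :-1].reshape(-1, n_levels),
    uncertainty_labels[:, 1:].reshape(-1),
)
```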

Loss definition

loss = loss_lm + λ * loss_uncertainty
  • loss_lm: standard next-token cross-entropy
  • loss_uncertainty: cross-entropy over uncertainty codes
  • λ ≈ 0.1 (to tune)

Inference (All Runtimes)

Decode loop (conceptual)

(tokenₜ, uncertaintyₜ) → model → (tokenₜ₊₁, uncertaintyₜ₊₁)

Runtime responsibilities

  • Transformers: custom generate() tracks uncertainty tensor
  • vLLM: sampler tracks uncertainty_code per request
  • llama.cpp: store one small uncertainty code in llama_context
  • MLX: works naturally (pure tensor graph)

No runtime needs to preserve Python object state.
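The conceptual loop above, written out for a Transformers-style runtime. This is a greedy-decoding sketch, not the reference implementation: it recomputes the full forward pass each step (no KV cache), and seeding prompt positions with a `neutral_code` of 0 is an assumption.

```python
import torch

@torch.no_grad()
def prisma_decode(model, input_ids, max_new_tokens=32, neutral_code=0):
    """Stateless decode: all state lives in two tensors, tokens and codes."""
    B, S = input_ids.shape
    # Assumption: prompt positions are conditioned on a neutral code.
    codes = torch.full((B, S), neutral_code, dtype=torch.long)
    for _ in range(max_new_tokens):
        out = model(input_ids, uncertainty_codes=codes)
        next_token = out["logits"][:, -1].argmax(dim=-1, keepdim=True)
        next_code = out["uncertainty_logits"][:, -1].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        codes = torch.cat([codes, next_code], dim=-1)
    return input_ids, codes
```

Because the predicted code is appended to a tensor rather than stored on the model, the same loop translates directly to a vLLM sampler field, a slot in `llama_context`, or an MLX graph input.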


Compatibility Matrix

| Runtime | Prisma V1 | Prisma V2 |
|---------|-----------|-----------|
| Transformers | ✅ | ✅ |
| vLLM | ❌ | ✅ |
| llama.cpp | ❌ | ✅ |
| MLX | ❌ | ✅ |
| Tensor Parallel | ❌ | ⚠️ |

Design Decisions

Classification vs regression

Chosen: classification (quantized uncertainty)

Reasons:

  • Stable training
  • Matches embedding lookup
  • Discrete semantics
  • Easier runtime handling

Regression remains an experimental alternative.


Uncertainty resolution

  • V1: 65,536 levels (overkill)
  • V2: 256 levels (sufficient, efficient, portable)

Known Limitations

  • Uncertainty reflects model confidence, not correctness
  • Learned uncertainty may differ from Shannon entropy
  • Does not guarantee abstention or correctness
  • Behavior depends on training data and loss weighting

Open Questions / Future Work

  • Tune n_uncertainty_levels
  • Tune λ (uncertainty loss weight)
  • Explore uncertainty-aware decoding strategies
  • Compare uncertainty prediction vs uncertainty tokens
  • Investigate bootstrapping without entropy supervision

Definition of Done (Prisma V2)

  • No mutable per-step Python state
  • Uncertainty passed explicitly as tensor
  • Works in Transformers, vLLM, llama.cpp, MLX
  • Zero-init uncertainty head
  • Single-pass training loop
  • Updated model card + documentation
  • Reference implementation + example

Guiding Principle (V2)

Uncertainty must be data, not memory.