Prisma-VL-8B / PRISMAv2.md

PRISMA V2: Joint Uncertainty Prediction Mechanism — Implementation Specification

Architecture Overview: PRISMA V2 replaces Python-side uncertainty state with a learned, explicit uncertainty latent predicted jointly with tokens. At each step, the model predicts both the next token and an uncertainty code that conditions the following step. This preserves temporal introspection while remaining fully compatible with stateless inference engines.


Core Design Principle

Uncertainty must be data, not memory.

All information required for the next decoding step is carried explicitly through tensors (tokens, uncertainty codes, or cache), never through mutable module state.


Differences from Prisma V1 (Detailed)

Prisma V2 is not a minor refactor of Prisma V1. It represents a fundamental shift in how uncertainty is represented, propagated, and learned.

This section documents those differences precisely.


1. Source of Uncertainty

Prisma V1

  • Uncertainty is measured post-hoc from the model’s output distribution
  • Computed via entropy of logits
  • Acts as an external diagnostic signal
uncertainty_t = H(P(y_t))

Prisma V2

  • Uncertainty is predicted by the model itself
  • Learned as an auxiliary latent variable
  • Acts as an internal representation
(token_{t+1}, uncertainty_{t+1}) = f(token_t, uncertainty_t)

Implication: V1 answers “how uncertain was I?” V2 answers “how uncertain will I be?”
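For concreteness, the V1 quantity can be computed post hoc from logits in a few lines (a sketch; `v1_uncertainty` is a hypothetical helper, not part of either codebase). V2 has no such post-hoc step — the uncertainty code comes out of the forward pass itself.

```python
import torch

# V1-style post-hoc uncertainty: entropy of the next-token distribution,
# normalized to [0, 1] by the maximum possible entropy log(V).
def v1_uncertainty(logits: torch.Tensor) -> torch.Tensor:  # logits: [..., V]
    logp = torch.log_softmax(logits, dim=-1)
    entropy = -(logp.exp() * logp).sum(dim=-1)
    max_entropy = torch.log(torch.tensor(float(logits.shape[-1])))
    return entropy / max_entropy
```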


2. State Representation

Prisma V1

  • Uses mutable Python-side state:
self.prev_uncertainty_code
  • State exists outside the model’s forward graph
  • Relies on strict step-by-step execution order

Prisma V2

  • No mutable state
  • Uncertainty is passed explicitly as a tensor:
uncertainty_codes: Tensor[B, S]
  • Fully contained within the model’s inputs and outputs

Implication: V1 requires engine cooperation. V2 requires only tensors.


3. Runtime Compatibility

| Runtime | Prisma V1 | Prisma V2 |
|---|---|---|
| HuggingFace Transformers | ✅ | ✅ |
| vLLM | ❌ | ✅ |
| llama.cpp | ❌ | ✅ |
| MLX | ❌ | ✅ |
| Tensor Parallel | ⚠️ | ✅ |

Reason:

  • V1 violates the stateless decoding assumptions of modern runtimes
  • V2 conforms to them by construction

4. Temporal Feedback Mechanism

Prisma V1

  • Feedback loop implemented via external buffer
  • Requires padding, truncation, and shifting logic
  • Not visible to KV cache or sampler

Prisma V2

  • Feedback loop is architectural
  • Uncertainty is predicted one step ahead and injected naturally
  • Temporal alignment is implicit in training and decoding

Implication: V2’s feedback loop is native, not simulated.


5. Learning Dynamics

Prisma V1

  • Uncertainty signal is fixed (entropy)
  • Model can only learn how to react to uncertainty
  • Cannot redefine what uncertainty means

Prisma V2

  • Uncertainty is supervised initially by entropy, then free to diverge

  • Model can learn:

    • epistemic uncertainty
    • ambiguity
    • distribution shift
    • task-specific hesitation signals

Implication: V1 teaches response to uncertainty. V2 teaches representation of uncertainty.


6. Training Complexity

Prisma V1

  • No additional loss
  • Entropy computed every forward
  • Sensitive to tensor parallel sharding

Prisma V2

  • Adds a lightweight auxiliary loss
  • Entropy used only as a teacher signal during training
  • No entropy computation at inference

Implication: V2 trades a small training cost for large inference robustness.


7. Inference Behavior

Prisma V1

  • Uncertainty exists only implicitly
  • Difficult to inspect or intervene at runtime
  • Breaks under batched or reordered decoding

Prisma V2

  • Uncertainty is explicit and inspectable
  • Sampler can condition on it
  • Works under any batching or scheduling strategy

8. Conceptual Framing

Prisma V1

  • Introspection via measurement
  • Confidence is something the model observes after the fact

Prisma V2

  • Introspection via prediction
  • Confidence is something the model reasons about and plans with

Prisma V1 makes the model aware of its uncertainty. Prisma V2 makes uncertainty part of the model’s internal world model.


Summary Table

| Dimension | Prisma V1 | Prisma V2 |
|---|---|---|
| Uncertainty source | Entropy (measured) | Learned latent |
| State handling | Mutable buffer | Explicit tensor |
| Runtime support | Limited | Universal |
| KV cache compatibility | ❌ | ✅ |
| Tensor parallel | Fragile | Safe |
| Introspection depth | Reactive | Predictive |
| Deployment readiness | Research-only | Production-capable |

Why Prisma V2 Exists

Prisma V1 demonstrated that temporal uncertainty feedback produces introspective behavior.

Prisma V2 makes that insight architectural, portable, and deployable.

It is not a workaround. It is the correct abstraction boundary.

Uncertainty must be data, not memory.


Core Components to Add

# In your CausalLM class
self.n_uncertainty_levels = 256  # V2: compact, sufficient
self.uncertainty_embeddings = nn.Embedding(
    self.n_uncertainty_levels,
    hidden_dim
)

# NEW: Uncertainty prediction head
self.uncertainty_head = nn.Linear(
    hidden_dim,
    self.n_uncertainty_levels,
    bias=False
)

Initialization Details

Uncertainty Embeddings

  • Initialized from N(0, σ²) where σ = config.initializer_range

Uncertainty Head (Important)

self.uncertainty_head.weight.data.zero_()

Rationale:

  • Model initially predicts neutral uncertainty
  • Early training behaves identically to the base model
  • Avoids destabilizing LM loss with noisy auxiliary signals
  • Uncertainty pathway is learned gradually
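A quick sanity check of the zero-init behavior (illustrative sizes): with zeroed weights, every uncertainty class receives the same logit, so the predicted distribution is exactly uniform — i.e., neutral — no matter what the hidden state is.

```python
import torch
import torch.nn as nn

# Zero-initialized head -> all-zero logits -> uniform softmax over the
# 256 uncertainty levels, regardless of the hidden state.
head = nn.Linear(64, 256, bias=False)
head.weight.data.zero_()
hidden = torch.randn(2, 5, 64)
probs = torch.softmax(head(hidden), dim=-1)
assert torch.allclose(probs, torch.full_like(probs, 1 / 256))
```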

Forward Pass Modifications (Input Side)

Location: Immediately after token embedding lookup

def forward(self, input_ids, uncertainty_codes=None, ...):
    inputs_embeds = self.embed_tokens(input_ids)

    if uncertainty_codes is not None:
        # uncertainty_codes: [B, S]
        u = self.uncertainty_embeddings(uncertainty_codes)
        inputs_embeds = inputs_embeds + u

    hidden_states = self.model(
        inputs_embeds=inputs_embeds,
        ...
    ).last_hidden_state

  • uncertainty_codes[t] conditions token position t
  • No padding, truncation, or shifting logic required
  • Temporal alignment is handled by the training and decoding loop

Forward Pass Modifications (Output Side)

Location: After transformer hidden states

logits = self.lm_head(hidden_states)
uncertainty_logits = self.uncertainty_head(hidden_states)

Returns:

return {
    "logits": logits,                      # [B, S, vocab]
    "uncertainty_logits": uncertainty_logits  # [B, S, n_uncertainty_levels]
}
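The input- and output-side changes above can be exercised end to end with a toy stand-in for the transformer body. Everything here — the sizes, the `PrismaStub` name, the encoder-layer body, the omitted causal mask — is illustrative, not the Prisma-VL-8B implementation:

```python
import torch
import torch.nn as nn

# Toy end-to-end module combining the pieces from this spec: additive
# uncertainty embeddings on the input side, a zero-initialized
# uncertainty head on the output side. The transformer body is a stub
# (no causal mask, toy sizes) purely for shape-checking.
class PrismaStub(nn.Module):
    def __init__(self, vocab=100, hidden=32, n_levels=256):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, hidden)
        self.uncertainty_embeddings = nn.Embedding(n_levels, hidden)
        self.body = nn.TransformerEncoderLayer(hidden, nhead=4, batch_first=True)
        self.lm_head = nn.Linear(hidden, vocab, bias=False)
        self.uncertainty_head = nn.Linear(hidden, n_levels, bias=False)
        self.uncertainty_head.weight.data.zero_()  # neutral start

    def forward(self, input_ids, uncertainty_codes=None):
        x = self.embed_tokens(input_ids)
        if uncertainty_codes is not None:
            x = x + self.uncertainty_embeddings(uncertainty_codes)
        h = self.body(x)
        return {"logits": self.lm_head(h),
                "uncertainty_logits": self.uncertainty_head(h)}
```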

Temporal Semantics

| Position | Input | Predicts |
|---|---|---|
| t | tokenₜ + uncertaintyₜ | tokenₜ₊₁, uncertaintyₜ₊₁ |

This preserves the original PRISMA temporal feedback loop without mutable state.


Training Objective

Language Modeling Loss

Standard next-token prediction:

loss_lm = cross_entropy(
    logits[:, :-1],
    labels[:, 1:]
)

Uncertainty Prediction Loss

Uncertainty is predicted one step ahead:

loss_uncertainty = cross_entropy(
    uncertainty_logits[:, :-1],
    uncertainty_labels[:, 1:]
)

Combined Loss

loss = loss_lm + λ * loss_uncertainty

  • Recommended starting point: λ ≈ 0.1 (tune per task)
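Spelled out with concrete tensor shapes (a sketch; `F.cross_entropy` expects the class dimension flattened out, which the shorthand above elides):

```python
import torch
import torch.nn.functional as F

# Combined objective: shifted next-token loss plus shifted next-step
# uncertainty loss, weighted by lam (λ ≈ 0.1 per the recommendation above).
def prisma_loss(logits, uncertainty_logits, labels, uncertainty_labels, lam=0.1):
    V = logits.shape[-1]
    C = uncertainty_logits.shape[-1]
    loss_lm = F.cross_entropy(logits[:, :-1].reshape(-1, V),
                              labels[:, 1:].reshape(-1))
    loss_u = F.cross_entropy(uncertainty_logits[:, :-1].reshape(-1, C),
                             uncertainty_labels[:, 1:].reshape(-1))
    return loss_lm + lam * loss_u
```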

Uncertainty Supervision (Teacher Signal)

During training only, entropy is used as a bootstrap target, not as the definition of uncertainty.

with torch.no_grad():
    probs = torch.softmax(logits, dim=-1)
    entropy = -(probs * torch.log(probs + 1e-9)).sum(dim=-1)
    entropy_norm = entropy / math.log(vocab_size)
    uncertainty_labels = quantize(entropy_norm)

Important:

  • Entropy is a teacher, not a constraint
  • The model may learn uncertainty signals that diverge from entropy
  • This is desirable if they correlate better with error or ambiguity
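The `quantize` step is left unspecified by this spec; one straightforward choice (an assumption, not the reference quantizer) is uniform binning of the normalized entropy into the 256 code levels:

```python
import torch

# Uniform quantizer: maps entropy_norm in [0, 1] to integer codes 0..255.
# Values at exactly 1.0 are clamped into the top bucket.
def quantize(entropy_norm: torch.Tensor, n_levels: int = 256) -> torch.Tensor:
    codes = (entropy_norm * n_levels).long()
    return codes.clamp(0, n_levels - 1)
```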

Single-Pass Training (Preferred)

A second forward pass is not required.

outputs = model(
    input_ids,
    uncertainty_codes=uncertainty_input
)

with torch.no_grad():
    uncertainty_labels = compute_uncertainty_labels(outputs.logits)

loss = compute_loss(
    outputs.logits,
    outputs.uncertainty_logits,
    labels,
    uncertainty_labels
)

Inference Loop (All Runtimes)

(tokenₜ, uncertaintyₜ) → model → (tokenₜ₊₁, uncertaintyₜ₊₁)

Neutral Start

uncertainty_code = n_uncertainty_levels // 2

Runtime Integration

| Runtime | Integration |
|---|---|
| Transformers | Custom `generate()` tracks the uncertainty_code tensor |
| vLLM | Sampler tracks one uncertainty code per request |
| llama.cpp | Store the uncertainty code in `llama_context` |
| MLX | Works directly (pure tensor graph) |

No runtime relies on Python-side mutable state.
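A minimal greedy decoding loop in this spirit (a sketch against the forward signature in this spec; `model` is any module returning `logits` and `uncertainty_logits`, and the KV cache is omitted for brevity):

```python
import torch

# Stateless decoding: all temporal feedback travels through the two
# explicit tensors (token ids and uncertainty codes), never module state.
@torch.no_grad()
def generate(model, input_ids, max_new_tokens, n_uncertainty_levels=256):
    B, S = input_ids.shape
    # Neutral start: the mid-level code at every prompt position.
    u_codes = torch.full((B, S), n_uncertainty_levels // 2, dtype=torch.long)
    for _ in range(max_new_tokens):
        out = model(input_ids, uncertainty_codes=u_codes)
        next_tok = out["logits"][:, -1].argmax(-1, keepdim=True)
        next_u = out["uncertainty_logits"][:, -1].argmax(-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_tok], dim=1)
        u_codes = torch.cat([u_codes, next_u], dim=1)  # conditions next step
    return input_ids, u_codes
```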


Performance Characteristics

| Component | Parameters | FLOPs | Memory | Latency |
|---|---|---|---|---|
| Uncertainty Head | hidden_dim × 256 | Negligible | Negligible | ~0 |
| Uncertainty Embedding | 256 × hidden_dim | 0 | Tiny | ~0 |
| Entropy (training only) | 0 | O(B×S×V) | O(1) | Not in inference |

Inference overhead: effectively zero
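For scale (assuming a hidden size of 4096; Prisma-VL-8B's actual hidden size may differ), the two added modules contribute about 2M parameters — roughly 0.03% of an 8B-parameter model:

```python
# Back-of-envelope parameter overhead for hidden_dim = 4096 (assumed).
hidden_dim, n_levels = 4096, 256
extra_params = 2 * hidden_dim * n_levels  # embedding table + head weight
share = extra_params / 8e9                # fraction of an 8B-parameter model
print(extra_params, f"{share:.4%}")
```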


Theoretical Intuition

PRISMA V2 transforms autoregressive generation from:

P(y_t | x, y_<t)

to:

P(y_t, c_t | x, y_<t, c_<t)

where c_t is a learned uncertainty latent.

This allows the model to:

  • Reduce commitment after uncertain predictions
  • Maintain momentum after confident predictions
  • Learn task-specific uncertainty signals
  • Develop introspection without relying on engine-level state

Why PRISMA V2 Works Everywhere

| Constraint | V1 | V2 |
|---|---|---|
| Stateless decoding | ❌ | ✅ |
| vLLM batching | ❌ | ✅ |
| llama.cpp KV cache | ❌ | ✅ |
| Tensor parallel | ⚠️ | ✅ |
| MLX tracing | ❌ | ✅ |

What to Watch For

  • Ablation: remove uncertainty input, measure perplexity / behavior
  • Calibration: does predicted uncertainty correlate with error?
  • Behavioral shifts: hedging, correction, abstention
  • Divergence from entropy: expected and healthy

Summary

Prisma V2 preserves the introspective insight of Prisma V1 while replacing fragile mutable state with an explicit, learned uncertainty representation. This makes introspection portable, scalable, and deployable across all modern inference engines.

The model no longer measures uncertainty — it learns what uncertainty means.