## PRISMA V2: Joint Uncertainty Prediction Mechanism — Implementation Specification

**Architecture Overview**:

PRISMA V2 replaces Python-side uncertainty state with a **learned, explicit uncertainty latent** predicted jointly with tokens. At each step, the model predicts both the next token *and* an uncertainty code that conditions the following step. This preserves temporal introspection while remaining compatible with stateless inference engines: the runtime only has to carry one extra integer (the uncertainty code) per sequence.

---

## **Core Design Principle**

> **Uncertainty must be data, not memory.**

All information required for the next decoding step is carried explicitly through tensors (tokens, uncertainty codes, or the KV cache), never through mutable module state.

---

## **Differences from Prisma V1 (Detailed)**

Prisma V2 is not a minor refactor of Prisma V1. It represents a **fundamental shift in how uncertainty is represented, propagated, and learned**.

This section documents those differences precisely.

---

### **1. Source of Uncertainty**

**Prisma V1**

* Uncertainty is **measured post-hoc** from the model’s output distribution
* Computed as the entropy of the softmax distribution over the logits
* Acts as an external diagnostic signal

```text
uncertainty_t = H(P(y_t))
```

**Prisma V2**

* Uncertainty is **predicted by the model itself**
* Learned as an auxiliary latent variable
* Acts as an internal representation

```text
(token_{t+1}, uncertainty_{t+1}) = f(token_{≤t}, uncertainty_{≤t})
```

**Implication**:
V1 answers *“how uncertain was I?”*
V2 answers *“how uncertain will I be?”*

---

### **2. State Representation**

**Prisma V1**

* Uses mutable Python-side state:

```python
self.prev_uncertainty_code
```

* State exists **outside** the model’s forward graph
* Relies on strict step-by-step execution order

**Prisma V2**

* No mutable state
* Uncertainty is passed explicitly as a tensor:

```python
uncertainty_codes: torch.LongTensor  # shape [B, S]
```

* Fully contained within the model’s inputs and outputs

**Implication**:
V1 requires engine cooperation.
V2 requires only that the runtime carry one integer code per sequence, alongside its tokens.
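
To make the difference concrete, here is a framework-free sketch in which a hypothetical `toy_step` function stands in for one model step. Two requests are decoded in round-robin order, as a batching scheduler would interleave them. The V1-style object keeps a single `prev_uncertainty_code` shared by every request, so request B reads request A's code; the V2-style path keeps one code per request and reproduces what each request would get if decoded alone.

```python
from typing import List, Tuple

NEUTRAL = 128  # n_uncertainty_levels // 2


def toy_step(token: int, code: int) -> Tuple[int, int]:
    """Hypothetical stand-in for one model step: (token_t, code_t) -> (token_t+1, code_t+1)."""
    return (token + code) % 100, (code + token) % 256


class StatefulV1:
    """V1 style: the previous code lives in one mutable attribute."""

    def __init__(self) -> None:
        self.prev_uncertainty_code = NEUTRAL

    def step(self, token: int) -> int:
        next_token, self.prev_uncertainty_code = toy_step(token, self.prev_uncertainty_code)
        return next_token


def decode_alone(start: int, steps: int) -> List[int]:
    """Reference: one request decoded by itself, code passed explicitly (V2 style)."""
    token, code, out = start, NEUTRAL, []
    for _ in range(steps):
        token, code = toy_step(token, code)
        out.append(token)
    return out


def interleave(starts: List[int], steps: int, stateful: bool) -> List[List[int]]:
    """Round-robin decode several requests, as a batching scheduler would."""
    shared = StatefulV1()              # V1: one module instance serves every request
    tokens = list(starts)
    codes = [NEUTRAL] * len(starts)    # V2: one code per request, carried like its token
    outs: List[List[int]] = [[] for _ in starts]
    for _ in range(steps):
        for i in range(len(starts)):
            if stateful:
                tokens[i] = shared.step(tokens[i])
            else:
                tokens[i], codes[i] = toy_step(tokens[i], codes[i])
            outs[i].append(tokens[i])
    return outs


reference = [decode_alone(1, 3), decode_alone(2, 3)]
print(interleave([1, 2], 3, stateful=False) == reference)  # True
print(interleave([1, 2], 3, stateful=True) == reference)   # False: requests leak codes
```

The V2 path needs no special ordering: the scheduler can reorder or batch requests freely because each request's code travels with it.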

---

### **3. Runtime Compatibility**

| Runtime                  | Prisma V1 | Prisma V2 |
| ------------------------ | --------- | --------- |
| HuggingFace Transformers | ✅         | ✅         |
| vLLM                     | ❌         | ✅         |
| llama.cpp                | ❌         | ✅         |
| MLX                      | ❌         | ✅         |
| Tensor Parallel          | ⚠️        | ✅         |

**Reason**:

* V1 violates the stateless decoding assumptions of modern runtimes
* V2 conforms to them by construction

---

### **4. Temporal Feedback Mechanism**

**Prisma V1**

* Feedback loop implemented via an external buffer
* Requires padding, truncation, and shifting logic
* Not visible to the KV cache or the sampler

**Prisma V2**

* Feedback loop is **architectural**
* Uncertainty is predicted one step ahead and fed back as an input embedding at the next position
* Temporal alignment comes from the same one-position shift already used for tokens in training and decoding

**Implication**:
V2’s feedback loop is **native**, not simulated.

---

### **5. Learning Dynamics**

**Prisma V1**

* Uncertainty signal is fixed (entropy)
* Model can only learn *how to react* to uncertainty
* Cannot redefine what uncertainty means

**Prisma V2**

* Uncertainty is bootstrapped from entropy (a soft auxiliary target weighted by `λ`) and is free to diverge from it
* Model can learn:
  * epistemic uncertainty
  * ambiguity
  * distribution shift
  * task-specific hesitation signals

**Implication**:
V1 teaches *response to uncertainty*.
V2 teaches *representation of uncertainty*.

---

### **6. Training Complexity**

**Prisma V1**

* No additional loss
* Entropy computed on every forward pass, including at inference
* Sensitive to tensor-parallel sharding (entropy needs the full-vocabulary distribution, which a vocab-parallel output layer splits across devices)

**Prisma V2**

* Adds a lightweight auxiliary loss
* Entropy used only as a teacher signal during training
* No entropy computation at inference

**Implication**:
V2 trades a small training cost for large inference robustness.

---

### **7. Inference Behavior**

**Prisma V1**

* Uncertainty exists only implicitly
* Difficult to inspect or intervene at runtime
* Breaks under batched or reordered decoding

**Prisma V2**

* Uncertainty is explicit and inspectable
* Sampler can condition on it
* Works under any batching or scheduling strategy
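
As one illustration of a sampler conditioning on the code, the sketch below maps the predicted code to a sampling temperature. The linear mapping, its bounds (`t_min`, `t_max`), and the function names are hypothetical policy choices for illustration only, not part of this specification; it uses only the standard library.

```python
import math
import random
from typing import List


def code_to_temperature(code: int, n_levels: int = 256,
                        t_min: float = 0.3, t_max: float = 1.0) -> float:
    """Hypothetical policy: code 0 (confident) -> t_min, top code -> t_max."""
    frac = code / (n_levels - 1)
    return t_min + frac * (t_max - t_min)


def sample_next_token(logits: List[float], code: int, rng: random.Random) -> int:
    """Sample from softmax(logits / T), where T depends on the predicted code."""
    temperature = code_to_temperature(code)
    scaled = [x / temperature for x in logits]
    m = max(scaled)                              # subtract the max for numerical stability
    weights = [math.exp(x - m) for x in scaled]
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]


rng = random.Random(0)
print(code_to_temperature(0), code_to_temperature(255))
print(sample_next_token([0.0, 100.0, 0.0], code=128, rng=rng))  # 1 (dominant logit)
```

Because the code arrives as ordinary per-request data, a policy like this can live entirely in the sampler without touching the model.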

---

### **8. Conceptual Framing**

**Prisma V1**

* Introspection via *measurement*
* Confidence is something the model observes after the fact

**Prisma V2**

* Introspection via *prediction*
* Confidence is something the model reasons about and plans with

> Prisma V1 makes the model *aware of its uncertainty.*
>
> Prisma V2 makes uncertainty part of the model’s internal world model.

---

### **Summary Table**

| Dimension              | Prisma V1          | Prisma V2          |
| ---------------------- | ------------------ | ------------------ |
| Uncertainty source     | Entropy (measured) | Learned latent     |
| State handling         | Mutable buffer     | Explicit tensor    |
| Runtime support        | Limited            | Universal          |
| KV cache compatibility | ❌                  | ✅                  |
| Tensor parallel        | Fragile            | Safe               |
| Introspection depth    | Reactive           | Predictive         |
| Deployment readiness   | Research-only      | Production-capable |

---

### **Why Prisma V2 Exists**

Prisma V1 demonstrated that **temporal uncertainty feedback produces introspective behavior**.
Prisma V2 makes that insight **architectural, portable, and deployable**.

It is not a workaround.
It is the correct abstraction boundary.

> *Uncertainty must be data, not memory.*

---

## **Core Components to Add**

```python
import torch.nn as nn

# In your CausalLM class's __init__ (hidden_dim = the model's hidden size)
self.n_uncertainty_levels = 256  # V2: compact, sufficient

# NEW: uncertainty input embeddings, one vector per code
self.uncertainty_embeddings = nn.Embedding(
    self.n_uncertainty_levels,
    hidden_dim,
)

# NEW: uncertainty prediction head
self.uncertainty_head = nn.Linear(
    hidden_dim,
    self.n_uncertainty_levels,
    bias=False,
)
```

---
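
## **Initialization Details**

### Uncertainty Embeddings

* Initialized from `N(0, σ²)` where `σ = config.initializer_range`

### Uncertainty Head (Important)

```python
import torch.nn as nn

nn.init.zeros_(self.uncertainty_head.weight)  # all uncertainty logits start at 0
```

**Rationale**:

* With zero weights every uncertainty logit is 0, so the model initially predicts a uniform (*neutral*) distribution over codes
* The LM logits do not pass through the uncertainty head, so the zero head leaves the LM loss unchanged
* At step 0 the auxiliary loss sends no gradient into the transformer (its gradient with respect to the hidden states is `Wᵀ(p − y)`, which is 0 when `W = 0`), which avoids destabilizing the LM with noisy auxiliary signals
* The uncertainty pathway is learned gradually as the head weights move away from zero

Note that exact equivalence with the base model at initialization also depends on the input side: with `N(0, σ²)` uncertainty embeddings, the input embeddings receive a small random offset. Zero-initializing the uncertainty embeddings as well would make early training behave identically to the base model.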
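
The gradient argument can be checked numerically with a framework-free sketch: a bias-free linear head followed by softmax cross-entropy, with the gradients written out by hand for a tiny case (4 codes and hidden size 3, standing in for 256 and `hidden_dim`). With `W = 0`, the predicted code distribution is uniform and the gradient reaching the hidden state is exactly zero, while the head's own weight gradient is not, so the head still learns.

```python
import math
from typing import List, Tuple


def softmax(z: List[float]) -> List[float]:
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]


def head_grads(W: List[List[float]], h: List[float],
               target: int) -> Tuple[List[float], List[float], List[List[float]]]:
    """Softmax cross-entropy on a bias-free linear head z = W h.

    Returns (p, dL/dh, dL/dW), using dL/dz = p - onehot(target).
    """
    z = [sum(w * x for w, x in zip(row, h)) for row in W]
    p = softmax(z)
    dz = [p_i - (1.0 if i == target else 0.0) for i, p_i in enumerate(p)]
    dh = [sum(W[i][j] * dz[i] for i in range(len(W))) for j in range(len(h))]  # W^T dz
    dW = [[dz_i * x for x in h] for dz_i in dz]                                # dz h^T
    return p, dh, dW


n_levels, hidden = 4, 3                           # tiny stand-ins for 256 and hidden_dim
W0 = [[0.0] * hidden for _ in range(n_levels)]    # zero-initialized uncertainty head
h = [0.5, -1.0, 2.0]                              # arbitrary hidden state
p, dh, dW = head_grads(W0, h, target=2)
print(p)   # [0.25, 0.25, 0.25, 0.25]: uniform, i.e. neutral
print(dh)  # all zeros: no gradient reaches the transformer
print(dW)  # nonzero: the head itself still learns
```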

---

## **Forward Pass Modifications (Input Side)**

**Location**: *Immediately after token embedding lookup*

```python
def forward(self, input_ids, uncertainty_codes=None, **kwargs):
    inputs_embeds = self.embed_tokens(input_ids)  # [B, S, hidden_dim]
    if uncertainty_codes is not None:
        # uncertainty_codes: LongTensor [B, S], values in [0, n_uncertainty_levels)
        u = self.uncertainty_embeddings(uncertainty_codes)  # [B, S, hidden_dim]
        inputs_embeds = inputs_embeds + u
    hidden_states = self.model(
        inputs_embeds=inputs_embeds,
        **kwargs,  # e.g. attention_mask, past_key_values
    ).last_hidden_state
```

* `uncertainty_codes[:, t]` conditions token position `t`
* No padding, truncation, or shifting logic is required inside the model
* Temporal alignment is handled by the training and decoding loop

---

## **Forward Pass Modifications (Output Side)**

**Location**: *After the transformer's final hidden states*

```python
logits = self.lm_head(hidden_states)                       # [B, S, vocab]
uncertainty_logits = self.uncertainty_head(hidden_states)  # [B, S, n_uncertainty_levels]
```

Returns:

```python
return {
    "logits": logits,                          # [B, S, vocab]
    "uncertainty_logits": uncertainty_logits,  # [B, S, n_uncertainty_levels]
}
```

---

## **Temporal Semantics**

| Position | Input                 | Predicts                 |
| -------- | --------------------- | ------------------------ |
| t        | tokenₜ + uncertaintyₜ | tokenₜ₊₁, uncertaintyₜ₊₁ |

This preserves the original PRISMA temporal feedback loop without mutable state.
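
The same shift can be spelled out for a single toy sequence; the token ids and codes below are made-up values. Each position is paired with the targets it is trained on, and the last position has no target.

```python
from typing import List, Tuple

Pair = Tuple[int, int]  # (token, uncertainty code)


def shifted_pairs(tokens: List[int], codes: List[int]) -> List[Tuple[Pair, Pair]]:
    """Pair position t's inputs with its one-step-ahead targets.

    Position t reads (tokens[t], codes[t]) and is trained to predict
    (tokens[t + 1], codes[t + 1]). The last position has no target, which
    is why both losses compare [:, :-1] outputs against [:, 1:] labels.
    """
    if len(tokens) != len(codes):
        raise ValueError("tokens and codes must have the same length")
    return [((tokens[t], codes[t]), (tokens[t + 1], codes[t + 1]))
            for t in range(len(tokens) - 1)]


for (tok, code), (nxt_tok, nxt_code) in shifted_pairs([11, 12, 13, 14], [128, 40, 200, 90]):
    print(f"input (token={tok}, code={code}) -> predict (token={nxt_tok}, code={nxt_code})")
```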
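
---

## **Training Objective**

### Language Modeling Loss

Standard next-token prediction (`labels` is usually `input_ids`):

```python
import torch.nn.functional as F

vocab_size = logits.size(-1)
loss_lm = F.cross_entropy(
    logits[:, :-1].reshape(-1, vocab_size),  # predictions at positions 0..S-2
    labels[:, 1:].reshape(-1),               # targets at positions 1..S-1
)  # ignore_index defaults to -100, so padded label positions are skipped
```

---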
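
### Uncertainty Prediction Loss

Uncertainty is predicted **one step ahead**: the uncertainty logits at position `t` are scored against the uncertainty label of position `t + 1`.

```python
n_levels = uncertainty_logits.size(-1)
loss_uncertainty = F.cross_entropy(
    uncertainty_logits[:, :-1].reshape(-1, n_levels),
    uncertainty_labels[:, 1:].reshape(-1),
)
```

---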

### Combined Loss

```python
loss = loss_lm + λ * loss_uncertainty
```

* Recommended starting point: `λ ≈ 0.1`, to be tuned
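
For intuition about scale, here is a stdlib-only sketch of the combined shifted loss for one sequence, with logits as nested lists; all names are illustrative, not part of the specification. With all-zero logits both heads are uniform, so the loss equals `log(V) + λ·log(n_levels)`.

```python
import math
from typing import List


def cross_entropy(logits: List[float], target: int) -> float:
    """-log softmax(logits)[target], computed stably with log-sum-exp."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[target]


def combined_loss(token_logits: List[List[float]], code_logits: List[List[float]],
                  tokens: List[int], codes: List[int], lam: float = 0.1) -> float:
    """Mean shifted LM loss + lam * mean shifted uncertainty loss for one sequence.

    Outputs at position t are scored against tokens[t + 1] and codes[t + 1].
    """
    n = len(tokens) - 1
    loss_lm = sum(cross_entropy(token_logits[t], tokens[t + 1]) for t in range(n)) / n
    loss_u = sum(cross_entropy(code_logits[t], codes[t + 1]) for t in range(n)) / n
    return loss_lm + lam * loss_u


V, n_levels, S = 8, 4, 5
uniform = combined_loss([[0.0] * V] * S, [[0.0] * n_levels] * S,
                        tokens=[1, 2, 3, 4, 5], codes=[2, 0, 3, 1, 2])
print(uniform, math.log(V) + 0.1 * math.log(n_levels))  # equal up to rounding
```

One consequence for the real model: with the zero-initialized head, the auxiliary term starts at exactly `λ·log(256) ≈ 0.55` nats for `λ = 0.1`, independent of the data.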

---

## **Uncertainty Supervision (Teacher Signal)**

During training only, entropy is used as a **bootstrap target**, not as the definition of uncertainty.

```python
import math
import torch
import torch.nn.functional as F

with torch.no_grad():
    log_probs = F.log_softmax(logits.float(), dim=-1)     # stable: no log(0)
    entropy = -(log_probs.exp() * log_probs).sum(dim=-1)  # [B, S], in nats
    entropy_norm = entropy / math.log(vocab_size)         # in [0, 1]
    # quantize(): uniform binning into n_uncertainty_levels codes (one simple choice)
    uncertainty_labels = (entropy_norm * n_uncertainty_levels).long().clamp(0, n_uncertainty_levels - 1)
```

**Important**:

* Entropy is a *teacher*, not a constraint
* The model may learn uncertainty signals that diverge from entropy
* This is desirable if they correlate better with error or ambiguity
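
A framework-free version of the same teacher computation for a single position makes the endpoints easy to check: a uniform distribution over the vocabulary has normalized entropy 1 and lands in the top code, while a near one-hot distribution lands in code 0. The uniform binning in `quantize` is an assumption; the specification does not fix a quantization scheme.

```python
import math
from typing import List


def normalized_entropy(logits: List[float]) -> float:
    """Entropy of softmax(logits) divided by log(V); lies in [0, 1]."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    log_probs = [x - log_z for x in logits]
    return -sum(math.exp(lp) * lp for lp in log_probs) / math.log(len(logits))


def quantize(entropy_norm: float, n_levels: int = 256) -> int:
    """Uniform binning of [0, 1] into n_levels integer codes (one simple choice)."""
    return min(max(int(entropy_norm * n_levels), 0), n_levels - 1)


print(quantize(normalized_entropy([0.0] * 1000)))          # 255: uniform, maximally uncertain
print(quantize(normalized_entropy([50.0] + [0.0] * 999)))  # 0: almost one-hot
print(quantize(0.5))                                        # 128
```

Under this binning, the neutral start code `n_uncertainty_levels // 2 = 128` corresponds to a normalized entropy of 0.5.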

---

## **Single-Pass Training (Preferred)**

A second forward pass is **not required**: the teacher labels are computed from the detached logits of the same forward pass.

```python
outputs = model(
    input_ids,
    uncertainty_codes=uncertainty_input,  # LongTensor [B, S] of input codes
)

with torch.no_grad():
    uncertainty_labels = compute_uncertainty_labels(outputs["logits"])

loss = compute_loss(
    outputs["logits"],
    outputs["uncertainty_logits"],
    labels,
    uncertainty_labels,
)
```

---

## **Inference Loop (All Runtimes)**

```text
(tokenₜ, uncertaintyₜ) → model → (tokenₜ₊₁, uncertaintyₜ₊₁)
```

### Neutral Start

At the first position there is no previous prediction, so decoding starts from the middle code:

```python
uncertainty_code = n_uncertainty_levels // 2  # 128 for 256 levels
```
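
A greedy decoding loop under this contract can be sketched without any framework. `toy_model` is a hypothetical stand-in that receives the full prefix and returns next-token and next-code logits for the last position (a real runtime would reuse the KV cache and feed only the newest pair). Giving every prompt position the neutral code is likewise an assumption of this sketch. The only per-sequence state is the two lists passed in and out.

```python
from typing import Callable, List, Tuple

Logits = List[float]
# Hypothetical model interface: full prefix in, last-position outputs out.
Model = Callable[[List[int], List[int]], Tuple[Logits, Logits]]


def argmax(xs: Logits) -> int:
    return max(range(len(xs)), key=xs.__getitem__)


def greedy_decode(model: Model, prompt: List[int], max_new_tokens: int,
                  n_levels: int = 256) -> Tuple[List[int], List[int]]:
    tokens = list(prompt)
    codes = [n_levels // 2] * len(prompt)  # assumption: neutral code at every prompt position
    for _ in range(max_new_tokens):
        token_logits, code_logits = model(tokens, codes)
        tokens.append(argmax(token_logits))  # token_{t+1}
        codes.append(argmax(code_logits))    # uncertainty_{t+1}, the input at the next step
    return tokens, codes


def toy_model(tokens: List[int], codes: List[int]) -> Tuple[Logits, Logits]:
    """Hypothetical model: favors token (last + 1) % 10 and code (last_code + 1) % 256."""
    token_logits = [0.0] * 10
    token_logits[(tokens[-1] + 1) % 10] = 1.0
    code_logits = [0.0] * 256
    code_logits[(codes[-1] + 1) % 256] = 1.0
    return token_logits, code_logits


tokens, codes = greedy_decode(toy_model, prompt=[3], max_new_tokens=4)
print(tokens)  # [3, 4, 5, 6, 7]
print(codes)   # [128, 129, 130, 131, 132]
```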

---

## **Runtime Integration**

| Runtime          | Integration                                          |
| ---------------- | ---------------------------------------------------- |
| **Transformers** | Custom `generate()` tracks `uncertainty_code` tensor |
| **vLLM**         | Sampler tracks one uncertainty code per request      |
| **llama.cpp**    | Store uncertainty code in `llama_context`            |
| **MLX**          | Works directly (pure tensor graph)                   |

No runtime relies on Python-side mutable state.

---

## **Performance Characteristics**

| Component               | Parameters         | FLOPs      | Memory                               | Latency          |
| ----------------------- | ------------------ | ---------- | ------------------------------------ | ---------------- |
| Uncertainty Head        | `hidden_dim × 256` | Negligible | Negligible                           | ~0               |
| Uncertainty Embedding   | `256 × hidden_dim` | 0 (lookup) | Tiny                                 | ~0               |
| Entropy (training only) | 0                  | `O(B×S×V)` | `O(B×S×V)` (transient probabilities) | Not in inference |

**Inference overhead**: negligible (one extra 256-way projection and one embedding lookup per token)
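
To put the parameter counts in perspective, here is the arithmetic for an assumed `hidden_dim` of 4096 compared against an assumed 7B-parameter base model; substitute your own sizes.

```python
def prisma_v2_extra_params(hidden_dim: int, n_levels: int = 256) -> int:
    head = hidden_dim * n_levels       # uncertainty_head: Linear(hidden_dim, n_levels, bias=False)
    embedding = n_levels * hidden_dim  # uncertainty_embeddings: Embedding(n_levels, hidden_dim)
    return head + embedding


extra = prisma_v2_extra_params(4096)
print(extra)                 # 2097152
print(f"{extra / 7e9:.4%}")  # 0.0300% of an assumed 7B-parameter base model
```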

---

## **Theoretical Intuition**

PRISMA V2 transforms autoregressive generation from:

```text
P(y_t | x, y_<t)
```

to:

```text
P(y_t, c_t | x, y_<t, c_<t)
```

where `c_t` is a learned uncertainty latent.
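
Because both heads read the same final hidden state and are not connected to each other, the joint step factorizes into a token distribution and a code distribution that are conditionally independent given the prefix. Writing $P_\theta$ for the token head, $Q_\theta$ for the uncertainty head, and $\hat{c}_t$ for the entropy-derived teacher code, the combined training loss for one sequence of length $T$ is, up to the per-position averaging done by `cross_entropy`:

$$
P(y_t, c_t \mid x, y_{<t}, c_{<t}) = P_\theta(y_t \mid x, y_{<t}, c_{<t}) \, Q_\theta(c_t \mid x, y_{<t}, c_{<t})
$$

$$
\mathcal{L} = -\sum_{t=1}^{T} \Big[ \log P_\theta(y_t \mid x, y_{<t}, c_{<t}) + \lambda \log Q_\theta(\hat{c}_t \mid x, y_{<t}, c_{<t}) \Big]
$$

With $\lambda = 1$ this would be the negative joint log-likelihood of the tokens and teacher codes under this factorization; $\lambda \approx 0.1$ down-weights the code term.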

This allows the model to:

* Reduce commitment after uncertain predictions
* Maintain momentum after confident predictions
* Learn task-specific uncertainty signals
* Develop introspection without relying on engine-level state

---

## **Why PRISMA V2 Works Everywhere**

| Constraint         | V1 | V2 |
| ------------------ | -- | -- |
| Stateless decoding | ❌  | ✅  |
| vLLM batching      | ❌  | ✅  |
| llama.cpp KV cache | ❌  | ✅  |
| Tensor parallel    | ⚠️ | ✅  |
| MLX tracing        | ❌  | ✅  |

---

## **What to Watch For**

* **Ablation**: remove the uncertainty input and measure the change in perplexity and behavior
* **Calibration**: does predicted uncertainty correlate with error?
* **Behavioral shifts**: hedging, correction, abstention
* **Divergence from entropy**: expected and healthy
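
For the calibration check, one simple metric (a choice made here, not one the specification prescribes) is the AUROC of the predicted code as a detector of wrong next-token predictions: 1.0 means every wrong prediction received a higher code than every correct one, and 0.5 means no signal. A stdlib sketch, counting ties as half; the pairwise loop is quadratic, which is fine for spot checks but slow for large evaluations.

```python
from typing import List


def error_auroc(codes: List[int], is_error: List[bool]) -> float:
    """Probability that a random wrong prediction has a higher predicted
    uncertainty code than a random correct one (ties count as 0.5)."""
    wrong = [c for c, e in zip(codes, is_error) if e]
    right = [c for c, e in zip(codes, is_error) if not e]
    if not wrong or not right:
        raise ValueError("need at least one correct and one wrong prediction")
    wins = 0.0
    for w in wrong:
        for r in right:
            wins += 1.0 if w > r else (0.5 if w == r else 0.0)
    return wins / (len(wrong) * len(right))


# Code predicted at each position, and whether that position's next token was wrong.
print(error_auroc([200, 30, 180, 25, 40], [True, False, True, False, False]))  # 1.0
```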

---

## **Summary**

Prisma V2 preserves the introspective insight of Prisma V1 while replacing fragile mutable state with an explicit, learned uncertainty representation. This makes introspection **portable, scalable, and deployable** across all modern inference engines.

> *The model no longer measures uncertainty — it learns what uncertainty means.*