## PRISMA V2: Joint Uncertainty Prediction Mechanism — Implementation Specification

**Architecture Overview**: PRISMA V2 replaces Python-side uncertainty state with a **learned, explicit uncertainty latent** predicted jointly with tokens. At each step, the model predicts both the next token *and* an uncertainty code that conditions the following step. This preserves temporal introspection while remaining fully compatible with stateless inference engines.

---

## **Core Design Principle**

> **Uncertainty must be data, not memory.**

All information required for the next decoding step is carried explicitly through tensors (tokens, uncertainty codes, or cache), never through mutable module state.

---

## **Differences from Prisma V1 (Detailed)**

Prisma V2 is not a minor refactor of Prisma V1. It represents a **fundamental shift in how uncertainty is represented, propagated, and learned**. This section documents those differences precisely.

---

### **1. Source of Uncertainty**

**Prisma V1**

* Uncertainty is **measured post-hoc** from the model's output distribution
* Computed via entropy of logits
* Acts as an external diagnostic signal

```text
uncertainty_t = H(P(y_t))
```

**Prisma V2**

* Uncertainty is **predicted by the model itself**
* Learned as an auxiliary latent variable
* Acts as an internal representation

```text
(token_{t+1}, uncertainty_{t+1}) = f(token_t, uncertainty_t)
```

**Implication**: V1 answers *"how uncertain was I?"* V2 answers *"how uncertain will I be?"*

---

### **2. State Representation**

**Prisma V1**

* Uses mutable Python-side state:

```python
self.prev_uncertainty_code
```

* State exists **outside** the model's forward graph
* Relies on strict step-by-step execution order

**Prisma V2**

* No mutable state
* Uncertainty is passed explicitly as a tensor:

```python
uncertainty_codes: Tensor[B, S]
```

* Fully contained within the model's inputs and outputs

**Implication**: V1 requires engine cooperation. V2 requires only tensors.

---

### **3. Runtime Compatibility**

| Runtime                  | Prisma V1 | Prisma V2 |
| ------------------------ | --------- | --------- |
| HuggingFace Transformers | ✅         | ✅         |
| vLLM                     | ❌         | ✅         |
| llama.cpp                | ❌         | ✅         |
| MLX                      | ❌         | ✅         |
| Tensor Parallel          | ⚠️        | ✅         |

**Reason**:

* V1 violates the stateless decoding assumptions of modern runtimes
* V2 conforms to them by construction

---

### **4. Temporal Feedback Mechanism**

**Prisma V1**

* Feedback loop implemented via an external buffer
* Requires padding, truncation, and shifting logic
* Not visible to the KV cache or sampler

**Prisma V2**

* Feedback loop is **architectural**
* Uncertainty is predicted one step ahead and injected naturally
* Temporal alignment is implicit in training and decoding

**Implication**: V2's feedback loop is **native**, not simulated.

---

### **5. Learning Dynamics**

**Prisma V1**

* Uncertainty signal is fixed (entropy)
* Model can only learn *how to react* to uncertainty
* Cannot redefine what uncertainty means

**Prisma V2**

* Uncertainty is supervised initially by entropy, then free to diverge
* Model can learn:
  * epistemic uncertainty
  * ambiguity
  * distribution shift
  * task-specific hesitation signals

**Implication**: V1 teaches *response to uncertainty*. V2 teaches *representation of uncertainty*.

---

### **6. Training Complexity**

**Prisma V1**

* No additional loss
* Entropy computed on every forward pass
* Sensitive to tensor parallel sharding

**Prisma V2**

* Adds a lightweight auxiliary loss
* Entropy used only as a teacher signal during training
* No entropy computation at inference

**Implication**: V2 trades a small training cost for large inference robustness.

---

### **7. Inference Behavior**

**Prisma V1**

* Uncertainty exists only implicitly
* Difficult to inspect or intervene at runtime
* Breaks under batched or reordered decoding

**Prisma V2**

* Uncertainty is explicit and inspectable
* Sampler can condition on it
* Works under any batching or scheduling strategy

---

### **8. Conceptual Framing**

**Prisma V1**

* Introspection via *measurement*
* Confidence is something the model observes after the fact

**Prisma V2**

* Introspection via *prediction*
* Confidence is something the model reasons about and plans with

> Prisma V1 makes the model *aware of its uncertainty.*
> Prisma V2 makes uncertainty part of the model's internal world model.

---

### **Summary Table**

| Dimension              | Prisma V1          | Prisma V2          |
| ---------------------- | ------------------ | ------------------ |
| Uncertainty source     | Entropy (measured) | Learned latent     |
| State handling         | Mutable buffer     | Explicit tensor    |
| Runtime support        | Limited            | Universal          |
| KV cache compatibility | ❌                  | ✅                  |
| Tensor parallel        | Fragile            | Safe               |
| Introspection depth    | Reactive           | Predictive         |
| Deployment readiness   | Research-only      | Production-capable |

---

### **Why Prisma V2 Exists**

Prisma V1 demonstrated that **temporal uncertainty feedback produces introspective behavior**. Prisma V2 makes that insight **architectural, portable, and deployable**.

It is not a workaround. It is the correct abstraction boundary.
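Both versions ground out in the same entropy signal: V1 treats it as the uncertainty value itself, while V2 uses it only as a bootstrap teacher target. A minimal sketch of that signal in plain Python (`entropy_code` is an illustrative helper name, not part of the spec; a real implementation would operate on batched logit tensors):

```python
import math

def entropy_code(probs, n_levels=256):
    """Map a next-token distribution to a discrete uncertainty code.

    This is the V1-style measured signal, and also the quantized
    teacher target PRISMA V2 trains against: normalized entropy in
    [0, 1] is binned into n_levels discrete codes.
    """
    # Shannon entropy of the distribution (0 * log 0 treated as 0)
    h = -sum(p * math.log(p) for p in probs if p > 0.0)
    # Normalize by the maximum possible entropy (uniform distribution)
    h_norm = h / math.log(len(probs))
    # Quantize into [0, n_levels - 1]
    return min(int(h_norm * n_levels), n_levels - 1)
```

A one-hot (fully confident) distribution maps to code 0 and a uniform distribution to the top code, so codes increase monotonically with measured uncertainty.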
> *Uncertainty must be data, not memory.*

---

## **Core Components to Add**

```python
# In your CausalLM class
self.n_uncertainty_levels = 256  # V2: compact, sufficient

self.uncertainty_embeddings = nn.Embedding(
    self.n_uncertainty_levels, hidden_dim
)

# NEW: Uncertainty prediction head
self.uncertainty_head = nn.Linear(
    hidden_dim, self.n_uncertainty_levels, bias=False
)
```

---

## **Initialization Details**

### Uncertainty Embeddings

* Initialized from `N(0, σ²)` where `σ = config.initializer_range`

### Uncertainty Head (Important)

```python
self.uncertainty_head.weight.data.zero_()
```

**Rationale**:

* Model initially predicts *neutral uncertainty*
* Early training behaves identically to the base model
* Avoids destabilizing the LM loss with noisy auxiliary signals
* Uncertainty pathway is learned gradually

---

## **Forward Pass Modifications (Input Side)**

**Location**: *Immediately after token embedding lookup*

```python
def forward(self, input_ids, uncertainty_codes=None, ...):
    inputs_embeds = self.embed_tokens(input_ids)

    if uncertainty_codes is not None:
        # uncertainty_codes: [B, S]
        u = self.uncertainty_embeddings(uncertainty_codes)
        inputs_embeds = inputs_embeds + u

    hidden_states = self.model(
        inputs_embeds=inputs_embeds, ...
    ).last_hidden_state
```

* `uncertainty_codes[t]` conditions token position `t`
* No padding, truncation, or shifting logic required
* Temporal alignment is handled by the training and decoding loop

---

## **Forward Pass Modifications (Output Side)**

**Location**: *After transformer hidden states*

```python
logits = self.lm_head(hidden_states)
uncertainty_logits = self.uncertainty_head(hidden_states)
```

Returns:

```python
return {
    "logits": logits,                        # [B, S, vocab]
    "uncertainty_logits": uncertainty_logits  # [B, S, n_uncertainty_levels]
}
```

---

## **Temporal Semantics**

| Position | Input                 | Predicts                 |
| -------- | --------------------- | ------------------------ |
| t        | tokenₜ + uncertaintyₜ | tokenₜ₊₁, uncertaintyₜ₊₁ |

This preserves the original PRISMA temporal feedback loop without mutable state.

---

## **Training Objective**

### Language Modeling Loss

Standard next-token prediction:

```python
loss_lm = cross_entropy(
    logits[:, :-1],
    labels[:, 1:]
)
```

---

### Uncertainty Prediction Loss

Uncertainty is predicted **one step ahead**:

```python
loss_uncertainty = cross_entropy(
    uncertainty_logits[:, :-1],
    uncertainty_labels[:, 1:]
)
```

---

### Combined Loss

```python
loss = loss_lm + λ * loss_uncertainty
```

* Recommended starting point: `λ ≈ 0.1` (tune empirically)

---

## **Uncertainty Supervision (Teacher Signal)**

During training only, entropy is used as a **bootstrap target**, not as the definition of uncertainty.

```python
with torch.no_grad():
    probs = softmax(logits)
    entropy = -(probs * log(probs)).sum(dim=-1)
    entropy_norm = entropy / log(vocab_size)
    uncertainty_labels = quantize(entropy_norm)
```

**Important**:

* Entropy is a *teacher*, not a constraint
* The model may learn uncertainty signals that diverge from entropy
* This is desirable if they correlate better with error or ambiguity

---

## **Single-Pass Training (Preferred)**

A second forward pass is **not required**.
```python
outputs = model(
    input_ids,
    uncertainty_codes=uncertainty_input
)

with torch.no_grad():
    uncertainty_labels = compute_uncertainty_labels(outputs["logits"])

loss = compute_loss(
    outputs["logits"],
    outputs["uncertainty_logits"],
    labels,
    uncertainty_labels
)
```

---

## **Inference Loop (All Runtimes)**

```text
(tokenₜ, uncertaintyₜ) → model → (tokenₜ₊₁, uncertaintyₜ₊₁)
```

### Neutral Start

```python
uncertainty_code = n_uncertainty_levels // 2
```

---

## **Runtime Integration**

| Runtime          | Integration                                          |
| ---------------- | ---------------------------------------------------- |
| **Transformers** | Custom `generate()` tracks `uncertainty_code` tensor |
| **vLLM**         | Sampler tracks one uncertainty code per request      |
| **llama.cpp**    | Store uncertainty code in `llama_context`            |
| **MLX**          | Works directly (pure tensor graph)                   |

No runtime relies on Python-side mutable state.

---

## **Performance Characteristics**

| Component               | Parameters         | FLOPs      | Memory     | Latency          |
| ----------------------- | ------------------ | ---------- | ---------- | ---------------- |
| Uncertainty Head        | `hidden_dim × 256` | Negligible | Negligible | ~0               |
| Uncertainty Embedding   | `256 × hidden_dim` | 0          | Tiny       | ~0               |
| Entropy (training only) | 0                  | `O(B×S×V)` | O(1)       | Not in inference |

**Inference overhead**: effectively zero

---

## **Theoretical Intuition**

PRISMA V2 transforms autoregressive generation from next-token prediction:

```text
P(y_t | x, y_{<t})
```

into joint token and uncertainty prediction:

```text
P(y_t, u_t | x, y_{<t}, u_{<t})
```

*The model no longer measures uncertainty — it learns what uncertainty means.*
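As a closing illustration, the stateless inference loop described above can be sketched end to end in plain Python. `generate` and `model_step` are illustrative names, not part of the spec; a real `model_step` would wrap the model forward pass and sampler, returning the sampled token and the argmax (or sampled) uncertainty code:

```python
def generate(model_step, prompt_ids, max_new_tokens, n_uncertainty_levels=256):
    """Stateless PRISMA V2 decoding loop sketch.

    `model_step` is a pure function:
        (token_id, uncertainty_code) -> (next_token_id, next_uncertainty_code)
    All state lives in the loop's local variables (data), never inside
    the model (memory), so any runtime can schedule or batch the steps.
    """
    token = prompt_ids[-1]
    code = n_uncertainty_levels // 2   # neutral start, per the spec
    tokens, codes = [], []
    for _ in range(max_new_tokens):
        token, code = model_step(token, code)
        tokens.append(token)
        codes.append(code)             # explicit, inspectable uncertainty trace
    return tokens, codes
```

Because the uncertainty code is returned alongside each token, the caller can log it, expose it to the sampler, or intervene (e.g. abstain when codes spike) without touching model internals.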