# Prisma V2 — ToDo.md

## Goal

Evolve PRISMA from a **Python-stateful uncertainty feedback mechanism** (V1) into a **fully stateless, runtime-portable architecture** compatible with:

* Hugging Face Transformers
* vLLM
* llama.cpp
* MLX

while preserving the core idea:

> **Condition future predictions on the model's own uncertainty to reduce confident errors and hallucinations.**

---

## Background (Why V2 Is Needed)

Prisma V1 relies on mutable per-step Python state:

```python
self.prev_uncertainty_code
```

This works in Hugging Face `generate()`, but fails in modern inference engines:

* **vLLM** — batched, reordered, stateless execution
* **llama.cpp** — compiled C/C++ loop with fixed KV cache semantics
* **MLX** — pure functional graph execution

All target runtimes require decoding to be **purely functional**:

> All information needed for the next step must be carried explicitly in tensors, tokens, or the KV cache — not in Python object state.

---

## Core Design Change (Prisma V2)

### Replace measured uncertainty with **predicted uncertainty**

Instead of:

* computing entropy post-hoc from logits
* storing uncertainty in a mutable buffer

**the model learns to predict uncertainty directly.**

At each position, the model outputs:

1. **Next-token logits** (existing)
2. **Uncertainty logits for the next position** (new)

Uncertainty becomes a **learned latent variable**, not a side effect.

---

## Architecture Changes

### 1. Add an uncertainty prediction head

```python
self.n_uncertainty_levels = 256  # V2: smaller, sufficient
self.uncertainty_head = nn.Linear(
    hidden_size, self.n_uncertainty_levels, bias=False
)
```

### 2. Uncertainty head initialization (important)

The uncertainty head is added to a pretrained model.
To avoid destabilizing early training, use **zero initialization**:

```python
self.uncertainty_head.weight.data.zero_()
```

**Rationale:**

* The model initially predicts neutral uncertainty everywhere
* Early training behaves identically to the base model
* The uncertainty signal is learned gradually, only where useful

Alternative initializations (small random, copying from `lm_head`) are left for experimentation.

---

### 3. Keep uncertainty embeddings (input side)

```python
self.uncertainty_embeddings = nn.Embedding(
    self.n_uncertainty_levels, hidden_size
)
```

---

### 4. Modify the forward signature

```python
def forward(
    input_ids: torch.Tensor,
    uncertainty_codes: Optional[torch.Tensor] = None,  # [B, S]
    **kwargs,
):
```

* `uncertainty_codes[t]` conditions token `t`
* No hidden buffers
* No mutable Python state

---

## Forward Pass Logic (V2)

1. Embed tokens
2. If `uncertainty_codes` is provided:
   * look up `uncertainty_embeddings`
   * add to `inputs_embeds`
3. Run the transformer
4. Produce:
   * `logits` (next token)
   * `uncertainty_logits` (next uncertainty)

```python
return {
    "logits": logits,
    "uncertainty_logits": uncertainty_logits,
}
```

---

## Temporal Semantics

| Position | Input                 | Predicts                 |
| -------- | --------------------- | ------------------------ |
| t        | tokenₜ + uncertaintyₜ | tokenₜ₊₁, uncertaintyₜ₊₁ |

This preserves the original PRISMA temporal feedback loop without mutable state.

---

## Training Plan

### Uncertainty supervision (teacher signal)

During training, entropy is used **only as a teacher signal**, not as the definition of uncertainty:

```text
entropy = -∑ p log p
normalized = entropy / log(vocab_size)
uncertainty_label = quantize(normalized)
```

---

### Single-pass training (preferred)

A second forward pass is **not required**.
```python
outputs = model(
    input_ids,
    uncertainty_codes=uncertainty_input,
)

with torch.no_grad():
    uncertainty_labels = quantize_entropy(outputs.logits)

loss = (
    loss_lm(outputs.logits, labels)
    + λ * loss_uncertainty(outputs.uncertainty_logits, uncertainty_labels)
)
```

**Key points:**

* Entropy is a **bootstrap target**
* The model is free to learn uncertainty representations that diverge from entropy over time
* This allows uncertainty to correlate better with *error* than raw entropy does

---

### Loss definition

```python
loss = loss_lm + λ * loss_uncertainty
```

* `loss_lm`: standard next-token cross-entropy
* `loss_uncertainty`: cross-entropy over uncertainty codes
* λ ≈ 0.1 (to tune)

---

## Inference (All Runtimes)

### Decode loop (conceptual)

```text
(tokenₜ, uncertaintyₜ) → model → (tokenₜ₊₁, uncertaintyₜ₊₁)
```

### Runtime responsibilities

* **Transformers**: a custom `generate()` tracks the uncertainty tensor
* **vLLM**: the sampler tracks an `uncertainty_code` per request
* **llama.cpp**: store one small uncertainty code in `llama_context`
* **MLX**: works naturally (pure tensor graph)

No runtime needs to preserve Python object state.

---

## Compatibility Matrix

| Runtime         | Prisma V1 | Prisma V2 |
| --------------- | --------- | --------- |
| Transformers    | ✅        | ✅        |
| vLLM            | ❌        | ✅        |
| llama.cpp       | ❌        | ✅        |
| MLX             | ❌        | ✅        |
| Tensor Parallel | ⚠️        | ✅        |

---

## Design Decisions

### Classification vs regression

**Chosen:** classification (quantized uncertainty)

Reasons:

* Stable training
* Matches the embedding lookup
* Discrete semantics
* Easier runtime handling

Regression remains an experimental alternative.
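The stateless decode loop described above can be sketched without any runtime-specific machinery. This is a hedged, dependency-free illustration — `ModelStep` and `toy_step` are stand-ins for a real Prisma V2 forward pass (which would return logits), not the actual model API; the point is only that the uncertainty code travels with the token sequence as data:

```python
from typing import Callable, List, Tuple

# Hypothetical stand-in for a Prisma V2 step: given the token and
# uncertainty-code sequences so far, return the argmax-decoded
# (next_token, next_uncertainty_code). A real model returns logits.
ModelStep = Callable[[List[int], List[int]], Tuple[int, int]]


def decode(model_step: ModelStep, prompt: List[int],
           start_code: int, max_new_tokens: int) -> Tuple[List[int], List[int]]:
    """Purely functional decode loop: uncertainty is carried alongside
    the tokens, never hidden in a Python object attribute."""
    tokens = list(prompt)
    codes = [start_code] * len(prompt)  # neutral code for prompt positions
    for _ in range(max_new_tokens):
        next_token, next_code = model_step(tokens, codes)
        tokens.append(next_token)  # tokenₜ₊₁
        codes.append(next_code)    # uncertaintyₜ₊₁ conditions the next step
    return tokens, codes


# Toy stand-in model: emits the current sequence length as the token,
# and a code derived from it (mod the 256 uncertainty levels).
def toy_step(tokens: List[int], codes: List[int]) -> Tuple[int, int]:
    return len(tokens), len(tokens) % 256


tokens, codes = decode(toy_step, prompt=[1, 2, 3], start_code=128, max_new_tokens=2)
# tokens → [1, 2, 3, 3, 4]; codes → [128, 128, 128, 3, 4]
```

Because the loop's only inputs are `tokens` and `codes`, any runtime that can thread one extra integer sequence through sampling can host it — which is exactly the per-runtime responsibility listed above.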
---

### Uncertainty resolution

* V1: 65,536 levels (overkill)
* V2: **256 levels** (sufficient, efficient, portable)

---

## Known Limitations

* Uncertainty reflects **model confidence**, not correctness
* Learned uncertainty may differ from Shannon entropy
* Does not guarantee abstention or correctness
* Behavior depends on training data and loss weighting

---

## Open Questions / Future Work

* Tune `n_uncertainty_levels`
* Tune λ (uncertainty loss weight)
* Explore uncertainty-aware decoding strategies
* Compare uncertainty prediction vs uncertainty tokens
* Investigate bootstrapping without entropy supervision

---

## Definition of Done (Prisma V2)

* [ ] No mutable per-step Python state
* [ ] Uncertainty passed explicitly as a tensor
* [ ] Works in Transformers, vLLM, llama.cpp, MLX
* [ ] Zero-init uncertainty head
* [ ] Single-pass training loop
* [ ] Updated model card + documentation
* [ ] Reference implementation + example

---

### Guiding Principle (V2)

> **Uncertainty must be data, not memory.**
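As a closing sketch, the entropy-to-code teacher signal from the training plan might look like this on plain probability lists. This is an assumption-laden illustration, not the reference implementation: a real `quantize_entropy` would operate on batched logit tensors, and `n_levels` here mirrors `n_uncertainty_levels` (256):

```python
import math
from typing import List


def quantize_entropy(probs: List[float], n_levels: int = 256) -> int:
    """Map a next-token distribution to a discrete uncertainty code.

    Entropy is normalized by log(vocab_size), so code 0 means fully
    confident and code n_levels - 1 means a uniform distribution.
    """
    vocab_size = len(probs)
    entropy = -sum(p * math.log(p) for p in probs if p > 0.0)
    normalized = entropy / math.log(vocab_size)  # in [0, 1]
    return min(int(normalized * n_levels), n_levels - 1)


# A one-hot distribution has zero entropy → code 0.
assert quantize_entropy([1.0, 0.0, 0.0, 0.0]) == 0

# A uniform distribution has maximal entropy → the top code, 255.
assert quantize_entropy([0.25] * 4) == 255
```

Dividing by `log(vocab_size)` is what makes the 256 levels comparable across models with different vocabulary sizes.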