# Prisma V2 — ToDo.md
## Goal
Evolve PRISMA from a **Python-stateful uncertainty feedback mechanism** (V1) into a **fully stateless, runtime-portable architecture** compatible with:
* Hugging Face Transformers
* vLLM
* llama.cpp
* MLX
while preserving the core idea:
> **Condition future predictions on the model’s own uncertainty to reduce confident errors and hallucinations.**
---
## Background (Why V2 Is Needed)
Prisma V1 relies on mutable per-step Python state:
```python
self.prev_uncertainty_code
```
This works in Hugging Face `generate()`, but fails in modern inference engines:
* **vLLM** — batched, reordered, stateless execution
* **llama.cpp** — compiled C/C++ loop with fixed KV cache semantics
* **MLX** — pure functional graph execution
All target runtimes require decoding to be **purely functional**:
> All information needed for the next step must be carried explicitly in tensors, tokens, or KV cache — not Python object state.
---
## Core Design Change (Prisma V2)
### Replace measured uncertainty with **predicted uncertainty**
Instead of:
* computing entropy post-hoc from logits
* storing uncertainty in a mutable buffer
**The model learns to predict uncertainty directly.**
At each position, the model outputs:
1. **Next-token logits** (existing)
2. **Uncertainty logits for the next position** (new)
Uncertainty becomes a **learned latent variable**, not a side effect.
---
## Architecture Changes
### 1. Add an uncertainty prediction head
```python
self.n_uncertainty_levels = 256  # V2: smaller, sufficient
self.uncertainty_head = nn.Linear(hidden_size, self.n_uncertainty_levels, bias=False)
```
### 2. Uncertainty head initialization (important)
The uncertainty head is added to a pretrained model. To avoid destabilizing early training, use **zero initialization**:
```python
self.uncertainty_head.weight.data.zero_()
```
**Rationale:**
* Model initially predicts neutral uncertainty everywhere
* Early training behaves identically to the base model
* Uncertainty signal is learned gradually, only when useful
Alternative initializations (small random, copying from `lm_head`) are left for experimentation.
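The rationale is easy to verify: a zero-initialized linear head maps every hidden state to all-zero logits, i.e. a uniform softmax over uncertainty levels. A minimal sanity-check sketch (the `hidden_size` and batch shape here are illustrative, not the real model's):

```python
import torch
import torch.nn as nn

hidden_size, n_uncertainty_levels = 64, 256  # illustrative sizes

head = nn.Linear(hidden_size, n_uncertainty_levels, bias=False)
head.weight.data.zero_()  # zero init: the head contributes nothing at first

h = torch.randn(2, 5, hidden_size)  # [batch, seq, hidden]
logits = head(h)                    # all zeros, regardless of h
probs = logits.softmax(dim=-1)      # exactly uniform: 1/256 everywhere

assert torch.all(logits == 0)
assert torch.allclose(probs, torch.full_like(probs, 1 / n_uncertainty_levels))
```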
---
### 3. Keep uncertainty embeddings (input side)
```python
self.uncertainty_embeddings = nn.Embedding(
    self.n_uncertainty_levels,
    hidden_size
)
```
---
### 4. Modify forward signature
```python
def forward(
    self,
    input_ids: torch.Tensor,
    uncertainty_codes: Optional[torch.Tensor] = None,  # [B, S]
    **kwargs
):
```
* `uncertainty_codes[t]` conditions token `t`
* No hidden buffers
* No mutable Python state
---
## Forward Pass Logic (V2)
1. Embed tokens
2. If `uncertainty_codes` provided:
   * lookup `uncertainty_embeddings`
   * add to `inputs_embeds`
3. Run transformer
4. Produce:
   * `logits` (next token)
   * `uncertainty_logits` (next uncertainty)
```python
return {
    "logits": logits,
    "uncertainty_logits": uncertainty_logits,
}
```
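The four steps above can be sketched end-to-end as a self-contained module. This is an illustrative stand-in, not the real PRISMA code: the `PrismaV2Sketch` name, the tiny `TransformerEncoderLayer` backbone, and all sizes are assumptions made for the sketch.

```python
from typing import Optional

import torch
import torch.nn as nn


class PrismaV2Sketch(nn.Module):
    """Illustrative stand-in showing the V2 dual-head, stateless layout."""

    def __init__(self, vocab_size=100, hidden_size=32, n_uncertainty_levels=256):
        super().__init__()
        self.token_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.uncertainty_embeddings = nn.Embedding(n_uncertainty_levels, hidden_size)
        # Stand-in for the real transformer stack:
        self.backbone = nn.TransformerEncoderLayer(
            d_model=hidden_size, nhead=4, batch_first=True
        )
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        self.uncertainty_head = nn.Linear(hidden_size, n_uncertainty_levels, bias=False)
        self.uncertainty_head.weight.data.zero_()  # zero init per section 2

    def forward(self, input_ids: torch.Tensor,
                uncertainty_codes: Optional[torch.Tensor] = None):
        # 1. Embed tokens
        inputs_embeds = self.token_embeddings(input_ids)
        # 2. Condition on uncertainty codes if provided
        if uncertainty_codes is not None:
            inputs_embeds = inputs_embeds + self.uncertainty_embeddings(uncertainty_codes)
        # 3. Run transformer
        hidden = self.backbone(inputs_embeds)
        # 4. Produce both outputs
        return {
            "logits": self.lm_head(hidden),
            "uncertainty_logits": self.uncertainty_head(hidden),
        }


model = PrismaV2Sketch()
out = model(torch.randint(0, 100, (2, 8)), torch.randint(0, 256, (2, 8)))
print(out["logits"].shape, out["uncertainty_logits"].shape)
```

Note that nothing is stored on `self` between calls: the uncertainty codes arrive as an argument and leave as logits, which is exactly the stateless contract.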
---
## Temporal Semantics
| Position | Input | Predicts |
| -------- | --------------------- | ------------------------ |
| t | tokenₜ + uncertaintyₜ | tokenₜ₊₁, uncertaintyₜ₊₁ |
This preserves the original PRISMA temporal feedback loop without mutable state.
---
## Training Plan
### Uncertainty supervision (teacher signal)
During training, entropy is used **only as a teacher signal**, not as the definition of uncertainty.
```text
entropy = -∑ p log p
normalized = entropy / log(vocab_size)
uncertainty_label = quantize(normalized)
```
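The pseudocode above can be made concrete as a small PyTorch function (a sketch; the name `quantize_entropy` matches the training snippet, and the 256-level default follows the resolution chosen for V2):

```python
import math

import torch


def quantize_entropy(logits: torch.Tensor, n_levels: int = 256) -> torch.Tensor:
    """Map next-token logits to discrete uncertainty labels in [0, n_levels - 1].

    Sketch of the teacher signal: entropy of the softmax distribution,
    normalized by log(vocab_size) so it lies in [0, 1], then bucketed.
    """
    log_probs = logits.log_softmax(dim=-1)
    probs = log_probs.exp()
    entropy = -(probs * log_probs).sum(dim=-1)         # -sum p log p
    normalized = entropy / math.log(logits.shape[-1])  # in [0, 1]
    return (normalized * n_levels).long().clamp_(max=n_levels - 1)


logits = torch.randn(2, 8, 100)  # [batch, seq, vocab]
labels = quantize_entropy(logits)
print(labels.shape)  # [2, 8], one code per position
```

A uniform distribution (all-zero logits) lands in the top bucket (255), while a near-deterministic distribution lands near 0.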
---
### Single-pass training (preferred)
A second forward pass is **not required**.
```python
outputs = model(
    input_ids,
    uncertainty_codes=uncertainty_input
)

with torch.no_grad():
    uncertainty_labels = quantize_entropy(outputs["logits"])

loss = (
    loss_lm(outputs["logits"], labels)
    + λ * loss_uncertainty(outputs["uncertainty_logits"], uncertainty_labels)
)
```
**Key point:**
* Entropy is a **bootstrap target**
* The model is free to learn uncertainty representations that diverge from entropy over time
* This allows uncertainty to correlate better with *error* than raw entropy does
---
### Loss definition
```python
loss = loss_lm + λ * loss_uncertainty
```
* `loss_lm`: standard next-token cross-entropy
* `loss_uncertainty`: cross-entropy over uncertainty codes
* λ ≈ 0.1 (to tune)
---
## Inference (All Runtimes)
### Decode loop (conceptual)
```text
(tokenₜ, uncertaintyₜ) → model → (tokenₜ₊₁, uncertaintyₜ₊₁)
```
### Runtime responsibilities
* **Transformers**: custom `generate()` tracks uncertainty tensor
* **vLLM**: sampler tracks `uncertainty_code` per request
* **llama.cpp**: store one small uncertainty code in `llama_context`
* **MLX**: works naturally (pure tensor graph)
No runtime needs to preserve Python object state.
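The conceptual loop can be sketched as a stateless greedy decoder. This is illustrative only: `model` is assumed to return the `{"logits", "uncertainty_logits"}` dict from the forward-pass section, the neutral starting code of 0 is an assumption, and `StubModel` exists purely to make the sketch runnable.

```python
import torch


@torch.no_grad()
def decode_stateless(model, input_ids, max_new_tokens=8, neutral_code=0):
    """Greedy decode where uncertainty is carried as a tensor, not object state."""
    # Seed every prompt position with a neutral uncertainty code (assumption).
    codes = torch.full_like(input_ids, neutral_code)
    for _ in range(max_new_tokens):
        out = model(input_ids, uncertainty_codes=codes)
        next_token = out["logits"][:, -1].argmax(dim=-1, keepdim=True)
        next_code = out["uncertainty_logits"][:, -1].argmax(dim=-1, keepdim=True)
        # Everything the next step needs is appended to tensors:
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        codes = torch.cat([codes, next_code], dim=-1)
    return input_ids, codes


class StubModel:
    """Hypothetical stand-in honoring the V2 output contract (random logits)."""
    def __call__(self, input_ids, uncertainty_codes=None):
        b, s = input_ids.shape
        return {"logits": torch.randn(b, s, 100),
                "uncertainty_logits": torch.randn(b, s, 256)}


tokens, codes = decode_stateless(StubModel(), torch.zeros(1, 4, dtype=torch.long))
print(tokens.shape, codes.shape)  # 4 prompt + 8 generated positions each
```

Because the loop's only carried values are `input_ids` and `codes`, any runtime that can thread two integer sequences through its sampler can host it.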
---
## Compatibility Matrix
| Runtime | Prisma V1 | Prisma V2 |
| --------------- | --------- | --------- |
| Transformers | ✅ | ✅ |
| vLLM | ❌ | ✅ |
| llama.cpp | ❌ | ✅ |
| MLX | ❌ | ✅ |
| Tensor Parallel | ⚠️ | ✅ |
---
## Design Decisions
### Classification vs regression
**Chosen:** classification (quantized uncertainty)
Reasons:
* Stable training
* Matches embedding lookup
* Discrete semantics
* Easier runtime handling
Regression remains an experimental alternative.
---
### Uncertainty resolution
* V1: 65,536 levels (overkill)
* V2: **256 levels** (sufficient, efficient, portable)
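The efficiency gain for the uncertainty embedding table alone is easy to quantify (illustrative arithmetic; the `hidden_size = 4096` and fp16 weights are assumptions, not PRISMA's actual configuration):

```python
hidden_size = 4096   # assumption: typical 7B-class hidden size
bytes_per_param = 2  # assumption: fp16 weights


def embedding_mib(n_levels: int) -> float:
    """Size of an (n_levels x hidden_size) uncertainty embedding table in MiB."""
    return n_levels * hidden_size * bytes_per_param / 2**20


print(embedding_mib(65_536))  # V1: 512.0 MiB
print(embedding_mib(256))     # V2: 2.0 MiB
```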
---
## Known Limitations
* Uncertainty reflects **model confidence**, not correctness
* Learned uncertainty may differ from Shannon entropy
* Does not guarantee abstention or correctness
* Behavior depends on training data and loss weighting
---
## Open Questions / Future Work
* Tune `n_uncertainty_levels`
* Tune λ (uncertainty loss weight)
* Explore uncertainty-aware decoding strategies
* Compare uncertainty prediction vs uncertainty tokens
* Investigate bootstrapping without entropy supervision
---
## Definition of Done (Prisma V2)
* [ ] No mutable per-step Python state
* [ ] Uncertainty passed explicitly as tensor
* [ ] Works in Transformers, vLLM, llama.cpp, MLX
* [ ] Zero-init uncertainty head
* [ ] Single-pass training loop
* [ ] Updated model card + documentation
* [ ] Reference implementation + example
---
### Guiding Principle (V2)
> **Uncertainty must be data, not memory.**