# Prisma V2 — ToDo.md
## Goal
Evolve PRISMA from a **Python-stateful uncertainty feedback mechanism** (V1) into a **fully stateless, runtime-portable architecture** compatible with:
* Hugging Face Transformers
* vLLM
* llama.cpp
* MLX
while preserving the core idea:
> **Condition future predictions on the model’s own uncertainty to reduce confident errors and hallucinations.**
---
## Background (Why V2 Is Needed)
Prisma V1 relies on mutable per-step Python state:
```python
self.prev_uncertainty_code
```
This works in Hugging Face `generate()`, but fails in modern inference engines:
* **vLLM** — batched, reordered, stateless execution
* **llama.cpp** — compiled C/C++ loop with fixed KV cache semantics
* **MLX** — pure functional graph execution
All target runtimes require decoding to be **purely functional**:
> All information needed for the next step must be carried explicitly in tensors, tokens, or KV cache — not Python object state.
---
## Core Design Change (Prisma V2)
### Replace measured uncertainty with **predicted uncertainty**
Instead of:
* computing entropy post-hoc from logits
* storing uncertainty in a mutable buffer
**The model learns to predict uncertainty directly.**
At each position, the model outputs:
1. **Next-token logits** (existing)
2. **Uncertainty logits for the next position** (new)
Uncertainty becomes a **learned latent variable**, not a side effect.
---
## Architecture Changes
### 1. Add an uncertainty prediction head
```python
self.n_uncertainty_levels = 256  # V2: smaller, sufficient
self.uncertainty_head = nn.Linear(hidden_size, self.n_uncertainty_levels, bias=False)
```
### 2. Uncertainty head initialization (important)
The uncertainty head is added to a pretrained model. To avoid destabilizing early training, use **zero initialization**:
```python
self.uncertainty_head.weight.data.zero_()
```
**Rationale:**
* Model initially predicts neutral uncertainty everywhere
* Early training behaves identically to the base model
* Uncertainty signal is learned gradually, only when useful
Alternative initializations (small random, copying from `lm_head`) are left for experimentation.
---
### 3. Keep uncertainty embeddings (input side)
```python
self.uncertainty_embeddings = nn.Embedding(
    self.n_uncertainty_levels,
    hidden_size,
)
```
---
### 4. Modify forward signature
```python
def forward(
    self,
    input_ids: torch.Tensor,
    uncertainty_codes: Optional[torch.Tensor] = None,  # [B, S]
    **kwargs,
):
```
* `uncertainty_codes[t]` conditions token `t`
* No hidden buffers
* No mutable Python state
---
## Forward Pass Logic (V2)
1. Embed tokens
2. If `uncertainty_codes` provided:
* lookup `uncertainty_embeddings`
* add to `inputs_embeds`
3. Run transformer
4. Produce:
* `logits` (next token)
* `uncertainty_logits` (next uncertainty)
```python
return {
    "logits": logits,
    "uncertainty_logits": uncertainty_logits,
}
```
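The forward-pass steps above can be sketched end to end. This is a toy stand-in: `PrismaV2Sketch`, the single linear "backbone", and all sizes are illustrative, not the real model.

```python
import torch
import torch.nn as nn

class PrismaV2Sketch(nn.Module):
    """Minimal illustration of the V2 forward pass: token embeddings plus
    optional uncertainty embeddings, then two output heads."""
    def __init__(self, vocab_size=32, hidden_size=16, n_uncertainty_levels=8):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.uncertainty_embeddings = nn.Embedding(n_uncertainty_levels, hidden_size)
        self.backbone = nn.Linear(hidden_size, hidden_size)  # stands in for the transformer stack
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        self.uncertainty_head = nn.Linear(hidden_size, n_uncertainty_levels, bias=False)
        self.uncertainty_head.weight.data.zero_()  # zero init: neutral at first

    def forward(self, input_ids, uncertainty_codes=None):
        h = self.tok_embeddings(input_ids)
        if uncertainty_codes is not None:
            # uncertainty conditions the input purely through added embeddings
            h = h + self.uncertainty_embeddings(uncertainty_codes)
        h = self.backbone(h)
        return {
            "logits": self.lm_head(h),
            "uncertainty_logits": self.uncertainty_head(h),
        }
```

Note that with the zero-initialized head, `uncertainty_logits` starts out all zeros, i.e. a uniform (neutral) predicted distribution over codes.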
---
## Temporal Semantics
| Position | Input | Predicts |
| -------- | --------------------- | ------------------------ |
| t | tokenₜ + uncertaintyₜ | tokenₜ₊₁, uncertaintyₜ₊₁ |
This preserves the original PRISMA temporal feedback loop without mutable state.
---
## Training Plan
### Uncertainty supervision (teacher signal)
During training, entropy is used **only as a teacher signal**, not as the definition of uncertainty.
```python
probs = logits.softmax(dim=-1)
entropy = -(probs * probs.log()).sum(dim=-1)  # -∑ p log p
normalized = entropy / math.log(vocab_size)   # scale into [0, 1]
uncertainty_label = quantize(normalized)      # discretize into codes
```
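A concrete sketch of the teacher-signal computation, assuming torch logits of shape `[B, S, V]`; `quantize_entropy` here is an illustrative implementation, not the repo's canonical one.

```python
import math
import torch

def quantize_entropy(logits: torch.Tensor, n_levels: int = 256) -> torch.Tensor:
    """Map per-position predictive entropy to a discrete code in [0, n_levels)."""
    probs = torch.softmax(logits, dim=-1)
    # clamp before log to avoid log(0) on near-deterministic positions
    entropy = -(probs * torch.log(probs.clamp_min(1e-12))).sum(dim=-1)
    normalized = entropy / math.log(logits.shape[-1])  # in [0, 1]
    codes = (normalized * n_levels).long().clamp_(0, n_levels - 1)
    return codes  # shape [B, S]
```

A uniform distribution maps to the top code (`n_levels - 1`); a near-deterministic one maps to code 0.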
---
### Single-pass training (preferred)
A second forward pass is **not required**.
```python
outputs = model(
    input_ids,
    uncertainty_codes=uncertainty_input,
)
with torch.no_grad():
    uncertainty_labels = quantize_entropy(outputs["logits"])

loss = (
    loss_lm(outputs["logits"], labels)
    + λ * loss_uncertainty(outputs["uncertainty_logits"], uncertainty_labels)
)
```
**Key point:**
* Entropy is a **bootstrap target**
* The model is free to learn uncertainty representations that diverge from entropy over time
* This allows uncertainty to correlate better with *error* than raw entropy does
---
### Loss definition
```python
loss = loss_lm + λ * loss_uncertainty
```
* `loss_lm`: standard next-token cross-entropy
* `loss_uncertainty`: cross-entropy over uncertainty codes
* λ ≈ 0.1 (to tune)
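The combined loss can be written as one function. A minimal sketch, assuming shapes `logits [B, S, V]`, `labels [B, S]`, `uncertainty_logits [B, S, K]`, `uncertainty_labels [B, S]`; the name `prisma_v2_loss` is illustrative.

```python
import torch
import torch.nn.functional as F

def prisma_v2_loss(logits, labels, uncertainty_logits, uncertainty_labels, lam=0.1):
    """loss_lm + λ * loss_uncertainty, both as cross-entropy."""
    # flatten batch and sequence dims so cross_entropy sees [N, classes]
    loss_lm = F.cross_entropy(logits.flatten(0, 1), labels.flatten())
    loss_unc = F.cross_entropy(
        uncertainty_logits.flatten(0, 1), uncertainty_labels.flatten()
    )
    return loss_lm + lam * loss_unc
```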
---
## Inference (All Runtimes)
### Decode loop (conceptual)
```text
(tokenₜ, uncertaintyₜ) → model → (tokenₜ₊₁, uncertaintyₜ₊₁)
```
### Runtime responsibilities
* **Transformers**: custom `generate()` tracks uncertainty tensor
* **vLLM**: sampler tracks `uncertainty_code` per request
* **llama.cpp**: store one small uncertainty code in `llama_context`
* **MLX**: works naturally (pure tensor graph)
No runtime needs to preserve Python object state.
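The conceptual decode loop can be sketched as a pure function of tensors, assuming the V2 forward interface above. Greedy sampling, the helper name `decode`, and the neutral code `0` for prompt positions are illustrative choices, not fixed API.

```python
import torch

@torch.no_grad()
def decode(model, input_ids, max_new_tokens=16, neutral_code=0):
    """Stateless greedy decode: uncertainty travels as a tensor beside the tokens."""
    # prompt positions get a neutral code, since nothing predicted them yet
    codes = torch.full_like(input_ids, neutral_code)
    for _ in range(max_new_tokens):
        out = model(input_ids, uncertainty_codes=codes)
        next_tok = out["logits"][:, -1].argmax(dim=-1, keepdim=True)
        next_code = out["uncertainty_logits"][:, -1].argmax(dim=-1, keepdim=True)
        # everything the next step needs is appended to explicit tensors
        input_ids = torch.cat([input_ids, next_tok], dim=-1)
        codes = torch.cat([codes, next_code], dim=-1)
    return input_ids, codes
```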
---
## Compatibility Matrix
| Runtime | Prisma V1 | Prisma V2 |
| --------------- | --------- | --------- |
| Transformers | ✅ | ✅ |
| vLLM | ❌ | ✅ |
| llama.cpp | ❌ | ✅ |
| MLX | ❌ | ✅ |
| Tensor Parallel | ⚠️ | ✅ |
---
## Design Decisions
### Classification vs regression
**Chosen:** classification (quantized uncertainty)
Reasons:
* Stable training
* Matches embedding lookup
* Discrete semantics
* Easier runtime handling
Regression remains an experimental alternative.
---
### Uncertainty resolution
* V1: 65,536 levels (overkill)
* V2: **256 levels** (sufficient, efficient, portable)
---
## Known Limitations
* Uncertainty reflects **model confidence**, not correctness
* Learned uncertainty may differ from Shannon entropy
* Does not guarantee abstention or correctness
* Behavior depends on training data and loss weighting
---
## Open Questions / Future Work
* Tune `n_uncertainty_levels`
* Tune λ (uncertainty loss weight)
* Explore uncertainty-aware decoding strategies
* Compare uncertainty prediction vs uncertainty tokens
* Investigate bootstrapping without entropy supervision
---
## Definition of Done (Prisma V2)
* [ ] No mutable per-step Python state
* [ ] Uncertainty passed explicitly as tensor
* [ ] Works in Transformers, vLLM, llama.cpp, MLX
* [ ] Zero-init uncertainty head
* [ ] Single-pass training loop
* [ ] Updated model card + documentation
* [ ] Reference implementation + example
---
### Guiding Principle (V2)
> **Uncertainty must be data, not memory.**