---
tags:
- energy-based-model
- guided-decoding
- constraint-satisfaction
- jax
- carnot
license: apache-2.0
---
> **Research Artifact: Not Production-Ready**
>
> Real-model validation is pending (Exp-111). Exp-110 results use a mock LLM
> with deterministic error injection. The constraint checker works correctly
> (0.006 ms/check on CPU); the guidance logic is unvalidated on live models.
# guided-decoding-adapter
Energy-guided decoding adapter for any HuggingFace causal LM.
Attaches Carnot's constraint energy pipeline to the token generation loop.
At each token step, a constraint-violation check runs on the text generated so far;
if it finds violations, `alpha × violation_count` is subtracted from all logits
before sampling.
## How It Works
```
prompt → encode → [forward pass → check constraints → penalise logits → sample] × N → text
```
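
The loop in the diagram can be sketched with toy stand-ins so that it runs without a model download. This is a minimal sketch, not the adapter's code. `VOCAB`, `toy_forward`, `toy_violation_count` and `guided_generate` are all hypothetical. The prompt is assumed to be already encoded to ids, and "sample" is replaced by greedy `argmax` so that the run is deterministic. The penalty step applies the subtraction described above.

```python
import numpy as np

# Hypothetical toy vocabulary and "model"; neither is part of carnot.
VOCAB = ["1", "2", "3", " + ", " = "]
rng = np.random.default_rng(0)


def toy_forward(token_ids):
    """Stand-in for a causal-LM forward pass: next-token logits over VOCAB."""
    return rng.normal(size=len(VOCAB))


def toy_violation_count(text):
    """Stand-in checker: every '=' after the first counts as one violation."""
    return max(text.count("=") - 1, 0)


def guided_generate(prompt_ids, n_tokens, alpha=0.5):
    ids = list(prompt_ids)
    energy_checks = 0
    for _ in range(n_tokens):
        logits = toy_forward(ids)                                     # forward pass
        generated = "".join(VOCAB[i] for i in ids[len(prompt_ids):])
        energy = toy_violation_count(generated)                       # check constraints
        energy_checks += 1
        logits = logits - alpha * energy                              # penalise logits
        ids.append(int(np.argmax(logits)))                            # "sample" (greedy here)
    new_ids = ids[len(prompt_ids):]
    return "".join(VOCAB[i] for i in new_ids), len(new_ids), energy_checks


text, tokens_generated, energy_checks = guided_generate([0, 3, 1], n_tokens=8)
print(repr(text), tokens_generated, energy_checks)
```

A real integration would call the HuggingFace model and tokenizer in place of the toy functions.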
The constraint checker (`AutoExtractor`) detects violations across four domains:
| Domain | Constraint types |
|--------|-----------------|
| Arithmetic | addition, multiplication, bounds |
| Code | type checks, return types, initialisation |
| Logic | implication, exclusion, disjunction, negation, universal |
| Natural language | NL consistency |
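
The sketch below makes "energy is a violation count" concrete with a much simpler, hypothetical checker. It covers only integer `a + b = c` and `a * b = c` claims and is not `AutoExtractor`. The regex, the function name and the 5-character guard are assumptions; the guard mirrors the min-text guard listed under Limitations.

```python
import re

# Matches integer claims such as "47 + 28 = 75" or "6 * 7 = 42".
CLAIM = re.compile(r"(-?\d+)\s*([+*])\s*(-?\d+)\s*=\s*(-?\d+)")


def count_arithmetic_violations(text, min_chars=5):
    """Return how many arithmetic claims in `text` are false (the "energy")."""
    if len(text) < min_chars:  # too little text to extract anything yet
        return 0
    violations = 0
    for a, op, b, c in CLAIM.findall(text):
        a, b, c = int(a), int(b), int(c)
        expected = a + b if op == "+" else a * b
        if expected != c:
            violations += 1
    return violations


print(count_arithmetic_violations("47 + 28 = 75"))  # 0
print(count_arithmetic_violations("47 + 28 = 65"))  # 1
```
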
Energy is a plain violation count (not a calibrated probability). The penalty
is applied uniformly across the vocabulary; token ranking is preserved while
overall entropy increases, discouraging the model from continuing down a
constraint-violating path.
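
It helps to separate *shifting* logits from *scaling* them. Under standard softmax sampling, subtracting the same constant from every logit leaves the probabilities exactly unchanged. A uniform subtraction alone therefore neither reorders tokens nor changes entropy. A ranking-preserving flattening needs a scaling (temperature-like) step. The NumPy sketch below checks both cases. The `1 + alpha * violations` temperature is a hypothetical illustration of the "flatter distribution" behaviour mentioned under Limitations and does not describe the adapter's code.

```python
import numpy as np


def softmax(logits):
    z = logits - logits.max()  # subtract max for numerical stability
    p = np.exp(z)
    return p / p.sum()


def entropy(p):
    return float(-(p * np.log(p)).sum())


logits = np.array([3.0, 1.0, 0.5, -1.0])
alpha, violations = 0.5, 4

# Uniform subtraction: softmax is shift-invariant, so probabilities are unchanged.
shifted = logits - alpha * violations

# Hypothetical temperature-style scaling: ranking kept, distribution flatter.
scaled = logits / (1.0 + alpha * violations)

p, p_shift, p_scale = softmax(logits), softmax(shifted), softmax(scaled)
print(np.allclose(p, p_shift))                               # True
print(np.array_equal(np.argsort(p), np.argsort(p_scale)))    # True
print(entropy(p_scale) > entropy(p))                         # True
```
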
## Latency Profile
From Exp-102 (CPU, `JAX_PLATFORMS=cpu`, 1000-iteration benchmark):
| Measurement | Value |
|---|---|
| Constraint check p50 | 0.006 ms |
| Constraint check p99 | 0.034 ms |
| Extraction p50 | 0.276 ms |
| Per-token budget fraction | 0.04% of 20 ms/token |
| Verdict | **Fits in real-time generation budget** |
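
Percentile figures like these come from timing many repeated calls. The following is a minimal harness sketch. It assumes a hypothetical stand-in check function rather than Carnot's checker, and reuses the 1000-iteration count and 20 ms/token budget from above. The budget fraction is computed as p50 divided by the per-token budget.

```python
import time

import numpy as np


def stand_in_check(text):
    """Hypothetical cheap check, used only to exercise the harness."""
    return text.count("=")


def benchmark(fn, arg, iterations=1000, budget_ms=20.0):
    samples_ms = []
    for _ in range(iterations):
        start = time.perf_counter()
        fn(arg)
        samples_ms.append((time.perf_counter() - start) * 1000.0)
    p50, p99 = np.percentile(samples_ms, [50, 99])
    return {
        "p50_ms": float(p50),
        "p99_ms": float(p99),
        "budget_fraction_pct": 100.0 * float(p50) / budget_ms,
    }


stats = benchmark(stand_in_check, "47 + 28 = 75")
print(stats)
```
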
## Usage
```python
from carnot.inference.guided_decoding import GuidedDecoder
from transformers import AutoModelForCausalLM, AutoTokenizer
# Load model (any HF causal LM)
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3.5-0.8B")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3.5-0.8B")
model.eval()
# Load adapter from local directory or HuggingFace Hub
decoder = GuidedDecoder.from_pretrained("Carnot-EBM/guided-decoding-adapter")
# Generate with constraint guidance
result = decoder.generate(model, tokenizer, "What is 47 + 28?")
print(result.text)
print(f"Energy checks: {result.energy_checks}, final energy: {result.final_energy}")
```
### Override defaults
```python
decoder = GuidedDecoder.from_pretrained(
    "Carnot-EBM/guided-decoding-adapter",
    alpha=1.0,              # stronger guidance
    check_every_k=5,        # check every 5 tokens (faster, less precise)
    energy_threshold=0.5,   # only penalise when violations > 0.5
)
```
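
Here is one plausible reading of how `check_every_k` and `energy_threshold` interact, written as a hypothetical helper. The function name, the 0-based step counting and the "no penalty on skipped steps" rule are assumptions, not the adapter's documented behaviour. Under this reading, the checker runs only on every k-th step, and a penalty of `alpha × energy` applies only when the energy exceeds the threshold.

```python
def penalty_for_step(step, text, check, alpha=1.0, check_every_k=5, energy_threshold=0.5):
    """Logit penalty for 0-based generation `step` (hypothetical gating rule)."""
    if step % check_every_k != 0:   # not a check step: the checker is not called
        return 0.0
    energy = check(text)            # violation count for the text so far
    if energy <= energy_threshold:  # at or below threshold: no guidance
        return 0.0
    return alpha * energy


calls = []


def counting_check(text):
    """Stand-in checker that records calls; counts '!' as violations."""
    calls.append(text)
    return text.count("!")


penalties = [penalty_for_step(s, "2 + 2 = 5!", counting_check) for s in range(10)]
print(penalties, len(calls))
```

With `check_every_k=5`, ten steps invoke the checker only twice. This is the "faster, less precise" trade-off noted in the comment above.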
### Load from a local export directory
```python
decoder = GuidedDecoder.from_pretrained("./exports/guided-decoding-adapter")
```
## Return Value
`generate()` returns a `GuidedDecodingResult`:
| Field | Type | Description |
|---|---|---|
| `text` | `str` | Generated text (prompt excluded) |
| `tokens_generated` | `int` | Number of tokens produced |
| `energy_checks` | `int` | Number of times the constraint check ran |
| `mean_penalty` | `float` | Average logit penalty applied |
| `latency_seconds` | `float` | Wall-clock generation time |
| `final_energy` | `float` | Violation count after last check |
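
For unit tests that should not load a model, a stand-in with the same field names can be convenient. This dataclass is a hypothetical mirror of the table above, not the class shipped in `carnot`. The `tokens_per_second` helper is an added convenience.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeGuidedDecodingResult:
    text: str
    tokens_generated: int
    energy_checks: int
    mean_penalty: float
    latency_seconds: float
    final_energy: float

    def tokens_per_second(self) -> float:
        # Guard against a zero latency in synthetic fixtures.
        if self.latency_seconds <= 0:
            return 0.0
        return self.tokens_generated / self.latency_seconds


r = FakeGuidedDecodingResult("75", 2, 2, 0.0, 0.5, 0.0)
print(r.tokens_per_second())  # 4.0
```
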
## Constraint Weights
Default weights are stored in `constraint_weights.safetensors`. Load and inspect:
```python
from safetensors.numpy import load_file
weights = load_file("constraint_weights.safetensors")
print(weights["all_weights"]) # shape (12,) float32
print(weights["default_alpha"]) # [0.5]
```
## Compatible Models
Target models in Exp-110 (run with a mock LLM; live validation pending in Exp-111):
- `Qwen/Qwen3.5-0.8B`
- `google/gemma-4-E4B-it`
Any HuggingFace `AutoModelForCausalLM` with `.logits` output should work.
The adapter does not modify model weights.
## Benchmark Results (Exp-138 & Exp-140)
> **Note (simulated inference):** All benchmark numbers below were produced
> with a *simulated* (mock) LLM, not a real transformer model. The constraint
> checker and logit-penalty logic are real; the generation loop uses a
> deterministic stand-in. Live-model E2E validation is pending (Exp-111).
### Accuracy (Exp-138, n=200/50/100, simulated inference)
| Dataset | Baseline | Guided | Guided+Verify-Repair | Delta, guided vs. baseline (pp) |
|---------|----------|--------|----------------------|---------------------------------|
| GSM8K (math) | 55.5% | 62.5% | 65.0% | **+7.0** |
| HumanEval (code) | 100.0% | 100.0% | n/a | **+0.0** |
| TruthfulQA | 55.0% | 56.0% | 61.0% | **+1.0** |
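
The delta column is guided accuracy minus baseline, in percentage points. The snippet below recomputes it from the table, together with the verify-repair gain that the table lists but does not summarise. These are still simulated-inference numbers.

```python
# Accuracy (%) from the Exp-138 table: (baseline, guided, guided+verify-repair).
# None stands for the empty HumanEval verify-repair cell.
table = {
    "GSM8K": (55.5, 62.5, 65.0),
    "HumanEval": (100.0, 100.0, None),
    "TruthfulQA": (55.0, 56.0, 61.0),
}

deltas = {}
for name, (baseline, guided, repaired) in table.items():
    guided_pp = round(guided - baseline, 1)
    repair_pp = None if repaired is None else round(repaired - baseline, 1)
    deltas[name] = (guided_pp, repair_pp)
    repair_str = "n/a" if repair_pp is None else f"{repair_pp:+.1f} pp"
    print(f"{name}: guided {guided_pp:+.1f} pp, verify-repair {repair_str}")
```
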
### Latency (Exp-138, n=485 samples, CPU)
| Metric | Value |
|--------|-------|
| Constraint-check p50 | 0.0719 ms |
| Constraint-check p99 | 0.1275 ms |
### Latency: KAN Projection Mode (Exp-140, batch=1, CPU)
| Operation | p50 | p99 |
|-----------|-----|-----|
| Logit projection (energy gradient) | 0.077 ms | 0.271 ms |
| Total per-token (grad + projection) | 0.405 ms | 0.924 ms |
Exp-140 pass criterion: total per-token p50 < 5 ms → **PASSED**
(actual 0.4054 ms vs 5.0 ms threshold).
## Installation
```bash
pip install carnot
```
Requires Python 3.11+. See [pypi.org/project/carnot](https://pypi.org/project/carnot)
for the full package including the verify-repair pipeline.
## Limitations
1. **Simulated inference benchmark**: Exp-138 and Exp-140 used a mock LLM.
Numbers show constraint-checker and logit-penalty overhead, not end-to-end
accuracy on real models. Treat accuracy deltas as directional, not final.
2. **No KV-cache**: Every token step runs a full forward pass over the whole sequence. Keep `max_tokens < 256`.
3. **Uniform penalty**: Adjusts entropy across the whole vocabulary; does not
steer towards specific correct tokens.
4. **Energy is a violation count**: Not a calibrated probability. A high `alpha`
combined with many violations leads to a very flat distribution (the model may repeat or stall).
5. **Min-text guard**: `AutoExtractor` skips texts shorter than 5 characters, so the earliest tokens are not checked.
6. **Live-model E2E pending**: Exp-111 validation against Qwen/Gemma has not been run yet.
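
The reason for the `max_tokens < 256` advice becomes clear once you count tokens. Without a KV-cache, step *t* re-encodes the prompt plus all *t* tokens generated so far, so total work grows quadratically with output length. A cached decoder processes each new position once. The sketch below counts token positions pushed through the model. This is a simplification that ignores the growing per-position attention cost, and the prompt length of 20 is arbitrary.

```python
def positions_processed(prompt_len, new_tokens, kv_cache):
    """Token positions fed through the model over one whole generation."""
    if new_tokens <= 0:
        return 0
    if kv_cache:
        # Prompt once for the first token, then only the newest token per step.
        return prompt_len + (new_tokens - 1)
    # Step t re-encodes the prompt plus the t tokens generated so far.
    return sum(prompt_len + t for t in range(new_tokens))


for n in (64, 256, 1024):
    print(n, positions_processed(20, n, kv_cache=False), positions_processed(20, n, kv_cache=True))
```
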
## Spec
- REQ-VERIFY-001: Constraint energy computed from partial text at each step.
- SCENARIO-VERIFY-004: Energy penalises logits before sampling.
## Citation
```bibtex
@misc{carnot2026guided,
title = {Carnot Guided Decoding Adapter},
author = {Carnot-EBM},
year = {2026},
url = {https://github.com/Carnot-EBM/carnot-ebm}
}
```