ianblenke committed on
Commit 3727dac · verified · 1 Parent(s): 7a49432

Exp 164: upload guided-decoding-adapter (Exp 137 artifact)

Files changed (4)
  1. README.md +197 -0
  2. config.json +66 -0
  3. constraint_weights.safetensors +3 -0
  4. example.py +44 -0
README.md ADDED
@@ -0,0 +1,197 @@
---
tags:
- energy-based-model
- guided-decoding
- constraint-satisfaction
- jax
- carnot
license: apache-2.0
---

> **Research Artifact — Not Production-Ready**
>
> Real-model validation is pending (Exp-111). Exp-110 results use a mock LLM
> with deterministic error injection. The constraint checker works correctly
> (0.006 ms/check on CPU); the guidance logic is unvalidated on live models.

# guided-decoding-adapter

Energy-guided decoding adapter for any HuggingFace causal LM.

Attaches Carnot's constraint energy pipeline to the token generation loop.
Each token step runs a constraint violation check on the text generated so far;
violating tokens are penalised by subtracting `alpha × violation_count` from
all logits before sampling.

## How It Works

```
prompt → encode → [forward pass → check constraints → penalise logits → sample] × N → text
```

The constraint checker (`AutoExtractor`) detects violations across four domains:

| Domain | Constraint types |
|--------|-----------------|
| Arithmetic | addition, multiplication, bounds |
| Code | type checks, return types, initialisation |
| Logic | implication, exclusion, disjunction, negation, universal |
| Natural language | NL consistency |
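
A violation check of the arithmetic kind can be sketched as below. This is an illustrative stand-in, not the `AutoExtractor` implementation; the function name and regex are assumptions for the example:

```python
import re

def count_arithmetic_violations(text: str) -> int:
    """Illustrative arithmetic check: find 'a + b = c' claims in the
    text and count the ones whose arithmetic is false."""
    violations = 0
    for a, b, c in re.findall(r"(\d+)\s*\+\s*(\d+)\s*=\s*(\d+)", text):
        if int(a) + int(b) != int(c):
            violations += 1
    return violations

print(count_arithmetic_violations("47 + 28 = 75"))  # 0 (claim is true)
print(count_arithmetic_violations("47 + 28 = 74"))  # 1 (claim is false)
```

The real checker covers all four domains and returns the total violation count used as the energy.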

Energy is a plain violation count (not a calibrated probability). The penalty
is applied uniformly across the vocabulary — token ranking is preserved while
overall entropy increases, discouraging the model from continuing down a
constraint-violating path.

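The penalty step itself is a one-liner; a minimal sketch (the `penalise` helper is illustrative, not the carnot API, and `ALPHA` mirrors `default_alpha` in `config.json`):

```python
ALPHA = 0.5  # default_alpha shipped in config.json

def penalise(logits, violation_count, alpha=ALPHA):
    """Subtract alpha * violation_count from every logit (uniform penalty)."""
    return [l - alpha * violation_count for l in logits]

logits = [2.0, 1.0, 0.5]
penalised = penalise(logits, violation_count=1)
print(penalised)  # [1.5, 0.5, 0.0]

# The shift is uniform, so the top-ranked token is unchanged:
assert max(range(3), key=lambda i: penalised[i]) == max(range(3), key=lambda i: logits[i])
```
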
## Latency Profile

From Exp-102 (CPU, JAX_PLATFORMS=cpu, 1000-iteration benchmark):

| Measurement | Value |
|---|---|
| Constraint check p50 | 0.006 ms |
| Constraint check p99 | 0.034 ms |
| Extraction p50 | 0.276 ms |
| Per-token budget fraction | 0.04% of 20 ms/token |
| Verdict | **Fits in real-time generation budget** |

## Usage

```python
from carnot.inference.guided_decoding import GuidedDecoder
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load model (any HF causal LM)
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3.5-0.8B")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3.5-0.8B")
model.eval()

# Load adapter from local directory or HuggingFace Hub
decoder = GuidedDecoder.from_pretrained("Carnot-EBM/guided-decoding-adapter")

# Generate with constraint guidance
result = decoder.generate(model, tokenizer, "What is 47 + 28?")
print(result.text)
print(f"Energy checks: {result.energy_checks}, final energy: {result.final_energy}")
```

### Override defaults

```python
decoder = GuidedDecoder.from_pretrained(
    "Carnot-EBM/guided-decoding-adapter",
    alpha=1.0,             # stronger guidance
    check_every_k=5,       # check every 5 tokens (faster, less precise)
    energy_threshold=0.5,  # only penalise when violations > 0.5
)
```

### Load from a local export directory

```python
decoder = GuidedDecoder.from_pretrained("./exports/guided-decoding-adapter")
```

## Return Value

`generate()` returns a `GuidedDecodingResult`:

| Field | Type | Description |
|---|---|---|
| `text` | `str` | Generated text (prompt excluded) |
| `tokens_generated` | `int` | Number of tokens produced |
| `energy_checks` | `int` | Times constraint check ran |
| `mean_penalty` | `float` | Average logit penalty applied |
| `latency_seconds` | `float` | Wall-clock time |
| `final_energy` | `float` | Violation count after last check |

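Per-token statistics can be derived from these fields; a sketch using a stand-in dataclass that mirrors the table (the real `GuidedDecodingResult` comes back from `decoder.generate`, and the sample values here are made up):

```python
from dataclasses import dataclass

@dataclass
class GuidedDecodingResult:  # stand-in mirroring the fields above
    text: str
    tokens_generated: int
    energy_checks: int
    mean_penalty: float
    latency_seconds: float
    final_energy: float

result = GuidedDecodingResult("75", tokens_generated=4, energy_checks=4,
                              mean_penalty=0.0, latency_seconds=0.08,
                              final_energy=0.0)

ms_per_token = 1000 * result.latency_seconds / result.tokens_generated
checks_per_token = result.energy_checks / result.tokens_generated  # 1.0 when check_every_k=1
print(f"{ms_per_token:.1f} ms/token, {checks_per_token:.1f} checks/token")
```
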
## Constraint Weights

Default weights are stored in `constraint_weights.safetensors`. Load and inspect:

```python
from safetensors.numpy import load_file

weights = load_file("constraint_weights.safetensors")
print(weights["all_weights"])    # shape (12,) float32
print(weights["default_alpha"])  # [0.5]
```
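
The 12-element vector can be paired with the constraint types; the values below are the `default_weights` from this repo's `config.json`, and the assumption that `all_weights` follows the `constraint_types` order is illustrative, not confirmed by the source:

```python
# Assumption: all_weights follows the constraint_types order in config.json.
constraint_types = [
    "arithmetic", "type_check", "return_type", "return_value_type",
    "bound", "initialization", "implication", "exclusion",
    "disjunction", "negation", "universal", "nl_consistency",
]
default_weights = [1.0, 1.0, 1.0, 1.0, 1.0, 0.8, 0.9, 0.9, 0.7, 0.8, 0.7, 0.5]

weight_by_type = dict(zip(constraint_types, default_weights))
print(weight_by_type["nl_consistency"])  # 0.5
```
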

## Compatible Models

Tested target models (Exp-110):
- `Qwen/Qwen3.5-0.8B`
- `google/gemma-4-E4B-it`

Any HuggingFace `AutoModelForCausalLM` with `.logits` output should work.
The adapter does not modify model weights.

129
+ ## Benchmark Results (Exp-138 & Exp-140)
130
+
131
+ > **Note — Simulated Inference**: All benchmark numbers below were produced
132
+ > with a *simulated* (mock) LLM, not a real transformer model. The constraint
133
+ > checker and logit-penalty logic are real; the generation loop uses a
134
+ > deterministic stand-in. Live-model E2E validation is pending (Exp-111).
135
+
136
+ ### Accuracy (Exp-138, n=200/50/100, simulated inference)
137
+
138
+ | Dataset | Baseline | Guided | Guided+Verify-Repair | Delta (guided) |
139
+ |---------|----------|--------|----------------------|----------------|
140
+ | GSM8K (math) | 55.5% | 62.5% | 65.0% | **+7.0%** |
141
+ | HumanEval (code) | 100.0% | 100.0% | — | **+0.0%** |
142
+ | TruthfulQA | 55.0% | 56.0% | 61.0% | **+1.0%** |
143
+
144
+ ### Latency (Exp-138, n=485 samples, CPU)
145
+
146
+ | Metric | Value |
147
+ |--------|-------|
148
+ | Constraint-check p50 | 0.0719 ms |
149
+ | Constraint-check p99 | 0.1275 ms |
150
+
151
+ ### Latency — KAN Projection Mode (Exp-140, batch=1, CPU)
152
+
153
+ | Operation | p50 | p99 |
154
+ |-----------|-----|-----|
155
+ | Logit projection (energy gradient) | 0.077 ms | 0.271 ms |
156
+ | Total per-token (grad + projection) | 0.405 ms | 0.924 ms |
157
+
158
+ Exp-140 pass criterion: total p50 < 5 ms — **PASSED**
159
+ (actual 0.4054 ms vs 5.0 ms threshold).

## Installation

```bash
pip install carnot
```

Requires Python 3.11+. See [pypi.org/project/carnot](https://pypi.org/project/carnot)
for the full package including the verify-repair pipeline.

## Limitations

1. **Simulated inference benchmark**: Exp-138 and Exp-140 used a mock LLM.
   Numbers show constraint-checker and logit-penalty overhead, not end-to-end
   accuracy on real models. Treat accuracy deltas as directional, not final.
2. **No KV-cache**: Full forward pass every token. Keep `max_tokens < 256`.
3. **Uniform penalty**: Adjusts entropy across the whole vocabulary; does not
   steer towards specific correct tokens.
4. **Energy is a violation count**: Not a calibrated probability. High `alpha`
   + many violations → very flat distribution (model may repeat or stall).
5. **Min-text guard**: `AutoExtractor` skips texts < 5 chars (early tokens).
6. **Live-model E2E pending**: Exp-111 validation against Qwen/Gemma not done yet.
183
+ ## Spec
184
+
185
+ - REQ-VERIFY-001: Constraint energy computed from partial text at each step.
186
+ - SCENARIO-VERIFY-004: Energy penalises logits before sampling.
187
+
188
+ ## Citation
189
+
190
+ ```bibtex
191
+ @misc{carnot2026guided,
192
+ title = {Carnot Guided Decoding Adapter},
193
+ author = {Carnot-EBM},
194
+ year = {2026},
195
+ url = {https://github.com/Carnot-EBM/carnot-ebm}
196
+ }
197
+ ```
config.json ADDED
@@ -0,0 +1,66 @@
{
  "adapter_type": "guided_decoding",
  "carnot_version": "0.1.0",
  "spec": ["REQ-VERIFY-001", "SCENARIO-VERIFY-004"],

  "constraint_types": [
    "arithmetic",
    "type_check",
    "return_type",
    "return_value_type",
    "bound",
    "initialization",
    "implication",
    "exclusion",
    "disjunction",
    "negation",
    "universal",
    "nl_consistency"
  ],

  "default_weights": {
    "arithmetic": 1.0,
    "type_check": 1.0,
    "return_type": 1.0,
    "return_value_type": 1.0,
    "bound": 1.0,
    "initialization": 0.8,
    "implication": 0.9,
    "exclusion": 0.9,
    "disjunction": 0.7,
    "negation": 0.8,
    "universal": 0.7,
    "nl_consistency": 0.5
  },

  "default_alpha": 0.5,
  "default_check_every_k": 1,
  "default_energy_threshold": 0.0,
  "default_max_tokens": 256,
  "default_temperature": 1.0,

  "compatible_model_families": [
    "Qwen/Qwen3-0.6B",
    "Qwen/Qwen3.5-0.8B",
    "Qwen/Qwen2.5-*",
    "google/gemma-4-*",
    "meta-llama/Llama-*",
    "mistralai/Mistral-*"
  ],

  "latency_profile": {
    "constraint_check_p50_ms": 0.006,
    "constraint_check_p99_ms": 0.034,
    "extraction_p50_ms": 0.275,
    "per_token_overhead_budget_fraction": 0.0004,
    "source_experiments": ["Exp-102", "Exp-110"]
  },

  "known_limitations": [
    "AutoExtractor min-text guard (< 5 chars) skips early tokens",
    "No KV-cache: full forward pass every token — use max_tokens < 256 for speed",
    "Uniform logit penalty preserves token ranking but does not steer vocabulary selection",
    "Energy is a violation count, not a calibrated probability",
    "Real-model validation pending (Exp-111); Exp-110 used MockArithmeticLLM"
  ]
}
constraint_weights.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:070451f8e88383b12f8957c8f703b5eec2f99673f8e302795d14f0fb56d4222e
size 1192
example.py ADDED
@@ -0,0 +1,44 @@
"""Minimal usage example for the guided-decoding-adapter.

Run from the carnot repo root:
    JAX_PLATFORMS=cpu python exports/guided-decoding-adapter/example.py
"""

import os
os.environ.setdefault("JAX_PLATFORMS", "cpu")

from unittest.mock import MagicMock

import torch

from carnot.inference.guided_decoding import GuidedDecoder

# Load adapter from this directory (local usage).
# To load from the HuggingFace Hub, swap the path for the repo ID:
#   decoder = GuidedDecoder.from_pretrained("Carnot-EBM/guided-decoding-adapter")
decoder = GuidedDecoder.from_pretrained("exports/guided-decoding-adapter")

# --- Minimal mock model and tokenizer (no GPU / model download needed) ---
# Replace these two blocks with a real HF model/tokenizer for production use:
#   model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3.5-0.8B")
#   tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3.5-0.8B")
step = [0]

def _forward(input_ids):
    logits = torch.zeros(1, input_ids.shape[1], 10)
    logits[0, -1, 1 if step[0] >= 3 else 0] = 10.0  # emit EOS after 3 tokens
    step[0] += 1
    out = MagicMock()
    out.logits = logits
    return out

model = MagicMock()
model.side_effect = _forward
model.parameters = MagicMock(return_value=iter([torch.zeros(1)]))

tokenizer = MagicMock()
tokenizer.eos_token_id = 1
tokenizer.encode = MagicMock(return_value=torch.tensor([[2, 3, 4]]))
tokenizer.decode = MagicMock(side_effect=lambda ids, **kw: "" if ids.item() == 1 else "A")

result = decoder.generate(model, tokenizer, "What is 47 + 28?")
print("Generated:", result.text)
print(f"Tokens: {result.tokens_generated}  Checks: {result.energy_checks}  "
      f"Mean penalty: {result.mean_penalty:.3f}  Latency: {result.latency_seconds*1000:.1f}ms")