"""Minimal usage example for the guided-decoding-adapter.

Run from the carnot repo root:

    JAX_PLATFORMS=cpu python exports/guided-decoding-adapter/example.py
"""
import os

os.environ.setdefault("JAX_PLATFORMS", "cpu")

from unittest.mock import MagicMock

import torch

# Importing carnot may pull in JAX, hence the environment guard above.
from carnot.inference.guided_decoding import GuidedDecoder

# Load the adapter from this directory (local usage).
# To load from the HuggingFace Hub instead, swap the path for the repo ID:
#   decoder = GuidedDecoder.from_pretrained("Carnot-EBM/guided-decoding-adapter")
decoder = GuidedDecoder.from_pretrained("exports/guided-decoding-adapter")

# --- Minimal mock model and tokenizer (no GPU / model download needed) ---
# Replace these two blocks with a real HF model and tokenizer for production use:
#   model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3.5-0.8B")
#   tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3.5-0.8B")
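# (If you do that, you will also need the matching import, which this
# mock-only example deliberately omits:
#   from transformers import AutoModelForCausalLM, AutoTokenizer
# )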
step = [0]  # mutable counter so the closure below can track decode steps

def _forward(input_ids):
    # Toy vocab of size 10: favor token 0 for the first three steps, then
    # token 1 (the mocked EOS token), so generation stops after 3 tokens.
    logits = torch.zeros(1, input_ids.shape[1], 10)
    logits[0, -1, 1 if step[0] >= 3 else 0] = 10.0  # EOS after 3 tokens
    step[0] += 1
    out = MagicMock()
    out.logits = logits
    return out

model = MagicMock()
model.side_effect = _forward  # calling model(input_ids) dispatches to _forward
model.parameters = MagicMock(return_value=iter([torch.zeros(1)]))  # presumably probed for device/dtype

tokenizer = MagicMock()
tokenizer.eos_token_id = 1
tokenizer.encode = MagicMock(return_value=torch.tensor([[2, 3, 4]]))  # fixed fake prompt ids
# Decode EOS (id 1) to the empty string and every other token to "A".
tokenizer.decode = MagicMock(side_effect=lambda ids, **kw: "" if ids.item() == 1 else "A")
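
# Note: these mocks cover only the surfaces this script appears to exercise --
# model(input_ids) returning an object with .logits, model.parameters() for
# device discovery, and tokenizer.encode/.decode/.eos_token_id. That contract
# is inferred from the mocks themselves, not from GuidedDecoder's docs.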

result = decoder.generate(model, tokenizer, "What is 47 + 28?")
print("Generated:", result.text)
print(f"Tokens: {result.tokens_generated}  Checks: {result.energy_checks}  "
      f"Mean penalty: {result.mean_penalty:.3f}  Latency: {result.latency_seconds * 1000:.1f}ms")