case0 / src /case_zero /llm /decoding.py
HusseinEid's picture
Case Zero - initial public release (fully local: Qwen2.5-1.5B via llama.cpp + Supertonic, custom pixel-noir SPA via gradio.Server)
414dc55
raw
history blame
10.8 kB
"""Structured decoding: turn raw model text into validated schema objects.
Two paths share this module:
- the interrogation hot path (``stream_turn`` / ``generate_turn``) parses the
dual-output wire shape produced by ``dual_output.gbnf``;
- the generator (``generate_model``) constrains a one-shot call to a pydantic
model's JSON schema, validates, and repairs once on failure.
"""
from __future__ import annotations
import json
from collections.abc import Iterator
from dataclasses import dataclass
from functools import lru_cache
from pydantic import BaseModel, Field, ValidationError
from ..constants import GRAMMARS_DIR, SPOKEN_MAX_TOKENS
from ..schemas.enums import EvidenceReaction, Intent
from ..schemas.interrogation import InternalState, InterrogationTurn
from .backend import GenParams, LLMBackend, LLMError
class _StateWire(BaseModel):
    """The mechanical-state half of the wire shape (drives grammar generation)."""

    # Nested inside _TurnWire, so these fields appear in the JSON schema that
    # _turn_schema() builds for the grammar-constrained path.
    # NOTE: pydantic copies the class docstring into the schema "description" too.
    intent: Intent  # must be one of the Intent enum values
    is_lying: bool
    active_lie_id: str | None  # no default: the key is required, but may be null
    evidence_reaction: EvidenceReaction  # must be one of the EvidenceReaction values
    deception_level: int = Field(ge=0, le=100)  # constrained to 0..100 inclusive
    stress: float = Field(ge=0.0, le=1.0)  # constrained to 0.0..1.0 inclusive
    revealed_fact_ids: list[str]
    slip: bool
class _TurnWire(BaseModel):
    """The full dual-output wire shape. Its JSON schema is converted to a grammar by
    llama.cpp (the supported path) so a small model always emits valid structure."""

    # Within this module the class is only used for its JSON schema (_turn_schema);
    # replies are parsed leniently into a plain dict and mapped by wire_to_turn,
    # not validated against this model.
    think: str  # private reasoning (becomes InternalState.private_reasoning)
    spoken: str  # the dialogue line that is streamed, rendered and voiced
    state: _StateWire  # the mechanical-state half
# Appended to the prompt for a retry after a reply failed to parse or validate.
_REPAIR_SUFFIX = (
    "\n\nYour previous reply was not valid. "
    "Reply again with ONLY the JSON object, matching the required schema exactly."
)
@lru_cache(maxsize=8)
def load_grammar(name: str) -> str:
    """Return the text of the GBNF grammar file ``name`` inside ``GRAMMARS_DIR``.

    Results are memoised for up to 8 distinct names. A missing file raises
    ``FileNotFoundError``; ``lru_cache`` does not cache exceptions, so a later call
    for the same name looks again.
    """
    grammar_path = GRAMMARS_DIR / name
    if grammar_path.exists():
        return grammar_path.read_text(encoding="utf-8")
    raise FileNotFoundError(f"grammar not found: {grammar_path}")
@dataclass()
class TurnEvent:
    """An event yielded by ``stream_turn`` in this module.

    Intermediate events carry a fragment of dialogue in ``spoken_delta``; the
    closing event carries the parsed turn in ``final`` and leaves ``spoken_delta``
    empty.
    """

    spoken_delta: str = ""  # newly decoded spoken text ("" when there is none)
    final: InterrogationTurn | None = None  # set only on the closing event
class SpokenScanner:
    r"""Incrementally extract the decoded value of the ``"spoken"`` JSON field.

    Fed a streaming completion chunk by chunk, ``feed`` returns the newly decoded
    characters of the ``spoken`` string value. Dialogue can therefore render and be
    voiced before the full object (including the trailing mechanical state) has
    arrived.

    String escapes are decoded as they arrive. A ``\uXXXX`` high surrogate is held
    back until the next character shows whether a low surrogate follows. A UTF-16
    pair such as ``\ud83d\ude00`` therefore comes out as one code point (the same
    text ``json.loads`` produces) instead of two unpaired surrogates. If the value
    after the key is not a string (e.g. ``null``), nothing is emitted and ``done``
    stays False.
    """

    # Include the colon so the scanner locks onto the KEY "spoken": and never onto
    # the word "spoken" appearing inside the (earlier) think value. In valid JSON a
    # quote inside a string is always escaped, so this exact sequence cannot occur
    # inside a string value.
    _KEY = '"spoken":'
    _HEX_DIGITS = frozenset("0123456789abcdefABCDEF")

    def __init__(self) -> None:
        # seek -> prequote -> instr -> done   (prequote -> skip for a non-string value)
        self._phase = "seek"
        self._tail = ""  # sliding window of the last len(_KEY) chars while seeking
        self._escape = False  # the previous char was a backslash inside the value
        self._uni: list[str] | None = None  # hex digits of an in-progress \u escape
        self._high: int | None = None  # decoded high surrogate awaiting its low half
        self._done = False

    def feed(self, delta: str) -> str:
        """Consume the next chunk of raw completion text; return newly decoded text."""
        return "".join(self._consume(ch) for ch in delta)

    def _consume(self, ch: str) -> str:
        """Advance the state machine by one character and return any decoded output."""
        if self._done or self._phase == "skip":
            return ""
        if self._phase == "seek":
            self._tail = (self._tail + ch)[-len(self._KEY):]
            if self._tail == self._KEY:
                self._phase = "prequote"
            return ""
        if self._phase == "prequote":
            if ch == '"':
                self._phase = "instr"
            elif ch not in " \t\n\r":
                # The value is not a string (null, a number, ...). Nothing can be
                # streamed, and scanning on would lock onto the NEXT string instead.
                self._phase = "skip"
            return ""
        # Inside the spoken string value.
        if self._uni is not None:
            self._uni.append(ch)
            if len(self._uni) < 4:
                return ""
            digits = "".join(self._uni)
            self._uni = None
            if not self._HEX_DIGITS.issuperset(digits):
                # Malformed escape: dropped, but a held surrogate is still released.
                return self._flush_high()
            return self._emit_code_point(int(digits, 16))
        if self._escape:
            self._escape = False
            if ch == "u":
                self._uni = []
                return ""
            return self._flush_high() + _UNESCAPE.get(ch, ch)
        if ch == "\\":
            # Keep any held high surrogate: this backslash may start its low half.
            self._escape = True
            return ""
        if ch == '"':
            self._phase = "done"
            self._done = True
            return self._flush_high()
        return self._flush_high() + ch

    def _emit_code_point(self, code: int) -> str:
        """Return the text for one decoded unicode escape, pairing UTF-16 surrogates."""
        if self._high is not None and 0xDC00 <= code <= 0xDFFF:
            high, self._high = self._high, None
            return chr(0x10000 + ((high - 0xD800) << 10) + (code - 0xDC00))
        released = self._flush_high()
        if 0xD800 <= code <= 0xDBFF:
            self._high = code  # wait: the next escape may be its low half
            return released
        return released + chr(code)

    def _flush_high(self) -> str:
        """Release a held high surrogate unpaired (``json.loads`` keeps it as-is too)."""
        if self._high is None:
            return ""
        high, self._high = self._high, None
        return chr(high)

    @property
    def done(self) -> bool:
        """True once the closing quote of the spoken string value has been consumed."""
        return self._done


# Single-character JSON escapes (the char after the backslash -> decoded char).
_UNESCAPE = {"n": "\n", "t": "\t", "r": "\r", "b": "\b", "f": "\f", "/": "/", '"': '"', "\\": "\\"}
def _extract_json(text: str) -> str:
    """Return the JSON object text embedded in a raw completion.

    The object is taken to start at the first ``{``. If one complete JSON value
    decodes from there, exactly that span is returned. Trailing chatter after the
    object (even chatter containing ``}``, or a repeated second object) therefore
    no longer poisons the parse. Otherwise the span from the first ``{`` to the
    last ``}`` is returned, and the caller's parser reports the error.

    Raises:
        LLMError: if the text contains no ``{ ... }`` span at all.
    """
    start = text.find("{")
    end = text.rfind("}")
    if start == -1 or end == -1 or end < start:
        raise LLMError("no JSON object found in completion")
    try:
        _, stop = json.JSONDecoder().raw_decode(text, start)
    except (ValueError, RecursionError):
        # Truncated or malformed: keep the widest-span slice and let the caller's
        # parser raise its usual error.
        return text[start : end + 1]
    return text[start:stop]
def _clamp(value: float, low: float, high: float) -> float:
    """Limit ``value`` to the closed range [``low``, ``high``].

    The upper bound is applied first. Because a NaN compares false against
    everything, ``min`` keeps ``high``, so a NaN ``value`` resolves to ``high``.
    """
    capped = min(high, value)
    return max(low, capped)
def _coerce_float(value: object, default: float) -> float:
    """Convert an advisory numeric field to ``float``, or return ``default``.

    JSON null, lists, objects, non-numeric strings and overflowing integers fall back
    to ``default`` instead of raising. A plain ``float()`` would raise ``TypeError``
    for null, and no caller catches that.
    """
    try:
        return float(value)  # type: ignore[arg-type]
    except (TypeError, ValueError, OverflowError):
        return default


def _coerce_bool(value: object) -> bool:
    """Interpret an advisory flag; strings count as true only for "true"/"yes"/"1".

    A plain ``bool()`` would read the string ``"false"`` as true.
    """
    if isinstance(value, str):
        return value.strip().lower() in ("true", "yes", "1")
    return bool(value)


def _coerce_choice(value: object, allowed: dict, default: str) -> object:
    """Return ``value`` if it is a key of ``allowed``, otherwise ``default``."""
    try:
        return value if value in allowed else default
    except TypeError:  # unhashable value (e.g. a JSON list) cannot be a member
        return default


def _coerce_fact_ids(value: object) -> tuple:
    """Normalise ``revealed_fact_ids`` to a tuple.

    A list/tuple is kept item for item. A bare non-empty string becomes a one-item
    tuple, whereas ``tuple("f1")`` would split it into characters. Anything else
    (null, numbers, objects) becomes empty.
    """
    if isinstance(value, str):
        return (value,) if value else ()
    if isinstance(value, (list, tuple)):
        return tuple(value)
    return ()


def _coerce_text(value: object) -> str:
    """``str(value)``, except that JSON null becomes ``""`` rather than ``"None"``."""
    return "" if value is None else str(value)


def wire_to_turn(wire: dict) -> InterrogationTurn:
    """Map the ``dual_output.gbnf`` wire shape (``think`` / ``spoken`` / ``state``)
    onto an InterrogationTurn.

    Advisory fields are coerced, not rejected: a slightly out-of-range, null or
    wrongly-typed value must never discard an otherwise-valid turn (the
    deterministic director, not these numbers, is the mechanical authority).
    Coercion works as follows:
      * numbers are clamped to their ranges (null or non-numeric values become 0);
      * unknown enum values fall back to ``deflect`` / ``none``;
      * a ``state`` that is not a JSON object is treated as empty.

    Raises:
        ValidationError: presumably, if ``InternalState`` / ``InterrogationTurn``
            still reject a value (e.g. a non-string ``active_lie_id``). This depends
            on those schemas, which are defined elsewhere; TODO confirm.
    """
    raw_state = wire.get("state")
    state = raw_state if isinstance(raw_state, dict) else {}
    active = state.get("active_lie_id")
    internal = InternalState(
        private_reasoning=_coerce_text(wire.get("think")),
        intent=_coerce_choice(
            state.get("intent", "deflect"), Intent._value2member_map_, "deflect"
        ),
        is_lying=_coerce_bool(state.get("is_lying", False)),
        # The model sometimes spells "no active lie" as the string "null" or "".
        active_lie_id=active if active not in (None, "null", "") else None,
        evidence_reaction=_coerce_choice(
            state.get("evidence_reaction", "none"),
            EvidenceReaction._value2member_map_,
            "none",
        ),
        deception_level=int(_clamp(_coerce_float(state.get("deception_level"), 0.0), 0, 100)),
        stress=_clamp(_coerce_float(state.get("stress"), 0.0), 0.0, 1.0),
        revealed_fact_ids=_coerce_fact_ids(state.get("revealed_fact_ids")),
        slip=_coerce_bool(state.get("slip", False)),
    )
    return InterrogationTurn(spoken=_coerce_text(wire.get("spoken")).strip(), internal=internal)
@lru_cache(maxsize=1)
def _turn_schema() -> dict:
    """Return the JSON schema of ``_TurnWire``, computed once and memoised.

    Every call returns the same dict object, so callers must treat it as read-only.
    """
    schema = _TurnWire.model_json_schema()
    return schema
def _turn_params(seed: int | None, temperature: float, *, grammar: bool) -> GenParams:
    """Assemble the ``GenParams`` for one dual-output interrogation turn.

    ``grammar=True`` attaches the cached ``_TurnWire`` JSON schema, which the backend
    turns into a sampling grammar. ``grammar=False`` leaves the output unconstrained.
    """
    # Grammar-free is the default: CPU grammar-sampling is ~8x slower, and the prompt
    # already spells out the exact wire shape. The schema (grammar) path is the
    # reliable fallback, used only when a free-form parse fails.
    schema = _turn_schema() if grammar else None
    # The spoken allowance plus a fixed 220 tokens, presumably headroom for the think
    # text and the state object that share the same completion.
    budget = SPOKEN_MAX_TOKENS + 220
    return GenParams(
        json_schema=schema,
        max_tokens=budget,
        temperature=temperature,
        seed=seed,
        # A 1.5B model otherwise repeats its previous answer verbatim across turns.
        # Penalise repeated tokens so each question gets a fresh reply, but only
        # mildly: strong penalties can starve the output (e.g. the evidence reaction).
        repeat_penalty=1.25,
        frequency_penalty=0.4,
        presence_penalty=0.3,
    )
def generate_turn(
    backend: LLMBackend, prompt: str, *, seed: int | None = None, temperature: float = 0.7
) -> InterrogationTurn:
    """Produce one interrogation turn without streaming.

    Every attempt is grammar-constrained, the slow but structurally reliable path.
    The first attempt runs at ``temperature``. If it fails, one retry runs at 0.0
    with ``_REPAIR_SUFFIX`` appended to the prompt. If that also fails,
    ``InterrogationTurn.safe_default()`` is returned instead of raising.

    Only JSON-decode, validation, ``LLMError`` and ``ValueError`` failures are
    handled; any other exception propagates.
    """
    current_prompt = prompt
    for attempt_temperature in (temperature, 0.0):
        try:
            raw = backend.generate(
                current_prompt, _turn_params(seed, attempt_temperature, grammar=True)
            )
            return wire_to_turn(json.loads(_extract_json(raw)))
        except (json.JSONDecodeError, ValidationError, LLMError, ValueError):
            # The next (final) attempt carries the repair nudge.
            current_prompt = prompt + _REPAIR_SUFFIX
    return InterrogationTurn.safe_default()
def stream_turn(
    backend: LLMBackend, prompt: str, *, seed: int | None = None, temperature: float = 0.7
) -> Iterator[TurnEvent]:
    """Stream the dual-output turn (grammar-free for speed), emitting spoken text as it
    arrives, then the final parsed turn. Falls back to the grammar path if parsing fails.

    Yields:
        ``TurnEvent(spoken_delta=...)`` for each non-empty chunk of the decoded
        ``spoken`` value. Then, unless an exception escapes, one closing
        ``TurnEvent(final=...)``.

    The fallback applies on a parse failure, or on an ``LLMError``/``ValueError``
    raised while streaming. The turn is then regenerated with ``generate_turn`` at
    temperature 0.0 (see ``generate_turn`` for its retry and default behaviour).
    """
    scanner = SpokenScanner()
    raw_parts: list[str] = []  # the whole raw completion, re-joined for the final parse
    try:
        # _turn_params and backend.stream are evaluated inside the try, so their
        # errors (if in the tuple below) also trigger the fallback.
        for delta in backend.stream(prompt, _turn_params(seed, temperature, grammar=False)):
            raw_parts.append(delta)
            spoken_delta = scanner.feed(delta)
            if spoken_delta:
                yield TurnEvent(spoken_delta=spoken_delta)
        turn = wire_to_turn(json.loads(_extract_json("".join(raw_parts))))
    except (json.JSONDecodeError, ValidationError, LLMError, ValueError):
        # NOTE(review): exceptions outside this tuple (e.g. TypeError) propagate to
        # the consumer mid-stream.
        # generate_turn manages its own repair suffix; don't double it.
        turn = generate_turn(backend, prompt, seed=seed, temperature=0.0)
        # The repaired turn was not streamed; emit its spoken text now.
        # NOTE(review): if the scanner had started but not finished the spoken value,
        # the new line is emitted as a further delta after the partial one already
        # sent. If it had finished, nothing is emitted here and final.spoken (from the
        # regeneration) may differ from the text that was streamed.
        if not scanner.done:
            yield TurnEvent(spoken_delta=turn.spoken)
    yield TurnEvent(final=turn)
def generate_model[M: BaseModel](
backend: LLMBackend,
prompt: str,
model_cls: type[M],
*,
temperature: float = 0.4,
max_tokens: int = 1200,
seed: int | None = None,
) -> M:
"""Generate a JSON object constrained to ``model_cls`` and validate it.
Fast path FIRST: generate WITHOUT a grammar (CPU grammar-sampling is ~8x slower)
using the JSON template baked into the prompt, then parse + validate. Only if that
fails do we fall back to the grammar-constrained path (slower, but guarantees
structure) with one repair retry. Raises ``LLMError`` if every attempt fails.
"""
# Fast path: TWO grammar-free attempts (the prompt carries the exact JSON shape).
# The second retries at temperature 0 with a repair nudge - still ~8x faster than
# the grammar path, so most recoveries stay fast.
last_err: Exception | None = None
for attempt in range(2):
try:
prompt_n = prompt if attempt == 0 else prompt + _REPAIR_SUFFIX
temp_n = temperature if attempt == 0 else 0.0
raw = backend.generate(
prompt_n, GenParams(temperature=temp_n, max_tokens=max_tokens, seed=seed)
)
return model_cls.model_validate_json(_extract_json(raw))
except (json.JSONDecodeError, ValidationError, LLMError) as exc:
last_err = exc
# Reliable fallback: grammar-constrained (slower, guarantees structure).
schema = model_cls.model_json_schema()
gparams = GenParams(json_schema=schema, temperature=0.0, max_tokens=max_tokens, seed=seed)
try:
raw = backend.generate(prompt, gparams)
return model_cls.model_validate_json(_extract_json(raw))
except (json.JSONDecodeError, ValidationError, LLMError) as exc:
raise LLMError(
f"structured generation failed for {model_cls.__name__}: {exc}"
) from last_err