"""Structured decoding: turn raw model text into validated schema objects. Two paths share this module: - the interrogation hot path (``stream_turn`` / ``generate_turn``) parses the dual-output wire shape produced by ``dual_output.gbnf``; - the generator (``generate_model``) constrains a one-shot call to a pydantic model's JSON schema, validates, and repairs once on failure. """ from __future__ import annotations import json from collections.abc import Iterator from dataclasses import dataclass from functools import lru_cache from pydantic import BaseModel, Field, ValidationError from ..constants import GRAMMARS_DIR, SPOKEN_MAX_TOKENS from ..schemas.enums import EvidenceReaction, Intent from ..schemas.interrogation import InternalState, InterrogationTurn from .backend import GenParams, LLMBackend, LLMError class _StateWire(BaseModel): """The mechanical-state half of the wire shape (drives grammar generation).""" intent: Intent is_lying: bool active_lie_id: str | None evidence_reaction: EvidenceReaction deception_level: int = Field(ge=0, le=100) stress: float = Field(ge=0.0, le=1.0) revealed_fact_ids: list[str] slip: bool class _TurnWire(BaseModel): """The full dual-output wire shape. Its JSON schema is converted to a grammar by llama.cpp (the supported path) so a small model always emits valid structure.""" think: str spoken: str state: _StateWire _REPAIR_SUFFIX = ( "\n\nYour previous reply was not valid. Reply again with ONLY the JSON object, " "matching the required schema exactly." ) @lru_cache(maxsize=8) def load_grammar(name: str) -> str: """Read a GBNF grammar file from the grammars directory (cached).""" path = GRAMMARS_DIR / name if not path.exists(): raise FileNotFoundError(f"grammar not found: {path}") return path.read_text(encoding="utf-8") @dataclass class TurnEvent: """One streaming event: a chunk of spoken text and/or the final parsed turn.""" spoken_delta: str = "" final: InterrogationTurn | None = None class SpokenScanner: """Incrementally extracts the decoded value of the ``"spoken"`` JSON field from a streaming completion, so dialogue can render and be voiced before the full object (including the trailing mechanical state) has arrived.""" # Include the colon so the scanner locks onto the KEY "spoken": and never onto # the word "spoken" appearing inside the (earlier) think value. _KEY = '"spoken":' def __init__(self) -> None: self._phase = "seek" # seek -> precolon -> instr -> done self._tail = "" self._escape = False self._uni: list[str] | None = None self._done = False def feed(self, delta: str) -> str: out: list[str] = [] for ch in delta: piece = self._consume(ch) if piece: out.append(piece) return "".join(out) def _consume(self, ch: str) -> str: if self._done: return "" if self._phase == "seek": self._tail = (self._tail + ch)[-len(self._KEY):] if self._tail == self._KEY: self._phase = "precolon" return "" if self._phase == "precolon": if ch == '"': self._phase = "instr" return "" # inside the spoken string value if self._uni is not None: self._uni.append(ch) if len(self._uni) == 4: code = "".join(self._uni) self._uni = None try: return chr(int(code, 16)) except ValueError: return "" return "" if self._escape: self._escape = False if ch == "u": self._uni = [] return "" return _UNESCAPE.get(ch, ch) if ch == "\\": self._escape = True return "" if ch == '"': self._phase = "done" self._done = True return "" return ch @property def done(self) -> bool: return self._done _UNESCAPE = {"n": "\n", "t": "\t", "r": "\r", "b": "\b", "f": "\f", "/": "/", '"': '"', "\\": "\\"} def _extract_json(text: str) -> str: start = text.find("{") end = text.rfind("}") if start == -1 or end == -1 or end < start: raise LLMError("no JSON object found in completion") return text[start : end + 1] def _clamp(value: float, low: float, high: float) -> float: return max(low, min(high, value)) def wire_to_turn(wire: dict) -> InterrogationTurn: """Map the ``dual_output.gbnf`` wire shape onto an InterrogationTurn. Advisory numeric fields are clamped, not rejected: a slightly out-of-range value must never discard an otherwise-valid turn (the deterministic director, not these numbers, is the mechanical authority).""" state = dict(wire.get("state") or {}) active = state.get("active_lie_id") # Coerce the advisory enums to safe defaults if the (grammar-free) model emits an # unknown value, so a single bad field never discards an otherwise-valid turn. intent = state.get("intent", "deflect") if intent not in Intent._value2member_map_: intent = "deflect" reaction = state.get("evidence_reaction", "none") if reaction not in EvidenceReaction._value2member_map_: reaction = "none" internal = InternalState( private_reasoning=str(wire.get("think", "")), intent=intent, is_lying=bool(state.get("is_lying", False)), active_lie_id=active if active not in (None, "null", "") else None, evidence_reaction=reaction, deception_level=int(_clamp(float(state.get("deception_level", 0)), 0, 100)), stress=_clamp(float(state.get("stress", 0.0)), 0.0, 1.0), revealed_fact_ids=tuple(state.get("revealed_fact_ids", []) or ()), slip=bool(state.get("slip", False)), ) return InterrogationTurn(spoken=str(wire.get("spoken", "")).strip(), internal=internal) @lru_cache(maxsize=1) def _turn_schema() -> dict: return _TurnWire.model_json_schema() def _turn_params(seed: int | None, temperature: float, *, grammar: bool) -> GenParams: # Grammar-free by default: CPU grammar-sampling is ~8x slower, and the prompt already # spells out the exact wire shape. The grammar (json_schema) path is the reliable # fallback used only when a free parse fails. return GenParams( json_schema=_turn_schema() if grammar else None, max_tokens=SPOKEN_MAX_TOKENS + 220, temperature=temperature, seed=seed, # A 1.5B model otherwise copies its previous answer verbatim across turns; penalise # repeated tokens so each question gets a fresh reply - but gently enough that the # evidence reaction still produces a line (high penalties can starve the output). repeat_penalty=1.25, frequency_penalty=0.4, presence_penalty=0.3, ) def generate_turn( backend: LLMBackend, prompt: str, *, seed: int | None = None, temperature: float = 0.7 ) -> InterrogationTurn: """Non-streaming dual-output fallback: grammar-constrained for reliability, with one repair retry and a safe default.""" for attempt, temp in enumerate((temperature, 0.0)): try: raw = backend.generate(prompt, _turn_params(seed, temp, grammar=True)) return wire_to_turn(json.loads(_extract_json(raw))) except (json.JSONDecodeError, ValidationError, LLMError, ValueError): if attempt == 0: prompt = prompt + _REPAIR_SUFFIX continue return InterrogationTurn.safe_default() def stream_turn( backend: LLMBackend, prompt: str, *, seed: int | None = None, temperature: float = 0.7 ) -> Iterator[TurnEvent]: """Stream the dual-output turn (grammar-free for speed), emitting spoken text as it arrives, then the final parsed turn. Falls back to the grammar path if parsing fails.""" scanner = SpokenScanner() raw_parts: list[str] = [] try: for delta in backend.stream(prompt, _turn_params(seed, temperature, grammar=False)): raw_parts.append(delta) spoken_delta = scanner.feed(delta) if spoken_delta: yield TurnEvent(spoken_delta=spoken_delta) turn = wire_to_turn(json.loads(_extract_json("".join(raw_parts)))) except (json.JSONDecodeError, ValidationError, LLMError, ValueError): # generate_turn manages its own repair suffix; don't double it. turn = generate_turn(backend, prompt, seed=seed, temperature=0.0) # The repaired turn was not streamed; emit its spoken text now. if not scanner.done: yield TurnEvent(spoken_delta=turn.spoken) yield TurnEvent(final=turn) def generate_model[M: BaseModel]( backend: LLMBackend, prompt: str, model_cls: type[M], *, temperature: float = 0.4, max_tokens: int = 1200, seed: int | None = None, ) -> M: """Generate a JSON object constrained to ``model_cls`` and validate it. Fast path FIRST: generate WITHOUT a grammar (CPU grammar-sampling is ~8x slower) using the JSON template baked into the prompt, then parse + validate. Only if that fails do we fall back to the grammar-constrained path (slower, but guarantees structure) with one repair retry. Raises ``LLMError`` if every attempt fails. """ # Fast path: TWO grammar-free attempts (the prompt carries the exact JSON shape). # The second retries at temperature 0 with a repair nudge - still ~8x faster than # the grammar path, so most recoveries stay fast. last_err: Exception | None = None for attempt in range(2): try: prompt_n = prompt if attempt == 0 else prompt + _REPAIR_SUFFIX temp_n = temperature if attempt == 0 else 0.0 raw = backend.generate( prompt_n, GenParams(temperature=temp_n, max_tokens=max_tokens, seed=seed) ) return model_cls.model_validate_json(_extract_json(raw)) except (json.JSONDecodeError, ValidationError, LLMError) as exc: last_err = exc # Reliable fallback: grammar-constrained (slower, guarantees structure). schema = model_cls.model_json_schema() gparams = GenParams(json_schema=schema, temperature=0.0, max_tokens=max_tokens, seed=seed) try: raw = backend.generate(prompt, gparams) return model_cls.model_validate_json(_extract_json(raw)) except (json.JSONDecodeError, ValidationError, LLMError) as exc: raise LLMError( f"structured generation failed for {model_cls.__name__}: {exc}" ) from last_err