Spaces:
Running
Running
Case Zero - initial public release (fully local: Qwen2.5-1.5B via llama.cpp + Supertonic, custom pixel-noir SPA via gradio.Server)
414dc55 | """Structured decoding: turn raw model text into validated schema objects. | |
| Two paths share this module: | |
| - the interrogation hot path (``stream_turn`` / ``generate_turn``) parses the | |
| dual-output wire shape produced by ``dual_output.gbnf``; | |
| - the generator (``generate_model``) constrains a one-shot call to a pydantic | |
| model's JSON schema, validates, and repairs once on failure. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| from collections.abc import Iterator | |
| from dataclasses import dataclass | |
| from functools import lru_cache | |
| from pydantic import BaseModel, Field, ValidationError | |
| from ..constants import GRAMMARS_DIR, SPOKEN_MAX_TOKENS | |
| from ..schemas.enums import EvidenceReaction, Intent | |
| from ..schemas.interrogation import InternalState, InterrogationTurn | |
| from .backend import GenParams, LLMBackend, LLMError | |
class _StateWire(BaseModel):
    """The mechanical-state half of the wire shape (drives grammar generation).

    Used as the type of the ``state`` field of ``_TurnWire``. Two of the fields
    carry explicit range constraints declared with ``Field``.
    """

    intent: Intent
    is_lying: bool
    # Nullable, with no default declared.
    active_lie_id: str | None
    evidence_reaction: EvidenceReaction
    # Integer constrained to 0..100 inclusive.
    deception_level: int = Field(ge=0, le=100)
    # Float constrained to 0.0..1.0 inclusive.
    stress: float = Field(ge=0.0, le=1.0)
    revealed_fact_ids: list[str]
    slip: bool
class _TurnWire(BaseModel):
    """The full dual-output wire shape.

    Its JSON schema is converted to a grammar by llama.cpp (the supported path) so a
    small model always emits valid structure. ``_turn_schema`` returns that schema.
    """

    # Private reasoning text; ``wire_to_turn`` stores it as ``private_reasoning``.
    think: str
    # The dialogue line; ``SpokenScanner`` extracts this field while streaming.
    spoken: str
    # Mechanical state block, shaped by ``_StateWire``.
    state: _StateWire
# Appended to the prompt on a retry (see ``generate_turn`` and ``generate_model``)
# to nudge the model back to emitting only the JSON object.
_REPAIR_SUFFIX = (
    "\n\nYour previous reply was not valid. Reply again with ONLY the JSON object, "
    "matching the required schema exactly."
)
@lru_cache(maxsize=None)
def load_grammar(name: str) -> str:
    """Read a GBNF grammar file from the grammars directory (cached).

    The result is memoised per ``name`` so a grammar is read from disk only once
    per process. A failed lookup is not cached (``lru_cache`` never stores
    exceptions), so a file that appears later can still be loaded.

    Args:
        name: File name of the grammar, relative to ``GRAMMARS_DIR``.

    Returns:
        The grammar text, decoded as UTF-8.

    Raises:
        FileNotFoundError: If no such file exists in the grammars directory.
    """
    path = GRAMMARS_DIR / name
    if not path.exists():
        raise FileNotFoundError(f"grammar not found: {path}")
    return path.read_text(encoding="utf-8")
@dataclass
class TurnEvent:
    """One streaming event: a chunk of spoken text and/or the final parsed turn.

    Declared as a dataclass so it can be built with keyword arguments, e.g.
    ``TurnEvent(spoken_delta=...)`` or ``TurnEvent(final=...)``.
    """

    # Newly decoded piece of the spoken line; empty when the event carries none.
    spoken_delta: str = ""
    # The fully parsed turn; set only on the last event of a stream.
    final: InterrogationTurn | None = None
class SpokenScanner:
    """Incrementally extracts the decoded value of the ``"spoken"`` JSON field from
    a streaming completion, so dialogue can render and be voiced before the full
    object (including the trailing mechanical state) has arrived.

    Feed raw completion chunks to :meth:`feed`; it returns whatever decoded spoken
    text those chunks contributed. :attr:`done` becomes true once the closing quote
    of the spoken string has been consumed.
    """

    # Include the colon so the scanner locks onto the KEY "spoken": and never onto
    # the word "spoken" appearing inside the (earlier) think value.
    _KEY = '"spoken":'

    def __init__(self) -> None:
        self._phase = "seek"  # seek -> precolon -> instr -> done
        self._tail = ""  # last len(_KEY) characters seen while seeking the key
        self._escape = False  # previous character was a backslash
        self._uni: list[str] | None = None  # hex digits collected for a \uXXXX escape
        self._done = False

    def feed(self, delta: str) -> str:
        """Consume one chunk of raw completion text and return the decoded spoken
        text it contributed (possibly an empty string)."""
        return "".join(self._consume(ch) for ch in delta)

    def _consume(self, ch: str) -> str:
        """Advance the state machine by one character; return decoded output or ""."""
        if self._done:
            return ""
        if self._phase == "seek":
            # Keep a sliding window exactly as long as the key and compare it whole.
            self._tail = (self._tail + ch)[-len(self._KEY):]
            if self._tail == self._KEY:
                self._phase = "precolon"
            return ""
        if self._phase == "precolon":
            # Key matched; skip anything up to the quote that opens the value.
            if ch == '"':
                self._phase = "instr"
            return ""
        # inside the spoken string value
        if self._uni is not None:
            self._uni.append(ch)
            if len(self._uni) == 4:
                code = "".join(self._uni)
                self._uni = None
                try:
                    return chr(int(code, 16))
                except ValueError:
                    # Malformed escape: drop it rather than abort the stream.
                    return ""
            return ""
        if self._escape:
            self._escape = False
            if ch == "u":
                self._uni = []
                return ""
            # Unknown escapes fall back to the escaped character itself.
            return _UNESCAPE.get(ch, ch)
        if ch == "\\":
            self._escape = True
            return ""
        if ch == '"':
            # Unescaped quote closes the spoken value.
            self._phase = "done"
            self._done = True
            return ""
        return ch

    @property
    def done(self) -> bool:
        """True once the whole spoken value has been consumed.

        Exposed as a property because callers read it as an attribute
        (``if not scanner.done``); as a plain method that test is always false.
        """
        return self._done


# JSON single-character escapes mapped to the characters they denote.
_UNESCAPE = {"n": "\n", "t": "\t", "r": "\r", "b": "\b", "f": "\f", "/": "/", '"': '"', "\\": "\\"}
| def _extract_json(text: str) -> str: | |
| start = text.find("{") | |
| end = text.rfind("}") | |
| if start == -1 or end == -1 or end < start: | |
| raise LLMError("no JSON object found in completion") | |
| return text[start : end + 1] | |
| def _clamp(value: float, low: float, high: float) -> float: | |
| return max(low, min(high, value)) | |
def _advisory_float(value: object, default: float) -> float:
    """Best-effort float conversion; *default* for null, non-numeric, or NaN input."""
    try:
        number = float(value)  # type: ignore[arg-type]
    except (TypeError, ValueError, OverflowError):
        return default
    # NaN is the only float unequal to itself; it cannot be clamped meaningfully.
    return default if number != number else number


def _advisory_enum(value: object, members: dict, default: str) -> object:
    """Return *value* if it is a known enum value, else *default* (also for unhashables)."""
    try:
        return value if value in members else default
    except TypeError:  # unhashable, e.g. the model emitted a list or an object
        return default


def _fact_ids(value: object) -> tuple:
    """Normalise the revealed-fact list: a bare string is one id, junk is no ids."""
    if isinstance(value, str):
        return (value,) if value else ()
    if isinstance(value, (list, tuple)):
        return tuple(value)
    return ()


def _wire_text(value: object) -> str:
    """Stringify a text field, treating JSON ``null`` as empty rather than ``"None"``."""
    return "" if value is None else str(value)


def wire_to_turn(wire: dict) -> InterrogationTurn:
    """Map the ``dual_output.gbnf`` wire shape onto an InterrogationTurn.

    Advisory numeric fields are clamped, not rejected: a slightly out-of-range value
    must never discard an otherwise-valid turn (the deterministic director, not these
    numbers, is the mechanical authority). The same leniency applies to malformed
    advisory values: null or non-numeric numbers fall back to their defaults, unknown
    enum values fall back to ``"deflect"`` / ``"none"``, and a missing or non-object
    ``state`` is treated as empty, so none of these raise ``TypeError``.

    Args:
        wire: The decoded JSON object with ``think``, ``spoken`` and ``state`` keys.

    Returns:
        The turn built from the wire data, with ``spoken`` stripped of surrounding
        whitespace.

    Raises:
        Whatever ``InternalState`` / ``InterrogationTurn`` raise for values they
        reject (presumably pydantic ``ValidationError`` - confirm against the schemas).
    """
    raw_state = wire.get("state")
    state = raw_state if isinstance(raw_state, dict) else {}
    active = state.get("active_lie_id")
    # Coerce the advisory enums to safe defaults if the (grammar-free) model emits an
    # unknown value, so a single bad field never discards an otherwise-valid turn.
    intent = _advisory_enum(state.get("intent", "deflect"), Intent._value2member_map_, "deflect")
    reaction = _advisory_enum(
        state.get("evidence_reaction", "none"), EvidenceReaction._value2member_map_, "none"
    )
    internal = InternalState(
        private_reasoning=_wire_text(wire.get("think")),
        intent=intent,
        is_lying=bool(state.get("is_lying", False)),
        # Models sometimes spell "no lie" as the string "null" or an empty string.
        active_lie_id=active if active not in (None, "null", "") else None,
        evidence_reaction=reaction,
        deception_level=int(_clamp(_advisory_float(state.get("deception_level"), 0.0), 0, 100)),
        stress=_clamp(_advisory_float(state.get("stress"), 0.0), 0.0, 1.0),
        revealed_fact_ids=_fact_ids(state.get("revealed_fact_ids")),
        slip=bool(state.get("slip", False)),
    )
    return InterrogationTurn(spoken=_wire_text(wire.get("spoken")).strip(), internal=internal)
def _turn_schema() -> dict:
    """Return the JSON schema of the dual-output wire model (``_TurnWire``)."""
    schema = _TurnWire.model_json_schema()
    return schema
def _turn_params(seed: int | None, temperature: float, *, grammar: bool) -> GenParams:
    """Build the sampling parameters for one dual-output turn.

    With ``grammar=True`` the turn schema is attached so decoding is constrained;
    otherwise generation is grammar-free.
    """
    # Grammar-free by default: CPU grammar-sampling is ~8x slower, and the prompt already
    # spells out the exact wire shape. The grammar (json_schema) path is the reliable
    # fallback used only when a free parse fails.
    schema = _turn_schema() if grammar else None
    token_budget = SPOKEN_MAX_TOKENS + 220
    return GenParams(
        json_schema=schema,
        max_tokens=token_budget,
        temperature=temperature,
        seed=seed,
        # A 1.5B model otherwise copies its previous answer verbatim across turns; penalise
        # repeated tokens so each question gets a fresh reply - but gently enough that the
        # evidence reaction still produces a line (high penalties can starve the output).
        repeat_penalty=1.25,
        frequency_penalty=0.4,
        presence_penalty=0.3,
    )
def generate_turn(
    backend: LLMBackend, prompt: str, *, seed: int | None = None, temperature: float = 0.7
) -> InterrogationTurn:
    """Non-streaming dual-output fallback: grammar-constrained for reliability, with one
    repair retry and a safe default.

    The first attempt uses ``temperature``; if it cannot be parsed into a turn, the
    repair suffix is appended to the prompt and a second attempt runs at temperature 0.
    If that also fails, ``InterrogationTurn.safe_default()`` is returned, so this
    function does not raise for malformed model output.
    """
    for attempt, temp in enumerate((temperature, 0.0)):
        try:
            raw = backend.generate(prompt, _turn_params(seed, temp, grammar=True))
            return wire_to_turn(json.loads(_extract_json(raw)))
        except (json.JSONDecodeError, ValidationError, LLMError, ValueError, TypeError):
            # TypeError: wrong-typed fields in the decoded object (e.g. a null where a
            # number is expected) must lead to the retry / safe default, not a crash.
            if attempt == 0:
                prompt = prompt + _REPAIR_SUFFIX
    return InterrogationTurn.safe_default()
def stream_turn(
    backend: LLMBackend, prompt: str, *, seed: int | None = None, temperature: float = 0.7
) -> Iterator[TurnEvent]:
    """Stream the dual-output turn (grammar-free for speed), emitting spoken text as it
    arrives, then the final parsed turn. Falls back to the grammar path if parsing fails.

    Yields:
        Zero or more events carrying ``spoken_delta`` text, followed by exactly one
        event whose ``final`` holds the parsed (or repaired) turn.
    """
    scanner = SpokenScanner()
    raw_parts: list[str] = []
    try:
        for delta in backend.stream(prompt, _turn_params(seed, temperature, grammar=False)):
            raw_parts.append(delta)
            spoken_delta = scanner.feed(delta)
            if spoken_delta:
                yield TurnEvent(spoken_delta=spoken_delta)
        turn = wire_to_turn(json.loads(_extract_json("".join(raw_parts))))
    except (json.JSONDecodeError, ValidationError, LLMError, ValueError, TypeError):
        # TypeError: the grammar-free model can emit wrong-typed fields (e.g. null for
        # a number); that must trigger the grammar fallback rather than kill the stream.
        # generate_turn manages its own repair suffix; don't double it.
        turn = generate_turn(backend, prompt, seed=seed, temperature=0.0)
        # The repaired turn was not streamed; emit its spoken text now.
        # NOTE(review): this reads ``done`` as an attribute - confirm SpokenScanner.done
        # is a property; if it is a plain method this guard is always false.
        if not scanner.done:
            yield TurnEvent(spoken_delta=turn.spoken)
    yield TurnEvent(final=turn)
| def generate_model[M: BaseModel]( | |
| backend: LLMBackend, | |
| prompt: str, | |
| model_cls: type[M], | |
| *, | |
| temperature: float = 0.4, | |
| max_tokens: int = 1200, | |
| seed: int | None = None, | |
| ) -> M: | |
| """Generate a JSON object constrained to ``model_cls`` and validate it. | |
| Fast path FIRST: generate WITHOUT a grammar (CPU grammar-sampling is ~8x slower) | |
| using the JSON template baked into the prompt, then parse + validate. Only if that | |
| fails do we fall back to the grammar-constrained path (slower, but guarantees | |
| structure) with one repair retry. Raises ``LLMError`` if every attempt fails. | |
| """ | |
| # Fast path: TWO grammar-free attempts (the prompt carries the exact JSON shape). | |
| # The second retries at temperature 0 with a repair nudge - still ~8x faster than | |
| # the grammar path, so most recoveries stay fast. | |
| last_err: Exception | None = None | |
| for attempt in range(2): | |
| try: | |
| prompt_n = prompt if attempt == 0 else prompt + _REPAIR_SUFFIX | |
| temp_n = temperature if attempt == 0 else 0.0 | |
| raw = backend.generate( | |
| prompt_n, GenParams(temperature=temp_n, max_tokens=max_tokens, seed=seed) | |
| ) | |
| return model_cls.model_validate_json(_extract_json(raw)) | |
| except (json.JSONDecodeError, ValidationError, LLMError) as exc: | |
| last_err = exc | |
| # Reliable fallback: grammar-constrained (slower, guarantees structure). | |
| schema = model_cls.model_json_schema() | |
| gparams = GenParams(json_schema=schema, temperature=0.0, max_tokens=max_tokens, seed=seed) | |
| try: | |
| raw = backend.generate(prompt, gparams) | |
| return model_cls.model_validate_json(_extract_json(raw)) | |
| except (json.JSONDecodeError, ValidationError, LLMError) as exc: | |
| raise LLMError( | |
| f"structured generation failed for {model_cls.__name__}: {exc}" | |
| ) from last_err | |