"""Structured decoding: turn raw model text into validated schema objects.

Two paths share this module:
- the interrogation hot path (``stream_turn`` / ``generate_turn``) parses the
dual-output wire shape produced by ``dual_output.gbnf``;
- the generator (``generate_model``) constrains a one-shot call to a pydantic
model's JSON schema, validates, and repairs once on failure.
"""
from __future__ import annotations
import json
from collections.abc import Iterator
from dataclasses import dataclass
from functools import lru_cache
from pydantic import BaseModel, Field, ValidationError
from ..constants import GRAMMARS_DIR, SPOKEN_MAX_TOKENS
from ..schemas.enums import EvidenceReaction, Intent
from ..schemas.interrogation import InternalState, InterrogationTurn
from .backend import GenParams, LLMBackend, LLMError
class _StateWire(BaseModel):
    """The mechanical-state half of the wire shape (drives grammar generation).

    Its constraints reach the grammar path through ``_TurnWire``'s JSON schema. The
    grammar-free path is instead parsed leniently by ``wire_to_turn``, which defaults
    unknown enum values and clamps the numbers rather than rejecting them.
    """

    intent: Intent
    is_lying: bool
    # May be null (the annotation allows None).
    active_lie_id: str | None
    evidence_reaction: EvidenceReaction
    # Schema-constrained to the integer range 0..100.
    deception_level: int = Field(ge=0, le=100)
    # Schema-constrained to the range 0.0..1.0.
    stress: float = Field(ge=0.0, le=1.0)
    revealed_fact_ids: list[str]
    slip: bool
class _TurnWire(BaseModel):
    """The full dual-output wire shape. Its JSON schema is converted to a grammar by
    llama.cpp (the supported path) so a small model always emits valid structure."""

    # Private reasoning text; declared first so it precedes the spoken line.
    think: str
    # The dialogue line that is rendered/voiced; streamed early by SpokenScanner.
    spoken: str
    # Trailing mechanical state.
    state: _StateWire
_REPAIR_SUFFIX = (
"\n\nYour previous reply was not valid. Reply again with ONLY the JSON object, "
"matching the required schema exactly."
)
@lru_cache(maxsize=8)
def load_grammar(name: str) -> str:
    """Return the UTF-8 text of the GBNF grammar file ``name`` in ``GRAMMARS_DIR``.

    Successful reads are memoised for up to 8 distinct names.

    Raises:
        FileNotFoundError: if ``GRAMMARS_DIR / name`` does not exist.
    """
    grammar_file = GRAMMARS_DIR / name
    if grammar_file.exists():
        return grammar_file.read_text(encoding="utf-8")
    raise FileNotFoundError(f"grammar not found: {grammar_file}")
@dataclass
class TurnEvent:
    """One streaming event: a chunk of spoken text and/or the final parsed turn.

    ``stream_turn`` yields events carrying only ``spoken_delta`` while the dialogue
    streams in, and finishes with one event carrying only ``final``.
    """

    # Newly decoded spoken text for this event ("" when the event carries none).
    spoken_delta: str = ""
    # The complete parsed turn; None except on the last event of a stream.
    final: InterrogationTurn | None = None
class SpokenScanner:
    """Incrementally extracts the decoded value of the ``"spoken"`` JSON field from
    a streaming completion, so dialogue can render and be voiced before the full
    object (including the trailing mechanical state) has arrived.

    Call :meth:`feed` with each raw completion chunk. It returns whatever spoken text
    became decodable in that chunk, which may be ``""``.

    Backslash escapes are decoded. For ``\\uXXXX`` escapes, UTF-16 surrogate pairs are
    combined into one character, and an unpaired surrogate becomes U+FFFD. This keeps
    the streamed text encodable (e.g. as UTF-8).
    """

    # Match the quoted key, then require a colon (whitespace allowed in between).
    # Inside a valid JSON string every quote is escaped, so the word "spoken" in the
    # earlier think value cannot produce this sequence. A string *value* that happens
    # to be "spoken" is rejected by the colon check.
    _KEY = '"spoken"'
    _HEX_DIGITS = frozenset("0123456789abcdefABCDEF")

    def __init__(self) -> None:
        # seek -> colon -> value -> instr -> done
        # "stopped" replaces the later phases when the value is not a string.
        self._phase = "seek"
        self._tail = ""  # last len(_KEY) characters seen while seeking
        self._escape = False  # previous character inside the string was a backslash
        self._uni: list[str] | None = None  # hex digits of a pending \uXXXX escape
        self._high: int | None = None  # UTF-16 high surrogate awaiting its low half
        self._done = False

    def feed(self, delta: str) -> str:
        """Consume one chunk of raw completion text; return newly decoded spoken text."""
        return "".join(self._consume(ch) for ch in delta)

    @property
    def done(self) -> bool:
        """True once the closing quote of the spoken string has been consumed.

        It stays False if the value turned out not to be a string. Callers then fall
        back to the fully parsed turn's text.
        """
        return self._done

    def _consume(self, ch: str) -> str:
        """Advance the scanner by one character; return any decoded spoken text."""
        phase = self._phase
        if phase == "seek":
            self._tail = (self._tail + ch)[-len(self._KEY) :]
            if self._tail == self._KEY:
                self._phase = "colon"
            return ""
        if phase == "colon":
            if ch == ":":
                self._phase = "value"
            elif not ch.isspace():
                # "spoken" was a string value, not a key: resume seeking from here.
                self._phase = "seek"
                self._tail = ch
            return ""
        if phase == "value":
            if ch == '"':
                self._phase = "instr"
            elif not ch.isspace():
                # Non-string value (e.g. null): there is nothing to stream, and the
                # following key must not be mistaken for dialogue.
                self._phase = "stopped"
            return ""
        if phase == "instr":
            return self._consume_string(ch)
        return ""  # "done" / "stopped": ignore the rest of the completion

    def _consume_string(self, ch: str) -> str:
        """Handle one character inside the spoken string value."""
        if self._uni is not None:
            self._uni.append(ch)
            if len(self._uni) < 4:
                return ""
            digits = "".join(self._uni)
            self._uni = None
            if not self._HEX_DIGITS.issuperset(digits):
                return self._flush_high()  # malformed escape: drop it
            return self._decode_unit(int(digits, 16))
        if self._escape:
            self._escape = False
            if ch == "u":
                self._uni = []
                return ""
            return self._flush_high() + _UNESCAPE.get(ch, ch)
        if ch == "\\":
            # Do not flush a pending high surrogate yet: a \u low half may follow.
            self._escape = True
            return ""
        if ch == '"':
            self._phase = "done"
            self._done = True
            return self._flush_high()
        return self._flush_high() + ch

    def _decode_unit(self, code: int) -> str:
        """Turn one UTF-16 code unit from a \\uXXXX escape into text, pairing surrogates."""
        if 0xDC00 <= code <= 0xDFFF and self._high is not None:
            high, self._high = self._high, None
            return chr(0x10000 + ((high - 0xD800) << 10) + (code - 0xDC00))
        pending = self._flush_high()
        if 0xD800 <= code <= 0xDBFF:
            self._high = code
            return pending
        if 0xDC00 <= code <= 0xDFFF:
            return pending + "\ufffd"  # lone low surrogate
        return pending + chr(code)

    def _flush_high(self) -> str:
        """Emit U+FFFD for a high surrogate that was not followed by a low surrogate."""
        if self._high is None:
            return ""
        self._high = None
        return "\ufffd"
_UNESCAPE = {"n": "\n", "t": "\t", "r": "\r", "b": "\b", "f": "\f", "/": "/", '"': '"', "\\": "\\"}
def _extract_json(text: str) -> str:
start = text.find("{")
end = text.rfind("}")
if start == -1 or end == -1 or end < start:
raise LLMError("no JSON object found in completion")
return text[start : end + 1]
def _clamp(value: float, low: float, high: float) -> float:
return max(low, min(high, value))
def _as_float(value: object, default: float) -> float:
    """Coerce an advisory number; ``default`` for null, non-numeric or NaN input."""
    if value is None:
        return default
    try:
        number = float(value)
    except (TypeError, ValueError):
        return default
    return default if number != number else number  # NaN is the only x != x


def _as_bool(value: object) -> bool:
    """Coerce an advisory flag; string spellings such as "false"/"0"/"no" are False."""
    if isinstance(value, str):
        return value.strip().lower() not in ("", "false", "0", "no", "null", "none")
    return bool(value)


def _as_id_tuple(value: object) -> tuple:
    """Coerce a fact-id list; a bare string is one id, any other non-list is empty."""
    if isinstance(value, str):
        return (value,) if value else ()
    if isinstance(value, (list, tuple)):
        return tuple(value)
    return ()


def _as_enum_value(value: object, members: dict, default: str) -> object:
    """Return ``value`` if it is a key of ``members`` (an enum value map), else ``default``."""
    try:
        return value if value in members else default
    except TypeError:  # unhashable (e.g. a list) can never be a member value
        return default


def wire_to_turn(wire: dict) -> InterrogationTurn:
    """Map the ``dual_output.gbnf`` wire shape onto an InterrogationTurn.

    Advisory fields are coerced, not rejected: a slightly out-of-range or mistyped
    value must never discard an otherwise-valid turn (the deterministic director, not
    these numbers, is the mechanical authority). Coercion rules:

    - Numbers are clamped to their range.
    - Null, non-numeric or NaN numbers fall back to 0.
    - Unknown enum values fall back to a safe default.
    - String booleans like "false" are honoured.
    - A bare string fact id is treated as a one-element list.
    - Null text fields become "".

    Raises:
        ValueError: if ``wire["state"]`` cannot be turned into a dict. Callers
            already route ValueError to their repair/fallback path.
    """
    try:
        state = dict(wire.get("state") or {})
    except (TypeError, ValueError) as exc:
        raise ValueError("wire 'state' is not a JSON object") from exc
    active = state.get("active_lie_id")
    think = wire.get("think")
    spoken = wire.get("spoken")
    internal = InternalState(
        private_reasoning="" if think is None else str(think),
        intent=_as_enum_value(state.get("intent"), Intent._value2member_map_, "deflect"),
        is_lying=_as_bool(state.get("is_lying", False)),
        active_lie_id=active if active not in (None, "null", "") else None,
        evidence_reaction=_as_enum_value(
            state.get("evidence_reaction"), EvidenceReaction._value2member_map_, "none"
        ),
        deception_level=int(_clamp(_as_float(state.get("deception_level"), 0.0), 0, 100)),
        stress=_clamp(_as_float(state.get("stress"), 0.0), 0.0, 1.0),
        revealed_fact_ids=_as_id_tuple(state.get("revealed_fact_ids")),
        slip=_as_bool(state.get("slip", False)),
    )
    spoken_text = "" if spoken is None else str(spoken)
    return InterrogationTurn(spoken=spoken_text.strip(), internal=internal)
@lru_cache(maxsize=1)
def _turn_schema() -> dict:
    """Return the JSON schema of the dual-output wire shape, computed once.

    The same cached dict object is returned on every call, so callers must treat
    it as read-only.
    """
    schema = _TurnWire.model_json_schema()
    return schema
def _turn_params(seed: int | None, temperature: float, *, grammar: bool) -> GenParams:
    """Build the sampling parameters for one dual-output turn.

    ``grammar=True`` attaches the cached wire-shape JSON schema, so generation is
    grammar-constrained. ``grammar=False`` leaves the schema off.
    """
    # Grammar-free is the default choice of callers: CPU grammar-sampling is ~8x
    # slower, and the prompt already spells out the exact wire shape. The grammar
    # (json_schema) path is the reliable fallback used only when a free parse fails.
    schema = _turn_schema() if grammar else None
    # SPOKEN_MAX_TOKENS bounds the spoken line; the extra 220 is presumably headroom
    # for the think/state parts of the wire shape.
    token_budget = SPOKEN_MAX_TOKENS + 220
    return GenParams(
        json_schema=schema,
        max_tokens=token_budget,
        temperature=temperature,
        seed=seed,
        # A 1.5B model otherwise copies its previous answer verbatim across turns.
        # Penalise repeated tokens so each question gets a fresh reply, but gently
        # enough that the evidence reaction still produces a line (high penalties
        # can starve the output).
        repeat_penalty=1.25,
        frequency_penalty=0.4,
        presence_penalty=0.3,
    )
def generate_turn(
    backend: LLMBackend, prompt: str, *, seed: int | None = None, temperature: float = 0.7
) -> InterrogationTurn:
    """Non-streaming dual-output fallback: grammar-constrained for reliability, with one
    repair retry and a safe default.

    The first attempt uses ``prompt`` at ``temperature``. If its reply cannot be
    parsed or mapped, a second attempt appends ``_REPAIR_SUFFIX`` and runs at
    temperature 0. If that also fails, ``InterrogationTurn.safe_default()`` is
    returned instead of raising.
    """
    attempts = ((prompt, temperature), (prompt + _REPAIR_SUFFIX, 0.0))
    for attempt_prompt, attempt_temperature in attempts:
        try:
            raw = backend.generate(
                attempt_prompt, _turn_params(seed, attempt_temperature, grammar=True)
            )
            return wire_to_turn(json.loads(_extract_json(raw)))
        except (json.JSONDecodeError, ValidationError, LLMError, ValueError, TypeError):
            # TypeError: a wrongly typed field (e.g. a null number reaching float())
            # can surface while mapping the wire shape. Treat it like any other
            # unparseable reply and move on to the next attempt.
            continue
    return InterrogationTurn.safe_default()
def stream_turn(
    backend: LLMBackend, prompt: str, *, seed: int | None = None, temperature: float = 0.7
) -> Iterator[TurnEvent]:
    """Stream the dual-output turn (grammar-free for speed), emitting spoken text as it
    arrives, then the final parsed turn. Falls back to the grammar path if parsing fails.

    Yields ``TurnEvent(spoken_delta=...)`` for each decoded piece of spoken text, then
    exactly one ``TurnEvent(final=turn)``. If the scanner never saw the complete
    spoken value, the final turn's (non-empty) spoken text is emitted as one delta
    just before the final event.
    """
    scanner = SpokenScanner()
    raw_parts: list[str] = []
    try:
        for delta in backend.stream(prompt, _turn_params(seed, temperature, grammar=False)):
            raw_parts.append(delta)
            spoken_delta = scanner.feed(delta)
            if spoken_delta:
                yield TurnEvent(spoken_delta=spoken_delta)
        turn = wire_to_turn(json.loads(_extract_json("".join(raw_parts))))
    except (json.JSONDecodeError, ValidationError, LLMError, ValueError, TypeError):
        # TypeError: wrongly typed fields (e.g. a null number) can surface from
        # wire_to_turn; take the fallback instead of crashing the hot path.
        # generate_turn manages its own repair suffix; don't double it.
        turn = generate_turn(backend, prompt, seed=seed, temperature=0.0)
    # Spoken text that was not streamed (fallback turn, or a key the scanner could
    # not lock onto) is emitted in one piece now.
    # NOTE(review): if the scanner already streamed a complete spoken line but the
    # rest failed to parse, the fallback turn's spoken text may differ from what was
    # shown, and it is not re-emitted.
    if not scanner.done and turn.spoken:
        yield TurnEvent(spoken_delta=turn.spoken)
    yield TurnEvent(final=turn)
def generate_model[M: BaseModel](
backend: LLMBackend,
prompt: str,
model_cls: type[M],
*,
temperature: float = 0.4,
max_tokens: int = 1200,
seed: int | None = None,
) -> M:
"""Generate a JSON object constrained to ``model_cls`` and validate it.
Fast path FIRST: generate WITHOUT a grammar (CPU grammar-sampling is ~8x slower)
using the JSON template baked into the prompt, then parse + validate. Only if that
fails do we fall back to the grammar-constrained path (slower, but guarantees
structure) with one repair retry. Raises ``LLMError`` if every attempt fails.
"""
# Fast path: TWO grammar-free attempts (the prompt carries the exact JSON shape).
# The second retries at temperature 0 with a repair nudge - still ~8x faster than
# the grammar path, so most recoveries stay fast.
last_err: Exception | None = None
for attempt in range(2):
try:
prompt_n = prompt if attempt == 0 else prompt + _REPAIR_SUFFIX
temp_n = temperature if attempt == 0 else 0.0
raw = backend.generate(
prompt_n, GenParams(temperature=temp_n, max_tokens=max_tokens, seed=seed)
)
return model_cls.model_validate_json(_extract_json(raw))
except (json.JSONDecodeError, ValidationError, LLMError) as exc:
last_err = exc
# Reliable fallback: grammar-constrained (slower, guarantees structure).
schema = model_cls.model_json_schema()
gparams = GenParams(json_schema=schema, temperature=0.0, max_tokens=max_tokens, seed=seed)
try:
raw = backend.generate(prompt, gparams)
return model_cls.model_validate_json(_extract_json(raw))
except (json.JSONDecodeError, ValidationError, LLMError) as exc:
raise LLMError(
f"structured generation failed for {model_cls.__name__}: {exc}"
) from last_err
|