File size: 10,802 Bytes
414dc55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
"""Structured decoding: turn raw model text into validated schema objects.

Two paths share this module:
- the interrogation hot path (``stream_turn`` / ``generate_turn``) parses the
  dual-output wire shape produced by ``dual_output.gbnf``;
- the generator (``generate_model``) constrains a one-shot call to a pydantic
  model's JSON schema, validates, and repairs once on failure.
"""

from __future__ import annotations

import json
from collections.abc import Iterator
from dataclasses import dataclass
from functools import lru_cache

from pydantic import BaseModel, Field, ValidationError

from ..constants import GRAMMARS_DIR, SPOKEN_MAX_TOKENS
from ..schemas.enums import EvidenceReaction, Intent
from ..schemas.interrogation import InternalState, InterrogationTurn
from .backend import GenParams, LLMBackend, LLMError


class _StateWire(BaseModel):
    """The mechanical-state half of the wire shape (drives grammar generation).

    Nested under the ``state`` key of ``_TurnWire``. No field declares a default,
    so every key is required by the generated JSON schema.
    """

    intent: Intent
    is_lying: bool
    # Nullable on the wire; ``wire_to_turn`` also treats "" and the literal
    # string "null" as "no active lie".
    active_lie_id: str | None
    evidence_reaction: EvidenceReaction
    # Schema bounds are inclusive: 0..100.
    deception_level: int = Field(ge=0, le=100)
    # Schema bounds are inclusive: 0.0..1.0.
    stress: float = Field(ge=0.0, le=1.0)
    revealed_fact_ids: list[str]
    slip: bool


class _TurnWire(BaseModel):
    """The full dual-output wire shape. Its JSON schema is converted to a grammar by
    llama.cpp (the supported path) so a small model always emits valid structure."""

    # Private reasoning; ``wire_to_turn`` stores it as ``private_reasoning``.
    think: str
    # The line of dialogue; ``SpokenScanner`` extracts this field while streaming.
    spoken: str
    # Mechanical state for the turn (see ``_StateWire``).
    state: _StateWire

# Appended to the prompt for the retry that follows a reply which failed to parse
# or validate (used by ``generate_turn`` and ``generate_model``).
_REPAIR_SUFFIX = (
    "\n\nYour previous reply was not valid. Reply again with ONLY the JSON object, "
    "matching the required schema exactly."
)


@lru_cache(maxsize=8)
def load_grammar(name: str) -> str:
    """Return the text of the GBNF grammar ``name`` from the grammars directory.

    Results are memoised per ``name`` (up to 8 entries), so each grammar file is
    read from disk at most once while it stays in the cache.

    Raises:
        FileNotFoundError: if no such file exists under ``GRAMMARS_DIR``.
    """
    grammar_file = GRAMMARS_DIR / name
    if grammar_file.exists():
        return grammar_file.read_text(encoding="utf-8")
    raise FileNotFoundError(f"grammar not found: {grammar_file}")


@dataclass
class TurnEvent:
    """One streaming event: a chunk of spoken text and/or the final parsed turn.

    ``stream_turn`` yields events carrying only ``spoken_delta`` while text
    arrives, then a last event carrying only ``final``.
    """

    # Newly decoded dialogue text; empty when the event carries no text.
    spoken_delta: str = ""
    # The fully parsed turn; None on every event except the last one.
    final: InterrogationTurn | None = None


class SpokenScanner:
    """Incrementally extracts the decoded value of the ``"spoken"`` JSON field from
    a streaming completion, so dialogue can render and be voiced before the full
    object (including the trailing mechanical state) has arrived.

    Feed raw completion text with :meth:`feed`; it returns whatever newly decoded
    dialogue text became available. :attr:`done` turns true once the closing quote
    of the spoken string has been consumed.

    Decoding follows JSON string rules: single-character escapes are mapped to
    their characters and ``\\uXXXX`` escapes are decoded, with a UTF-16 surrogate
    pair (``\\ud83d\\ude00``) combined into the single character it encodes. An
    unpaired surrogate is emitted as U+FFFD so the streamed text is always
    encodable.
    """

    # Include the colon so the scanner locks onto the KEY "spoken": and never onto
    # the word "spoken" appearing inside the (earlier) think value.
    _KEY = '"spoken":'
    # JSON insignificant whitespace that may separate the colon from the value.
    _JSON_WS = frozenset(" \t\r\n")
    # ``int(code, 16)`` alone is too lenient (accepts signs, "_" and spaces).
    _HEX_DIGITS = frozenset("0123456789abcdefABCDEF")
    # Stand-in for a surrogate code unit that has no partner.
    _REPLACEMENT = "\ufffd"

    def __init__(self) -> None:
        # seek: looking for the key; precolon: key matched, waiting for the
        # opening quote of the value; instr: inside the string; done: closed.
        self._phase = "seek"
        self._tail = ""  # last len(_KEY) characters seen while seeking
        self._escape = False  # previous character was an unescaped backslash
        self._uni: list[str] | None = None  # hex digits of a \uXXXX in progress
        self._high: int | None = None  # high surrogate awaiting its low half
        self._done = False

    def feed(self, delta: str) -> str:
        """Consume ``delta`` and return the spoken text decoded from it (may be "")."""
        return "".join(self._consume(ch) for ch in delta)

    def _flush_high(self) -> str:
        """Give up on a pending high surrogate; return its replacement (or "")."""
        if self._high is None:
            return ""
        self._high = None
        return self._REPLACEMENT

    def _decode_unit(self, code: str) -> str:
        """Decode the four characters following ``\\u`` into output text."""
        if not all(digit in self._HEX_DIGITS for digit in code):
            # Malformed escape: emit nothing for it (as before), but do not leave
            # a high surrogate dangling.
            return self._flush_high()
        unit = int(code, 16)
        if 0xD800 <= unit <= 0xDBFF:
            # High surrogate: hold it until the low half arrives.
            pending = self._flush_high()
            self._high = unit
            return pending
        if 0xDC00 <= unit <= 0xDFFF:
            if self._high is None:
                return self._REPLACEMENT
            high, self._high = self._high, None
            return chr(0x10000 + ((high - 0xD800) << 10) + (unit - 0xDC00))
        return self._flush_high() + chr(unit)

    def _consume(self, ch: str) -> str:
        """Advance the state machine by one character; return decoded output."""
        if self._done:
            return ""
        if self._phase == "seek":
            self._tail = (self._tail + ch)[-len(self._KEY):]
            if self._tail == self._KEY:
                self._phase = "precolon"
            return ""
        if self._phase == "precolon":
            if ch == '"':
                self._phase = "instr"
            elif ch not in self._JSON_WS:
                # The value is not a string (e.g. null). Resume seeking rather
                # than mistaking the next key's opening quote for the value.
                self._phase = "seek"
                self._tail = ""
            return ""
        # inside the spoken string value
        if self._uni is not None:
            self._uni.append(ch)
            if len(self._uni) < 4:
                return ""
            code = "".join(self._uni)
            self._uni = None
            return self._decode_unit(code)
        if self._escape:
            self._escape = False
            if ch == "u":
                self._uni = []
                return ""
            return self._flush_high() + _UNESCAPE.get(ch, ch)
        if ch == "\\":
            # Keep any pending high surrogate: this may start its \uXXXX low half.
            self._escape = True
            return ""
        if ch == '"':
            self._phase = "done"
            self._done = True
            return self._flush_high()
        return self._flush_high() + ch

    @property
    def done(self) -> bool:
        """True once the closing quote of the spoken string has been consumed."""
        return self._done


# JSON single-character escapes: the character following a backslash mapped to
# its decoded value. ``\uXXXX`` is handled separately inside ``SpokenScanner``.
_UNESCAPE = {"n": "\n", "t": "\t", "r": "\r", "b": "\b", "f": "\f", "/": "/", '"': '"', "\\": "\\"}


def _extract_json(text: str) -> str:
    start = text.find("{")
    end = text.rfind("}")
    if start == -1 or end == -1 or end < start:
        raise LLMError("no JSON object found in completion")
    return text[start : end + 1]


def _clamp(value: float, low: float, high: float) -> float:
    return max(low, min(high, value))


def _advisory_number(value: object, default: float) -> float:
    """Convert an advisory numeric field to float, or ``default`` if unusable.

    Covers what a grammar-free model may emit in place of a number: ``null``,
    a non-numeric string, a container, an integer too large for a float, NaN.
    """
    import math

    try:
        number = float(value)  # type: ignore[arg-type]
    except (TypeError, ValueError, OverflowError):
        return default
    return default if math.isnan(number) else number


def _advisory_enum(value: object, enum_cls: type, default: str) -> object:
    """Return ``value`` if it is a valid value of ``enum_cls``, else ``default``.

    Uses the enum's public value lookup, which also copes with unhashable
    values (lists, dicts) that a plain membership test would raise on.
    """
    try:
        enum_cls(value)
    except (ValueError, TypeError):
        return default
    return value


def wire_to_turn(wire: dict) -> InterrogationTurn:
    """Map the ``dual_output.gbnf`` wire shape onto an InterrogationTurn.

    Advisory fields are coerced, not rejected: a missing, ``null``, wrongly-typed
    or out-of-range value must never discard an otherwise-valid turn (the
    deterministic director, not these numbers, is the mechanical authority).

    Coercions applied to ``wire["state"]`` (treated as empty if not an object):

    - ``intent`` / ``evidence_reaction``: unknown values become ``"deflect"`` /
      ``"none"``.
    - ``deception_level`` / ``stress``: unusable values become 0, then the value
      is clamped to 0..100 / 0.0..1.0.
    - ``active_lie_id``: anything other than a non-empty string that is not the
      literal ``"null"`` becomes None.
    - ``revealed_fact_ids``: anything other than a list becomes empty (a bare
      string is NOT split into characters).

    ``think`` and ``spoken`` are stringified; ``null`` is treated like a missing
    key (empty string) rather than becoming the text "None".

    Args:
        wire: The decoded JSON object with ``think``, ``spoken`` and ``state``.

    Returns:
        The turn built from ``wire``.
    """
    raw_state = wire.get("state")
    state = raw_state if isinstance(raw_state, dict) else {}

    think = wire.get("think")
    spoken = wire.get("spoken")
    active = state.get("active_lie_id")
    fact_ids = state.get("revealed_fact_ids")

    has_active_lie = isinstance(active, str) and active not in ("null", "")
    internal = InternalState(
        private_reasoning="" if think is None else str(think),
        intent=_advisory_enum(state.get("intent", "deflect"), Intent, "deflect"),
        is_lying=bool(state.get("is_lying", False)),
        active_lie_id=active if has_active_lie else None,
        evidence_reaction=_advisory_enum(
            state.get("evidence_reaction", "none"), EvidenceReaction, "none"
        ),
        deception_level=int(
            _clamp(_advisory_number(state.get("deception_level"), 0.0), 0, 100)
        ),
        stress=_clamp(_advisory_number(state.get("stress"), 0.0), 0.0, 1.0),
        revealed_fact_ids=tuple(fact_ids) if isinstance(fact_ids, (list, tuple)) else (),
        slip=bool(state.get("slip", False)),
    )
    spoken_text = "" if spoken is None else str(spoken)
    return InterrogationTurn(spoken=spoken_text.strip(), internal=internal)


@lru_cache(maxsize=1)
def _turn_schema() -> dict:
    """Return the JSON schema of the dual-output wire model (computed once, cached)."""
    schema = _TurnWire.model_json_schema()
    return schema


def _turn_params(seed: int | None, temperature: float, *, grammar: bool) -> GenParams:
    """Build the generation parameters for one dual-output turn.

    Args:
        seed: Sampling seed, or None.
        temperature: Sampling temperature.
        grammar: When true, constrain output to the turn JSON schema.
    """
    # Callers run grammar-free first: CPU grammar-sampling is ~8x slower, and the
    # prompt already spells out the exact wire shape. The schema-constrained path
    # is the reliable fallback for when a free parse fails.
    schema = _turn_schema() if grammar else None
    # Room for the spoken line plus the think/state parts of the object.
    token_budget = SPOKEN_MAX_TOKENS + 220
    # A 1.5B model otherwise copies its previous answer verbatim across turns, so
    # repeated tokens are penalised to get a fresh reply per question - gently,
    # because high penalties can starve the output and the evidence reaction must
    # still produce a line.
    return GenParams(
        json_schema=schema,
        max_tokens=token_budget,
        temperature=temperature,
        seed=seed,
        repeat_penalty=1.25,
        frequency_penalty=0.4,
        presence_penalty=0.3,
    )


def generate_turn(
    backend: LLMBackend, prompt: str, *, seed: int | None = None, temperature: float = 0.7
) -> InterrogationTurn:
    """Produce one turn without streaming, constrained to the turn JSON schema.

    Makes at most two attempts: the first at ``temperature`` with ``prompt`` as
    given, the second at temperature 0 with the repair suffix appended. If both
    fail to generate, parse or validate, ``InterrogationTurn.safe_default()`` is
    returned instead of raising.
    """
    recoverable = (json.JSONDecodeError, ValidationError, LLMError, ValueError)
    current_prompt = prompt
    for index, attempt_temperature in enumerate((temperature, 0.0)):
        try:
            params = _turn_params(seed, attempt_temperature, grammar=True)
            completion = backend.generate(current_prompt, params)
            payload = json.loads(_extract_json(completion))
            return wire_to_turn(payload)
        except recoverable:
            # Only the first failure earns a retry; nudge the model on that one.
            if index == 0:
                current_prompt = current_prompt + _REPAIR_SUFFIX
    return InterrogationTurn.safe_default()


def stream_turn(
    backend: LLMBackend, prompt: str, *, seed: int | None = None, temperature: float = 0.7
) -> Iterator[TurnEvent]:
    """Stream one dual-output turn, yielding spoken text as soon as it decodes.

    The stream runs grammar-free for speed. Each chunk is scanned for the
    ``"spoken"`` value and any newly decoded text is yielded as a
    ``TurnEvent(spoken_delta=...)``. Once the stream ends the whole completion is
    parsed and a last ``TurnEvent(final=...)`` is yielded.

    If streaming, parsing or validation fails, the turn is regenerated through
    ``generate_turn`` (schema-constrained, temperature 0). The regenerated spoken
    line is yielded in full unless the scanner had already seen the streamed
    spoken string close.

    NOTE(review): when the failure happens after part of the spoken string was
    streamed but before it closed, consumers receive that partial text followed
    by the full regenerated line.
    """
    spoken_tracker = SpokenScanner()
    chunks: list[str] = []
    try:
        params = _turn_params(seed, temperature, grammar=False)
        for chunk in backend.stream(prompt, params):
            chunks.append(chunk)
            fresh_text = spoken_tracker.feed(chunk)
            if fresh_text:
                yield TurnEvent(spoken_delta=fresh_text)
        completion = "".join(chunks)
        parsed = json.loads(_extract_json(completion))
        turn = wire_to_turn(parsed)
    except (json.JSONDecodeError, ValidationError, LLMError, ValueError):
        # generate_turn appends its own repair suffix, so pass the bare prompt.
        turn = generate_turn(backend, prompt, seed=seed, temperature=0.0)
        # Nothing of the regenerated turn has been streamed; emit its line now.
        if not spoken_tracker.done:
            yield TurnEvent(spoken_delta=turn.spoken)
    yield TurnEvent(final=turn)


def generate_model[M: BaseModel](
    backend: LLMBackend,
    prompt: str,
    model_cls: type[M],
    *,
    temperature: float = 0.4,
    max_tokens: int = 1200,
    seed: int | None = None,
) -> M:
    """Generate a JSON object constrained to ``model_cls`` and validate it.

    Fast path FIRST: generate WITHOUT a grammar (CPU grammar-sampling is ~8x slower)
    using the JSON template baked into the prompt, then parse + validate. Only if that
    fails do we fall back to the grammar-constrained path (slower, but guarantees
    structure) with one repair retry. Raises ``LLMError`` if every attempt fails.
    """
    # Fast path: TWO grammar-free attempts (the prompt carries the exact JSON shape).
    # The second retries at temperature 0 with a repair nudge - still ~8x faster than
    # the grammar path, so most recoveries stay fast.
    last_err: Exception | None = None
    for attempt in range(2):
        try:
            prompt_n = prompt if attempt == 0 else prompt + _REPAIR_SUFFIX
            temp_n = temperature if attempt == 0 else 0.0
            raw = backend.generate(
                prompt_n, GenParams(temperature=temp_n, max_tokens=max_tokens, seed=seed)
            )
            return model_cls.model_validate_json(_extract_json(raw))
        except (json.JSONDecodeError, ValidationError, LLMError) as exc:
            last_err = exc

    # Reliable fallback: grammar-constrained (slower, guarantees structure).
    schema = model_cls.model_json_schema()
    gparams = GenParams(json_schema=schema, temperature=0.0, max_tokens=max_tokens, seed=seed)
    try:
        raw = backend.generate(prompt, gparams)
        return model_cls.model_validate_json(_extract_json(raw))
    except (json.JSONDecodeError, ValidationError, LLMError) as exc:
        raise LLMError(
            f"structured generation failed for {model_cls.__name__}: {exc}"
        ) from last_err