File size: 5,365 Bytes
414dc55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
"""One interrogation turn: prompt -> streamed model call -> deterministic director.

Exactly one model call per turn. Spoken text streams out as it arrives; mechanics are
decided afterwards by the director against ground truth.
"""

from __future__ import annotations

import re
from collections.abc import Iterator
from dataclasses import dataclass

from ..llm.backend import LLMBackend
from ..llm.decoding import stream_turn
from ..projections.suspect_brief import SuspectBrief
from ..schemas.case import CaseFile
from ..schemas.enums import Relevance
from ..schemas.interrogation import InterrogationTurn
from ..suspects.memory import buffer_text, ledger_text
from ..suspects.persona import build_prompt
from ..suspects.scrub import scrub_spoken
from .director import Adjudication, adjudicate
from .game_state import GameState
from .relevance import assess_relevance
from .state_update import apply_turn

_SUSPECT_TEMPERATURE = 0.8

# Deterministic in-character deflections, used as a backstop when the small model returns an
# empty "spoken" field OR parrots a line it already used this session. Same pattern as the
# confession backstop in scrub.py: the model authors the dialogue; canned lines only stand in
# when its output is unusable, so the player never sees a blank or a duplicate reply.
_PRESSED = (
    "I... I don't know what you want me to say.",
    "Don't look at me like that - I had nothing to do with this.",
    "You're putting words in my mouth.",
    "That doesn't prove a thing, and you know it.",
    "You can stare all you like - it wasn't me.",
    "I'm done answering that. I want a lawyer.",
    "You're twisting this. That is not what happened.",
)
_CALM = (
    "I've already told you everything I know.",
    "There's nothing more to it than that.",
    "I don't see what that has to do with me.",
    "You're reaching, Detective.",
    "Ask me something that actually matters.",
    "I wasn't anywhere near it, if that's what you're getting at.",
    "What else do you want me to say?",
    "That's all there is to tell.",
)


def _norm(s: str) -> str:
    return re.sub(r"[^a-z0-9 ]", " ", s.lower()).strip()


def _too_similar(a: str, b: str) -> bool:
    """True if two replies are effectively the same line (a small model parroting itself)."""
    na, nb = _norm(a), _norm(b)
    if not na or not nb:
        return False
    if na == nb:
        return True
    wa, wb = set(na.split()), set(nb.split())
    if not wa or not wb:
        return False
    return len(wa & wb) / min(len(wa), len(wb)) >= 0.7


def _distinct_deflection(relevance: Relevance, recent: list[str], key: str) -> str:
    """Return a canned in-character deflection that does not echo any recent answer.

    BREAKING/DIRECT relevance draws from ``_PRESSED``; anything else from ``_CALM``.
    The search starts at an offset derived deterministically from ``key`` (sum of its
    code points) and wraps around the pool, returning the first line that is not
    ``_too_similar`` to any entry of ``recent``. If every line is too similar, the line
    at the starting offset is returned anyway.
    """
    if relevance in (Relevance.BREAKING, Relevance.DIRECT):
        pool = _PRESSED
    else:
        pool = _CALM
    offset = sum(ord(ch) for ch in key) % len(pool)
    rotated = pool[offset:] + pool[:offset]
    return next(
        (line for line in rotated if all(not _too_similar(line, answer) for answer in recent)),
        pool[offset],
    )


@dataclass()
class FinalTurn:
    """Closing payload of one interrogation turn, as yielded by ``interrogate``.

    Bundles the suspect's turn (after the spoken-line backstops), the director's
    adjudication of it, and the game state produced by ``apply_turn``.
    """

    # Suspect turn whose ``spoken`` line has already been scrubbed / deflected.
    turn: InterrogationTurn
    # Result of ``adjudicate`` for this turn.
    adjudication: Adjudication
    # Game state returned by ``apply_turn`` for this turn.
    state: GameState


@dataclass()
class InterrogationEvent:
    """One item yielded by ``interrogate``.

    Streaming events carry a non-empty ``spoken_delta`` and no ``final``; the single
    closing event carries ``final`` and leaves ``spoken_delta`` empty.
    """

    # Chunk of spoken text forwarded from the model stream; "" on the closing event.
    spoken_delta: str = ""
    # Populated only on the closing event of the turn.
    final: FinalTurn | None = None


def interrogate(
    backend: LLMBackend,
    case: CaseFile,
    brief: SuspectBrief,
    state: GameState,
    sus_id: str,
    question: str,
    presented_clue_id: str | None = None,
    seed: int | None = None,
) -> Iterator[InterrogationEvent]:
    """Run one interrogation turn against suspect ``sus_id``.

    Builds the persona prompt and makes exactly one streamed model call. Two
    deterministic backstops then run on the spoken line: the confession scrub, and
    the blank/repeat deflection. Finally the director adjudicates the turn and the
    game state is advanced via ``apply_turn``.

    Yields:
        ``InterrogationEvent(spoken_delta=...)`` for every non-empty text chunk the
        model streams, then exactly one ``InterrogationEvent(final=...)`` holding the
        cleaned turn, its adjudication and the new ``GameState``.

    NOTE(review): the streamed deltas are the raw model text. They are yielded
    *before* the backstops run, so they can differ from ``final.turn.spoken`` (e.g. a
    scrubbed confession or a swapped-in deflection). Presumably the consumer replaces
    the streamed text with the final line; confirm against the caller.
    """
    suspect = case.suspect(sus_id)
    memory = state.state_for(sus_id)
    assessment = assess_relevance(case, suspect, presented_clue_id)
    clue = None
    if presented_clue_id:
        clue = case.clue(presented_clue_id)
    relevance = assessment.relevance

    prompt = build_prompt(
        case=case,
        brief=brief,
        ledger=ledger_text(case, suspect, memory),
        buffer=buffer_text(memory),
        question=question,
        clue=clue,
        relevance=relevance,
    )

    # Exactly one model call: forward text as it streams and keep the last parsed turn.
    reply: InterrogationTurn | None = None
    stream = stream_turn(backend, prompt, seed=seed, temperature=_SUSPECT_TEMPERATURE)
    for chunk in stream:
        if chunk.spoken_delta:
            yield InterrogationEvent(spoken_delta=chunk.spoken_delta)
        if chunk.final is not None:
            reply = chunk.final
    if reply is None:
        reply = InterrogationTurn.safe_default()

    # Backstop 1 (scrub_spoken): the suspect's line must not confess. The win is decided
    # by the director, not by dialogue.
    spoken = scrub_spoken(reply.spoken, breaking=relevance is Relevance.BREAKING)

    # Backstop 2: replace a blank line, or a near-repeat of one of the last four answers,
    # with a deterministic canned deflection (no extra model call).
    history = memory.transcript
    recent = [entry.answer for entry in history[-4:]]
    blank = not (spoken and spoken.strip())
    if blank or any(_too_similar(spoken, prev) for prev in recent):
        spoken = _distinct_deflection(relevance, recent, f"{sus_id}:{question}:{len(history)}")

    # Copy the turn (model_copy) instead of mutating it, and only when a backstop
    # actually changed the line.
    if spoken != reply.spoken:
        reply = reply.model_copy(update={"spoken": spoken})

    ruling = adjudicate(case, suspect, memory, reply, presented_clue_id)
    next_state = apply_turn(state, case, sus_id, question, reply, ruling, presented_clue_id)
    yield InterrogationEvent(final=FinalTurn(turn=reply, adjudication=ruling, state=next_state))