File size: 14,702 Bytes
807d5cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
"""Deterministic seeker simulator with hidden internal state.

Why rule-based / deterministic?
-------------------------------
The OpenEnv graders must be reproducible. An LLM-driven seeker would make
reward non-deterministic and fail the "score variance check" in Phase 2 of
judging. We deliberately trade some linguistic realism for full determinism
so that the same action sequence always yields the same reward — a hard
requirement of the hackathon rubric ("graders deterministic and reproducible").

Design
------
The seeker is a finite-state machine with continuous hidden variables:

    distress   ∈ [0, 1]   — how emotionally overwhelmed the seeker feels
    trust      ∈ [0, 1]   — how safe the seeker feels with the agent
    openness   ∈ [0, 1]   — willingness to reveal the *true* issue
    revealed   ∈ {0, 1}   — has the core issue surfaced yet?
    stage      ∈ enum     — opening / exploring / reflecting / planning / closing

On each turn, the environment analyses the agent's reply with a small bank of
deterministic feature detectors (keyword/regex based), then applies a
transition rule to update the hidden state and pick the seeker's next
utterance from a scripted response tree indexed by (stage, features).
"""
from __future__ import annotations

import re
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List, Tuple


class Stage(str, Enum):
    """Conversation phase the seeker is currently in.

    Because of the ``str`` mixin, each member compares equal to its string
    value. The forward order is given by ``STAGE_ORDER``; ``step_seeker``
    only ever moves a state forward through it, never back.
    """

    OPENING = "opening"
    EXPLORING = "exploring"
    REFLECTING = "reflecting"
    PLANNING = "planning"
    CLOSING = "closing"


# ---------------------------------------------------------------------------
# Feature detectors — deterministic text analysis of the agent's reply.
# ---------------------------------------------------------------------------
#
# Each bank below is a list of regexes. `_count_matches` lowercases the reply,
# runs `re.search` with every pattern, and counts how many patterns hit at
# least once. A bank therefore contributes at most len(bank) per reply,
# however often a phrase repeats. Apostrophes in contractions are written as
# the ASCII `'` only.

# Warmth / acknowledgement, e.g. "I hear you", "that sounds hard", "I'm sorry".
EMPATHY_PATTERNS = [
    r"\bi\s+(hear|understand|get|see)\s+(you|that|how)",
    r"\bthat\s+(sounds|must\s+be|seems)\b",
    r"\bit\s+makes\s+sense\b",
    r"\bi\s+can\s+imagine\b",
    r"\bthank\s+you\s+for\s+sharing\b",
    r"\bi'?m\s+(here|glad|sorry)\b",
]

# Explicitly legitimising the seeker's feelings. In step_seeker, validation
# hits (unlike empathy hits) also lower distress.
VALIDATION_PATTERNS = [
    r"\byour\s+feelings?\s+(are|make)\s+(valid|sense)",
    r"\bit'?s\s+(okay|ok|normal|understandable)\s+to\s+feel",
    r"\banyone\s+would\s+feel\b",
    r"\bof\s+course\s+you\s+(feel|are)\b",
]

# Invitations to elaborate. extract_features subtracts this bank's hit count
# from the total '?' count to estimate closed questions.
OPEN_QUESTION_PATTERNS = [
    r"\bhow\s+(are|do|did|does)\b",
    r"\bwhat\s+(is|are|do|does|has|makes|brought|happened)\b",
    r"\bcan\s+you\s+tell\s+me\s+more\b",
    r"\bwould\s+you\s+like\s+to\s+(talk|share)\b",
]

# Directive / prescriptive phrasing. step_seeker penalises any hit while the
# seeker's trust is below 0.55.
ADVICE_PATTERNS = [
    r"\byou\s+should\b",
    r"\byou\s+(need|have|ought)\s+to\b",
    r"\btry\s+(to|doing|this)\b",
    r"\bjust\s+(do|go|try|stop|start)\b",
    r"\bwhy\s+don'?t\s+you\b",
    r"\bmy\s+advice\b",
]

# Minimising or hostile phrasing. Any hit triggers step_seeker's largest
# trust penalty.
DISMISSIVE_PATTERNS = [
    r"\bget\s+over\s+it\b",
    r"\bstop\s+(complaining|whining|crying)\b",
    r"\byou'?re\s+overreacting\b",
    r"\bit'?s\s+not\s+a\s+big\s+deal\b",
    r"\bcalm\s+down\b",
    r"\bit\s+could\s+be\s+worse\b",
]

INTERROGATIVE_PATTERNS = [  # two '?' separated only by whitespace ("??", "? ?"); any hit drains trust
    r"\?\s*\?",
]

# Safety check-ins and referrals to professional support. Counted into
# Features.safety. NOTE(review): step_seeker does not read this count;
# presumably it is consumed by reward code elsewhere, so verify.
SAFETY_PATTERNS = [
    r"\bare\s+you\s+safe\b",
    r"\bprofessional\s+help\b",
    r"\bcrisis\s+line\b",
    r"\btherapist\b",
]


def _count_matches(patterns: List[str], text: str) -> int:
    t = text.lower()
    return sum(1 for p in patterns if re.search(p, t))


@dataclass
class Features:
    """Deterministic summary of one agent reply.

    Built by ``extract_features`` from real text, or synthetically by
    ``oracle_features``. The pattern-bank fields count how many *distinct*
    patterns of the corresponding bank matched (each pattern adds at most 1),
    not total occurrences.
    """

    empathy: int  # EMPATHY_PATTERNS hits
    validation: int  # VALIDATION_PATTERNS hits
    open_question: int  # OPEN_QUESTION_PATTERNS hits
    advice: int  # ADVICE_PATTERNS hits
    dismissive: int  # DISMISSIVE_PATTERNS hits
    interrogative: int  # INTERROGATIVE_PATTERNS hits
    safety: int  # SAFETY_PATTERNS hits
    length: int  # character length of the whitespace-stripped reply
    closed_question: int  # any '?' not matched by open (total '?' minus open_question, floored at 0)
    bare: bool  # very short / empty reply (fewer than 8 chars after stripping)


# Typographic apostrophes that LLMs and word processors commonly emit, folded
# to ASCII so the contraction patterns (written with `'`) still match.
_APOSTROPHE_FOLD = str.maketrans({"\u2018": "'", "\u2019": "'", "\u02bc": "'"})


def extract_features(text: str) -> Features:
    """Run every deterministic detector bank over the agent's reply.

    Args:
        text: The agent's reply. ``None`` is tolerated and treated as "".

    Returns:
        A :class:`Features` record.
          * Pattern-based fields count how many distinct patterns of the
            corresponding bank matched.
          * ``length`` is the length of the whitespace-stripped reply.
          * ``closed_question`` is the number of ``?`` characters minus the
            open-question count, floored at 0.
          * ``bare`` is true for replies shorter than 8 characters after
            stripping.

    Typographic apostrophes (U+2018, U+2019, U+02BC) are folded to ASCII
    ``'`` before matching. Without this, replies such as "you’re
    overreacting" or "I’m sorry" would slip past the dismissive / empathy
    detectors, whose patterns only accept the ASCII apostrophe.
    """
    stripped = (text or "").strip()
    # 1:1 character mapping, so '?' counts and lengths are unaffected.
    lower = stripped.translate(_APOSTROPHE_FOLD).lower()
    empathy = _count_matches(EMPATHY_PATTERNS, lower)
    validation = _count_matches(VALIDATION_PATTERNS, lower)
    open_q = _count_matches(OPEN_QUESTION_PATTERNS, lower)
    advice = _count_matches(ADVICE_PATTERNS, lower)
    dismissive = _count_matches(DISMISSIVE_PATTERNS, lower)
    interrogative = _count_matches(INTERROGATIVE_PATTERNS, lower)
    safety = _count_matches(SAFETY_PATTERNS, lower)
    # Heuristic: every '?' not explained by an open-question pattern is
    # treated as a closed question.
    total_q = lower.count("?")
    closed_q = max(0, total_q - open_q)
    bare = len(stripped) < 8
    return Features(
        empathy=empathy,
        validation=validation,
        open_question=open_q,
        advice=advice,
        dismissive=dismissive,
        interrogative=interrogative,
        safety=safety,
        length=len(stripped),
        closed_question=closed_q,
        bare=bare,
    )


# ---------------------------------------------------------------------------
# Seeker state + scripted persona
# ---------------------------------------------------------------------------

@dataclass
class SeekerPersona:
    """Static configuration describing the seeker's initial state + script.

    Field groups:
      * narrative text (``task_id``, ``scenario_brief``, ``surface_concern``,
        ``true_issue``);
      * starting values for the hidden scalars. ``SeekerState.from_persona``
        copies them in, and ``resolution_score`` measures gains against them;
      * dynamics coefficients read by ``step_seeker``;
      * utterance pools, cycled round-robin by ``_next_line``.
    """

    task_id: str
    scenario_brief: str
    surface_concern: str  # what seeker says at turn 0
    true_issue: str  # hidden; only revealed if openness crosses threshold
    initial_distress: float  # starting distress; also the denominator for relief in resolution_score
    initial_trust: float  # starting trust; baseline for trust gain in resolution_score
    initial_openness: float  # starting openness
    reveal_threshold: float  # openness value at which true_issue is revealed
    trust_fragility: float  # how much a misstep drops trust (0..1); dismissal/advice penalties scale by (1 + this)
    openness_gain_per_empathy: float  # openness += this * (empathy + validation hits)
    distress_drop_per_validation: float  # distress -= this * validation hits
    # Scripted utterances by stage when cooperative
    opening_lines: List[str]
    exploring_lines: List[str]
    reflecting_lines: List[str]
    planning_lines: List[str]
    closing_lines: List[str]
    reveal_line: str  # said on the turn the reveal fires (see step_seeker for precedence vs. adverse lines)
    # Adverse reactions. When empty, step_seeker falls back to the normal line choice.
    dismissed_lines: List[str] = field(default_factory=list)
    advice_too_early_lines: List[str] = field(default_factory=list)


@dataclass
class SeekerState:
    """Per-episode hidden state of the seeker, advanced in place by step_seeker."""

    persona: SeekerPersona
    distress: float
    trust: float
    openness: float
    revealed: bool
    stage: Stage
    last_line_idx_by_stage: Dict[Stage, int]  # round-robin cursor per stage; -1 = nothing said yet
    turn: int

    @classmethod
    def from_persona(cls, persona: SeekerPersona) -> "SeekerState":
        """Build the turn-0 state for *persona*.

        The hidden scalars come from the persona's ``initial_*`` values. The
        stage starts at OPENING, with no reveal and every line cursor at -1.
        """
        return cls(
            turn=0,
            stage=Stage.OPENING,
            revealed=False,
            persona=persona,
            distress=persona.initial_distress,
            trust=persona.initial_trust,
            openness=persona.initial_openness,
            last_line_idx_by_stage=dict.fromkeys(Stage, -1),
        )

    def snapshot(self) -> "SeekerState":
        """Return an independent copy that a lookahead rollout may advance.

        The scalars are immutable values. The persona is shared, since it is
        only read, never written, by step_seeker. The line-cursor dict is
        copied so a rollout cannot shift the live conversation's utterance
        rotation.
        """
        cursor_copy = dict(self.last_line_idx_by_stage)
        return SeekerState(
            turn=self.turn,
            stage=self.stage,
            revealed=self.revealed,
            persona=self.persona,
            distress=self.distress,
            trust=self.trust,
            openness=self.openness,
            last_line_idx_by_stage=cursor_copy,
        )


def _clip(x: float) -> float:
    return max(0.0, min(1.0, x))


# Stage ordering used for "progress" scalar in [0,1].
# stage_progress maps a stage's index here to a fraction, and step_seeker's
# advance_to uses the same indices to allow forward-only stage moves.
# Listed explicitly; it matches Stage's declaration order.
STAGE_ORDER: List[Stage] = [
    Stage.OPENING,
    Stage.EXPLORING,
    Stage.REFLECTING,
    Stage.PLANNING,
    Stage.CLOSING,
]


def stage_progress(stage: Stage) -> float:
    """Map *stage* to a fraction of the way through STAGE_ORDER.

    Returns 0.0 for the first stage and 1.0 for the last, evenly spaced in
    between. Raises ValueError (from ``list.index``) for a stage not in
    STAGE_ORDER.
    """
    position = STAGE_ORDER.index(stage)
    final_position = len(STAGE_ORDER) - 1
    return position / final_position


def resolution_score(state: SeekerState) -> float:
    """Summarise how 'resolved' the conversation is as a float in [0, 1].

    Weighted sum, clamped to [0, 1]:
      * 0.40 × stage progress (see ``stage_progress``);
      * 0.25 × trust gained, relative to the headroom above the persona's
        initial trust;
      * 0.25 × distress relieved, relative to the persona's initial distress;
      * 0.10 × 1 if the true issue has been revealed, else 0.

    Gains are floored at 0, so regressions below the starting point never
    subtract. The denominators are floored at 1e-6 to avoid division by zero.
    ``simulate_oracle_rollout`` returns this score at the end of its rollout.
    """
    persona = state.persona
    trust_span = max(1e-6, 1.0 - persona.initial_trust)
    distress_span = max(1e-6, persona.initial_distress)

    progress_term = 0.40 * stage_progress(state.stage)
    trust_term = 0.25 * max(0.0, state.trust - persona.initial_trust) / trust_span
    relief_term = 0.25 * max(0.0, persona.initial_distress - state.distress) / distress_span
    reveal_term = 0.10 * (1.0 if state.revealed else 0.0)

    return _clip(progress_term + trust_term + relief_term + reveal_term)


# ---------------------------------------------------------------------------
# Transition: given current state + agent features, produce new state +
# seeker's next utterance + transition info.
# ---------------------------------------------------------------------------

@dataclass
class Transition:
    """Outcome of one ``step_seeker`` call."""

    new_state: SeekerState  # the same object step_seeker received, mutated in place
    seeker_utterance: str  # the seeker's reply line for this turn
    # Keys set by step_seeker: "dismissed", "advice_too_early", "bare_reply",
    # "empathic", "interrogated", "revealed_this_turn".
    flags: Dict[str, bool]  # e.g. {"dismissed": True, "advice_too_early": False, ...}


def _next_line(state: SeekerState, stage: Stage, pool: List[str]) -> str:
    if not pool:
        return "..."
    idx = (state.last_line_idx_by_stage[stage] + 1) % len(pool)
    state.last_line_idx_by_stage[stage] = idx
    return pool[idx]


def step_seeker(state: SeekerState, features: Features) -> Transition:
    """Apply one turn of seeker dynamics given the agent's extracted features.

    ``state`` is mutated **in place**, and the very same object is returned
    as ``Transition.new_state``; nothing is copied. Callers that need to keep
    the original (e.g. a lookahead rollout) must pass ``state.snapshot()``.

    The turn runs in a fixed order. Later steps see the effects of earlier
    ones; for example, the premature-advice check uses trust *after* any
    dismissal penalty.

      1-6. Adjust trust / distress / openness from the features, clamping
           each value to [0, 1] after every adjustment.
      7.   Advance the stage by at most one step, never backwards.
      8.   Reveal the true issue once openness reaches ``reveal_threshold``.
           The reveal is deferred while an adverse-reaction line is being
           spoken.
      9.   Choose the seeker's reply line.

    Args:
        state: Seeker state to advance; modified in place.
        features: Reply features from ``extract_features`` or ``oracle_features``.

    Returns:
        A :class:`Transition` holding the mutated state, the seeker's
        utterance, and per-turn boolean flags.
    """
    p = state.persona
    flags: Dict[str, bool] = {
        "dismissed": False,
        "advice_too_early": False,
        "bare_reply": features.bare,
        "empathic": features.empathy + features.validation > 0,
        "interrogated": False,
        "revealed_this_turn": False,
    }

    # --- 1. Dismissive / hostile language: hard drop on trust & distress spike.
    #        Fragile personas lose proportionally more trust.
    if features.dismissive > 0:
        state.trust = _clip(state.trust - 0.4 * (1.0 + p.trust_fragility))
        state.distress = _clip(state.distress + 0.15)
        state.openness = _clip(state.openness - 0.2)
        flags["dismissed"] = True

    # --- 2. Premature advice (any advice while trust < 0.55): trust and openness drop.
    #        Evaluated against trust *after* step 1.
    if features.advice > 0 and state.trust < 0.55:
        state.trust = _clip(state.trust - 0.15 * (1.0 + p.trust_fragility))
        state.openness = _clip(state.openness - 0.1)
        flags["advice_too_early"] = True

    # --- 3. Empathy & validation: trust + openness up; validation also lowers distress.
    if features.empathy > 0 or features.validation > 0:
        gain = p.openness_gain_per_empathy * (features.empathy + features.validation)
        state.trust = _clip(state.trust + 0.12 * (features.empathy + features.validation))
        state.openness = _clip(state.openness + gain)
        state.distress = _clip(state.distress - p.distress_drop_per_validation * features.validation)

    # --- 4. Open questions: small trust and openness gain. They also count
    #        toward the OPENING->EXPLORING and PLANNING->CLOSING moves in step 7.
    if features.open_question > 0:
        state.trust = _clip(state.trust + 0.05)
        state.openness = _clip(state.openness + 0.04)

    # --- 5. Interrogation (3+ closed questions, or a "??"-style hit): trust drain.
    if features.closed_question >= 3 or features.interrogative > 0:
        state.trust = _clip(state.trust - 0.1)
        flags["interrogated"] = True

    # --- 6. Bare / empty reply: small penalty across the board.
    if features.bare:
        state.trust = _clip(state.trust - 0.05)
        state.distress = _clip(state.distress + 0.02)

    # --- 7. Stage progression (monotonic forward with cooperative conditions).
    #        The elif chain allows at most one step per turn.
    def advance_to(s: Stage) -> None:
        if STAGE_ORDER.index(s) > STAGE_ORDER.index(state.stage):
            state.stage = s

    if state.stage == Stage.OPENING and (
        features.empathy + features.validation + features.open_question > 0
    ):
        advance_to(Stage.EXPLORING)
    elif state.stage == Stage.EXPLORING and state.trust >= 0.5 and state.openness >= 0.5:
        advance_to(Stage.REFLECTING)
    elif state.stage == Stage.REFLECTING and state.revealed and state.distress <= 0.5:
        advance_to(Stage.PLANNING)
    elif state.stage == Stage.PLANNING and features.open_question + features.empathy > 0:
        advance_to(Stage.CLOSING)

    # --- 8. Reveal check (cross threshold once).
    #        In step 9, an adverse-reaction line (dismissal / premature advice)
    #        takes precedence over reveal_line. Revealing on such a turn would
    #        flip `revealed` without the seeker ever voicing the issue. That
    #        flip unlocks PLANNING and earns the resolution reveal bonus. So
    #        the reveal waits for the next turn on which openness is still at
    #        or above the threshold.
    adverse_line_pending = bool(
        (flags["dismissed"] and p.dismissed_lines)
        or (flags["advice_too_early"] and p.advice_too_early_lines)
    )
    if (
        not state.revealed
        and not adverse_line_pending
        and state.openness >= p.reveal_threshold
    ):
        state.revealed = True
        flags["revealed_this_turn"] = True

    # --- 9. Pick seeker's next utterance.
    #        Priority: dismissal reaction > premature-advice reaction >
    #        reveal line > the current stage's pool.
    if flags["dismissed"] and p.dismissed_lines:
        utterance = _next_line(state, state.stage, p.dismissed_lines)
    elif flags["advice_too_early"] and p.advice_too_early_lines:
        utterance = _next_line(state, state.stage, p.advice_too_early_lines)
    elif flags["revealed_this_turn"]:
        utterance = p.reveal_line
    else:
        pool_by_stage = {
            Stage.OPENING: p.opening_lines,
            Stage.EXPLORING: p.exploring_lines,
            Stage.REFLECTING: p.reflecting_lines,
            Stage.PLANNING: p.planning_lines,
            Stage.CLOSING: p.closing_lines,
        }
        utterance = _next_line(state, state.stage, pool_by_stage[state.stage])

    state.turn += 1
    return Transition(new_state=state, seeker_utterance=utterance, flags=flags)


# ---------------------------------------------------------------------------
# Oracle policy for the future-oriented reward lookahead.
# ---------------------------------------------------------------------------

def oracle_features(state: SeekerState) -> Features:
    """Return the hand-crafted 'ideal reply' features for the current stage.

    This is the fixed policy followed by ``simulate_oracle_rollout``:

      opening / exploring  -> empathy + open question       (length 80)
      reflecting           -> empathy + validation          (length 90)
      planning             -> open question + advice        (length 90)
      closing / otherwise  -> empathy + safety mention      (length 90)

    The planning advice assumes trust is already high by then. If trust has
    fallen below 0.55, step_seeker still treats that advice as premature.
    The oracle is never dismissive, never interrogates, and never sends a
    bare reply.
    """
    stage = state.stage
    if stage in (Stage.OPENING, Stage.EXPLORING):
        empathy, validation, open_question, advice, safety, length = 1, 0, 1, 0, 0, 80
    elif stage == Stage.REFLECTING:
        empathy, validation, open_question, advice, safety, length = 1, 1, 0, 0, 0, 90
    elif stage == Stage.PLANNING:
        empathy, validation, open_question, advice, safety, length = 0, 0, 1, 1, 0, 90
    else:  # CLOSING
        empathy, validation, open_question, advice, safety, length = 1, 0, 0, 0, 1, 90

    return Features(
        empathy=empathy,
        validation=validation,
        open_question=open_question,
        advice=advice,
        dismissive=0,
        interrogative=0,
        safety=safety,
        length=length,
        closed_question=0,
        bare=False,
    )


def simulate_oracle_rollout(state: SeekerState, k: int) -> float:
    """Project the conversation ``k`` turns ahead under the oracle policy.

    The rollout runs on ``state.snapshot()``, so the live state, including
    its utterance cursors, is left untouched. Each simulated turn feeds
    ``oracle_features`` of the current rollout state into ``step_seeker``.
    The rollout's final ``resolution_score`` is returned. A ``k`` of 0 or
    less scores the snapshot unchanged.
    """
    rollout = state.snapshot()
    for _turn in range(k):
        ideal_reply = oracle_features(rollout)
        step_seeker(rollout, ideal_reply)
    return resolution_score(rollout)