File size: 12,290 Bytes
807d5cc
 
 
 
e1ec6bc
 
f5e0477
 
 
 
 
 
 
807d5cc
 
 
 
 
 
 
 
 
 
 
 
 
80d39c4
807d5cc
 
 
 
 
 
 
3fbae38
807d5cc
 
3fbae38
807d5cc
 
 
 
 
 
 
 
 
 
3fbae38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
807d5cc
 
 
e1ec6bc
 
 
807d5cc
f5e0477
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
807d5cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3fbae38
 
 
 
807d5cc
 
 
 
 
 
 
3fbae38
 
 
807d5cc
 
 
 
 
 
 
3fbae38
 
 
 
 
807d5cc
 
 
 
f5e0477
807d5cc
 
f5e0477
807d5cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3fbae38
 
 
 
 
 
 
 
 
80d39c4
 
3fbae38
 
 
 
 
 
 
80d39c4
 
 
 
 
 
 
 
 
 
3fbae38
 
 
 
 
 
 
 
 
 
 
 
 
80d39c4
 
 
 
 
 
 
 
3fbae38
 
807d5cc
 
f5e0477
 
 
 
 
 
 
807d5cc
3fbae38
 
 
 
 
807d5cc
 
 
 
 
 
 
 
 
 
 
 
 
3fbae38
 
 
 
807d5cc
 
 
 
 
 
 
3fbae38
 
 
 
807d5cc
3fbae38
 
 
807d5cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e1ec6bc
 
f5e0477
 
 
 
 
807d5cc
 
 
 
f5e0477
807d5cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
"""Baseline inference script for the ESC OpenEnv environment.

MANDATORY env vars
------------------
    API_BASE_URL   - LLM endpoint (defaults to https://api.openai.com/v1)
    MODEL_NAME     - Model identifier (defaults to gpt-4.1-mini)
    HF_TOKEN       - Hugging Face / router token (preferred)
    ESC_ENV_URL    - URL of the running ESC OpenEnv HTTP server (defaults to http://127.0.0.1:7860)

Compatible auth env vars
------------------------
    OPENAI_API_KEY - standard OpenAI-compatible auth key
    API_KEY        - generic OpenAI-compatible auth key

STDOUT contract (strict)
------------------------
One [START] line per episode, one [STEP] per step, one [END] per episode.
See the hackathon spec for exact format.

Runs all 3 tasks (easy/medium/hard) sequentially and prints a final summary
to stderr. Total wall-clock budget kept well under 20min on 2 vCPU / 8GB.
"""
from __future__ import annotations

import asyncio
import os
import re
import sys
import textwrap
import traceback
from typing import List, Optional

from openai import OpenAI

from src.agentic import AgentMemory, SkillRouter, build_default_skills
from src.client import ESCHttpClient
from src.models import Action
from src.seeker import extract_features

# Benchmark name passed as `env=` to log_start for every episode.
BENCHMARK = "emotional-support-conversations"
MAX_STEPS = 14  # upper bound; env imposes per-task limits too
# Sampling settings forwarded to the chat-completions request in call_llm.
TEMPERATURE = 0.6
MAX_TOKENS = 220

# Task ids that main() runs sequentially, in this order.
TASK_IDS = ["work_stress_venting", "guarded_relationship", "crisis_fragile_trust"]

# System message sent with every call_llm request. It restricts the model to
# lightly polishing the controller's draft reply rather than writing a new one.
SYSTEM_PROMPT = textwrap.dedent(
    """
    You are the response generator inside a controlled emotional-support agent.

    A deterministic controller has already selected the correct conversational
    move for this turn and written a draft reply. Your job is only to lightly
    polish that draft while preserving its intent and structure.

    Hard rules:
    - Stay extremely close to the draft.
    - Keep the same stage objective. Do not change exploration into advice or
      advice into exploration.
    - Preserve any explicit safety support mention, validation, and questions
      already present in the draft.
    - Do not add extra questions, extra advice, or new topics.
    - Keep replies warm, brief, and human.
    - If the draft is already strong, repeat it verbatim.

    Reply with ONLY the next message to the seeker.
    """
).strip()

# Fallbacks used by main() when API_BASE_URL / MODEL_NAME are unset or empty.
DEFAULT_API_BASE_URL = "https://api.openai.com/v1"
DEFAULT_MODEL_NAME = "gpt-4.1-mini"


def require_env(name: str) -> str:
    """Return the value of environment variable *name*.

    Raises:
        SystemExit: if the variable is unset or set to an empty string.
    """
    found = os.environ.get(name)
    if found:
        return found
    raise SystemExit(
        f"Missing required environment variable: {name}\n"
        "Set the judging env vars and rerun `python inference.py`."
    )


def resolve_api_key() -> str:
    """Return the first non-empty auth token from the environment.

    Variables are checked in priority order: ``HF_TOKEN``, then
    ``OPENAI_API_KEY``, then ``API_KEY``.

    Raises:
        SystemExit: if none of the three variables holds a non-empty value.
    """
    for var_name in ("HF_TOKEN", "OPENAI_API_KEY", "API_KEY"):
        token = os.getenv(var_name)
        if token:
            return token
    raise SystemExit(
        "Missing authentication token. Set HF_TOKEN, OPENAI_API_KEY, or API_KEY "
        "before running `python inference.py`."
    )


# -------------------------- stdout contract ----------------------------------

def log_start(task: str, env: str, model: str) -> None:
    """Print the ``[START]`` line for an episode to stdout and flush it."""
    start_line = f"[START] task={task} env={env} model={model}"
    print(start_line, flush=True)


def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Print one ``[STEP]`` line to stdout and flush it.

    Args:
        step: 1-based step index within the episode.
        action: Message sent to the environment; may span several lines.
        reward: Reward for this step, printed with two decimals.
        done: Whether the episode ended on this step (printed as true/false).
        error: Error description, or ``None``/empty for no error (printed
            as ``null``).
    """
    # Collapse all whitespace runs (including newlines) in both free-text
    # fields so the stdout contract stays single-line. Exception messages
    # forwarded as `error` can span multiple lines just like actions can.
    flat_action = " ".join((action or "").split())
    flat_error = " ".join((error or "").split()) or "null"
    print(
        f"[STEP] step={step} action={flat_action} reward={reward:.2f} "
        f"done={str(done).lower()} error={flat_error}",
        flush=True,
    )


def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the ``[END]`` line for an episode to stdout and flush it.

    ``rewards`` is rendered as a comma-separated list with two decimals per
    value; ``score`` is printed with three decimals.
    """
    formatted_rewards = [f"{value:.2f}" for value in rewards]
    rewards_str = ",".join(formatted_rewards)
    success_flag = str(success).lower()
    print(
        f"[END] success={success_flag} steps={steps} "
        f"score={score:.3f} rewards={rewards_str}",
        flush=True,
    )


# -------------------------- LLM call -----------------------------------------

def build_user_prompt(
    scenario_brief: str,
    stage_hint: str,
    turn: int,
    remaining: int,
    seeker_utterance: str,
    history: List[str],
    skill_name: str,
    rationale: str,
    skill_instruction: str,
    draft_reply: str,
) -> str:
    """Assemble the user message sent to the LLM for one turn.

    Args:
        scenario_brief: Scenario description from the observation.
        stage_hint: Public conversation-stage hint from the observation.
        turn: Current turn number.
        remaining: Turns remaining in the episode.
        seeker_utterance: The seeker's latest message.
        history: Transcript lines so far; only the last 8 are included.
        skill_name: Name of the skill the controller selected.
        rationale: Why that skill was selected.
        skill_instruction: Directive text for the selected skill.
        draft_reply: Deterministic draft the LLM is asked to polish.

    Returns:
        The prompt text with every template line starting at column 0.
    """
    history_block = "\n".join(history[-8:]) if history else "(this is the first turn)"
    # Build the prompt from already-unindented lines instead of dedenting an
    # indented f-string: dedent runs after interpolation, so any multi-line
    # value (e.g. two or more history entries) introduces unindented lines,
    # empties the common margin, and leaves the template indented.
    prompt_lines = [
        f"Scenario: {scenario_brief}",
        f"Conversation stage (public hint): {stage_hint}",
        f"Turn: {turn}   Remaining turns: {remaining}",
        f"Selected skill: {skill_name}",
        f"Why this skill was selected: {rationale}",
        f"Skill directive: {skill_instruction}",
        "",
        "Recent exchange:",
        history_block,
        "",
        "Seeker just said:",
        f'"{seeker_utterance}"',
        "",
        "Deterministic draft reply:",
        f'"{draft_reply}"',
        "",
        "Lightly polish the draft only if needed. Preserve its goal and",
        "structure. If unsure, output the draft unchanged.",
    ]
    return "\n".join(prompt_lines)


def call_llm(client: OpenAI, model_name: str, user_prompt: str) -> str:
    """Request a polished reply from the chat-completions endpoint.

    The request carries ``SYSTEM_PROMPT`` as the system message and
    ``user_prompt`` as the user message. This is best-effort: any exception is
    reported on stderr and a fixed fallback reply is returned, and an empty
    completion yields a different fixed fallback.

    Returns:
        The stripped completion text, or one of the two fallback sentences.
    """
    chat_messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_prompt},
    ]
    try:
        response = client.chat.completions.create(
            model=model_name,
            messages=chat_messages,
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
            stream=False,
        )
        # `content` may be None, so normalise to a string before stripping.
        reply = (response.choices[0].message.content or "").strip()
    except Exception as exc:
        print(f"[DEBUG] LLM call failed: {exc}", file=sys.stderr, flush=True)
        return "That sounds really hard. I'm here — do you want to tell me more about what's going on?"
    if reply:
        return reply
    return "I hear you. That sounds really hard — can you tell me a little more about what's weighing on you?"


def _count_questions(text: str) -> int:
    """Return the number of ``?`` characters in *text* (0 for empty or None)."""
    if not text:
        return 0
    return text.count("?")


def should_accept_rewrite(draft: str, candidate: str) -> bool:
    """Decide whether the LLM rewrite may replace the deterministic draft.

    The candidate is rejected when any of the following holds:

    * it is empty after stripping;
    * its features report ``dismissive > 0``, ``bare``, or
      ``interrogative > 0``, or it contains more than one ``?``;
    * it has more words than ``max(24, 1.2 * draft word count)``;
    * its ``open_question``, ``advice``, ``safety`` or ``validation`` feature
      differs from the draft's;
    * the draft has ``empathy > 0`` and the candidate does not.

    A candidate that passes is accepted if it equals the draft after
    lowercasing, removing punctuation and collapsing whitespace, or if at
    least 80% of the draft's distinct normalised tokens appear in it.

    Args:
        draft: Deterministic reply produced by the selected skill.
        candidate: LLM output proposed as a replacement; may be None or empty.

    Returns:
        True to use ``candidate``, False to keep ``draft``.
    """
    candidate = (candidate or "").strip()
    if not candidate:
        return False

    draft_norm = " ".join(re.sub(r"[^\w\s]", "", draft.lower()).split())
    candidate_norm = " ".join(re.sub(r"[^\w\s]", "", candidate.lower()).split())
    draft_features = extract_features(draft)
    candidate_features = extract_features(candidate)

    if candidate_features.dismissive > 0 or candidate_features.bare:
        return False
    if _count_questions(candidate) > 1 or candidate_features.interrogative > 0:
        return False
    if len(candidate.split()) > max(24, int(len(draft.split()) * 1.2)):
        return False

    # The stage-driving signals must match the draft exactly. These equality
    # checks already cover "the rewrite dropped a signal the draft had", so no
    # separate `draft > 0 and candidate <= 0` test is needed for these fields.
    if draft_features.open_question != candidate_features.open_question:
        return False
    if draft_features.advice != candidate_features.advice:
        return False
    if draft_features.safety != candidate_features.safety:
        return False
    if draft_features.validation != candidate_features.validation:
        return False

    # Empathy is not matched exactly; it only must not disappear.
    if draft_features.empathy > 0 and candidate_features.empathy <= 0:
        return False
    if draft_norm == candidate_norm:
        return True

    # Only accept near-verbatim rewrites; otherwise keep the proven draft.
    draft_tokens = set(draft_norm.split())
    candidate_tokens = set(candidate_norm.split())
    # max(1, ...) guards the division when the draft has no tokens.
    overlap = len(draft_tokens & candidate_tokens) / max(1, len(draft_tokens))
    return overlap >= 0.8


# -------------------------- per-task episode ---------------------------------

async def run_task(
    openai_client: OpenAI,
    env_client: ESCHttpClient,
    model_name: str,
    task_id: str,
) -> dict:
    """Run one episode of ``task_id`` and emit its [START]/[STEP]/[END] lines.

    Each turn, the router picks a skill, the skill renders a deterministic
    draft, the LLM is asked to polish it, and the rewrite is used only if
    ``should_accept_rewrite`` approves; otherwise the draft is sent unchanged.

    Args:
        openai_client: Client passed to ``call_llm``.
        env_client: Environment client used for ``reset``, ``step`` and ``state``.
        model_name: Model identifier, used for the request and the [START] line.
        task_id: Task to reset the environment to.

    Returns:
        A dict with keys ``task_id``, ``score``, ``success`` and ``steps``.
        Exceptions raised inside the episode are caught, so [END] is still
        logged and a result is still returned with whatever was recorded.
    """
    log_start(task=task_id, env=BENCHMARK, model=model_name)

    # Controller state is created anew on every call.
    router = SkillRouter()
    skills = build_default_skills()
    memory = AgentMemory()
    memory.reset(task_id)

    rewards: List[float] = []
    steps_taken = 0
    score = 0.0
    success = False
    history: List[str] = []
    # NOTE(review): only the step-failure value is ever read (it is passed to
    # log_step); the "episode_failed" value set below is never used.
    last_error: Optional[str] = None

    try:
        reset = await env_client.reset(task_id=task_id)
        obs = reset.observation
        history.append(f"Seeker: {obs.seeker_utterance!r}")

        for step in range(1, MAX_STEPS + 1):
            memory.observe(obs)
            decision = router.choose(obs, memory)
            skill = skills[decision.skill_name]
            draft_message = skill.render(obs, memory, decision)
            user_prompt = build_user_prompt(
                scenario_brief=obs.scenario_brief,
                stage_hint=obs.stage_hint,
                turn=obs.turn,
                remaining=obs.remaining_turns,
                seeker_utterance=obs.seeker_utterance,
                history=history,
                skill_name=decision.skill_name,
                rationale=decision.rationale,
                skill_instruction=skill.llm_instruction(obs, memory, decision),
                draft_reply=draft_message,
            )
            # call_llm is a plain synchronous function, so this coroutine does
            # not yield to the event loop while the request is in flight.
            candidate_message = call_llm(openai_client, model_name, user_prompt)
            # Use the rewrite only when it passes the guard; else send the draft.
            message = candidate_message if should_accept_rewrite(draft_message, candidate_message) else draft_message
            memory.remember(decision.skill_name, message)

            try:
                result = await env_client.step(Action(message=message))
            except Exception as e:
                # A failed step is logged with reward 0.0 and done=True and ends
                # the episode; it is not added to `rewards` or `steps_taken`.
                last_error = f"step_failed: {e}"
                log_step(step=step, action=message, reward=0.0, done=True, error=last_error)
                break

            reward = float(result.reward)
            done = bool(result.done)
            rewards.append(reward)
            steps_taken = step
            obs = result.observation

            history.append(f"Agent: {message!r}")
            history.append(f"Seeker: {obs.seeker_utterance!r}")

            log_step(step=step, action=message, reward=reward, done=done, error=None)

            if done:
                # Prefer the environment's final summary when info is a dict;
                # otherwise the score falls back to the mean step reward.
                final = result.info.get("final", {}) if isinstance(result.info, dict) else {}
                score = float(final.get("score", sum(rewards) / max(1, steps_taken)))
                # A missing "success" entry defaults to 0.0, i.e. not successful.
                success = bool(final.get("success", 0.0) >= 1.0)
                break
        else:
            # Ran out of outer loop without env-side done — fall back to state().
            # (This branch is skipped whenever the loop exits via `break`.)
            st = await env_client.state()
            score = float(st.get("cumulative_reward", 0.0)) / max(1, steps_taken)
            success = score >= 0.5

    except Exception as exc:
        # Keep the stdout contract intact: the traceback goes to stderr and the
        # [END] line below is still printed.
        last_error = f"episode_failed: {exc}"
        traceback.print_exc(file=sys.stderr)

    log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
    return {"task_id": task_id, "score": score, "success": success, "steps": steps_taken}


# -------------------------- main ---------------------------------------------

async def main() -> None:
    """Run every task in ``TASK_IDS`` in order and print a summary to stderr.

    Endpoint, model and environment URL come from ``API_BASE_URL``,
    ``MODEL_NAME`` and ``ESC_ENV_URL``, each with a default used when the
    variable is unset or empty. The environment client is closed even if a
    task raises.

    Raises:
        SystemExit: from ``resolve_api_key`` when no auth token is configured.
    """
    api_base_url = os.getenv("API_BASE_URL") or DEFAULT_API_BASE_URL
    model_name = os.getenv("MODEL_NAME") or DEFAULT_MODEL_NAME
    api_key = resolve_api_key()
    env_url = os.getenv("ESC_ENV_URL") or "http://127.0.0.1:7860"

    llm_client = OpenAI(base_url=api_base_url, api_key=api_key)
    env_client = ESCHttpClient.from_url(env_url)

    summaries = []
    try:
        for task_id in TASK_IDS:
            summaries.append(await run_task(llm_client, env_client, model_name, task_id))
    finally:
        await env_client.close()

    # The summary goes to stderr so stdout carries only the contract lines.
    print("\n=== Baseline summary ===", file=sys.stderr)
    for row in summaries:
        print(
            f"  {row['task_id']:<26} score={row['score']:.3f}  success={row['success']}  steps={row['steps']}",
            file=sys.stderr,
        )
    score_total = sum(row["score"] for row in summaries)
    avg = score_total / max(1, len(summaries))
    print(f"  {'AVERAGE':<26} score={avg:.3f}", file=sys.stderr)


# Script entry point: run the async main() to completion when executed directly.
if __name__ == "__main__":
    asyncio.run(main())