5ivatej commited on
Commit
3fbae38
·
1 Parent(s): f5e0477

Strengthen hosted inference baseline with hybrid controller

Browse files
Files changed (1) hide show
  1. inference.py +80 -19
inference.py CHANGED
@@ -31,8 +31,10 @@ from typing import List, Optional
31
 
32
  from openai import OpenAI
33
 
 
34
  from src.client import ESCHttpClient
35
  from src.models import Action
 
36
 
37
  BENCHMARK = "emotional-support-conversations"
38
  MAX_STEPS = 14 # upper bound; env imposes per-task limits too
@@ -43,23 +45,23 @@ TASK_IDS = ["work_stress_venting", "guarded_relationship", "crisis_fragile_trust
43
 
44
  SYSTEM_PROMPT = textwrap.dedent(
45
  """
46
- You are an emotionally attuned peer supporter chatting with someone who is
47
- going through a hard time. Your job is NOT to fix their problem. Your job
48
- is to make them feel heard, safe, and understood first — and only move
49
- toward gentle exploration or light action once trust is established.
50
-
51
- Principles:
52
- - Lead with empathy and validation. Reflect what you hear.
53
- - Do NOT give advice until the person has clearly shared what's really
54
- going on and feels heard.
55
- - Ask at most one open-ended question per reply. Never interrogate.
56
- - Never be dismissive, minimising, or instructive in a judgmental tone.
57
- - Keep replies warm, brief (1-3 sentences), and human.
58
- - In high-distress / crisis scenarios, gently reference professional
59
- support (a therapist, crisis line) only after rapport is built.
60
-
61
- You will receive the current conversation state. Reply with ONLY your
62
- next message to the person no role labels, no prefixes, no quotes.
63
  """
64
  ).strip()
65
 
@@ -119,6 +121,10 @@ def build_user_prompt(
119
  remaining: int,
120
  seeker_utterance: str,
121
  history: List[str],
 
 
 
 
122
  ) -> str:
123
  history_block = "\n".join(history[-8:]) if history else "(this is the first turn)"
124
  return textwrap.dedent(
@@ -126,6 +132,9 @@ def build_user_prompt(
126
  Scenario: {scenario_brief}
127
  Conversation stage (public hint): {stage_hint}
128
  Turn: {turn} Remaining turns: {remaining}
 
 
 
129
 
130
  Recent exchange:
131
  {history_block}
@@ -133,7 +142,11 @@ def build_user_prompt(
133
  Seeker just said:
134
  "{seeker_utterance}"
135
 
136
- Write your next reply (1-3 sentences, warm, no advice unless rapport is clearly established):
 
 
 
 
137
  """
138
  ).strip()
139
 
@@ -157,6 +170,39 @@ def call_llm(client: OpenAI, model_name: str, user_prompt: str) -> str:
157
  return "That sounds really hard. I'm here — do you want to tell me more about what's going on?"
158
 
159
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
  # -------------------------- per-task episode ---------------------------------
161
 
162
  async def run_task(
@@ -167,6 +213,11 @@ async def run_task(
167
  ) -> dict:
168
  log_start(task=task_id, env=BENCHMARK, model=model_name)
169
 
 
 
 
 
 
170
  rewards: List[float] = []
171
  steps_taken = 0
172
  score = 0.0
@@ -180,6 +231,10 @@ async def run_task(
180
  history.append(f"Seeker: {obs.seeker_utterance!r}")
181
 
182
  for step in range(1, MAX_STEPS + 1):
 
 
 
 
183
  user_prompt = build_user_prompt(
184
  scenario_brief=obs.scenario_brief,
185
  stage_hint=obs.stage_hint,
@@ -187,8 +242,14 @@ async def run_task(
187
  remaining=obs.remaining_turns,
188
  seeker_utterance=obs.seeker_utterance,
189
  history=history,
 
 
 
 
190
  )
191
- message = call_llm(openai_client, model_name, user_prompt)
 
 
192
 
193
  try:
194
  result = await env_client.step(Action(message=message))
 
31
 
32
  from openai import OpenAI
33
 
34
+ from src.agentic import AgentMemory, SkillRouter, build_default_skills
35
  from src.client import ESCHttpClient
36
  from src.models import Action
37
+ from src.seeker import extract_features
38
 
39
  BENCHMARK = "emotional-support-conversations"
40
  MAX_STEPS = 14 # upper bound; env imposes per-task limits too
 
45
 
46
  SYSTEM_PROMPT = textwrap.dedent(
47
  """
48
+ You are the response generator inside a controlled emotional-support agent.
49
+
50
+ A deterministic controller has already selected the correct conversational
51
+ move for this turn and written a draft reply. Your job is only to lightly
52
+ polish that draft while preserving its intent and structure.
53
+
54
+ Hard rules:
55
+ - Stay extremely close to the draft.
56
+ - Keep the same stage objective. Do not change exploration into advice or
57
+ advice into exploration.
58
+ - Preserve any explicit safety support mention, validation, and questions
59
+ already present in the draft.
60
+ - Do not add extra questions, extra advice, or new topics.
61
+ - Keep replies warm, brief, and human.
62
+ - If the draft is already strong, repeat it verbatim.
63
+
64
+ Reply with ONLY the next message to the seeker.
65
  """
66
  ).strip()
67
 
 
121
  remaining: int,
122
  seeker_utterance: str,
123
  history: List[str],
124
+ skill_name: str,
125
+ rationale: str,
126
+ skill_instruction: str,
127
+ draft_reply: str,
128
  ) -> str:
129
  history_block = "\n".join(history[-8:]) if history else "(this is the first turn)"
130
  return textwrap.dedent(
 
132
  Scenario: {scenario_brief}
133
  Conversation stage (public hint): {stage_hint}
134
  Turn: {turn} Remaining turns: {remaining}
135
+ Selected skill: {skill_name}
136
+ Why this skill was selected: {rationale}
137
+ Skill directive: {skill_instruction}
138
 
139
  Recent exchange:
140
  {history_block}
 
142
  Seeker just said:
143
  "{seeker_utterance}"
144
 
145
+ Deterministic draft reply:
146
+ "{draft_reply}"
147
+
148
+ Lightly polish the draft only if needed. Preserve its goal and
149
+ structure. If unsure, output the draft unchanged.
150
  """
151
  ).strip()
152
 
 
170
  return "That sounds really hard. I'm here — do you want to tell me more about what's going on?"
171
 
172
 
173
+ def _count_questions(text: str) -> int:
174
+ return (text or "").count("?")
175
+
176
+
177
+ def should_accept_rewrite(draft: str, candidate: str) -> bool:
178
+ candidate = (candidate or "").strip()
179
+ if not candidate:
180
+ return False
181
+
182
+ draft_features = extract_features(draft)
183
+ candidate_features = extract_features(candidate)
184
+
185
+ if candidate_features.dismissive > 0 or candidate_features.bare:
186
+ return False
187
+ if _count_questions(candidate) > 1 or candidate_features.interrogative > 0:
188
+ return False
189
+
190
+ # Do not let the rewrite weaken the key stage-driving signals already
191
+ # present in the deterministic draft.
192
+ if draft_features.open_question > 0 and candidate_features.open_question <= 0:
193
+ return False
194
+ if draft_features.validation > 0 and candidate_features.validation <= 0:
195
+ return False
196
+ if draft_features.empathy > 0 and candidate_features.empathy <= 0:
197
+ return False
198
+ if draft_features.advice > 0 and candidate_features.advice <= 0:
199
+ return False
200
+ if draft_features.safety > 0 and candidate_features.safety <= 0:
201
+ return False
202
+
203
+ return True
204
+
205
+
206
  # -------------------------- per-task episode ---------------------------------
207
 
208
  async def run_task(
 
213
  ) -> dict:
214
  log_start(task=task_id, env=BENCHMARK, model=model_name)
215
 
216
+ router = SkillRouter()
217
+ skills = build_default_skills()
218
+ memory = AgentMemory()
219
+ memory.reset(task_id)
220
+
221
  rewards: List[float] = []
222
  steps_taken = 0
223
  score = 0.0
 
231
  history.append(f"Seeker: {obs.seeker_utterance!r}")
232
 
233
  for step in range(1, MAX_STEPS + 1):
234
+ memory.observe(obs)
235
+ decision = router.choose(obs, memory)
236
+ skill = skills[decision.skill_name]
237
+ draft_message = skill.render(obs, memory, decision)
238
  user_prompt = build_user_prompt(
239
  scenario_brief=obs.scenario_brief,
240
  stage_hint=obs.stage_hint,
 
242
  remaining=obs.remaining_turns,
243
  seeker_utterance=obs.seeker_utterance,
244
  history=history,
245
+ skill_name=decision.skill_name,
246
+ rationale=decision.rationale,
247
+ skill_instruction=skill.llm_instruction(obs, memory, decision),
248
+ draft_reply=draft_message,
249
  )
250
+ candidate_message = call_llm(openai_client, model_name, user_prompt)
251
+ message = candidate_message if should_accept_rewrite(draft_message, candidate_message) else draft_message
252
+ memory.remember(decision.skill_name, message)
253
 
254
  try:
255
  result = await env_client.step(Action(message=message))