Spaces:
Sleeping
Sleeping
Strengthen hosted inference baseline with hybrid controller
Browse files- inference.py +80 -19
inference.py
CHANGED
|
@@ -31,8 +31,10 @@ from typing import List, Optional
|
|
| 31 |
|
| 32 |
from openai import OpenAI
|
| 33 |
|
|
|
|
| 34 |
from src.client import ESCHttpClient
|
| 35 |
from src.models import Action
|
|
|
|
| 36 |
|
| 37 |
BENCHMARK = "emotional-support-conversations"
|
| 38 |
MAX_STEPS = 14 # upper bound; env imposes per-task limits too
|
|
@@ -43,23 +45,23 @@ TASK_IDS = ["work_stress_venting", "guarded_relationship", "crisis_fragile_trust
|
|
| 43 |
|
| 44 |
SYSTEM_PROMPT = textwrap.dedent(
|
| 45 |
"""
|
| 46 |
-
You are
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
-
|
| 57 |
-
|
| 58 |
-
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
"""
|
| 64 |
).strip()
|
| 65 |
|
|
@@ -119,6 +121,10 @@ def build_user_prompt(
|
|
| 119 |
remaining: int,
|
| 120 |
seeker_utterance: str,
|
| 121 |
history: List[str],
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
) -> str:
|
| 123 |
history_block = "\n".join(history[-8:]) if history else "(this is the first turn)"
|
| 124 |
return textwrap.dedent(
|
|
@@ -126,6 +132,9 @@ def build_user_prompt(
|
|
| 126 |
Scenario: {scenario_brief}
|
| 127 |
Conversation stage (public hint): {stage_hint}
|
| 128 |
Turn: {turn} Remaining turns: {remaining}
|
|
|
|
|
|
|
|
|
|
| 129 |
|
| 130 |
Recent exchange:
|
| 131 |
{history_block}
|
|
@@ -133,7 +142,11 @@ def build_user_prompt(
|
|
| 133 |
Seeker just said:
|
| 134 |
"{seeker_utterance}"
|
| 135 |
|
| 136 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
"""
|
| 138 |
).strip()
|
| 139 |
|
|
@@ -157,6 +170,39 @@ def call_llm(client: OpenAI, model_name: str, user_prompt: str) -> str:
|
|
| 157 |
return "That sounds really hard. I'm here — do you want to tell me more about what's going on?"
|
| 158 |
|
| 159 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
# -------------------------- per-task episode ---------------------------------
|
| 161 |
|
| 162 |
async def run_task(
|
|
@@ -167,6 +213,11 @@ async def run_task(
|
|
| 167 |
) -> dict:
|
| 168 |
log_start(task=task_id, env=BENCHMARK, model=model_name)
|
| 169 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
rewards: List[float] = []
|
| 171 |
steps_taken = 0
|
| 172 |
score = 0.0
|
|
@@ -180,6 +231,10 @@ async def run_task(
|
|
| 180 |
history.append(f"Seeker: {obs.seeker_utterance!r}")
|
| 181 |
|
| 182 |
for step in range(1, MAX_STEPS + 1):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 183 |
user_prompt = build_user_prompt(
|
| 184 |
scenario_brief=obs.scenario_brief,
|
| 185 |
stage_hint=obs.stage_hint,
|
|
@@ -187,8 +242,14 @@ async def run_task(
|
|
| 187 |
remaining=obs.remaining_turns,
|
| 188 |
seeker_utterance=obs.seeker_utterance,
|
| 189 |
history=history,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 190 |
)
|
| 191 |
-
|
|
|
|
|
|
|
| 192 |
|
| 193 |
try:
|
| 194 |
result = await env_client.step(Action(message=message))
|
|
|
|
| 31 |
|
| 32 |
from openai import OpenAI
|
| 33 |
|
| 34 |
+
from src.agentic import AgentMemory, SkillRouter, build_default_skills
|
| 35 |
from src.client import ESCHttpClient
|
| 36 |
from src.models import Action
|
| 37 |
+
from src.seeker import extract_features
|
| 38 |
|
| 39 |
BENCHMARK = "emotional-support-conversations"
|
| 40 |
MAX_STEPS = 14 # upper bound; env imposes per-task limits too
|
|
|
|
| 45 |
|
| 46 |
SYSTEM_PROMPT = textwrap.dedent(
|
| 47 |
"""
|
| 48 |
+
You are the response generator inside a controlled emotional-support agent.
|
| 49 |
+
|
| 50 |
+
A deterministic controller has already selected the correct conversational
|
| 51 |
+
move for this turn and written a draft reply. Your job is only to lightly
|
| 52 |
+
polish that draft while preserving its intent and structure.
|
| 53 |
+
|
| 54 |
+
Hard rules:
|
| 55 |
+
- Stay extremely close to the draft.
|
| 56 |
+
- Keep the same stage objective. Do not change exploration into advice or
|
| 57 |
+
advice into exploration.
|
| 58 |
+
- Preserve any explicit safety support mention, validation, and questions
|
| 59 |
+
already present in the draft.
|
| 60 |
+
- Do not add extra questions, extra advice, or new topics.
|
| 61 |
+
- Keep replies warm, brief, and human.
|
| 62 |
+
- If the draft is already strong, repeat it verbatim.
|
| 63 |
+
|
| 64 |
+
Reply with ONLY the next message to the seeker.
|
| 65 |
"""
|
| 66 |
).strip()
|
| 67 |
|
|
|
|
| 121 |
remaining: int,
|
| 122 |
seeker_utterance: str,
|
| 123 |
history: List[str],
|
| 124 |
+
skill_name: str,
|
| 125 |
+
rationale: str,
|
| 126 |
+
skill_instruction: str,
|
| 127 |
+
draft_reply: str,
|
| 128 |
) -> str:
|
| 129 |
history_block = "\n".join(history[-8:]) if history else "(this is the first turn)"
|
| 130 |
return textwrap.dedent(
|
|
|
|
| 132 |
Scenario: {scenario_brief}
|
| 133 |
Conversation stage (public hint): {stage_hint}
|
| 134 |
Turn: {turn} Remaining turns: {remaining}
|
| 135 |
+
Selected skill: {skill_name}
|
| 136 |
+
Why this skill was selected: {rationale}
|
| 137 |
+
Skill directive: {skill_instruction}
|
| 138 |
|
| 139 |
Recent exchange:
|
| 140 |
{history_block}
|
|
|
|
| 142 |
Seeker just said:
|
| 143 |
"{seeker_utterance}"
|
| 144 |
|
| 145 |
+
Deterministic draft reply:
|
| 146 |
+
"{draft_reply}"
|
| 147 |
+
|
| 148 |
+
Lightly polish the draft only if needed. Preserve its goal and
|
| 149 |
+
structure. If unsure, output the draft unchanged.
|
| 150 |
"""
|
| 151 |
).strip()
|
| 152 |
|
|
|
|
| 170 |
return "That sounds really hard. I'm here — do you want to tell me more about what's going on?"
|
| 171 |
|
| 172 |
|
| 173 |
+
def _count_questions(text: str) -> int:
|
| 174 |
+
return (text or "").count("?")
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def should_accept_rewrite(draft: str, candidate: str) -> bool:
|
| 178 |
+
candidate = (candidate or "").strip()
|
| 179 |
+
if not candidate:
|
| 180 |
+
return False
|
| 181 |
+
|
| 182 |
+
draft_features = extract_features(draft)
|
| 183 |
+
candidate_features = extract_features(candidate)
|
| 184 |
+
|
| 185 |
+
if candidate_features.dismissive > 0 or candidate_features.bare:
|
| 186 |
+
return False
|
| 187 |
+
if _count_questions(candidate) > 1 or candidate_features.interrogative > 0:
|
| 188 |
+
return False
|
| 189 |
+
|
| 190 |
+
# Do not let the rewrite weaken the key stage-driving signals already
|
| 191 |
+
# present in the deterministic draft.
|
| 192 |
+
if draft_features.open_question > 0 and candidate_features.open_question <= 0:
|
| 193 |
+
return False
|
| 194 |
+
if draft_features.validation > 0 and candidate_features.validation <= 0:
|
| 195 |
+
return False
|
| 196 |
+
if draft_features.empathy > 0 and candidate_features.empathy <= 0:
|
| 197 |
+
return False
|
| 198 |
+
if draft_features.advice > 0 and candidate_features.advice <= 0:
|
| 199 |
+
return False
|
| 200 |
+
if draft_features.safety > 0 and candidate_features.safety <= 0:
|
| 201 |
+
return False
|
| 202 |
+
|
| 203 |
+
return True
|
| 204 |
+
|
| 205 |
+
|
| 206 |
# -------------------------- per-task episode ---------------------------------
|
| 207 |
|
| 208 |
async def run_task(
|
|
|
|
| 213 |
) -> dict:
|
| 214 |
log_start(task=task_id, env=BENCHMARK, model=model_name)
|
| 215 |
|
| 216 |
+
router = SkillRouter()
|
| 217 |
+
skills = build_default_skills()
|
| 218 |
+
memory = AgentMemory()
|
| 219 |
+
memory.reset(task_id)
|
| 220 |
+
|
| 221 |
rewards: List[float] = []
|
| 222 |
steps_taken = 0
|
| 223 |
score = 0.0
|
|
|
|
| 231 |
history.append(f"Seeker: {obs.seeker_utterance!r}")
|
| 232 |
|
| 233 |
for step in range(1, MAX_STEPS + 1):
|
| 234 |
+
memory.observe(obs)
|
| 235 |
+
decision = router.choose(obs, memory)
|
| 236 |
+
skill = skills[decision.skill_name]
|
| 237 |
+
draft_message = skill.render(obs, memory, decision)
|
| 238 |
user_prompt = build_user_prompt(
|
| 239 |
scenario_brief=obs.scenario_brief,
|
| 240 |
stage_hint=obs.stage_hint,
|
|
|
|
| 242 |
remaining=obs.remaining_turns,
|
| 243 |
seeker_utterance=obs.seeker_utterance,
|
| 244 |
history=history,
|
| 245 |
+
skill_name=decision.skill_name,
|
| 246 |
+
rationale=decision.rationale,
|
| 247 |
+
skill_instruction=skill.llm_instruction(obs, memory, decision),
|
| 248 |
+
draft_reply=draft_message,
|
| 249 |
)
|
| 250 |
+
candidate_message = call_llm(openai_client, model_name, user_prompt)
|
| 251 |
+
message = candidate_message if should_accept_rewrite(draft_message, candidate_message) else draft_message
|
| 252 |
+
memory.remember(decision.skill_name, message)
|
| 253 |
|
| 254 |
try:
|
| 255 |
result = await env_client.step(Action(message=message))
|