"""
Interactive chat session driver for the simulator.
When the annotation server's ``instance_display`` includes an
``interactive_chat`` field, the annotator is expected to chat with a live
agent backend before submitting trajectory ratings.
:class:`InteractiveSessionRunner` plays the user side of that chat: it asks
a small "persona" LLM to generate user messages, posts them to the server's
``/agent_chat/send`` route, and finishes with ``/agent_chat/finish`` so the
captured conversation is written into the instance data.
After the runner returns, the regular annotation pipeline (e.g.
:class:`AgentSimulatorStrategy`) picks up the freshly populated
``conversation`` field and produces ratings.
"""
from __future__ import annotations

import logging
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import requests

from .config import InteractiveConfig
logger = logging.getLogger(__name__)
@dataclass
class InteractiveSessionResult:
    """Summary of a single ``interactive_chat`` session.

    Instances are produced by ``InteractiveSessionRunner.run`` once the
    chat loop and the finish call have both been attempted.
    """

    # Identifier of the instance the chat was driven for.
    instance_id: str
    # True when the session counts as finished (see the runner for the rules).
    completed: bool
    # Number of user/agent exchanges that were recorded.
    turns: int
    # Recorded messages, each a dict with "speaker" and "text" keys.
    conversation: List[Dict[str, Any]]
    # Short description of the first failure, or None when nothing failed.
    error: Optional[str] = None
class InteractiveSessionRunner:
    """Drive a multi-turn ``interactive_chat`` against the server.

    The runner is stateless across instances: ``run`` is called once per
    instance and returns the resulting conversation. The persona LLM is
    initialized lazily on first use so importing the module is cheap.
    """

    def __init__(self, config: InteractiveConfig, server_url: str):
        """Store the session settings and normalise the server base URL.

        Args:
            config: Interactive-chat settings (persona model parameters,
                ``max_turns``, ``done_marker``, prompts).
            server_url: Base URL of the annotation server; any trailing
                slashes are removed so route paths can be appended directly.
        """
        self.config = config
        self.server_url = server_url.rstrip("/")
        self._endpoint = None  # lazy

    # ------------------------------------------------------------------
    # Persona endpoint setup
    # ------------------------------------------------------------------
    def _get_endpoint(self):
        """Return the persona LLM endpoint, creating it on first use.

        Returns:
            The endpoint built by ``AIEndpointFactory.create_endpoint``, or
            ``None`` when creation failed. A failure is logged and not
            cached, so the next call tries again.
        """
        if self._endpoint is not None:
            return self._endpoint
        try:
            # Imported here so that importing this module stays cheap.
            from potato.ai.ai_endpoint import AIEndpointFactory

            ai_cfg: Dict[str, Any] = {
                "model": self.config.model,
                "api_key": self.config.api_key,
                "max_tokens": self.config.max_tokens,
                "temperature": self.config.temperature,
            }
            if self.config.base_url:
                ai_cfg["base_url"] = self.config.base_url
            self._endpoint = AIEndpointFactory.create_endpoint({
                "ai_support": {
                    "enabled": True,
                    "endpoint_type": self.config.endpoint_type,
                    "ai_config": ai_cfg,
                }
            })
        except Exception as e:
            # Best effort: a missing or broken persona backend must not crash
            # the simulator; callers treat ``None`` as "no persona available".
            logger.warning("InteractiveSessionRunner: persona endpoint init failed: %s", e)
            self._endpoint = None
        return self._endpoint

    # ------------------------------------------------------------------
    # Persona messaging
    # ------------------------------------------------------------------
    def _generate_persona_message(
        self,
        task: str,
        history: List[Dict[str, str]],
    ) -> Optional[str]:
        """Produce the next message the simulated user sends to the agent.

        Args:
            task: Task description the persona wants the agent to complete.
            history: Messages exchanged so far, as dicts with ``"role"``
                (``"user"`` for persona messages, anything else for agent
                replies) and ``"content"``.

        Returns:
            The stripped message text, or ``None`` when no persona endpoint
            is available, the LLM call raised, or the reply was empty.
        """
        # The endpoint is required even when the first message comes from the
        # template, so a session never starts without a working persona.
        endpoint = self._get_endpoint()
        if endpoint is None:
            return None

        if not history and self.config.first_message_template:
            return self.config.first_message_template.format(task=task)

        # Build chat history. The persona's "user" role is what we send to
        # the agent, so from the persona LLM's perspective those are
        # *assistant* messages and the agent's replies are *user* prompts.
        messages: List[Dict[str, str]] = [
            {"role": "system", "content": self.config.persona_system_prompt
             + f"\n\nThe task you want completed is:\n{task}"}
        ]
        for msg in history:
            if msg["role"] == "user":  # what the persona previously sent
                messages.append({"role": "assistant", "content": msg["content"]})
            else:  # agent's reply
                messages.append({"role": "user", "content": msg["content"]})

        try:
            if hasattr(endpoint, "chat_query"):
                reply = endpoint.chat_query(messages)
            else:
                # Fall back to flattening into a single prompt
                flat = "\n".join(f'{m["role"]}: {m["content"]}' for m in messages)
                reply = endpoint.query(flat + "\nassistant:", None)
        except Exception as e:
            logger.warning("Persona LLM call failed: %s", e)
            return None

        # Endpoints may return either plain text or a dict wrapper.
        if isinstance(reply, dict):
            reply = reply.get("response") or reply.get("content") or str(reply)
        text = str(reply or "").strip()
        return text or None

    def _split_done_marker(self, message: str) -> tuple:
        """Separate the configured done marker from a persona message.

        Detection and removal use the same case-insensitive match, so a
        marker written in a different case than configured is both honoured
        and hidden from the agent.

        Args:
            message: Raw persona message.

        Returns:
            ``(should_finish, text_to_send)``. ``text_to_send`` is the
            message without the marker, stripped, falling back to
            ``"Thanks!"`` when nothing else remains.
        """
        marker = self.config.done_marker
        if marker:
            cleaned, hits = re.subn(
                re.escape(marker), "", message, flags=re.IGNORECASE
            )
        else:
            # An empty marker would match every message; treat it as "none".
            cleaned, hits = message, 0
        return hits > 0, cleaned.strip() or "Thanks!"

    # ------------------------------------------------------------------
    # Server interaction
    # ------------------------------------------------------------------
    def _send_to_agent(
        self, session: requests.Session, message: str
    ) -> Optional[Dict[str, Any]]:
        """POST one user message to ``/agent_chat/send``.

        Args:
            session: HTTP session used to reach the server.
            message: Text to deliver to the agent.

        Returns:
            The decoded JSON object of the response, or ``None`` when the
            request failed, the status was not 200, or the body was not a
            JSON object. Every failure is logged.
        """
        try:
            resp = session.post(
                f"{self.server_url}/agent_chat/send",
                json={"message": message},
                timeout=120,
            )
        except requests.exceptions.RequestException as e:
            logger.warning("agent_chat/send request failed: %s", e)
            return None
        if resp.status_code != 200:
            logger.warning(
                "agent_chat/send returned %d: %s", resp.status_code, resp.text[:200]
            )
            return None
        try:
            payload = resp.json()
        except ValueError as e:
            logger.warning("agent_chat/send returned invalid JSON: %s", e)
            return None
        if not isinstance(payload, dict):
            # Callers read fields with ``.get``; reject lists, strings, etc.
            logger.warning(
                "agent_chat/send returned non-object JSON (%s)",
                type(payload).__name__,
            )
            return None
        return payload

    def _finish(self, session: requests.Session) -> bool:
        """POST to ``/agent_chat/finish`` to close the chat on the server.

        Args:
            session: HTTP session used to reach the server.

        Returns:
            ``True`` when the server answered with status 200, ``False``
            otherwise (the failure is logged).
        """
        try:
            resp = session.post(
                f"{self.server_url}/agent_chat/finish",
                timeout=60,
            )
        except requests.exceptions.RequestException as e:
            logger.warning("agent_chat/finish request failed: %s", e)
            return False
        if resp.status_code != 200:
            logger.warning(
                "agent_chat/finish returned %d: %s",
                resp.status_code, resp.text[:200],
            )
            return False
        return True

    # ------------------------------------------------------------------
    # Main entry point
    # ------------------------------------------------------------------
    def run(
        self,
        session: requests.Session,
        instance_id: str,
        task_description: str,
    ) -> InteractiveSessionResult:
        """Drive one chat session end-to-end.

        Alternates persona messages and agent replies for at most
        ``config.max_turns`` turns, stops early when the persona emits the
        done marker or a step fails, then always calls the finish route.

        Args:
            session: HTTP session used to reach the server.
            instance_id: Identifier copied into the result.
            task_description: Task handed to the persona LLM.

        Returns:
            An ``InteractiveSessionResult`` with the recorded conversation,
            the number of completed turns and the first error, if any.
        """
        history: List[Dict[str, str]] = []
        completed = False
        error: Optional[str] = None

        for _ in range(self.config.max_turns):
            user_msg = self._generate_persona_message(task_description, history)
            if not user_msg:
                error = "persona produced no message"
                break

            # Strip the done marker before sending so the agent doesn't see
            # it; remember that we should finish after this turn.
            should_finish, send_msg = self._split_done_marker(user_msg)

            agent_reply = self._send_to_agent(session, send_msg)
            if agent_reply is None:
                error = "agent send failed"
                break

            history.append({"role": "user", "content": send_msg})
            history.append({
                "role": "agent",
                # A null/missing content is recorded as an empty reply.
                "content": agent_reply.get("content") or "",
            })
            if should_finish:
                completed = True
                break

        ok = self._finish(session)
        if not ok and not error:
            error = "finish failed"
        if ok and not completed:
            # We hit max_turns without an explicit DONE; still consider the
            # session "done" for accounting purposes.
            # NOTE(review): this also marks sessions that stopped on an error
            # as completed when finish succeeds - confirm that is intended.
            completed = True

        # Build the conversation array the way the server's finish route does
        conversation = [
            {
                "speaker": "User" if msg["role"] == "user" else "Agent",
                "text": msg["content"],
            }
            for msg in history
        ]
        return InteractiveSessionResult(
            instance_id=instance_id,
            completed=completed,
            turns=len(history) // 2,
            conversation=conversation,
            error=error,
        )
|