codebook / potato /simulator /interactive_runner.py
davidjurgens's picture
Deploy: Potato — Codebook Annotation
aceb1b2 verified
Raw
History Blame Contribute Delete
8.52 kB
"""
Interactive chat session driver for the simulator.
When the annotation server's ``instance_display`` includes an
``interactive_chat`` field, the annotator is expected to chat with a live
agent backend before submitting trajectory ratings.
:class:`InteractiveSessionRunner` plays the user side of that chat: it asks
a small "persona" LLM to generate user messages, posts them to the server's
``/agent_chat/send`` route, and finishes with ``/agent_chat/finish`` so the
captured conversation is written into the instance data.
After the runner returns, the regular annotation pipeline (e.g.
:class:`AgentSimulatorStrategy`) picks up the freshly populated
``conversation`` field and produces ratings.
"""
from __future__ import annotations

import logging
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import requests

from .config import InteractiveConfig
# Module-level logger named after this module; the runner reports persona and
# HTTP failures through it as warnings.
logger = logging.getLogger(__name__)
@dataclass
class InteractiveSessionResult:
    """Summary record returned for a single ``interactive_chat`` session.

    Attributes:
        instance_id: Identifier of the instance the chat was run for.
        completed: Whether the runner counted the session as done.
        turns: Number of user/agent exchanges that were recorded.
        conversation: Recorded messages in order, one dict per message.
        error: Short description of what went wrong, or ``None`` when
            nothing was recorded.
    """

    instance_id: str
    completed: bool
    turns: int
    conversation: List[Dict[str, Any]]
    error: Optional[str] = None
class InteractiveSessionRunner:
    """Drive a multi-turn ``interactive_chat`` against the server.

    The runner is stateless across instances: ``run`` is called once per
    instance and returns the resulting conversation. The persona LLM is
    initialized lazily on first use so importing the module is cheap.

    A session counts as completed only when no error was recorded, i.e. the
    turn loop ended normally (done marker or turn limit) and the finish call
    succeeded.
    """

    # Request timeouts, in seconds, for the two server routes.
    SEND_TIMEOUT = 120
    FINISH_TIMEOUT = 60

    def __init__(self, config: InteractiveConfig, server_url: str):
        """Store the configuration and the normalized server URL.

        Args:
            config: Persona and session settings read by this class (model
                parameters, prompts, ``max_turns``, ``done_marker``).
            server_url: Base URL of the server; trailing slashes are dropped
                so routes can be appended with a single ``/``.
        """
        self.config = config
        self.server_url = server_url.rstrip("/")
        self._endpoint = None  # lazy

    # ------------------------------------------------------------------
    # Persona endpoint setup
    # ------------------------------------------------------------------
    def _get_endpoint(self):
        """Return the persona LLM endpoint, creating it on first use.

        Returns ``None`` (after logging a warning) when the endpoint cannot
        be created. A failed attempt is not cached, so the next call tries
        again.
        """
        if self._endpoint is not None:
            return self._endpoint
        try:
            # Imported here so that importing this module stays cheap.
            from potato.ai.ai_endpoint import AIEndpointFactory

            ai_cfg: Dict[str, Any] = {
                "model": self.config.model,
                "api_key": self.config.api_key,
                "max_tokens": self.config.max_tokens,
                "temperature": self.config.temperature,
            }
            if self.config.base_url:
                ai_cfg["base_url"] = self.config.base_url
            self._endpoint = AIEndpointFactory.create_endpoint({
                "ai_support": {
                    "enabled": True,
                    "endpoint_type": self.config.endpoint_type,
                    "ai_config": ai_cfg,
                }
            })
        except Exception as e:
            # Best effort: a missing or misconfigured persona backend must not
            # crash the simulator; the caller sees ``None`` instead.
            logger.warning("InteractiveSessionRunner: persona endpoint init failed: %s", e)
            self._endpoint = None
        return self._endpoint

    # ------------------------------------------------------------------
    # Persona messaging
    # ------------------------------------------------------------------
    def _generate_persona_message(
        self,
        task: str,
        history: List[Dict[str, str]],
    ) -> Optional[str]:
        """Produce the next user-side message for the chat.

        Args:
            task: Task description the persona wants completed.
            history: Messages exchanged so far, each with ``role`` (``"user"``
                for persona messages, anything else for agent replies) and
                ``content``.

        Returns:
            The stripped message text, or ``None`` when the endpoint is
            unavailable, the LLM call raises, or the reply is empty.
        """
        endpoint = self._get_endpoint()
        if endpoint is None:
            return None

        # The opening message can come from a fixed template instead of the LLM.
        if not history and self.config.first_message_template:
            return self.config.first_message_template.format(task=task)

        # Build chat history. The persona's "user" role is what we send to
        # the agent, so from the persona LLM's perspective those are
        # *assistant* messages and the agent's replies are *user* prompts.
        messages: List[Dict[str, str]] = [
            {
                "role": "system",
                "content": self.config.persona_system_prompt
                + f"\n\nThe task you want completed is:\n{task}",
            }
        ]
        for msg in history:
            role = "assistant" if msg["role"] == "user" else "user"
            messages.append({"role": role, "content": msg["content"]})

        try:
            if hasattr(endpoint, "chat_query"):
                reply = endpoint.chat_query(messages)
            else:
                # Fall back to flattening into a single prompt
                flat = "\n".join(f'{m["role"]}: {m["content"]}' for m in messages)
                reply = endpoint.query(flat + "\nassistant:", None)
        except Exception as e:
            logger.warning("Persona LLM call failed: %s", e)
            return None

        # Endpoints may answer with a dict; pull the text out of it.
        if isinstance(reply, dict):
            reply = reply.get("response") or reply.get("content") or str(reply)
        text = str(reply or "").strip()
        return text or None

    def _has_done_marker(self, text: str) -> bool:
        """Return True if *text* contains the done marker, ignoring case.

        An empty or unset marker never matches; otherwise every message would
        count as "done" and the session would end after one turn.
        """
        marker = self.config.done_marker
        if not marker:
            return False
        return re.search(re.escape(marker), text, flags=re.IGNORECASE) is not None

    def _strip_done_marker(self, text: str) -> str:
        """Return *text* without the done marker, ready to send to the agent.

        Removal ignores case, matching ``_has_done_marker``, so a marker the
        persona wrote in different casing is not leaked to the agent. If
        nothing is left after stripping, ``"Thanks!"`` is sent instead.
        """
        marker = self.config.done_marker
        if marker:
            text = re.sub(re.escape(marker), "", text, flags=re.IGNORECASE)
        return text.strip() or "Thanks!"

    # ------------------------------------------------------------------
    # Server interaction
    # ------------------------------------------------------------------
    def _post(
        self,
        session: requests.Session,
        route: str,
        *,
        timeout: float,
        payload: Optional[Dict[str, Any]] = None,
    ):
        """POST to ``<server_url>/<route>`` and return the response on HTTP 200.

        Request errors and non-200 statuses are logged as warnings and
        reported as ``None``. ``payload`` is sent as the JSON body when given.
        """
        kwargs: Dict[str, Any] = {"timeout": timeout}
        if payload is not None:
            kwargs["json"] = payload
        try:
            resp = session.post(f"{self.server_url}/{route}", **kwargs)
        except requests.exceptions.RequestException as e:
            logger.warning("%s request failed: %s", route, e)
            return None
        if resp.status_code != 200:
            logger.warning(
                "%s returned %d: %s", route, resp.status_code, resp.text[:200]
            )
            return None
        return resp

    def _send_to_agent(
        self, session: requests.Session, message: str
    ) -> Optional[Dict[str, Any]]:
        """Send one user message to the agent and return its JSON reply.

        Returns:
            The decoded JSON object, or ``None`` when the request fails, the
            status is not 200, the body is not valid JSON, or the JSON is not
            an object.
        """
        resp = self._post(
            session,
            "agent_chat/send",
            timeout=self.SEND_TIMEOUT,
            payload={"message": message},
        )
        if resp is None:
            return None
        try:
            data = resp.json()
        except ValueError as e:
            logger.warning("agent_chat/send returned invalid JSON: %s", e)
            return None
        if not isinstance(data, dict):
            # Callers read the reply with ``.get``; reject lists, strings, etc.
            logger.warning(
                "agent_chat/send returned unexpected JSON type: %s",
                type(data).__name__,
            )
            return None
        return data

    def _finish(self, session: requests.Session) -> bool:
        """Call the finish route; return True only on an HTTP 200 response."""
        resp = self._post(session, "agent_chat/finish", timeout=self.FINISH_TIMEOUT)
        return resp is not None

    # ------------------------------------------------------------------
    # Main entry point
    # ------------------------------------------------------------------
    def _run_turns(
        self,
        session: requests.Session,
        task_description: str,
    ) -> Tuple[List[Dict[str, str]], Optional[str]]:
        """Exchange messages with the agent until done, limit, or failure.

        Returns:
            ``(history, error)``. ``history`` alternates ``"user"`` and
            ``"agent"`` entries for every exchange that succeeded; ``error``
            is ``None`` unless the persona or the send step failed.
        """
        history: List[Dict[str, str]] = []
        for _ in range(self.config.max_turns):
            user_msg = self._generate_persona_message(task_description, history)
            if not user_msg:
                return history, "persona produced no message"

            # Strip the done marker before sending so the agent doesn't see
            # it; remember that we should finish after this turn.
            should_finish = self._has_done_marker(user_msg)
            send_msg = self._strip_done_marker(user_msg)

            agent_reply = self._send_to_agent(session, send_msg)
            if agent_reply is None:
                return history, "agent send failed"

            # A JSON null content is recorded as an empty string so later
            # persona prompts never receive ``None`` as message text.
            content = agent_reply.get("content")
            if content is None:
                content = ""
            history.append({"role": "user", "content": send_msg})
            history.append({"role": "agent", "content": content})

            if should_finish:
                break
        return history, None

    def _close_session(
        self,
        session: requests.Session,
        error: Optional[str],
    ) -> Tuple[bool, Optional[str]]:
        """Call the finish route and derive the final session status.

        The finish route is called even after a failed turn. An earlier error
        is kept; otherwise a failed finish is recorded as ``"finish failed"``.

        Returns:
            ``(completed, error)`` where ``completed`` is True only when no
            error was recorded at all.
        """
        finished = self._finish(session)
        if error is None and not finished:
            error = "finish failed"
        return error is None, error

    def run(
        self,
        session: requests.Session,
        instance_id: str,
        task_description: str,
    ) -> InteractiveSessionResult:
        """Drive one chat session end-to-end.

        Args:
            session: HTTP session used for all requests to the server.
            instance_id: Identifier copied into the result.
            task_description: Task text handed to the persona.

        Returns:
            An ``InteractiveSessionResult`` holding the recorded conversation,
            the number of exchanges, and the completion status. Hitting
            ``max_turns`` without a done marker still counts as completed as
            long as the finish call succeeds.
        """
        history, error = self._run_turns(session, task_description)
        completed, error = self._close_session(session, error)

        # Build the conversation array the way the server's finish route does
        conversation = [
            {
                "speaker": "User" if msg["role"] == "user" else "Agent",
                "text": msg["content"],
            }
            for msg in history
        ]
        return InteractiveSessionResult(
            instance_id=instance_id,
            completed=completed,
            turns=len(history) // 2,
            conversation=conversation,
            error=error,
        )