Spaces:
Paused
Paused
| """ | |
| Interactive chat session driver for the simulator. | |
| When the annotation server's ``instance_display`` includes an | |
| ``interactive_chat`` field, the annotator is expected to chat with a live | |
| agent backend before submitting trajectory ratings. | |
| :class:`InteractiveSessionRunner` plays the user side of that chat: it asks | |
| a small "persona" LLM to generate user messages, posts them to the server's | |
| ``/agent_chat/send`` route, and finishes with ``/agent_chat/finish`` so the | |
| captured conversation is written into the instance data. | |
| After the runner returns, the regular annotation pipeline (e.g. | |
| :class:`AgentSimulatorStrategy`) picks up the freshly populated | |
| ``conversation`` field and produces ratings. | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| from dataclasses import dataclass | |
| from typing import Any, Dict, List, Optional | |
| import requests | |
| from .config import InteractiveConfig | |
| logger = logging.getLogger(__name__) | |
@dataclass
class InteractiveSessionResult:
    """Outcome of one interactive_chat run.

    The class must be a dataclass: the runner constructs it with keyword
    arguments, which requires the generated ``__init__``.
    """

    # Identifier of the instance the chat was run for (echoed from the caller).
    instance_id: str
    # Whether the runner considers the session done.
    completed: bool
    # Number of user/agent exchanges that were recorded.
    turns: int
    # Recorded messages as ``{"speaker": ..., "text": ...}`` dicts, in order.
    conversation: List[Dict[str, Any]]
    # Short failure reason, or ``None`` when no step reported an error.
    error: Optional[str] = None
class InteractiveSessionRunner:
    """Drive a multi-turn ``interactive_chat`` against the server.

    The runner is stateless across instances: ``run`` is called once per
    instance and returns the resulting conversation. The persona LLM is
    initialized lazily on first use so importing the module is cheap.
    """

    # Sent instead of an empty message when a persona reply consists of
    # nothing but the done marker.
    _FALLBACK_MESSAGE = "Thanks!"

    def __init__(self, config: InteractiveConfig, server_url: str):
        """Store the configuration and the normalized server base URL.

        Args:
            config: Persona/LLM settings and chat limits.
            server_url: Base URL of the server; trailing slashes are removed
                so route paths can be appended with a single ``/``.
        """
        self.config = config
        self.server_url = server_url.rstrip("/")
        self._endpoint: Optional[Any] = None  # lazy

    # ------------------------------------------------------------------
    # Persona endpoint setup
    # ------------------------------------------------------------------
    def _get_endpoint(self):
        """Return the persona LLM endpoint, creating it on first use.

        Returns:
            The endpoint object, or ``None`` when creation failed. A failed
            creation is not cached, so the next call tries again.
        """
        if self._endpoint is not None:
            return self._endpoint
        try:
            # Imported here so that importing this module stays cheap.
            from potato.ai.ai_endpoint import AIEndpointFactory

            ai_cfg: Dict[str, Any] = {
                "model": self.config.model,
                "api_key": self.config.api_key,
                "max_tokens": self.config.max_tokens,
                "temperature": self.config.temperature,
            }
            if self.config.base_url:
                ai_cfg["base_url"] = self.config.base_url
            self._endpoint = AIEndpointFactory.create_endpoint({
                "ai_support": {
                    "enabled": True,
                    "endpoint_type": self.config.endpoint_type,
                    "ai_config": ai_cfg,
                }
            })
        except Exception as e:
            # Deliberately broad: any import or construction failure is
            # reported as "no endpoint" rather than aborting the caller.
            logger.warning("InteractiveSessionRunner: persona endpoint init failed: %s", e)
            self._endpoint = None
        return self._endpoint

    # ------------------------------------------------------------------
    # Persona messaging
    # ------------------------------------------------------------------
    def _generate_persona_message(
        self,
        task: str,
        history: List[Dict[str, str]],
    ) -> Optional[str]:
        """Produce the next user-side message for the chat.

        Args:
            task: Description of the task the persona wants completed.
            history: Messages exchanged so far, each with ``role``
                (``"user"`` for persona messages, anything else for agent
                replies) and ``content``.

        Returns:
            The message text, or ``None`` when the endpoint is unavailable,
            the LLM call raised, or the reply was empty.
        """
        endpoint = self._get_endpoint()
        if endpoint is None:
            return None

        # The opening message can come from a fixed template instead of the LLM.
        if not history and self.config.first_message_template:
            return self.config.first_message_template.format(task=task)

        # Build chat history. The persona's "user" role is what we send to
        # the agent, so from the persona LLM's perspective those are
        # *assistant* messages and the agent's replies are *user* prompts.
        messages: List[Dict[str, str]] = [
            {"role": "system", "content": self.config.persona_system_prompt
             + f"\n\nThe task you want completed is:\n{task}"}
        ]
        for msg in history:
            if msg["role"] == "user":  # what the persona previously sent
                messages.append({"role": "assistant", "content": msg["content"]})
            else:  # agent's reply
                messages.append({"role": "user", "content": msg["content"]})

        try:
            if hasattr(endpoint, "chat_query"):
                reply = endpoint.chat_query(messages)
            else:
                # Fall back to flattening into a single prompt
                flat = "\n".join(f'{m["role"]}: {m["content"]}' for m in messages)
                reply = endpoint.query(flat + "\nassistant:", None)
        except Exception as e:
            logger.warning("Persona LLM call failed: %s", e)
            return None

        # Dict replies carry the text under "response" or "content".
        if isinstance(reply, dict):
            reply = reply.get("response") or reply.get("content") or str(reply)
        text = str(reply or "").strip()
        return text or None

    def _split_done_marker(self, message: str) -> tuple[str, bool]:
        """Remove the configured done marker from a persona message.

        Detection and removal use the same case-insensitive match, so a
        marker written in a different case (``[done]`` vs ``[DONE]``) both
        ends the session and is stripped before the agent sees the text.

        Args:
            message: Raw persona message.

        Returns:
            ``(text_to_send, should_finish)``. ``text_to_send`` falls back to
            ``"Thanks!"`` when nothing is left after stripping. An empty
            marker never matches.
        """
        import re  # local import: only this helper needs it

        marker = self.config.done_marker
        if not marker:
            # "" is a substring of every string; treat it as "no marker"
            # instead of ending every session after the first turn.
            return message.strip() or self._FALLBACK_MESSAGE, False

        cleaned, hits = re.subn(re.escape(marker), "", message, flags=re.IGNORECASE)
        return cleaned.strip() or self._FALLBACK_MESSAGE, hits > 0

    # ------------------------------------------------------------------
    # Server interaction
    # ------------------------------------------------------------------
    def _send_to_agent(
        self, session: requests.Session, message: str
    ) -> Optional[Dict[str, Any]]:
        """POST one user message to ``/agent_chat/send``.

        Args:
            session: HTTP session used for the request.
            message: Text to send as the ``message`` field of the JSON body.

        Returns:
            The decoded JSON object of the response, or ``None`` when the
            request failed, the status was not 200, or the body was not a
            JSON object.
        """
        try:
            resp = session.post(
                f"{self.server_url}/agent_chat/send",
                json={"message": message},
                timeout=120,
            )
        except requests.exceptions.RequestException as e:
            logger.warning("agent_chat/send request failed: %s", e)
            return None

        if resp.status_code != 200:
            logger.warning(
                "agent_chat/send returned %d: %s", resp.status_code, resp.text[:200]
            )
            return None

        try:
            payload = resp.json()
        except ValueError:
            logger.warning("agent_chat/send returned invalid JSON: %s", resp.text[:200])
            return None

        # Callers use ``.get`` on the result, so anything but a JSON object
        # is reported as a failed send.
        if not isinstance(payload, dict):
            logger.warning(
                "agent_chat/send returned non-object JSON: %s", resp.text[:200]
            )
            return None
        return payload

    def _finish(self, session: requests.Session) -> bool:
        """POST to ``/agent_chat/finish`` and report whether it returned 200.

        Args:
            session: HTTP session used for the request.

        Returns:
            ``True`` on a 200 response, ``False`` on a request error or any
            other status code.
        """
        try:
            resp = session.post(
                f"{self.server_url}/agent_chat/finish",
                timeout=60,
            )
        except requests.exceptions.RequestException as e:
            logger.warning("agent_chat/finish request failed: %s", e)
            return False

        if resp.status_code != 200:
            logger.warning(
                "agent_chat/finish returned %d: %s",
                resp.status_code, resp.text[:200],
            )
            return False
        return True

    # ------------------------------------------------------------------
    # Main entry point
    # ------------------------------------------------------------------
    def run(
        self,
        session: requests.Session,
        instance_id: str,
        task_description: str,
    ) -> InteractiveSessionResult:
        """Drive one chat session end-to-end.

        Alternates persona messages and agent replies for at most
        ``config.max_turns`` turns, stopping early when the persona emits the
        done marker or a step fails, then always calls the finish route.

        Args:
            session: HTTP session used for all server requests.
            instance_id: Identifier copied into the result.
            task_description: Task text given to the persona.

        Returns:
            The recorded conversation together with completion and error
            information.
        """
        history: List[Dict[str, str]] = []
        completed = False
        error: Optional[str] = None

        for _ in range(self.config.max_turns):
            user_msg = self._generate_persona_message(task_description, history)
            if not user_msg:
                error = "persona produced no message"
                break

            # Strip the [DONE] marker before sending so the agent doesn't see
            # it; remember that we should finish after this turn.
            send_msg, should_finish = self._split_done_marker(user_msg)

            agent_reply = self._send_to_agent(session, send_msg)
            if agent_reply is None:
                error = "agent send failed"
                break

            history.append({"role": "user", "content": send_msg})
            history.append({
                "role": "agent",
                # A missing or null "content" is recorded as an empty reply.
                "content": agent_reply.get("content") or "",
            })

            if should_finish:
                completed = True
                break

        ok = self._finish(session)
        if not ok and error is None:
            error = "finish failed"
        if ok and not completed:
            # We hit max_turns without an explicit DONE; still consider the
            # session "done" for accounting purposes.
            # NOTE(review): this also applies when the loop stopped on an
            # error but finish succeeded, so ``completed`` can be True while
            # ``error`` is set -- confirm that callers expect this.
            completed = True

        # Build the conversation array the way the server's finish route does
        conversation = [
            {
                "speaker": "User" if msg["role"] == "user" else "Agent",
                "text": msg["content"],
            }
            for msg in history
        ]
        return InteractiveSessionResult(
            instance_id=instance_id,
            completed=completed,
            turns=len(history) // 2,
            conversation=conversation,
            error=error,
        )