Spaces:
Sleeping
Sleeping
| """WatchDog Environment β Shared models for multi-turn oversight and plugins. | |
| Shared types (used by plugins and env): AgentTurn, MultiAgentStep, MultiAgentState, | |
| MultiAgentConfig, ContextMessage. Env-specific types extend OpenEnv Action/Observation/State. | |
| """ | |
| from __future__ import annotations | |
| from dataclasses import dataclass, field | |
| from typing import Any, Literal, TYPE_CHECKING | |
| from pydantic import Field | |
| if TYPE_CHECKING: | |
| pass | |
| # βββ Shared Types (plugins + env) ββββββββββββββββββββββββββββββββββββ | |
| ContextRole = Literal["system", "user", "assistant"] | |
| class ContextMessage: | |
| """A single message in the system context (LLM conversation history).""" | |
| role: ContextRole | |
| content: str | |
| class AgentTurn: | |
| """Canonical turn representation. Plugins and env both use this.""" | |
| agent_id: str | |
| action_text: str | |
| step_index: int = 0 | |
| phase: str = "" | |
| display_name: str = "" | |
| moderator_prompt: str = "" | |
| metadata: dict[str, Any] = field(default_factory=dict) | |
| class MultiAgentConfig: | |
| """Base config for a multi-agent system run. Plugins subclass for game-specific fields.""" | |
| pass | |
| ConversationLogEntry = dict[str, Any] | |
| """Plain conversation log entry: speaker_id, speaker_display, message, optionally moderator_prompt.""" | |
| class MultiAgentState: | |
| """Tracks system behaviour across the run. Used when generating each MultiAgentStep.""" | |
| step_index: int = 0 | |
| turns_so_far: list[AgentTurn] = field(default_factory=list) | |
| config: MultiAgentConfig | None = None | |
| done: bool = False | |
| metadata: dict[str, Any] = field(default_factory=dict) | |
| conversation_log: list[ConversationLogEntry] = field(default_factory=list) | |
| class MultiAgentStep: | |
| """One step: multiple agent turns. done=True means scenario is finished.""" | |
| turns: list[AgentTurn] | |
| done: bool = False | |
| step_index: int = 0 | |
| game_id: str = "" | |
| task_id: str = "" | |
| domain: str = "" | |
| state: MultiAgentState | None = None | |
| # βββ Format Helpers ββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def format_conversation(turns: list[AgentTurn]) -> str: | |
| """Format turns for conversation_so_far. Uses display_name or agent_id.""" | |
| if not turns: | |
| return "[Conversation start]" | |
| lines = [] | |
| for i, t in enumerate(turns): | |
| label = t.display_name or t.agent_id | |
| lines.append(f"[Turn {i + 1}] {label}: {t.action_text}") | |
| return "\n".join(lines) | |
| def format_current_turn(turn: AgentTurn, moderator_prompt: str = "") -> str: | |
| """Build current_turn string. Includes moderator prompt if present.""" | |
| label = turn.display_name or turn.agent_id | |
| prompt = moderator_prompt or turn.moderator_prompt | |
| if prompt: | |
| return f"[Moderator]: {prompt}\n\n[{label}]: {turn.action_text}" | |
| return f"[{label}]: {turn.action_text}" | |
| def agent_turn_to_dict( | |
| turn: AgentTurn, | |
| has_error: bool = False, | |
| displayed_response: str | None = None, | |
| error_detail: dict[str, Any] | None = None, | |
| ) -> dict[str, Any]: | |
| """Convert AgentTurn to dict for _turns_seen compatibility.""" | |
| out: dict[str, Any] = { | |
| "speaker_id": turn.agent_id, | |
| "speaker_display": turn.display_name or turn.agent_id, | |
| "message": turn.action_text, | |
| "displayed_response": displayed_response if displayed_response is not None else turn.action_text, | |
| "has_error": has_error, | |
| "moderator_prompt": turn.moderator_prompt or "", | |
| "phase": turn.phase, | |
| **turn.metadata, | |
| } | |
| if error_detail is not None: | |
| out["error_detail"] = error_detail | |
| return out | |
| # βββ Env-Specific (extend OpenEnv types) ββββββββββββββββββββββββββββββ | |
| from openenv.core.env_server.types import Action, Observation, State | |
| from typing import Optional | |
| class MultiTurnAction(Action): | |
| """Overseer action in a multi-turn oversight episode. Supports any multi-agent system.""" | |
| action_type: str = Field( | |
| ..., | |
| description='One of: "pass", "flag", "question"', | |
| ) | |
| error_type: Optional[str] = Field( | |
| default=None, | |
| description='For flag: "factual_error", "logic_error", "code_bug", "safety_violation", "sycophancy"', | |
| ) | |
| explanation: Optional[str] = Field( | |
| default=None, | |
| description="Explanation for flag actions", | |
| ) | |
| question_text: Optional[str] = Field( | |
| default=None, | |
| description="Question to ask the Worker AI (for question action)", | |
| ) | |
| class MultiTurnObservation(Observation): | |
| """Observation at each step of a multi-agent oversight episode. Plugin-agnostic.""" | |
| conversation_so_far: str = Field( | |
| default="", description="All turns revealed so far" | |
| ) | |
| current_turn: str = Field( | |
| default="", description="The latest turn to evaluate" | |
| ) | |
| current_turn_number: int = Field( | |
| default=0, description="Current worker turn number (1-indexed)" | |
| ) | |
| total_turns: int = Field( | |
| default=0, description="Total worker turns in this episode" | |
| ) | |
| task_domain: str = Field(default="general", description="Conversation domain / game_id") | |
| task_id: str = Field(default="", description="Episode ID") | |
| difficulty: int = Field(default=1, description="Curriculum difficulty 1-4") | |
| remaining_questions: int = Field( | |
| default=2, description="QUESTION actions remaining (investigation budget)" | |
| ) | |
| flags_so_far: int = Field( | |
| default=0, description="Number of FLAGS issued this episode" | |
| ) | |
| phase: str = Field( | |
| default="observe", | |
| description='"observe" (new worker turn), "question_response" (after QUESTION), "done" (episode over)', | |
| ) | |
| step_reward: Optional[float] = Field( | |
| default=None, description="Reward from last action" | |
| ) | |
| cumulative_reward: Optional[float] = Field( | |
| default=None, description="Total reward this episode" | |
| ) | |
| feedback: Optional[str] = Field( | |
| default=None, description="Feedback from last action" | |
| ) | |
| class MultiTurnState(State): | |
| """Episode state for multi-agent oversight. Works with any plugin.""" | |
| episode_id: str = Field(default="", description="Episode identifier") | |
| step_count: int = Field(default=0, description="Steps in current episode") | |
| current_level: int = Field(default=1, description="Current curriculum level") | |
| total_episodes: int = Field(default=0, description="Total episodes completed") | |
| errors_detected: int = 0 | |
| errors_missed: int = 0 | |
| false_flags: int = 0 | |
| correct_passes: int = 0 | |
| questions_used: int = 0 | |
| cumulative_reward: float = 0.0 | |