Spaces:
Paused
Paused
File size: 6,854 Bytes
1070765 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 | """WatchDog Environment — Shared models for multi-turn oversight and plugins.
Shared types (used by plugins and env): AgentTurn, MultiAgentStep, MultiAgentState,
MultiAgentConfig, ContextMessage. Env-specific types extend OpenEnv Action/Observation/State.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Literal, TYPE_CHECKING
from pydantic import Field
if TYPE_CHECKING:
pass
# ─── Shared Types (plugins + env) ────────────────────────────────────
# Closed set of role labels a context message may carry.
ContextRole = Literal["system", "user", "assistant"]


@dataclass
class ContextMessage:
    """One entry of the system context, i.e. the LLM conversation history."""

    # Role label; restricted (for type checkers) to the ContextRole literals.
    role: ContextRole
    # Text body of the message.
    content: str
@dataclass
class AgentTurn:
    """Canonical representation of one agent turn, shared by plugins and the env."""

    # Required: identifier of the acting agent and the text of its action.
    agent_id: str
    action_text: str

    # Optional fields; each defaults to an "empty" value.
    step_index: int = 0
    phase: str = ""
    # When non-empty, the format helpers in this module use it as the
    # speaker label in place of agent_id.
    display_name: str = ""
    moderator_prompt: str = ""
    # Free-form extras; default_factory gives every instance its own dict.
    metadata: dict[str, Any] = field(default_factory=dict)
@dataclass
class MultiAgentConfig:
    """Base configuration for one multi-agent system run.

    Declares no fields of its own; plugins subclass it to add
    game-specific fields.
    """
ConversationLogEntry = dict[str, Any]
"""Plain conversation log entry: speaker_id, speaker_display, message, optionally moderator_prompt."""


@dataclass
class MultiAgentState:
    """Running record of system behaviour over a run.

    Consulted when each MultiAgentStep is generated.
    """

    step_index: int = 0
    # Turns accumulated so far; a fresh list per instance.
    turns_so_far: list[AgentTurn] = field(default_factory=list)
    # Run configuration, or None when none has been attached.
    config: MultiAgentConfig | None = None
    done: bool = False
    # Free-form extras; a fresh dict per instance.
    metadata: dict[str, Any] = field(default_factory=dict)
    # Plain-dict log entries (see ConversationLogEntry); a fresh list per instance.
    conversation_log: list[ConversationLogEntry] = field(default_factory=list)
@dataclass
class MultiAgentStep:
    """A single step made up of several agent turns.

    ``done=True`` signals that the scenario has finished.
    """

    # Required: the turns that make up this step.
    turns: list[AgentTurn]
    done: bool = False
    step_index: int = 0
    # Identifier strings; all default to empty.
    game_id: str = ""
    task_id: str = ""
    domain: str = ""
    # Optional state attached to this step; None when absent.
    state: MultiAgentState | None = None
# ─── Format Helpers ──────────────────────────────────────────────────
def format_conversation(turns: list[AgentTurn]) -> str:
    """Render the turns seen so far as the ``conversation_so_far`` text.

    Each turn becomes one line of the form ``[Turn N] <label>: <action_text>``
    with N counted from 1. The label is the turn's display_name, falling
    back to agent_id when display_name is empty. An empty ``turns`` yields
    the placeholder ``"[Conversation start]"``.
    """
    if not turns:
        return "[Conversation start]"
    return "\n".join(
        f"[Turn {number}] {turn.display_name or turn.agent_id}: {turn.action_text}"
        for number, turn in enumerate(turns, start=1)
    )
def format_current_turn(turn: AgentTurn, moderator_prompt: str = "") -> str:
    """Render one turn as the ``current_turn`` text.

    The turn itself is rendered as ``[<label>]: <action_text>``, where the
    label is display_name or, if that is empty, agent_id. When a moderator
    prompt is available, a ``[Moderator]: <prompt>`` line and a blank line
    are placed in front of it. The ``moderator_prompt`` argument wins over
    ``turn.moderator_prompt``; the latter is used only when the argument
    is empty.
    """
    speaker = turn.display_name or turn.agent_id
    effective_prompt = moderator_prompt or turn.moderator_prompt
    turn_text = f"[{speaker}]: {turn.action_text}"
    return (
        f"[Moderator]: {effective_prompt}\n\n{turn_text}"
        if effective_prompt
        else turn_text
    )
def agent_turn_to_dict(
    turn: AgentTurn,
    has_error: bool = False,
    displayed_response: str | None = None,
    error_detail: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Flatten an AgentTurn into a plain dict (``_turns_seen`` compatibility).

    Args:
        turn: The turn to convert.
        has_error: Stored as-is under ``"has_error"``.
        displayed_response: Stored under ``"displayed_response"``; when None,
            the turn's action_text is used instead (an empty string is kept).
        error_detail: Added under ``"error_detail"`` only when not None.

    Returns:
        A new dict with the fixed keys, followed by the entries of
        ``turn.metadata``, followed by ``"error_detail"`` if given.
    """
    shown = turn.action_text if displayed_response is None else displayed_response
    fixed_fields: dict[str, Any] = {
        "speaker_id": turn.agent_id,
        "speaker_display": turn.display_name or turn.agent_id,
        "message": turn.action_text,
        "displayed_response": shown,
        "has_error": has_error,
        "moderator_prompt": turn.moderator_prompt or "",
        "phase": turn.phase,
    }
    # Metadata is merged last, so a metadata key that collides with one of
    # the fixed keys replaces that value.
    record: dict[str, Any] = {**fixed_fields, **turn.metadata}
    if error_detail is not None:
        record["error_detail"] = error_detail
    return record
# ─── Env-Specific (extend OpenEnv types) ──────────────────────────────
from openenv.core.env_server.types import Action, Observation, State
from typing import Optional
class MultiTurnAction(Action):
    """Action issued by the overseer during a multi-turn oversight episode.

    Not tied to any particular multi-agent system.
    """

    # Required (no default). Typed as a plain str; the description lists
    # the expected values.
    action_type: str = Field(..., description='One of: "pass", "flag", "question"')

    # The remaining fields are optional and default to None; per their
    # descriptions each one accompanies a specific action_type.
    error_type: Optional[str] = Field(
        default=None,
        description='For flag: "factual_error", "logic_error", "code_bug", "safety_violation", "sycophancy"',
    )
    explanation: Optional[str] = Field(default=None, description="Explanation for flag actions")
    question_text: Optional[str] = Field(
        default=None,
        description="Question to ask the Worker AI (for question action)",
    )
class MultiTurnObservation(Observation):
    """What the overseer observes at each step of a multi-agent oversight episode.

    Plugin-agnostic: every field has a default.
    """

    # Conversation text.
    conversation_so_far: str = Field(default="", description="All turns revealed so far")
    current_turn: str = Field(default="", description="The latest turn to evaluate")

    # Position within the episode.
    current_turn_number: int = Field(default=0, description="Current worker turn number (1-indexed)")
    total_turns: int = Field(default=0, description="Total worker turns in this episode")

    # Task identification and curriculum level.
    task_domain: str = Field(
        default="general",
        description="Conversation domain / game_id",
    )
    task_id: str = Field(
        default="",
        description="Episode ID",
    )
    difficulty: int = Field(
        default=1,
        description="Curriculum difficulty 1-4",
    )

    # Overseer budget and counters.
    remaining_questions: int = Field(default=2, description="QUESTION actions remaining (investigation budget)")
    flags_so_far: int = Field(default=0, description="Number of FLAGS issued this episode")

    # Episode phase; the description lists the expected values.
    phase: str = Field(
        default="observe",
        description='"observe" (new worker turn), "question_response" (after QUESTION), "done" (episode over)',
    )

    # Outcome of the previous action; each defaults to None.
    step_reward: Optional[float] = Field(default=None, description="Reward from last action")
    cumulative_reward: Optional[float] = Field(default=None, description="Total reward this episode")
    feedback: Optional[str] = Field(default=None, description="Feedback from last action")
class MultiTurnState(State):
    """Per-episode state for multi-agent oversight; usable with any plugin."""

    # Episode bookkeeping.
    episode_id: str = Field(
        default="",
        description="Episode identifier",
    )
    step_count: int = Field(
        default=0,
        description="Steps in current episode",
    )
    current_level: int = Field(
        default=1,
        description="Current curriculum level",
    )
    total_episodes: int = Field(
        default=0,
        description="Total episodes completed",
    )

    # Counters with plain defaults (no Field metadata); all start at zero.
    errors_detected: int = 0
    errors_missed: int = 0
    false_flags: int = 0
    correct_passes: int = 0
    questions_used: int = 0

    # Running reward total; starts at 0.0.
    cumulative_reward: float = 0.0
|