File size: 6,854 Bytes
1070765
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
"""WatchDog Environment β€” Shared models for multi-turn oversight and plugins.

Shared types (used by plugins and env): AgentTurn, MultiAgentStep, MultiAgentState,
MultiAgentConfig, ContextMessage. Env-specific types extend OpenEnv Action/Observation/State.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Literal, TYPE_CHECKING

from pydantic import Field

if TYPE_CHECKING:
    pass

# ─── Shared Types (plugins + env) ────────────────────────────────────

ContextRole = Literal["system", "user", "assistant"]  # allowed values for ContextMessage.role


@dataclass
class ContextMessage:
    """A single message in the system context (LLM conversation history)."""

    role: ContextRole  # "system" | "user" | "assistant"
    content: str  # raw message text


@dataclass
class AgentTurn:
    """Canonical turn representation. Plugins and env both use this."""

    agent_id: str  # stable identifier of the speaking agent
    action_text: str  # the turn's message content
    step_index: int = 0  # index of the step this turn belongs to
    phase: str = ""  # plugin-defined phase label (copied into turn dicts)
    display_name: str = ""  # human-readable label; formatters fall back to agent_id when empty
    moderator_prompt: str = ""  # prompt shown before the turn, if any (see format_current_turn)
    metadata: dict[str, Any] = field(default_factory=dict)  # plugin extras, merged into turn dicts


@dataclass
class MultiAgentConfig:
    """Base config for a multi-agent system run. Plugins subclass for game-specific fields."""

    # Intentionally empty: exists only as a common base type for plugin configs.
    pass


ConversationLogEntry = dict[str, Any]
"""Plain conversation log entry: speaker_id, speaker_display, message, optionally moderator_prompt."""


@dataclass
class MultiAgentState:
    """Tracks system behaviour across the run. Used when generating each MultiAgentStep."""

    step_index: int = 0  # index of the next/current step
    turns_so_far: list[AgentTurn] = field(default_factory=list)  # all turns emitted so far
    config: MultiAgentConfig | None = None  # run config, if the plugin supplies one
    done: bool = False  # True once the scenario has finished
    metadata: dict[str, Any] = field(default_factory=dict)  # plugin-specific state
    conversation_log: list[ConversationLogEntry] = field(default_factory=list)  # plain-dict transcript


@dataclass
class MultiAgentStep:
    """One step: multiple agent turns. done=True means scenario is finished."""

    turns: list[AgentTurn]  # the agent turns revealed at this step
    done: bool = False  # True when this is the final step of the scenario
    step_index: int = 0  # position of this step in the run
    game_id: str = ""  # identifier of the game/scenario type
    task_id: str = ""  # identifier of the concrete episode/task
    domain: str = ""  # task domain label
    state: MultiAgentState | None = None  # optional snapshot of run state at this step


# ─── Format Helpers ──────────────────────────────────────────────────

def format_conversation(turns: list[AgentTurn]) -> str:
    """Render *turns* as the conversation_so_far text block.

    Each turn becomes one ``[Turn N] <label>: <text>`` line, where the label
    prefers display_name and falls back to agent_id. An empty turn list
    yields the "[Conversation start]" placeholder.
    """
    if not turns:
        return "[Conversation start]"
    return "\n".join(
        f"[Turn {n}] {turn.display_name or turn.agent_id}: {turn.action_text}"
        for n, turn in enumerate(turns, start=1)
    )


def format_current_turn(turn: AgentTurn, moderator_prompt: str = "") -> str:
    """Render the current turn, prefixing a moderator prompt when one exists.

    The explicit *moderator_prompt* argument takes precedence over the
    turn's own moderator_prompt field.
    """
    label = turn.display_name or turn.agent_id
    body = f"[{label}]: {turn.action_text}"
    effective_prompt = moderator_prompt if moderator_prompt else turn.moderator_prompt
    if not effective_prompt:
        return body
    return f"[Moderator]: {effective_prompt}\n\n{body}"


def agent_turn_to_dict(
    turn: AgentTurn,
    has_error: bool = False,
    displayed_response: str | None = None,
    error_detail: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Flatten an AgentTurn into the legacy _turns_seen dict shape.

    turn.metadata entries are merged in after the standard keys and therefore
    override them on collision, matching the original ``**`` expansion order.
    error_detail is attached only when explicitly provided.
    """
    shown = turn.action_text if displayed_response is None else displayed_response
    record: dict[str, Any] = {
        "speaker_id": turn.agent_id,
        "speaker_display": turn.display_name or turn.agent_id,
        "message": turn.action_text,
        "displayed_response": shown,
        "has_error": has_error,
        "moderator_prompt": turn.moderator_prompt or "",
        "phase": turn.phase,
    }
    record.update(turn.metadata)
    if error_detail is not None:
        record["error_detail"] = error_detail
    return record


# ─── Env-Specific (extend OpenEnv types) ──────────────────────────────

from openenv.core.env_server.types import Action, Observation, State
from typing import Optional


class MultiTurnAction(Action):
    """Overseer action in a multi-turn oversight episode. Supports any multi-agent system."""

    # NOTE(review): action_type is a plain str, not a Literal — the allowed
    # values in the description are not validated here; presumably the env
    # enforces them. Confirm before tightening.
    action_type: str = Field(
        ...,
        description='One of: "pass", "flag", "question"',
    )
    # Only meaningful for "flag" actions.
    error_type: Optional[str] = Field(
        default=None,
        description='For flag: "factual_error", "logic_error", "code_bug", "safety_violation", "sycophancy"',
    )
    # Only meaningful for "flag" actions.
    explanation: Optional[str] = Field(
        default=None,
        description="Explanation for flag actions",
    )
    # Only meaningful for "question" actions.
    question_text: Optional[str] = Field(
        default=None,
        description="Question to ask the Worker AI (for question action)",
    )


class MultiTurnObservation(Observation):
    """Observation at each step of a multi-agent oversight episode. Plugin-agnostic.

    All fields have defaults so a blank observation can be constructed; the
    env fills them per step. step_reward/cumulative_reward/feedback are None
    until an action has been taken.
    """

    # --- conversation view ---
    conversation_so_far: str = Field(
        default="", description="All turns revealed so far"
    )
    current_turn: str = Field(
        default="", description="The latest turn to evaluate"
    )
    current_turn_number: int = Field(
        default=0, description="Current worker turn number (1-indexed)"
    )
    total_turns: int = Field(
        default=0, description="Total worker turns in this episode"
    )
    # --- episode metadata ---
    task_domain: str = Field(default="general", description="Conversation domain / game_id")
    task_id: str = Field(default="", description="Episode ID")
    difficulty: int = Field(default=1, description="Curriculum difficulty 1-4")
    # --- overseer budget / progress ---
    remaining_questions: int = Field(
        default=2, description="QUESTION actions remaining (investigation budget)"
    )
    flags_so_far: int = Field(
        default=0, description="Number of FLAGS issued this episode"
    )
    phase: str = Field(
        default="observe",
        description='"observe" (new worker turn), "question_response" (after QUESTION), "done" (episode over)',
    )
    # --- feedback from the last action (None before the first action) ---
    step_reward: Optional[float] = Field(
        default=None, description="Reward from last action"
    )
    cumulative_reward: Optional[float] = Field(
        default=None, description="Total reward this episode"
    )
    feedback: Optional[str] = Field(
        default=None, description="Feedback from last action"
    )


class MultiTurnState(State):
    """Episode state for multi-agent oversight. Works with any plugin.

    Tracks per-episode progress (episode_id, step_count) plus running
    overseer outcome counters. Previously only the first four fields were
    declared via ``Field`` with descriptions while the counters were bare
    defaults; all fields now use ``Field`` consistently (behaviorally
    equivalent in pydantic — same defaults, same types).
    """

    episode_id: str = Field(default="", description="Episode identifier")
    step_count: int = Field(default=0, description="Steps in current episode")
    current_level: int = Field(default=1, description="Current curriculum level")
    total_episodes: int = Field(default=0, description="Total episodes completed")

    # Overseer outcome counters. NOTE(review): semantics inferred from field
    # names and MultiTurnAction's flag/pass/question vocabulary — confirm
    # against the env's reward logic.
    errors_detected: int = Field(default=0, description="Errors detected count")
    errors_missed: int = Field(default=0, description="Errors missed count")
    false_flags: int = Field(default=0, description="False flags count")
    correct_passes: int = Field(default=0, description="Correct passes count")
    questions_used: int = Field(default=0, description="QUESTION actions used")

    cumulative_reward: float = Field(default=0.0, description="Total reward accumulated")