New-space-openenv / models.py
Mooizz's picture
Upload folder using huggingface_hub
1070765 verified
"""WatchDog Environment — Shared models for multi-turn oversight and plugins.
Shared types (used by plugins and env): AgentTurn, MultiAgentStep, MultiAgentState,
MultiAgentConfig, ContextMessage. Env-specific types extend OpenEnv Action/Observation/State.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Literal, TYPE_CHECKING
from pydantic import Field
if TYPE_CHECKING:
pass
# ─── Shared Types (plugins + env) ────────────────────────────────────
# Closed set of roles a context message may carry.
ContextRole = Literal["system", "user", "assistant"]


@dataclass
class ContextMessage:
    """One entry of the system context, i.e. the LLM conversation history.

    Attributes:
        role: Author of the message; one of the ``ContextRole`` literals.
        content: The message text.
    """

    role: ContextRole
    content: str
@dataclass
class AgentTurn:
    """Canonical record of a single agent turn, shared by plugins and the env.

    Attributes:
        agent_id: Identifier of the acting agent; the format helpers in this
            module fall back to it when ``display_name`` is empty.
        action_text: What the agent said or did on this turn.
        step_index: Index of the step this turn belongs to (default 0).
        phase: Free-form phase label (default empty).
        display_name: Label shown in formatted output in preference to
            ``agent_id`` (default empty).
        moderator_prompt: Moderator text associated with this turn
            (default empty).
        metadata: Extra key/value pairs; each instance gets its own dict.
    """

    # Required identification and payload.
    agent_id: str
    action_text: str

    # Optional positional context.
    step_index: int = 0
    phase: str = ""

    # Optional presentation details.
    display_name: str = ""
    moderator_prompt: str = ""

    # default_factory keeps the dict from being shared between instances.
    metadata: dict[str, Any] = field(default_factory=dict)
@dataclass
class MultiAgentConfig:
    """Field-less base configuration for a multi-agent system run.

    Plugins subclass this to add their game-specific fields.
    """
ConversationLogEntry = dict[str, Any]
"""Plain conversation log entry: speaker_id, speaker_display, message, optionally moderator_prompt."""


@dataclass
class MultiAgentState:
    """Running state of the multi-agent system across a run.

    Consulted when generating each ``MultiAgentStep``.

    Attributes:
        step_index: Index of the current step (default 0).
        turns_so_far: Turns accumulated so far.
        config: Run configuration, or ``None`` when none is attached.
        done: Whether the run has finished (default ``False``).
        metadata: Free-form extra state.
        conversation_log: Plain-dict log entries (see ``ConversationLogEntry``).
    """

    step_index: int = 0
    # Every mutable default goes through default_factory so that instances
    # never share a list or dict.
    turns_so_far: list[AgentTurn] = field(default_factory=list)
    config: MultiAgentConfig | None = None
    done: bool = False
    metadata: dict[str, Any] = field(default_factory=dict)
    conversation_log: list[ConversationLogEntry] = field(default_factory=list)
@dataclass
class MultiAgentStep:
    """A single step made up of one or more agent turns.

    Attributes:
        turns: The agent turns belonging to this step (required).
        done: ``True`` means the scenario is finished (default ``False``).
        step_index: Index of this step (default 0).
        game_id: Game identifier (default empty).
        task_id: Task identifier (default empty).
        domain: Domain label (default empty).
        state: Accompanying ``MultiAgentState``, or ``None`` when absent.
    """

    # The only required field; everything below has a default.
    turns: list[AgentTurn]

    done: bool = False
    step_index: int = 0

    # Identification labels.
    game_id: str = ""
    task_id: str = ""
    domain: str = ""

    state: MultiAgentState | None = None
# ─── Format Helpers ──────────────────────────────────────────────────
def format_conversation(turns: list[AgentTurn]) -> str:
    """Render the turns seen so far as the ``conversation_so_far`` text.

    Args:
        turns: Turns to render, in order.

    Returns:
        ``"[Conversation start]"`` when ``turns`` is empty; otherwise one
        ``"[Turn N] <label>: <action_text>"`` line per turn joined by
        newlines, where N counts from 1 and the label is ``display_name``
        or, if that is empty, ``agent_id``.
    """
    if not turns:
        return "[Conversation start]"
    return "\n".join(
        f"[Turn {number}] {entry.display_name or entry.agent_id}: {entry.action_text}"
        for number, entry in enumerate(turns, start=1)
    )
def format_current_turn(turn: AgentTurn, moderator_prompt: str = "") -> str:
    """Render a single turn as the ``current_turn`` text.

    Args:
        turn: The turn to render.
        moderator_prompt: Prompt to show before the turn; when empty, the
            turn's own ``moderator_prompt`` is used instead.

    Returns:
        ``"[<label>]: <action_text>"``, preceded by a
        ``"[Moderator]: <prompt>"`` line and a blank line when a prompt is
        available. The label is ``display_name`` or, if empty, ``agent_id``.
    """
    speaker = turn.display_name or turn.agent_id
    turn_line = f"[{speaker}]: {turn.action_text}"

    # The explicit argument wins over the prompt stored on the turn.
    effective_prompt = moderator_prompt or turn.moderator_prompt
    if not effective_prompt:
        return turn_line
    return f"[Moderator]: {effective_prompt}\n\n{turn_line}"
def agent_turn_to_dict(
    turn: AgentTurn,
    has_error: bool = False,
    displayed_response: str | None = None,
    error_detail: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Flatten an ``AgentTurn`` into a plain dict for ``_turns_seen`` compatibility.

    Args:
        turn: The turn to convert.
        has_error: Value stored under ``"has_error"``.
        displayed_response: Text stored under ``"displayed_response"``;
            ``None`` means fall back to the turn's ``action_text``.
        error_detail: Stored under ``"error_detail"`` only when not ``None``.

    Returns:
        A new dict with the speaker, message, error and phase keys, overlaid
        with every entry of ``turn.metadata``.
    """
    shown_text = turn.action_text if displayed_response is None else displayed_response
    canonical: dict[str, Any] = {
        "speaker_id": turn.agent_id,
        "speaker_display": turn.display_name or turn.agent_id,
        "message": turn.action_text,
        "displayed_response": shown_text,
        "has_error": has_error,
        "moderator_prompt": turn.moderator_prompt or "",
        "phase": turn.phase,
    }
    # Metadata is applied last, so a metadata key with the same name as a
    # canonical key replaces it.
    # NOTE(review): presumably intentional — confirm with callers.
    record: dict[str, Any] = {**canonical, **turn.metadata}
    if error_detail is not None:
        record["error_detail"] = error_detail
    return record
# ─── Env-Specific (extend OpenEnv types) ──────────────────────────────
from openenv.core.env_server.types import Action, Observation, State
from typing import Optional
class MultiTurnAction(Action):
    """Action taken by the overseer in a multi-turn oversight episode.

    Not tied to a particular multi-agent system. ``action_type`` is the only
    required field; the remaining fields default to ``None``, and each field's
    description states which action type it applies to.
    """

    # Required: which kind of action the overseer takes.
    action_type: str = Field(..., description='One of: "pass", "flag", "question"')

    # Details for "flag" actions.
    error_type: Optional[str] = Field(
        default=None,
        description='For flag: "factual_error", "logic_error", "code_bug", "safety_violation", "sycophancy"',
    )
    explanation: Optional[str] = Field(default=None, description="Explanation for flag actions")

    # Details for "question" actions.
    question_text: Optional[str] = Field(
        default=None,
        description="Question to ask the Worker AI (for question action)",
    )
class MultiTurnObservation(Observation):
    """What the overseer sees at each step of a multi-agent oversight episode.

    Plugin-agnostic. Every field has a default, so an instance can be built
    with no arguments beyond whatever the ``Observation`` base requires.
    """

    # --- Conversation content ---
    conversation_so_far: str = Field(default="", description="All turns revealed so far")
    current_turn: str = Field(default="", description="The latest turn to evaluate")

    # --- Position within the episode ---
    current_turn_number: int = Field(default=0, description="Current worker turn number (1-indexed)")
    total_turns: int = Field(default=0, description="Total worker turns in this episode")

    # --- Task identification ---
    task_domain: str = Field(default="general", description="Conversation domain / game_id")
    task_id: str = Field(default="", description="Episode ID")
    difficulty: int = Field(default=1, description="Curriculum difficulty 1-4")

    # --- Budgets and tallies ---
    remaining_questions: int = Field(default=2, description="QUESTION actions remaining (investigation budget)")
    flags_so_far: int = Field(default=0, description="Number of FLAGS issued this episode")

    # --- Episode phase ---
    phase: str = Field(
        default="observe",
        description='"observe" (new worker turn), "question_response" (after QUESTION), "done" (episode over)',
    )

    # --- Outcome of the previous action (None when not set) ---
    step_reward: Optional[float] = Field(default=None, description="Reward from last action")
    cumulative_reward: Optional[float] = Field(default=None, description="Total reward this episode")
    feedback: Optional[str] = Field(default=None, description="Feedback from last action")
class MultiTurnState(State):
    """State of a multi-agent oversight episode; independent of the plugin in use.

    All fields carry defaults. How the counters are updated is not shown in
    this module.
    """

    # --- Episode bookkeeping ---
    episode_id: str = Field(default="", description="Episode identifier")
    step_count: int = Field(default=0, description="Steps in current episode")

    # --- Curriculum progress ---
    current_level: int = Field(default=1, description="Current curriculum level")
    total_episodes: int = Field(default=0, description="Total episodes completed")

    # --- Outcome counters, all starting at zero ---
    errors_detected: int = 0
    errors_missed: int = 0
    false_flags: int = 0
    correct_passes: int = 0
    questions_used: int = 0

    # --- Running reward total ---
    cumulative_reward: float = 0.0