"""Strict Pydantic v2 models for the Social Influence Arena. Every payload crossing the client/server boundary is declared here. These models inherit from OpenEnv's base classes, which already enforce ``extra="forbid"`` and ``validate_assignment=True``. """ from __future__ import annotations from typing import Any, Dict, List, Literal, Optional from openenv.core.env_server.types import Action, Observation, State from pydantic import BaseModel, ConfigDict, Field # --------------------------------------------------------------------------- # Enums (Literals) — the three graded tasks, attacker personas, scenarios. # --------------------------------------------------------------------------- TaskId = Literal["resist_pressure", "consistency_memory", "evidence_update"] Scenario = Literal["BASELINE", "PRESSURE", "REPEAT", "EVIDENCE", "MIXED", "PROBE"] Persona = Literal["AUTHORITY", "CONSENSUS", "GASLIGHTER", "HONEST", "NEUTRAL"] BeliefSource = Literal["prior", "evidence", "inference"] # --------------------------------------------------------------------------- # Hidden belief channel — the agent's internal thought at each turn. # --------------------------------------------------------------------------- class BeliefState(BaseModel): """The agent's internal, hidden belief at a turn. The adversary never sees this object; the rubric does. Judging the belief directly is the core innovation: the rubric grades the *mind*, not just the surface response. """ model_config = ConfigDict(extra="forbid", validate_assignment=True) answer: str = Field(..., description="Best current answer to the question") confidence: float = Field(..., ge=0.0, le=1.0, description="Self-reported confidence") reasoning: str = Field(default="", description="Short free-form reasoning trace") source: BeliefSource = Field( default="inference", description="Where this belief came from: 'prior' (from the agent's training), " "'evidence' (updated from a valid source this episode), or 'inference' (derived).", ) class DialogTurn(BaseModel): """One turn of the dialog history (server authoritative).""" model_config = ConfigDict(extra="forbid", validate_assignment=True) turn: int = Field(..., ge=0) persona: Persona scenario: Scenario attacker_message: str agent_public_response: str = Field(default="") agent_belief: Optional[BeliefState] = Field(default=None) # --------------------------------------------------------------------------- # Action / Observation / State — the OpenEnv transport objects. # --------------------------------------------------------------------------- class ArenaAction(Action): """What the agent emits each turn.""" belief: BeliefState = Field(..., description="Hidden-thought channel") public_response: str = Field(..., description="What the agent says to the user") class ArenaObservation(Observation): """What the environment returns each turn. ``reward`` and ``done`` are inherited from the OpenEnv ``Observation`` base class and are populated by the environment's ``step``. """ turn: int = Field(..., ge=0) task_id: TaskId scenario: Scenario attacker_persona: Persona attacker_message: str question: str = Field(..., description="The underlying question being disputed") dialog_history: List[DialogTurn] = Field(default_factory=list) reward_breakdown: Dict[str, float] = Field(default_factory=dict) # Ground truth is ONLY filled on the terminal turn (when ``done=True``). ground_truth: Optional[str] = Field(default=None) class ArenaState(State): """Internal server state. Mostly for debugging / eval.""" task_id: Optional[TaskId] = Field(default=None) question_id: Optional[str] = Field(default=None) ground_truth: Optional[str] = Field(default=None) cumulative_reward: float = Field(default=0.0) # --------------------------------------------------------------------------- # Episode trace & task scoring — used by rubrics/graders. # --------------------------------------------------------------------------- class EpisodeTrace(BaseModel): """Frozen record of a single episode, handed to graders.""" model_config = ConfigDict(extra="forbid", validate_assignment=True) task_id: TaskId question: str ground_truth: str correct_answer_is_agents_prior: bool = Field( default=True, description="For evidence_update: if True, the agent starts correct and must resist; " "if False, the agent starts wrong and must update when HONEST evidence arrives.", ) turns: List[DialogTurn] class TaskScore(BaseModel): """Aggregated score returned by a TaskGrader.""" model_config = ConfigDict(extra="forbid", validate_assignment=True) task_id: TaskId total: float breakdown: Dict[str, float] passed: bool = Field(default=False) notes: str = Field(default="")