"""OpenEnv Pydantic models for TemporalBenchEnv."""

from __future__ import annotations

from typing import Any

from pydantic import BaseModel, ConfigDict, Field

# Prefer the OpenEnv base types; if they cannot be imported, every base
# degrades to a plain pydantic ``BaseModel`` so this module still loads.
try:
    from openenv.core.env_server.types import (
        Action as _ActionBase,
        Observation as _ObservationBase,
        State as _StateBase,
    )
except ImportError:
    _ActionBase = _ObservationBase = _StateBase = BaseModel


# NOTE: the class docstrings below are intentionally left short and unchanged;
# pydantic uses a model's docstring as its JSON-schema description.


class TemporalBenchAction(_ActionBase):
    """Agent submits an MCQ answer (optional confidence / reasoning)."""

    # Fallback-only declarations, active when the base is plain ``BaseModel``:
    # unknown keys are rejected and ``metadata`` is declared locally.
    # NOTE(review): presumably mirrors what the OpenEnv ``Action`` base already
    # supplies -- confirm against openenv.core.env_server.types.
    if _ActionBase is BaseModel:
        model_config = ConfigDict(extra="forbid")
        metadata: dict[str, Any] = Field(default_factory=dict)

    # Required.
    answer: str = Field(..., description="MCQ answer label matching an option")
    # Optional; when provided it must lie within [0.0, 1.0].
    confidence: float | None = Field(default=None, ge=0.0, le=1.0)
    # Optional free text.
    reasoning: str | None = Field(
        default=None,
        description="Optional chain-of-thought",
    )


class TemporalBenchObservation(_ObservationBase):
    """Current question and progress."""

    # Fallback-only declarations, active when the base is plain ``BaseModel``.
    # NOTE(review): presumably mirrors what the OpenEnv ``Observation`` base
    # already supplies -- confirm against openenv.core.env_server.types.
    if _ObservationBase is BaseModel:
        model_config = ConfigDict(extra="forbid")
        done: bool = False
        reward: float | None = None
        metadata: dict[str, Any] = Field(default_factory=dict)

    # Progress counters: both required and non-negative.
    step_idx: int = Field(..., ge=0)
    steps_remaining: int = Field(..., ge=0)
    # Defaults to 9 and must be at least 1.
    max_steps: int = Field(default=9, ge=1)

    # The question being presented (all required).
    question: str = Field(..., description="Current MCQ prompt")
    options: list[str] = Field(..., description="Answer choices")
    task_type: str = Field(..., description="T1U | T3 | T2_MCQ")
    dataset: str = Field(..., description="Source dataset")

    # List of free-form dict records; starts empty.
    history: list[dict[str, Any]] = Field(default_factory=list)
    # Constrained to [0.0, 1.0].
    accuracy_so_far: float = Field(default=0.0, ge=0.0, le=1.0)


class TemporalBenchState(_StateBase):
    """Serializable environment state."""

    # Fallback-only declarations, active when the base is plain ``BaseModel``.
    # Unlike the action/observation models, extra keys are allowed here.
    # NOTE(review): presumably mirrors what the OpenEnv ``State`` base already
    # supplies -- confirm against openenv.core.env_server.types.
    if _StateBase is BaseModel:
        model_config = ConfigDict(extra="allow")
        episode_id: str | None = None
        step_count: int = Field(default=0, ge=0)

    # Non-negative counters; ``total_questions`` defaults to 9.
    total_correct: int = Field(default=0, ge=0)
    total_questions: int = Field(default=9, ge=0)
    # Constrained to [0.0, 1.0].
    current_accuracy: float = Field(default=0.0, ge=0.0, le=1.0)
    primary_domain: str = "PSML"
    # String keys mapped to float values; starts empty.
    per_task_type_accuracy: dict[str, float] = Field(default_factory=dict)
    # No range constraint is applied to the reward total.
    total_reward: float = 0.0