# TemporalBenchEnv/env/models.py
# Uploaded by yashu2000 using huggingface_hub (commit d954568, verified).
"""OpenEnv Pydantic models for TemporalBenchEnv."""
from __future__ import annotations
from typing import Any
from pydantic import BaseModel, ConfigDict, Field
# Prefer the openenv server base classes when the package is installed;
# otherwise every model below derives directly from pydantic's BaseModel
# and declares the shared fields/config itself.
try:
    from openenv.core.env_server.types import (
        Action as _ActionBase,
        Observation as _ObservationBase,
        State as _StateBase,
    )
except ImportError:
    _ActionBase = _ObservationBase = _StateBase = BaseModel
class TemporalBenchAction(_ActionBase):
    """Agent submits an MCQ answer (optional confidence / reasoning).

    ``answer`` is expected to be the label of one of the presented options,
    but this model does not check it against the options. ``confidence``,
    when given, must lie in [0.0, 1.0].
    """

    # Only when openenv is missing (the base is plain BaseModel) does this
    # class declare the config and ``metadata`` field itself.
    # NOTE(review): presumably mirrors what openenv's Action base provides — confirm.
    if _ActionBase is BaseModel:
        model_config = ConfigDict(extra="forbid")
        metadata: dict[str, Any] = Field(default_factory=dict)

    answer: str = Field(..., description="MCQ answer label matching an option")
    confidence: float | None = Field(None, ge=0.0, le=1.0)
    reasoning: str | None = Field(None, description="Optional chain-of-thought")
class TemporalBenchObservation(_ObservationBase):
    """Current question and progress.

    Holds the MCQ currently posed to the agent together with step counters,
    the episode history and the running accuracy (bounded to [0.0, 1.0]).
    The structure of individual ``history`` entries is not constrained here.
    """

    # Fallback declarations used only without openenv (base is BaseModel).
    # NOTE(review): presumably openenv's Observation base supplies
    # done/reward/metadata itself — confirm.
    if _ObservationBase is BaseModel:
        model_config = ConfigDict(extra="forbid")
        done: bool = False
        reward: float | None = None
        metadata: dict[str, Any] = Field(default_factory=dict)

    # Step counters.
    step_idx: int = Field(..., ge=0)
    steps_remaining: int = Field(..., ge=0)
    max_steps: int = Field(9, ge=1)

    # The question being asked.
    question: str = Field(..., description="Current MCQ prompt")
    options: list[str] = Field(..., description="Answer choices")
    task_type: str = Field(..., description="T1U | T3 | T2_MCQ")
    dataset: str = Field(..., description="Source dataset")

    # Record of the episode so far.
    history: list[dict[str, Any]] = Field(default_factory=list)
    accuracy_so_far: float = Field(0.0, ge=0.0, le=1.0)
class TemporalBenchState(_StateBase):
    """Serializable environment state.

    Tracks answer counts, overall and per-task-type accuracy, and the
    accumulated reward for an episode. ``primary_domain`` defaults to "PSML".
    """

    if _StateBase is BaseModel:
        # Fallback config/fields used only when openenv is not installed.
        # Extra keys are accepted (extra="allow") rather than rejected.
        model_config = ConfigDict(extra="allow")
        episode_id: str | None = None
        step_count: int = Field(0, ge=0)

    total_correct: int = Field(0, ge=0)
    total_questions: int = Field(9, ge=0)
    current_accuracy: float = Field(0.0, ge=0.0, le=1.0)
    primary_domain: str = "PSML"
    per_task_type_accuracy: dict[str, float] = Field(default_factory=dict)
    total_reward: float = 0.0