"""Compatibility layer for the OpenEnv core types.

Re-exports ``Action``, ``Observation``, ``State``, ``Environment``,
``EnvironmentMetadata`` and ``StepResult`` from ``openenv`` when that package
can be imported, and defines lightweight local stand-ins otherwise.
``OPENENV_AVAILABLE`` records which of the two paths was taken.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Generic, Optional, TypeVar

from pydantic import BaseModel, ConfigDict, Field

# Type parameters for observation, action and state payloads.
ObsT = TypeVar("ObsT")
ActT = TypeVar("ActT")
StateT = TypeVar("StateT")

try:  # pragma: no cover - exercised when openenv-core is installed
    from openenv.core.client_types import StepResult as OpenEnvStepResult
    from openenv.core.env_server.interfaces import (
        Environment as OpenEnvEnvironment,
    )
    from openenv.core.env_server.types import (
        Action as OpenEnvAction,
        EnvironmentMetadata as OpenEnvEnvironmentMetadata,
        Observation as OpenEnvObservation,
        State as OpenEnvState,
    )

    OPENENV_AVAILABLE = True
except ImportError:  # pragma: no cover - lightweight fallback for local imports/tests
    OPENENV_AVAILABLE = False

    # The pydantic models below are documented with comments instead of
    # docstrings, since pydantic copies a model's docstring into the
    # "description" of its generated JSON schema.

    # Fallback action model: carries only a free-form metadata mapping.
    # Unknown fields are rejected and assignments are re-validated.
    class Action(BaseModel):
        model_config = ConfigDict(
            extra="forbid",
            validate_assignment=True,
            arbitrary_types_allowed=True,
        )

        metadata: dict[str, Any] = Field(default_factory=dict)

    # Fallback observation model: a done flag, an optional reward and a
    # free-form metadata mapping. Unknown fields are rejected.
    class Observation(BaseModel):
        model_config = ConfigDict(
            extra="forbid",
            validate_assignment=True,
            arbitrary_types_allowed=True,
        )

        done: bool = False
        reward: bool | int | float | None = None
        metadata: dict[str, Any] = Field(default_factory=dict)

    # Fallback state model: an optional episode id and a step counter.
    # Unlike the models above, extra fields are allowed here.
    class State(BaseModel):
        model_config = ConfigDict(
            extra="allow",
            validate_assignment=True,
            arbitrary_types_allowed=True,
        )

        episode_id: str | None = None
        step_count: int = 0

    # Fallback descriptive record returned by Environment.get_metadata().
    class EnvironmentMetadata(BaseModel):
        model_config = ConfigDict(extra="forbid")

        name: str
        description: str
        version: str | None = None

    @dataclass
    class StepResult(Generic[ObsT]):
        """Fallback container pairing an observation with reward and done."""

        observation: ObsT
        reward: Optional[float] = None
        done: bool = False

    class Environment(Generic[ActT, ObsT, StateT]):
        """Fallback environment base class.

        ``reset``, ``step`` and ``state`` raise ``NotImplementedError`` here;
        subclasses are expected to provide them.
        """

        # Defaults to False. Not read anywhere in this module; presumably
        # consumed by whatever hosts the environment (confirm with callers).
        SUPPORTS_CONCURRENT_SESSIONS: bool = False

        def __init__(self, transform: Any | None = None) -> None:
            """Store the optional observation transform on the instance."""
            self.transform = transform

        def reset(
            self,
            seed: Optional[int] = None,
            episode_id: Optional[str] = None,
            **kwargs: Any,
        ) -> ObsT:
            """Not implemented in the base class; always raises."""
            raise NotImplementedError

        def step(
            self,
            action: ActT,
            timeout_s: Optional[float] = None,
            **kwargs: Any,
        ) -> ObsT:
            """Not implemented in the base class; always raises."""
            raise NotImplementedError

        @property
        def state(self) -> StateT:
            """Not implemented in the base class; always raises."""
            raise NotImplementedError

        def get_metadata(self) -> EnvironmentMetadata:
            """Return metadata derived from the concrete class name.

            The name is the class name, the description is that name followed
            by " environment", and the version is fixed at "1.0.0".
            """
            class_name = self.__class__.__name__
            return EnvironmentMetadata(
                name=class_name,
                description=f"{class_name} environment",
                version="1.0.0",
            )

        def _apply_transform(self, observation: ObsT) -> ObsT:
            """Pass the observation through ``self.transform`` when one is set."""
            transform = self.transform
            if transform is None:
                return observation
            return transform(observation)

        def close(self) -> None:
            """Do nothing; the base class holds no resources."""
            return None

else:
    # openenv imported cleanly: expose its types under the local names.
    StepResult = OpenEnvStepResult
    Environment = OpenEnvEnvironment
    EnvironmentMetadata = OpenEnvEnvironmentMetadata
    Action = OpenEnvAction
    Observation = OpenEnvObservation
    State = OpenEnvState


def build_step_result(observation: ObsT) -> StepResult[ObsT]:
    """Wrap an observation in a ``StepResult``.

    Reads ``reward`` and ``done`` from the observation with ``getattr``, so
    objects lacking either attribute are accepted. A reward that is not
    ``None`` is converted with ``float``; a missing ``done`` counts as False
    and any other value is coerced with ``bool``.
    """
    raw_reward = getattr(observation, "reward", None)
    # Convert before reading ``done`` so a failing float() surfaces first.
    reward = None if raw_reward is None else float(raw_reward)
    done = bool(getattr(observation, "done", False))
    return StepResult(observation=observation, reward=reward, done=done)