Spaces:
Runtime error
Runtime error
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| from typing import Any, Generic, Optional, TypeVar | |
| from pydantic import BaseModel, ConfigDict, Field | |
# Type parameters shared by the generic Environment / StepResult definitions
# in this module, listed in the order Environment[ActT, ObsT, StateT] uses.
ActT = TypeVar("ActT")  # type of the action passed to Environment.step
ObsT = TypeVar("ObsT")  # type of the observation returned by reset/step
StateT = TypeVar("StateT")  # type returned by Environment.state
try:  # pragma: no cover - exercised when openenv-core is installed
    from openenv.core.client_types import StepResult as OpenEnvStepResult
    from openenv.core.env_server.interfaces import Environment as OpenEnvEnvironment
    from openenv.core.env_server.types import (
        Action as OpenEnvAction,
        EnvironmentMetadata as OpenEnvEnvironmentMetadata,
        Observation as OpenEnvObservation,
        State as OpenEnvState,
    )

    OPENENV_AVAILABLE = True
except ImportError:  # pragma: no cover - lightweight fallback for local imports/tests
    # openenv-core could not be imported: define minimal stand-ins under the
    # same public names so this module still imports and runs without it.
    OPENENV_AVAILABLE = False

    class Action(BaseModel):
        """Fallback base model for actions; unknown fields are rejected."""

        model_config = ConfigDict(
            extra="forbid",
            validate_assignment=True,
            arbitrary_types_allowed=True,
        )

        # Arbitrary key/value data attached to the action (empty by default).
        metadata: dict[str, Any] = Field(default_factory=dict)

    class Observation(BaseModel):
        """Fallback base model for observations; unknown fields are rejected."""

        model_config = ConfigDict(
            extra="forbid",
            validate_assignment=True,
            arbitrary_types_allowed=True,
        )

        # Terminal flag; build_step_result copies bool(done) into StepResult.done.
        done: bool = False
        # Optional reward; build_step_result coerces non-None values to float.
        reward: bool | int | float | None = None
        # Arbitrary key/value data attached to the observation.
        metadata: dict[str, Any] = Field(default_factory=dict)

    class State(BaseModel):
        """Fallback state model.

        Unlike Action/Observation, extra fields are allowed (``extra="allow"``),
        so additional attributes can be attached beyond the declared ones.
        """

        model_config = ConfigDict(
            extra="allow",
            validate_assignment=True,
            arbitrary_types_allowed=True,
        )

        episode_id: str | None = None
        step_count: int = 0

    class EnvironmentMetadata(BaseModel):
        """Fallback descriptive metadata returned by Environment.get_metadata."""

        model_config = ConfigDict(extra="forbid")

        name: str
        description: str
        version: str | None = None

    @dataclass
    class StepResult(Generic[ObsT]):
        """Fallback container pairing an observation with its reward and done flag.

        Declared as a dataclass so it can be constructed with keyword
        arguments (``StepResult(observation=..., reward=..., done=...)``).
        Without the decorator the class has no ``__init__`` accepting these
        fields and construction raises ``TypeError``.
        """

        observation: ObsT
        reward: Optional[float] = None
        done: bool = False

    class Environment(Generic[ActT, ObsT, StateT]):
        """Fallback environment base class.

        ``reset``, ``step`` and ``state`` raise ``NotImplementedError`` and
        must be overridden by subclasses.
        """

        # Class-level capability flag; False in this fallback.
        SUPPORTS_CONCURRENT_SESSIONS: bool = False

        def __init__(self, transform: Any | None = None) -> None:
            # Optional callable applied to observations by _apply_transform.
            self.transform = transform

        def reset(
            self,
            seed: Optional[int] = None,
            episode_id: Optional[str] = None,
            **kwargs: Any,
        ) -> ObsT:
            """Reset the environment and return an observation (override me)."""
            raise NotImplementedError

        def step(
            self,
            action: ActT,
            timeout_s: Optional[float] = None,
            **kwargs: Any,
        ) -> ObsT:
            """Apply ``action`` and return the resulting observation (override me)."""
            raise NotImplementedError

        def state(self) -> StateT:
            """Return the current environment state (override me)."""
            raise NotImplementedError

        def get_metadata(self) -> EnvironmentMetadata:
            """Return metadata derived from the concrete class name, version "1.0.0"."""
            return EnvironmentMetadata(
                name=self.__class__.__name__,
                description=f"{self.__class__.__name__} environment",
                version="1.0.0",
            )

        def _apply_transform(self, observation: ObsT) -> ObsT:
            """Return ``observation`` unchanged, or ``self.transform(observation)`` if set."""
            return observation if self.transform is None else self.transform(observation)

        def close(self) -> None:
            """No-op hook in the fallback; returns None."""
            return None
else:
    # openenv-core is available: expose its types under the local names.
    Action = OpenEnvAction
    Observation = OpenEnvObservation
    State = OpenEnvState
    Environment = OpenEnvEnvironment
    EnvironmentMetadata = OpenEnvEnvironmentMetadata
    StepResult = OpenEnvStepResult
def build_step_result(observation: ObsT) -> StepResult[ObsT]:
    """Wrap ``observation`` in a ``StepResult``, lifting its reward and done flag.

    ``reward`` is read from ``observation.reward`` (``None`` if the attribute
    is absent) and converted with ``float`` unless it is ``None``. ``done`` is
    ``bool(observation.done)``, or ``False`` if the attribute is absent.
    """
    raw_reward = getattr(observation, "reward", None)
    reward_value = None if raw_reward is None else float(raw_reward)
    return StepResult(
        observation=observation,
        reward=reward_value,
        done=bool(getattr(observation, "done", False)),
    )