FATHOM-Hero / agents /shared /openenv_compat.py
aarushgupta's picture
Deploy FATHOM-Hero Space bundle
c782fbf verified
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Generic, Optional, TypeVar
from pydantic import BaseModel, ConfigDict, Field
# Type variables used to parameterize the generic StepResult and Environment
# definitions in this module by observation, action and state type.
ObsT = TypeVar("ObsT")
ActT = TypeVar("ActT")
StateT = TypeVar("StateT")
# Try to import the real OpenEnv types under private aliases. If any of the
# imports fails, the except branch defines minimal stand-ins with the same
# names (Action, Observation, State, EnvironmentMetadata, StepResult,
# Environment) instead.
try: # pragma: no cover - exercised when openenv-core is installed
from openenv.core.client_types import StepResult as OpenEnvStepResult
from openenv.core.env_server.interfaces import Environment as OpenEnvEnvironment
from openenv.core.env_server.types import (
Action as OpenEnvAction,
EnvironmentMetadata as OpenEnvEnvironmentMetadata,
Observation as OpenEnvObservation,
State as OpenEnvState,
)
# Records that the real openenv imports succeeded.
OPENENV_AVAILABLE = True
except ImportError: # pragma: no cover - lightweight fallback for local imports/tests
# Records that the fallback definitions below are in use.
OPENENV_AVAILABLE = False
class Action(BaseModel):
    # Fallback action base model, defined only when the openenv imports fail.
    # Documented with comments rather than a docstring so the generated
    # pydantic schema carries no extra description text.

    # Unknown fields are rejected, attribute assignments are re-validated,
    # and arbitrary (non-pydantic) types are permitted in field annotations.
    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        validate_assignment=True,
        extra="forbid",
    )

    # Free-form mapping; every instance starts with its own empty dict.
    metadata: dict[str, Any] = Field(default_factory=dict)
class Observation(BaseModel):
    # Fallback observation base model, defined only when the openenv imports
    # fail. Documented with comments rather than a docstring so the generated
    # pydantic schema carries no extra description text.

    # Unknown fields are rejected, attribute assignments are re-validated,
    # and arbitrary (non-pydantic) types are permitted in field annotations.
    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        validate_assignment=True,
        extra="forbid",
    )

    # Boolean "done" flag; False unless set.
    done: bool = False

    # Optional reward value; may be a bool, int, float, or absent (None).
    reward: bool | int | float | None = None

    # Free-form mapping; every instance starts with its own empty dict.
    metadata: dict[str, Any] = Field(default_factory=dict)
class State(BaseModel):
    # Fallback state base model, defined only when the openenv imports fail.
    # Documented with comments rather than a docstring so the generated
    # pydantic schema carries no extra description text.

    # Unlike the sibling fallback models, undeclared fields are accepted
    # here (extra="allow"); assignments are still re-validated.
    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        validate_assignment=True,
        extra="allow",
    )

    # Optional episode identifier; None until one is assigned.
    episode_id: str | None = None

    # Integer step counter; starts at 0.
    step_count: int = 0
class EnvironmentMetadata(BaseModel):
    # Fallback metadata model, defined only when the openenv imports fail.
    # Documented with comments rather than a docstring so the generated
    # pydantic schema carries no extra description text.

    # Reject any field that is not declared below.
    model_config = ConfigDict(
        extra="forbid",
    )

    # Required fields.
    name: str
    description: str

    # Optional version string; None when not supplied.
    version: str | None = None
@dataclass
class StepResult(Generic[ObsT]):
    """Fallback step-result record, generic over the observation type.

    Defined only when the openenv imports fail. Holds an observation
    together with an optional float reward and a done flag.
    """

    # Required: the observation object being wrapped.
    observation: ObsT

    # Optional reward; None when no reward is available.
    reward: Optional[float] = None

    # Done flag; False unless set.
    done: bool = False
class Environment(Generic[ActT, ObsT, StateT]):
    """Fallback environment base class, defined only when openenv imports fail.

    Generic over the action, observation and state types. ``reset``,
    ``step`` and ``state`` raise ``NotImplementedError`` here, so concrete
    environments must override them. The remaining members provide
    default behavior.
    """

    # Class-level flag, False by default. Nothing in this class reads it.
    SUPPORTS_CONCURRENT_SESSIONS: bool = False

    def __init__(self, transform: Any | None = None) -> None:
        """Remember an optional callable later used by ``_apply_transform``."""
        self.transform = transform

    def reset(self, seed: Optional[int] = None, episode_id: Optional[str] = None, **kwargs: Any) -> ObsT:
        """Subclass hook returning an observation; the base version always raises.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError

    def step(self, action: ActT, timeout_s: Optional[float] = None, **kwargs: Any) -> ObsT:
        """Subclass hook taking an action and returning an observation.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError

    @property
    def state(self) -> StateT:
        """Subclass hook exposing the state object; the base version always raises.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError

    def get_metadata(self) -> EnvironmentMetadata:
        """Build default metadata derived from the concrete class name.

        Returns:
            An ``EnvironmentMetadata`` whose name is the class name, whose
            description is "<class name> environment", and whose version
            is the fixed string "1.0.0".
        """
        env_name = self.__class__.__name__
        return EnvironmentMetadata(
            name=env_name,
            description=f"{env_name} environment",
            version="1.0.0",
        )

    def _apply_transform(self, observation: ObsT) -> ObsT:
        """Pass ``observation`` through the configured transform, if any.

        Returns the observation unchanged when no transform was supplied;
        otherwise returns whatever calling the transform on it produces.
        """
        if self.transform is None:
            return observation
        return self.transform(observation)

    def close(self) -> None:
        """Do nothing; present so callers can invoke ``close`` uniformly."""
        return None
else:
# The openenv imports succeeded: bind the real OpenEnv classes to the same
# names that the fallback branch would otherwise define.
Action = OpenEnvAction
Observation = OpenEnvObservation
State = OpenEnvState
Environment = OpenEnvEnvironment
EnvironmentMetadata = OpenEnvEnvironmentMetadata
StepResult = OpenEnvStepResult
def build_step_result(observation: ObsT) -> StepResult[ObsT]:
    """Wrap ``observation`` in a ``StepResult``, copying its reward and done.

    Args:
        observation: Any object. Its ``reward`` and ``done`` attributes are
            read if present; missing attributes fall back to ``None`` and
            ``False`` respectively.

    Returns:
        A ``StepResult`` holding the observation, the reward converted to
        ``float`` (or ``None`` when the observation has no reward), and
        the done flag coerced to ``bool``.

    Raises:
        TypeError, ValueError: propagated from ``float()`` when a non-None
            reward cannot be converted.
    """
    raw_reward = getattr(observation, "reward", None)
    # Convert the reward before reading "done", matching the original order.
    numeric_reward = None if raw_reward is None else float(raw_reward)
    done_flag = bool(getattr(observation, "done", False))
    return StepResult(observation=observation, reward=numeric_reward, done=done_flag)