File size: 3,690 Bytes
2803d7e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Generic, Optional, TypeVar

from pydantic import BaseModel, ConfigDict, Field

# Type variables for the observation, action and state types that parameterise
# the generic StepResult / Environment fallbacks and build_step_result below.
ObsT = TypeVar("ObsT")
ActT = TypeVar("ActT")
StateT = TypeVar("StateT")

try:  # pragma: no cover - exercised when openenv-core is installed
    from openenv.core.client_types import StepResult as OpenEnvStepResult
    from openenv.core.env_server.interfaces import Environment as OpenEnvEnvironment
    from openenv.core.env_server.types import (
        Action as OpenEnvAction,
        EnvironmentMetadata as OpenEnvEnvironmentMetadata,
        Observation as OpenEnvObservation,
        State as OpenEnvState,
    )

    OPENENV_AVAILABLE = True
except ImportError:  # pragma: no cover - lightweight fallback for local imports/tests
    OPENENV_AVAILABLE = False

    class Action(BaseModel):
        """Fallback action base model, defined only when the OpenEnv imports fail.

        The model configuration rejects unknown fields (``extra="forbid"``),
        re-validates on attribute assignment, and permits arbitrary field types.
        """

        model_config = ConfigDict(
            extra="forbid",
            validate_assignment=True,
            arbitrary_types_allowed=True,
        )

        # Free-form metadata; defaults to a new empty dict via ``default_factory``.
        metadata: dict[str, Any] = Field(default_factory=dict)

    class Observation(BaseModel):
        """Fallback observation base model, defined only when the OpenEnv imports fail.

        Carries ``done``, ``reward`` and ``metadata`` fields. The model
        configuration rejects unknown fields (``extra="forbid"``), re-validates
        on attribute assignment, and permits arbitrary field types.
        """

        model_config = ConfigDict(
            extra="forbid",
            validate_assignment=True,
            arbitrary_types_allowed=True,
        )

        # NOTE(review): pydantic has to evaluate these PEP 604 unions at class
        # creation; that is expected to need Python 3.10+ (or the
        # eval_type_backport package) - confirm the supported interpreter range.
        done: bool = False
        # May be a bool, int, float or None; None is the default.
        reward: bool | int | float | None = None
        # Free-form metadata; defaults to a new empty dict via ``default_factory``.
        metadata: dict[str, Any] = Field(default_factory=dict)

    class State(BaseModel):
        """Fallback state base model, defined only when the OpenEnv imports fail.

        Configured with ``extra="allow"``, so fields beyond the two declared
        here are accepted. Assignments are re-validated and arbitrary field
        types are permitted.
        """

        model_config = ConfigDict(
            extra="allow",
            validate_assignment=True,
            arbitrary_types_allowed=True,
        )

        # Optional episode identifier; None by default.
        episode_id: str | None = None
        # Step counter; starts at 0.
        step_count: int = 0

    class EnvironmentMetadata(BaseModel):
        """Fallback environment metadata model, defined only when the OpenEnv imports fail.

        ``name`` and ``description`` are required; ``version`` is optional.
        Unknown fields are rejected (``extra="forbid"``).
        """

        model_config = ConfigDict(extra="forbid")

        name: str
        description: str
        # Optional version string; None when not supplied.
        version: str | None = None

    @dataclass
    class StepResult(Generic[ObsT]):
        """Fallback step result, generic over the observation type.

        Bundles an observation with an optional float reward and a done flag.
        Defined only when the OpenEnv imports fail.
        """

        # The observation object; the only required field.
        observation: ObsT
        # Reward as a float, or None (the default).
        reward: Optional[float] = None
        # Done flag; defaults to False.
        done: bool = False

    class Environment(Generic[ActT, ObsT, StateT]):
        """Fallback environment base class, generic over action, observation and state.

        Defined only when the OpenEnv imports fail. ``reset``, ``step`` and
        ``state`` are not implemented here and always raise
        ``NotImplementedError``; the remaining methods have working defaults.
        """

        # Class-level flag; False in this base class.
        SUPPORTS_CONCURRENT_SESSIONS: bool = False

        def __init__(self, transform: Any | None = None) -> None:
            """Store the optional ``transform`` callable on the instance."""
            self.transform = transform

        def reset(
            self,
            seed: Optional[int] = None,
            episode_id: Optional[str] = None,
            **kwargs: Any,
        ) -> ObsT:
            """Not implemented in this base class.

            Raises:
                NotImplementedError: Always.
            """
            raise NotImplementedError

        def step(
            self,
            action: ActT,
            timeout_s: Optional[float] = None,
            **kwargs: Any,
        ) -> ObsT:
            """Not implemented in this base class.

            Raises:
                NotImplementedError: Always.
            """
            raise NotImplementedError

        @property
        def state(self) -> StateT:
            """Not implemented in this base class.

            Raises:
                NotImplementedError: Always.
            """
            raise NotImplementedError

        def get_metadata(self) -> EnvironmentMetadata:
            """Return default metadata derived from the concrete class name.

            The name is the class name, the description is the class name
            followed by " environment", and the version is fixed at "1.0.0".
            """
            env_name = self.__class__.__name__
            return EnvironmentMetadata(
                name=env_name,
                description=f"{env_name} environment",
                version="1.0.0",
            )

        def _apply_transform(self, observation: ObsT) -> ObsT:
            """Pass ``observation`` through ``self.transform`` when one is set.

            Returns the observation unchanged when no transform was configured.
            """
            if self.transform is None:
                return observation
            return self.transform(observation)

        def close(self) -> None:
            """Do nothing; this base class holds no resources to release."""


else:
    # The OpenEnv imports succeeded: bind the OpenEnv classes to the same
    # names the fallback branch defines, so the rest of the module can use
    # one set of names regardless of which branch ran.
    Action = OpenEnvAction
    Observation = OpenEnvObservation
    State = OpenEnvState
    Environment = OpenEnvEnvironment
    EnvironmentMetadata = OpenEnvEnvironmentMetadata
    StepResult = OpenEnvStepResult


def build_step_result(observation: ObsT) -> StepResult[ObsT]:
    """Wrap ``observation`` in a ``StepResult``.

    The reward and done flag are read off the observation itself.

    Args:
        observation: Object whose optional ``reward`` and ``done`` attributes
            are copied into the result. Missing attributes are treated as
            ``None`` and ``False`` respectively.

    Returns:
        A ``StepResult`` holding the observation, its reward converted to
        ``float`` (or ``None`` when absent or ``None``), and its done flag
        coerced to ``bool``.
    """
    raw_reward = getattr(observation, "reward", None)
    # None means "no reward" and must not be coerced to a number.
    numeric_reward = None if raw_reward is None else float(raw_reward)
    is_done = bool(getattr(observation, "done", False))
    return StepResult(
        observation=observation,
        reward=numeric_reward,
        done=is_done,
    )