File size: 9,717 Bytes
de0f978
 
 
 
 
 
 
 
913ebda
 
 
de0f978
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import inspect
from abc import ABC, abstractmethod
from typing import Any, Generic, Optional, Protocol, TypeVar, TYPE_CHECKING

from typing_extensions import TypedDict

from .types import Action, Observation, State, EnvironmentMetadata

if TYPE_CHECKING:
    from openenv.core.rubrics import Rubric

# Type variables used to parameterize the generic Transform and Environment
# classes defined in this module. Each one is bounded by the corresponding
# base type imported from .types, so only subclasses of that base are accepted.
ActT = TypeVar("ActT", bound=Action)  # action type taken by Environment.step
ObsT = TypeVar("ObsT", bound=Observation)  # observation type returned by reset/step
StateT = TypeVar("StateT", bound=State)  # type returned by Environment.state


class Message(TypedDict):
    """Single entry of a chat conversation.

    The key layout ('role' plus 'content') is compatible with the message
    format used by Huggingface chat templates.
    """

    # Role label attached to this message (plain string).
    role: str
    # Text body of the message.
    content: str


class ModelTokenizer(Protocol):
    """Structural interface for tokenizers that understand chat templates.

    Any object exposing these two methods can be used by chat-based
    environments; the signatures are compatible with Huggingface
    transformers tokenizers.
    """

    def apply_chat_template(
        self,
        conversation: list[Message],
        tokenize: bool = True,
        return_tensors: str | None = None,
        **kwargs: Any,
    ) -> Any:
        """Format a conversation with the chat template, optionally tokenizing it.

        Args:
            conversation: Messages to format, each carrying 'role' and 'content'.
            tokenize: When True, the formatted output is tokenized as well.
            return_tensors: Tensor format for the result ('pt' for PyTorch).
            **kwargs: Extra arguments passed through to the implementation.

        Returns:
            The formatted conversation, tokenized if requested.
        """
        ...

    def decode(
        self,
        token_ids: Any,
        skip_special_tokens: bool = False,
        **kwargs: Any,
    ) -> str:
        """Turn token IDs back into text.

        Args:
            token_ids: The token IDs to decode.
            skip_special_tokens: When True, special tokens are left out of
                the decoded output.
            **kwargs: Extra arguments passed through to the implementation.

        Returns:
            The decoded text.
        """
        ...


class Transform(ABC, Generic[ObsT]):
    """Abstract callable that maps an observation to an observation.

    Modeled on the TorchRL transform pattern: a transform receives an
    observation and hands back a (possibly modified) one, which makes it a
    convenient place to attach rewards, metrics, or other augmentations.
    """

    @abstractmethod
    def __call__(self, observation: ObsT) -> ObsT:
        """Apply the transform to a single observation.

        Args:
            observation: Observation to transform.

        Returns:
            The transformed observation.
        """
        ...


class Environment(ABC, Generic[ActT, ObsT, StateT]):
    """Base class for all environment servers following Gym/Gymnasium API.

    Subclasses must implement ``reset``, ``step`` and the ``state`` property.
    The async variants, metadata, transform/rubric helpers and ``close`` all
    have default implementations that may be overridden.

    Args:
        transform: Optional transform to apply to observations
        rubric: Optional rubric for reward computation. When provided, the
            rubric's output can be used to set the observation's reward in step().

    Class Attributes:
        SUPPORTS_CONCURRENT_SESSIONS: Whether this environment supports concurrent sessions.
            When True, multiple WebSocket connections can each have their own
            environment instance (up to max_concurrent_envs). When False (default),
            the environment should only be used with a single session at a time.

            Set this to True in your Environment subclass if:
            - The environment uses proper session isolation (e.g., unique working dirs)
            - No shared mutable state exists between instances
            - External resources (databases, APIs) can handle concurrent access

    Attributes:
        transform: Optional transform applied by ``_apply_transform``.
        rubric: Optional rubric for computing rewards. Environments can set this
            in __init__ and use it in step() to compute observation rewards.
            Training infrastructure can access it for introspection:
                for name, r in env.rubric.named_rubrics():
                    print(f"{name}: {r.last_score}")

    See RFC 004 for rubric design: rfcs/004-rubrics.md
    """

    # Class-level flag indicating whether this environment supports concurrent sessions
    SUPPORTS_CONCURRENT_SESSIONS: bool = False

    # Optional rubric for reward computation
    rubric: Optional["Rubric"]

    def __init__(
        self,
        transform: Optional[Transform[ObsT]] = None,
        rubric: Optional["Rubric"] = None,
    ):
        """Store the optional transform and rubric on the instance.

        Args:
            transform: Transform used by ``_apply_transform``, or None.
            rubric: Rubric used by the ``_apply_rubric*`` / ``_reset_rubric*``
                helpers, or None.
        """
        self.transform = transform
        self.rubric = rubric

    @abstractmethod
    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> ObsT:
        """Reset the environment and return initial observation."""
        pass

    async def reset_async(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> ObsT:
        """Async version of reset. Default implementation calls sync reset.

        Override to provide true async implementation.
        """
        return self.reset(seed=seed, episode_id=episode_id, **kwargs)

    @abstractmethod
    def step(
        self,
        action: ActT,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> ObsT:
        """Take a step in the environment."""
        pass

    async def step_async(
        self,
        action: ActT,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> ObsT:
        """Async version of step. Default implementation calls sync step.

        Override to provide true async implementation.
        """
        return self.step(action, timeout_s=timeout_s, **kwargs)

    @property
    @abstractmethod
    def state(self) -> StateT:
        """Get the current environment state."""
        pass

    def get_metadata(self) -> EnvironmentMetadata:
        """
        Get metadata about this environment.

        Override this method to provide custom metadata for the environment.
        Default implementation returns basic metadata derived from class name.

        Returns:
            EnvironmentMetadata with environment information
        """
        class_name = self.__class__.__name__
        return EnvironmentMetadata(
            name=class_name,
            description=f"{class_name} environment",
            version="1.0.0",
        )

    def _apply_transform(self, observation: ObsT) -> ObsT:
        """Apply transform if one is provided.

        Args:
            observation: The observation to pass through the transform.

        Returns:
            The transform's output, or the observation unchanged when no
            transform is set.
        """
        if self.transform is None:
            return observation
        return self.transform(observation)

    def _apply_rubric(self, action: ActT, observation: ObsT) -> float:
        """Apply rubric if one is provided.

        Args:
            action: The action taken by the agent.
            observation: The resulting observation.

        Returns:
            Reward value from the rubric, or 0.0 if no rubric is set.

        Usage in step():
            def step(self, action: MyAction, ...) -> MyObservation:
                # ... execute action and create observation ...
                observation.reward = self._apply_rubric(action, observation)
                return observation
        """
        if self.rubric is None:
            return 0.0
        return self.rubric(action, observation)

    async def _apply_rubric_async(self, action: ActT, observation: ObsT) -> float:
        """Apply rubric asynchronously if one is provided.

        The rubric is called like a regular function; if what it returns is
        awaitable, that result is awaited before being handed back.

        Args:
            action: The action taken by the agent.
            observation: The resulting observation.

        Returns:
            Reward value from the rubric, or 0.0 if no rubric is set.

        Usage in step_async():
            async def step_async(self, action: MyAction, ...) -> MyObservation:
                # ... execute action and create observation ...
                observation.reward = await self._apply_rubric_async(action, observation)
                return observation
        """
        if self.rubric is None:
            return 0.0
        result = self.rubric(action, observation)
        # isawaitable (rather than iscoroutine) also recognizes Futures, Tasks
        # and objects defining __await__, so those are resolved to a value
        # instead of being returned un-awaited as the "reward".
        if inspect.isawaitable(result):
            return await result
        return result

    def _reset_rubric(self) -> None:
        """Reset the rubric state if one is provided.

        Call this in reset() to clear any trajectory state in the rubric.

        Usage in reset():
            def reset(self, ...) -> MyObservation:
                self._reset_rubric()
                # ... create initial observation ...
                return observation
        """
        if self.rubric is not None:
            self.rubric.reset()

    async def _reset_rubric_async(self) -> None:
        """Reset the rubric state asynchronously if one is provided.

        Prefers the rubric's ``reset_async`` when it has one (awaiting its
        result if awaitable) and falls back to the sync ``reset`` otherwise.

        Call this in reset_async() to clear any trajectory state in the rubric.

        Usage in reset_async():
            async def reset_async(self, ...) -> MyObservation:
                await self._reset_rubric_async()
                # ... create initial observation ...
                return observation
        """
        if self.rubric is None:
            return
        # No async reset available: fall back to the synchronous one.
        if not hasattr(self.rubric, "reset_async"):
            self.rubric.reset()
            return
        result = self.rubric.reset_async()
        # Await any awaitable (coroutine, Future, Task, __await__ object) so
        # the reset actually completes before this helper returns.
        if inspect.isawaitable(result):
            await result

    def close(self) -> None:
        """Clean up resources used by the environment.

        Override this method to implement custom cleanup logic.
        Called when the environment is being destroyed or reset.
        The default implementation does nothing.
        """
        pass