burtenshaw's picture
burtenshaw HF Staff
Upload folder using huggingface_hub
0ea7763 verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import inspect
from abc import ABC, abstractmethod
from typing import Any, Generic, Optional, Protocol, TypeVar, TYPE_CHECKING
from typing_extensions import TypedDict
from .types import Action, Observation, State, EnvironmentMetadata
if TYPE_CHECKING:
from openenv.core.rubrics import Rubric
# Type variables used to parameterize the generic Transform and Environment
# classes in this module. Each is bounded by the matching base type imported
# from .types, so only subclasses of that base are accepted by type checkers.
ActT = TypeVar("ActT", bound=Action)  # action type passed to Environment.step()
ObsT = TypeVar("ObsT", bound=Observation)  # observation type returned by reset()/step()
StateT = TypeVar("StateT", bound=State)  # type returned by the Environment.state property
class Message(TypedDict):
    """Single entry of a chat conversation.

    The key layout matches the message dictionaries used by the
    Huggingface chat template format.
    """

    # Role label of the message author.
    role: str
    # Text body of the message.
    content: str
class ModelTokenizer(Protocol):
    """Protocol for tokenizers that support chat templates.

    This protocol defines the interface that tokenizers must implement
    to work with chat-based environments. It's compatible with
    Huggingface transformers tokenizers.

    Being a ``Protocol``, it is satisfied structurally: any object that
    provides both methods with compatible signatures qualifies, without
    inheriting from this class.
    """

    def apply_chat_template(
        self,
        conversation: list[Message],
        tokenize: bool = True,
        # Optional[...] instead of ``str | None``: annotations are evaluated
        # eagerly in this module, and the ``|`` form raises TypeError at
        # import time on Python < 3.10. It also matches the Optional[...]
        # spelling used everywhere else in this file.
        return_tensors: Optional[str] = None,
        **kwargs: Any,
    ) -> Any:
        """Apply a chat template to format and optionally tokenize a conversation.

        Args:
            conversation: List of message dictionaries with 'role' and 'content'
            tokenize: Whether to tokenize the output
            return_tensors: Format for returned tensors ('pt' for PyTorch)
            **kwargs: Additional arguments

        Returns:
            Formatted and optionally tokenized conversation
        """
        ...

    def decode(
        self, token_ids: Any, skip_special_tokens: bool = False, **kwargs: Any
    ) -> str:
        """Decode token IDs back to text.

        Args:
            token_ids: Token IDs to decode
            skip_special_tokens: Whether to skip special tokens in output
            **kwargs: Additional arguments

        Returns:
            Decoded text string
        """
        ...
class Transform(ABC, Generic[ObsT]):
    """Callable hook that rewrites an observation.

    Modelled on the TorchRL transform pattern: an observation goes in and
    a (possibly modified) observation of the same type comes out. Typical
    uses are attaching rewards or metrics, or otherwise augmenting the
    observation.
    """

    @abstractmethod
    def __call__(self, observation: ObsT) -> ObsT:
        """Return the transformed version of ``observation``.

        Args:
            observation: The input observation

        Returns:
            The transformed observation
        """
        ...
class Environment(ABC, Generic[ActT, ObsT, StateT]):
    """Base class for all environment servers following Gym/Gymnasium API.

    Subclasses must implement ``reset()``, ``step()`` and the ``state``
    property. The async variants, metadata, transform/rubric helpers and
    ``close()`` have default implementations that may be overridden.

    Args:
        transform: Optional transform to apply to observations
        rubric: Optional rubric for reward computation. When provided, the
            rubric's output can be used to set the observation's reward in step().

    Class Attributes:
        SUPPORTS_CONCURRENT_SESSIONS: Whether this environment supports concurrent sessions.
            When True, multiple WebSocket connections can each have their own
            environment instance (up to max_concurrent_envs). When False (default),
            the environment should only be used with a single session at a time.

            Set this to True in your Environment subclass if:
            - The environment uses proper session isolation (e.g., unique working dirs)
            - No shared mutable state exists between instances
            - External resources (databases, APIs) can handle concurrent access

    Attributes:
        rubric: Optional rubric for computing rewards. Environments can set this
            in __init__ and use it in step() to compute observation rewards.
            Training infrastructure can access it for introspection:
                for name, r in env.rubric.named_rubrics():
                    print(f"{name}: {r.last_score}")

    See RFC 004 for rubric design: rfcs/004-rubrics.md
    """

    # Class-level flag indicating whether this environment supports concurrent sessions
    SUPPORTS_CONCURRENT_SESSIONS: bool = False

    # Optional rubric for reward computation
    rubric: Optional["Rubric"]

    def __init__(
        self,
        transform: Optional[Transform[ObsT]] = None,
        rubric: Optional["Rubric"] = None,
    ) -> None:
        """Store the optional transform and rubric on the instance.

        Args:
            transform: Transform used by ``_apply_transform``; ``None`` disables it.
            rubric: Rubric used by the ``_apply_rubric*`` / ``_reset_rubric*``
                helpers; ``None`` disables it.
        """
        self.transform = transform
        self.rubric = rubric

    @abstractmethod
    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> ObsT:
        """Reset the environment and return initial observation."""
        pass

    async def reset_async(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> ObsT:
        """Async version of reset. Default implementation calls sync reset.

        The sync ``reset()`` is invoked directly inside this coroutine.
        Override to provide true async implementation.
        """
        return self.reset(seed=seed, episode_id=episode_id, **kwargs)

    @abstractmethod
    def step(
        self,
        action: ActT,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> ObsT:
        """Take a step in the environment."""
        pass

    async def step_async(
        self,
        action: ActT,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> ObsT:
        """Async version of step. Default implementation calls sync step.

        The sync ``step()`` is invoked directly inside this coroutine.
        Override to provide true async implementation.
        """
        return self.step(action, timeout_s=timeout_s, **kwargs)

    @property
    @abstractmethod
    def state(self) -> StateT:
        """Get the current environment state."""
        pass

    def get_metadata(self) -> EnvironmentMetadata:
        """Get metadata about this environment.

        Override this method to provide custom metadata for the environment.
        Default implementation returns basic metadata derived from class name.

        Returns:
            EnvironmentMetadata with environment information
        """
        return EnvironmentMetadata(
            name=self.__class__.__name__,
            description=f"{self.__class__.__name__} environment",
            version="1.0.0",
        )

    def _apply_transform(self, observation: ObsT) -> ObsT:
        """Apply transform if one is provided.

        Args:
            observation: Observation to pass through the transform.

        Returns:
            The transform's result, or ``observation`` unchanged when no
            transform is set.
        """
        if self.transform is not None:
            return self.transform(observation)
        return observation

    def _apply_rubric(self, action: ActT, observation: ObsT) -> float:
        """Apply rubric if one is provided.

        Args:
            action: The action taken by the agent.
            observation: The resulting observation.

        Returns:
            Reward value from the rubric, or 0.0 if no rubric is set.

        Usage in step():
            def step(self, action: MyAction, ...) -> MyObservation:
                # ... execute action and create observation ...
                observation.reward = self._apply_rubric(action, observation)
                return observation
        """
        if self.rubric is not None:
            return self.rubric(action, observation)
        return 0.0

    async def _apply_rubric_async(self, action: ActT, observation: ObsT) -> float:
        """Apply rubric asynchronously if one is provided.

        The rubric is called like a plain function; if what it returns is
        awaitable, it is awaited before being returned.

        Args:
            action: The action taken by the agent.
            observation: The resulting observation.

        Returns:
            Reward value from the rubric, or 0.0 if no rubric is set.

        Usage in step_async():
            async def step_async(self, action: MyAction, ...) -> MyObservation:
                # ... execute action and create observation ...
                observation.reward = await self._apply_rubric_async(action, observation)
                return observation
        """
        if self.rubric is None:
            return 0.0
        result = self.rubric(action, observation)
        # isawaitable (not iscoroutine) also covers Futures, Tasks and objects
        # defining __await__, so they are awaited rather than handed back as
        # the "reward".
        if inspect.isawaitable(result):
            return await result
        return result

    def _reset_rubric(self) -> None:
        """Reset the rubric state if one is provided.

        Call this in reset() to clear any trajectory state in the rubric.

        Usage in reset():
            def reset(self, ...) -> MyObservation:
                self._reset_rubric()
                # ... create initial observation ...
                return observation
        """
        if self.rubric is not None:
            self.rubric.reset()

    async def _reset_rubric_async(self) -> None:
        """Reset the rubric state asynchronously if one is provided.

        Prefers the rubric's ``reset_async`` when it has one (awaiting the
        result if it is awaitable); otherwise falls back to the sync
        ``reset()``.

        Call this in reset_async() to clear any trajectory state in the rubric.

        Usage in reset_async():
            async def reset_async(self, ...) -> MyObservation:
                await self._reset_rubric_async()
                # ... create initial observation ...
                return observation
        """
        if self.rubric is None:
            return
        reset_async = getattr(self.rubric, "reset_async", None)
        if reset_async is None:
            # No async reset available on this rubric: use the sync one.
            self.rubric.reset()
            return
        result = reset_async()
        # Accept any awaitable, not only coroutine objects, so a Future/Task
        # returned by reset_async is not silently dropped.
        if inspect.isawaitable(result):
            await result

    def close(self) -> None:
        """Clean up resources used by the environment.

        Override this method to implement custom cleanup logic.
        Called when the environment is being destroyed or reset.
        The default implementation does nothing.
        """
        pass