|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from abc import ABC, abstractmethod |
|
|
from typing import Any, Protocol, TypedDict |
|
|
|
|
|
from .types import Action, Observation, State |
|
|
|
|
|
|
|
|
class Message(TypedDict):
    """Single conversation entry.

    The shape is compatible with the Huggingface chat template format.
    """

    # Who produced the message.
    role: str
    # Text of the message.
    content: str
|
|
|
|
|
|
|
|
class ModelTokenizer(Protocol):
    """Structural interface for tokenizers that support chat templates.

    A tokenizer must provide the methods declared here to work with
    chat-based environments. Huggingface transformers tokenizers are
    compatible with this interface.
    """

    def apply_chat_template(
        self,
        conversation: list[Message],
        tokenize: bool = True,
        return_tensors: str | None = None,
        **kwargs: Any,
    ) -> Any:
        """Format a conversation with a chat template, optionally tokenizing it.

        Args:
            conversation: Messages, each a dictionary with 'role' and
                'content' entries.
            tokenize: Whether the output should be tokenized.
            return_tensors: Format for returned tensors ('pt' for PyTorch).
            **kwargs: Additional arguments.

        Returns:
            The formatted and optionally tokenized conversation.
        """
        ...

    def decode(
        self,
        token_ids: Any,
        skip_special_tokens: bool = False,
        **kwargs: Any,
    ) -> str:
        """Turn token IDs back into text.

        Args:
            token_ids: The token IDs to decode.
            skip_special_tokens: Whether special tokens are skipped in the
                output.
            **kwargs: Additional arguments.

        Returns:
            The decoded text string.
        """
        ...
|
|
|
|
|
|
|
|
class Transform(ABC):
    """Observation-to-observation hook for adding rewards, metrics, or other changes.

    Follows the TorchRL pattern: a transform receives an observation and
    returns a (potentially modified) observation, which allows flexible
    reward computation and observation augmentation.
    """

    @abstractmethod
    def __call__(self, observation: Observation) -> Observation:
        """Transform an observation.

        Concrete subclasses must override this method.

        Args:
            observation: The input observation.

        Returns:
            The transformed observation.
        """
|
|
|
|
|
|
|
|
class Environment(ABC):
    """Abstract base for all environment servers following the Gym/Gymnasium API.

    Subclasses must implement ``reset``, ``step`` and the ``state`` property.

    Args:
        transform: Optional transform to apply to observations.
    """

    def __init__(self, transform: Transform | None = None):
        # Stored as given; ``_apply_transform`` reads it on each call.
        self.transform = transform

    @abstractmethod
    def reset(self) -> Observation:
        """Reset the environment and return the initial observation."""

    @abstractmethod
    def step(self, action: Action) -> Observation:
        """Take a step in the environment."""

    @property
    @abstractmethod
    def state(self) -> State:
        """Get the current environment state."""

    def _apply_transform(self, observation: Observation) -> Observation:
        """Run the configured transform on ``observation``, if there is one.

        Args:
            observation: The observation to process.

        Returns:
            The transform's result, or ``observation`` unchanged when no
            transform was provided.
        """
        transform = self.transform
        if transform is None:
            return observation
        return transform(observation)
|
|
|