| | |
| | |
| | |
| | |
| | |
| |
|
| | from abc import ABC, abstractmethod |
| | from typing import Any, Protocol, TypedDict |
| |
|
| | from .types import Action, Observation, State |
| |
|
| |
|
class Message(TypedDict):
    """One turn of a conversation, expressed as a plain dictionary.

    The two-key layout is compatible with the Huggingface chat template
    format.
    """

    # Who produced this message.
    role: str
    # The text of the message.
    content: str
| |
|
| |
|
class ModelTokenizer(Protocol):
    """Structural interface for tokenizers that understand chat templates.

    Any object providing these two methods can be used as a tokenizer by
    chat-based environments. The interface is compatible with Huggingface
    transformers tokenizers.
    """

    def apply_chat_template(
        self,
        conversation: list[Message],
        tokenize: bool = True,
        return_tensors: str | None = None,
        **kwargs: Any,
    ) -> Any:
        """Format a conversation with a chat template, optionally tokenizing it.

        Args:
            conversation: Messages to format; each is a dictionary with
                'role' and 'content' keys.
            tokenize: Whether the output should be tokenized.
            return_tensors: Format for returned tensors ('pt' for PyTorch).
            **kwargs: Additional arguments.

        Returns:
            The formatted and optionally tokenized conversation.
        """
        ...

    def decode(
        self,
        token_ids: Any,
        skip_special_tokens: bool = False,
        **kwargs: Any,
    ) -> str:
        """Turn token IDs back into text.

        Args:
            token_ids: The token IDs to decode.
            skip_special_tokens: Whether special tokens are skipped in the
                output.
            **kwargs: Additional arguments.

        Returns:
            The decoded text string.
        """
        ...
| |
|
| |
|
class Transform(ABC):
    """Callable that maps an observation to a (possibly modified) observation.

    Intended for adding rewards, metrics, or other modifications to
    observations. This follows the TorchRL pattern: a transform receives an
    observation and hands one back, which keeps reward computation and
    observation augmentation flexible.
    """

    @abstractmethod
    def __call__(self, observation: Observation) -> Observation:
        """Transform an observation.

        Args:
            observation: The observation to transform.

        Returns:
            The transformed observation.
        """
        ...
| |
|
| |
|
class Environment(ABC):
    """Abstract base for environment servers following the Gym/Gymnasium API.

    Subclasses must implement ``reset``, ``step`` and the ``state`` property.

    Args:
        transform: Optional transform to apply to observations. When it is
            ``None``, ``_apply_transform`` returns observations unchanged.
    """

    def __init__(self, transform: Transform | None = None):
        # Stored as given; ``_apply_transform`` consults it on every call.
        self.transform = transform

    @abstractmethod
    def reset(self) -> Observation:
        """Reset the environment and return the initial observation."""
        ...

    @abstractmethod
    def step(self, action: Action) -> Observation:
        """Take a step in the environment using ``action``."""
        ...

    @property
    @abstractmethod
    def state(self) -> State:
        """Get the current environment state."""
        ...

    def _apply_transform(self, observation: Observation) -> Observation:
        """Return ``observation`` passed through the transform, if one is set."""
        if self.transform is None:
            # No transform configured: hand the observation back untouched.
            return observation
        return self.transform(observation)
| |
|