# chat_env-v2-1-0 / models.py
# Uploaded by burtenshaw (HF Staff) using huggingface_hub — commit 8b391f5 (verified)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Data models for the Chat Environment.
The Chat environment provides a chat-based interface for LLMs with support
for tokenization and message history management.
"""
from pydantic import Field
import torch
from openenv.core.env_server.interfaces import Message
from openenv.core.env_server.types import Action, Observation, State
class ChatAction(Action):
    """Action for chat environments.

    Contains tokens that represent the action to be taken.
    This interfaces directly with models.
    """

    # Token-id tensor produced by the policy model; must be non-empty.
    tokens: torch.Tensor = Field(default_factory=lambda: torch.tensor([]))

    def __post_init__(self):
        """Validate required Fields after initialization (dataclass hook)."""
        # numel() == 0 catches both the default empty tensor and any
        # explicitly-passed empty tensor.
        if self.tokens.numel() == 0:
            raise ValueError("tokens is required and cannot be empty")

    def model_post_init(self, __context) -> None:
        """Pydantic v2 post-initialization hook.

        BUG FIX: `__post_init__` is a dataclass hook and is never invoked by
        a Pydantic v2 ``BaseModel`` — so if ``Action`` is a BaseModel (as the
        ``pydantic.Field`` usage suggests), the emptiness check above silently
        never ran. Delegating here makes the validation fire on both the
        dataclass path and the BaseModel path.

        Raises:
            ValueError: if ``tokens`` is empty.
        """
        self.__post_init__()
class ChatState(State):
    """State of the ChatEnvironment containing message history."""

    # Conversation turns accumulated so far, in order.
    history_messages: list[Message] = Field(default_factory=list)
    # Per-message token tensors, kept in lockstep with history_messages.
    # Presumably history_tokens[i] is the tokenization of history_messages[i]
    # — TODO confirm against the environment that populates both lists.
    history_tokens: list[torch.Tensor] = Field(
        default_factory=list
    )  # Same len as messages
class ChatObservation(Observation):
    """Observation returned by ChatEnvironment.

    Contains the message history in Huggingface format (list of dicts with
    role/content) and the tokenized representation of the entire conversation.
    The environment owns the tokenizer and generates the tokens from the
    messages.

    Example:
        messages = [
            {"role": "system", "content": "You are a helpful assistant"},
            {"role": "user", "content": "How tall is the Eiffel Tower?"},
        ]
        tokens = tensor([1, 2, 3, 4, 5, ...])  # tokenized entire conversation
    """

    # Full conversation history in Huggingface chat format (role/content).
    messages: list[Message] = Field(default_factory=list)
    # Tokenization of the entire conversation; defaults to an empty tensor
    # until the environment (which owns the tokenizer) fills it in.
    tokens: torch.Tensor = Field(default_factory=lambda: torch.tensor([]))
    # Inherited Fields from Observation ABC: reward, done, metadata