# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Data models for the Chat Environment.

The Chat environment provides a chat-based interface for LLMs with support
for tokenization and message history management.
"""

from typing import Any

from pydantic import Field

import torch

from openenv.core.env_server.interfaces import Message
from openenv.core.env_server.types import Action, Observation, State


class ChatAction(Action):
    """Action for chat environments.

    Contains tokens that represent the action to be taken.
    This interfaces directly with models.

    An action whose ``tokens`` tensor has zero elements is rejected with a
    ``ValueError``; the empty-tensor default therefore only exists so the
    field has a declared default, and callers are expected to pass tokens.
    """

    tokens: torch.Tensor = Field(default_factory=lambda: torch.tensor([]))

    def model_post_init(self, context: Any, /) -> None:
        """Run the empty-token validation after pydantic builds the model.

        NOTE(review): this assumes ``Action`` is a pydantic ``BaseModel``
        (the fields here are declared with ``pydantic.Field``) -- confirm.
        Pydantic never calls the dataclass-style ``__post_init__`` hook, so
        without this method the check below would silently never run.

        Raises:
            ValueError: If ``tokens`` contains no elements.
        """
        super().model_post_init(context)
        self.__post_init__()

    def __post_init__(self) -> None:
        """Validate required Fields after initialization.

        Kept as a standalone method for backward compatibility with any code
        that invokes it directly (or with a dataclass-style base class).

        Raises:
            ValueError: If ``tokens`` contains no elements.
        """
        if self.tokens.numel() == 0:
            raise ValueError("tokens is required and cannot be empty")


class ChatState(State):
    """State of the ChatEnvironment containing message history."""

    # Messages exchanged so far.
    history_messages: list[Message] = Field(default_factory=list)
    # Token tensors for the history; same length as ``history_messages``.
    history_tokens: list[torch.Tensor] = Field(default_factory=list)


class ChatObservation(Observation):
    """Observation returned by ChatEnvironment.

    Contains the message history in Huggingface format (list of dicts with
    role/content) and the tokenized representation of the entire conversation.

    The environment owns the tokenizer and generates the tokens from the
    messages.

    Example:
        messages = [
            {"role": "system", "content": "You are a helpful assistant"},
            {"role": "user", "content": "How tall is the Eiffel Tower?"},
        ]
        tokens = tensor([1, 2, 3, 4, 5, ...])  # tokenized entire conversation
    """

    messages: list[Message] = Field(default_factory=list)
    tokens: torch.Tensor = Field(default_factory=lambda: torch.tensor([]))
    # Inherited Fields from Observation ABC: reward, done, metadata