# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Data models for the Chat Environment.

The Chat environment provides a chat-based interface for LLMs with support
for tokenization and message history management.
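
A minimal sketch of the message-history bookkeeping described above, assuming
the conversation's tokens are the concatenation of per-message token
sequences. `ToyTokenizer`, `add_message`, and `conversation_tokens` are
hypothetical names for illustration only. A whitespace word-to-id map stands
in for the environment's real tokenizer, and plain Python lists stand in for
`torch.Tensor`s.

```python
# Hypothetical sketch: plain lists stand in for torch.Tensor, and a toy
# whitespace tokenizer stands in for the environment's real tokenizer.


class ToyTokenizer:
    def __init__(self) -> None:
        # Maps each word to an integer id, assigned in first-seen order from 1.
        self.vocab: dict[str, int] = {}

    def encode(self, message: dict[str, str]) -> list[int]:
        text = f"{message['role']}: {message['content']}"
        return [self.vocab.setdefault(w, len(self.vocab) + 1) for w in text.split()]


def add_message(history_messages, history_tokens, tokenizer, message):
    # Append the message and its tokens together so the two lists stay aligned.
    history_messages.append(message)
    history_tokens.append(tokenizer.encode(message))


def conversation_tokens(history_tokens):
    # Flatten the per-message token lists into one sequence for the conversation.
    return [tok for msg_tokens in history_tokens for tok in msg_tokens]


history_messages: list[dict[str, str]] = []
history_tokens: list[list[int]] = []
tokenizer = ToyTokenizer()
add_message(history_messages, history_tokens, tokenizer,
            {"role": "system", "content": "You are a helpful assistant"})
add_message(history_messages, history_tokens, tokenizer,
            {"role": "user", "content": "How tall is the Eiffel Tower?"})
tokens = conversation_tokens(history_tokens)  # [1, 2, ..., 13]
```

Keeping one token sequence per message mirrors the `ChatState` invariant that
`history_tokens` has the same length as `history_messages`. A real chat
tokenizer may add template tokens, so its output can differ from this toy
concatenation.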
"""

from pydantic import Field

import torch

from openenv.core.env_server.interfaces import Message
from openenv.core.env_server.types import Action, Observation, State


class ChatAction(Action):
    """Action for chat environments.

    Contains the tokens that represent the action to be taken. Because
    actions are expressed as tokens, they interface directly with models.
    """

    tokens: torch.Tensor = Field(default_factory=lambda: torch.tensor([]))

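    def model_post_init(self, context: object, /) -> None:
        """Validate required fields after initialization.

        Pydantic does not call dataclass-style ``__post_init__``;
        ``model_post_init`` is the pydantic v2 hook that runs after validation.
        """
        super().model_post_init(context)
        if self.tokens.numel() == 0:
            raise ValueError("tokens is required and cannot be empty")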


class ChatState(State):
    """State of the ChatEnvironment containing message history."""

    history_messages: list[Message] = Field(default_factory=list)
    # One token tensor per message; kept the same length as history_messages.
    history_tokens: list[torch.Tensor] = Field(default_factory=list)


class ChatObservation(Observation):
    """Observation returned by ChatEnvironment.

    Contains the message history in Hugging Face chat format (a list of dicts
    with "role" and "content" keys) and the tokenized representation of the
    entire conversation.

    The environment owns the tokenizer and generates the tokens from the messages.

    Example:
        messages = [
            {"role": "system", "content": "You are a helpful assistant"},
            {"role": "user", "content": "How tall is the Eiffel Tower?"},
        ]
        tokens = tensor([1, 2, 3, 4, 5, ...])  # the entire conversation, tokenized
    """

    messages: list[Message] = Field(default_factory=list)
    tokens: torch.Tensor = Field(default_factory=lambda: torch.tensor([]))
    # Fields inherited from Observation: reward, done, metadata