chat_env-v2-1-0 / server /test_chat_env.py
burtenshaw's picture
burtenshaw HF Staff
Upload folder using huggingface_hub
8b391f5 verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Test suite for ChatEnvironment.
Proper unit tests with assertions to verify correct behavior.
"""
import torch
from openenv.core.env_server.interfaces import Message
from ..models import ChatAction
from .chat_environment import ChatEnvironment
class MockTokenizer:
"""Mock tokenizer for testing without requiring transformers library."""
def apply_chat_template(
self,
conversation: list[Message],
tokenize: bool = True,
return_tensors: str | None = None,
**kwargs,
):
"""Mock implementation that creates deterministic token tensors from text."""
# Concatenate all message content
text = " ".join([msg["content"] for msg in conversation])
# Create deterministic tokens based on text content
# Use character codes modulo 256 to get valid token IDs
tokens = [ord(c) % 256 for c in text]
if return_tensors == "pt":
return torch.tensor([tokens])
return tokens
def decode(self, token_ids, skip_special_tokens: bool = False, **kwargs) -> str:
"""Mock decode that reverses the encoding process."""
if isinstance(token_ids, torch.Tensor):
token_ids = token_ids.tolist()
# Reverse the encoding: convert tokens back to characters
chars = [chr(t) for t in token_ids]
return "".join(chars)
def test_tokenization_consistency():
"""Test that tokenizing the same string produces the same tokens."""
tokenizer = MockTokenizer()
env = ChatEnvironment(tokenizer=tokenizer)
# Create the same message twice
message1: Message = {"role": "user", "content": "Hello, world!"}
message2: Message = {"role": "user", "content": "Hello, world!"}
# Convert to actions
action1 = env.message_to_action(message1)
action2 = env.message_to_action(message2)
# Verify tokens are identical
assert torch.equal(
action1.tokens, action2.tokens
), "Same message should produce identical tokens"
# Verify tokens are not empty
assert action1.tokens.numel() > 0, "Tokens should not be empty"
print("βœ“ test_tokenization_consistency passed")
def test_message_content_preservation():
"""Test that message content is preserved in the observation."""
tokenizer = MockTokenizer()
env = ChatEnvironment(tokenizer=tokenizer)
env.reset()
# Test with user message
user_content = "What is the capital of France?"
user_message: Message = {"role": "user", "content": user_content}
action = env.message_to_action(user_message)
obs = env.step(action)
# The last message should have the decoded content
assert len(obs.messages) > 0, "Observation should have at least one message"
last_message = obs.messages[-1]
# Verify the decoded content matches what we sent
# Note: The environment decodes the tokens, so we verify the round-trip
decoded_content = last_message["content"]
assert decoded_content == user_content, (
f"Message content should be preserved. "
f"Expected: {user_content}, Got: {decoded_content}"
)
# Test with assistant message
assistant_content = "The capital of France is Paris."
assistant_message: Message = {"role": "assistant", "content": assistant_content}
action = env.message_to_action(assistant_message)
obs = env.step(action)
# Verify the last message has the assistant content
assert len(obs.messages) >= 2, "Should have at least 2 messages now"
last_message = obs.messages[-1]
decoded_content = last_message["content"]
assert decoded_content == assistant_content, (
f"Assistant message content should be preserved. "
f"Expected: {assistant_content}, Got: {decoded_content}"
)
print("βœ“ test_message_content_preservation passed")
def test_system_prompt_preserved():
"""Test that system prompt is preserved after reset."""
tokenizer = MockTokenizer()
system_prompt = "You are a helpful assistant."
env = ChatEnvironment(tokenizer=tokenizer, system_prompt=system_prompt)
# Check after initialization
obs = env.reset()
assert len(obs.messages) == 1, "Should have exactly one message (system prompt)"
assert obs.messages[0]["role"] == "system", "First message should have system role"
assert (
obs.messages[0]["content"] == system_prompt
), "System prompt content should match"
# Add some messages
action = env.message_to_action({"role": "user", "content": "Hello"})
env.step(action)
# Reset and verify system prompt is still there
obs = env.reset()
assert len(obs.messages) == 1, "After reset, should only have system prompt"
assert (
obs.messages[0]["content"] == system_prompt
), "System prompt should be preserved after reset"
print("βœ“ test_system_prompt_preserved passed")
def test_token_history_accumulation():
"""Test that tokens accumulate correctly in the observation."""
tokenizer = MockTokenizer()
env = ChatEnvironment(tokenizer=tokenizer)
obs = env.reset()
initial_token_count = obs.tokens.numel()
# Step with first message
message1 = {"role": "user", "content": "Hi"}
action1 = env.message_to_action(message1)
obs1 = env.step(action1)
token_count_1 = obs1.tokens.numel()
# Tokens should increase
assert token_count_1 > initial_token_count, "Token count should increase after step"
# Step with second message
message2 = {"role": "assistant", "content": "Hello there"}
action2 = env.message_to_action(message2)
obs2 = env.step(action2)
token_count_2 = obs2.tokens.numel()
# Tokens should continue to accumulate
assert (
token_count_2 > token_count_1
), "Token count should keep increasing with more messages"
# Verify tokens are the concatenation of both messages
expected_tokens = torch.cat([action1.tokens.flatten(), action2.tokens.flatten()])
assert torch.equal(
obs2.tokens, expected_tokens
), "Tokens should be concatenation of all actions"
print("βœ“ test_token_history_accumulation passed")
def test_direct_token_action():
"""Test creating actions directly from tokens."""
tokenizer = MockTokenizer()
env = ChatEnvironment(tokenizer=tokenizer)
env.reset()
# Create raw tokens
raw_tokens = torch.tensor([[72, 101, 108, 108, 111]]) # ASCII for "Hello"
action = ChatAction(tokens=raw_tokens)
# Step with raw tokens
obs = env.step(action)
# Verify message was added
assert len(obs.messages) == 1, "Should have one message"
assert obs.messages[0]["role"] == "assistant", "Should default to assistant role"
# Verify tokens match what we sent (flattened)
assert torch.equal(
obs.tokens, raw_tokens.flatten()
), "Observation tokens should match input tokens"
print("βœ“ test_direct_token_action passed")
def test_empty_tokens_validation():
"""Test that empty tokens raise a ValueError."""
try:
action = ChatAction(tokens=torch.tensor([]))
assert False, "Should have raised ValueError for empty tokens"
except ValueError as e:
assert "empty" in str(e).lower(), "Error message should mention empty tokens"
print("βœ“ test_empty_tokens_validation passed")
def test_message_validation():
"""Test that invalid messages raise appropriate errors."""
tokenizer = MockTokenizer()
env = ChatEnvironment(tokenizer=tokenizer)
# Test missing 'role' key
try:
env.message_to_action({"content": "test"}) # type: ignore
assert False, "Should have raised error for missing 'role' key"
except (ValueError, KeyError):
pass
# Test missing 'content' key
try:
env.message_to_action({"role": "user"}) # type: ignore
assert False, "Should have raised error for missing 'content' key"
except (ValueError, KeyError):
pass
# Test None content
try:
env.message_to_action({"role": "user", "content": None}) # type: ignore
assert False, "Should have raised error for None content"
except ValueError:
pass
print("βœ“ test_message_validation passed")
def test_reset_clears_history():
"""Test that reset properly clears all message and token history."""
tokenizer = MockTokenizer()
env = ChatEnvironment(tokenizer=tokenizer, system_prompt="System message")
# Add some messages
obs1 = env.reset()
initial_messages = len(obs1.messages)
action = env.message_to_action({"role": "user", "content": "Test message"})
obs2 = env.step(action)
# Verify message was added
assert (
len(obs2.messages) > initial_messages
), "Message should be added after step"
# Reset
obs3 = env.reset()
# Verify we're back to just the system prompt
assert (
len(obs3.messages) == initial_messages
), "Reset should clear history back to initial state"
assert (
obs3.messages[0]["content"] == "System message"
), "System prompt should be preserved"
print("βœ“ test_reset_clears_history passed")
def main():
"""Run all tests."""
print("\n" + "=" * 60)
print("ChatEnvironment Test Suite")
print("=" * 60 + "\n")
tests = [
test_tokenization_consistency,
test_message_content_preservation,
test_system_prompt_preserved,
test_token_history_accumulation,
test_direct_token_action,
test_empty_tokens_validation,
test_message_validation,
test_reset_clears_history,
]
failed = []
for test in tests:
try:
test()
except AssertionError as e:
print(f"βœ— {test.__name__} failed: {e}")
failed.append(test.__name__)
except Exception as e:
print(f"βœ— {test.__name__} errored: {e}")
import traceback
traceback.print_exc()
failed.append(test.__name__)
print("\n" + "=" * 60)
if not failed:
print(f"βœ“ All {len(tests)} tests passed!")
print("=" * 60)
return 0
else:
print(f"βœ— {len(failed)}/{len(tests)} tests failed:")
for name in failed:
print(f" - {name}")
print("=" * 60)
return 1
if __name__ == "__main__":
exit(main())