# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Test suite for ChatEnvironment.

Proper unit tests with assertions to verify correct behavior.
"""

import torch

from openenv.core.env_server.interfaces import Message

from ..models import ChatAction
from .chat_environment import ChatEnvironment
class MockTokenizer:
    """Mock tokenizer for testing without requiring the transformers library.

    Encoding is deterministic: each character of the concatenated message
    contents becomes one token (its code point modulo 256), so ``decode``
    can exactly invert ``apply_chat_template`` for ASCII text.
    """

    def apply_chat_template(
        self,
        conversation: "list[Message]",
        tokenize: bool = True,
        return_tensors: "str | None" = None,
        **kwargs,
    ):
        """Mock implementation that creates deterministic token tensors from text.

        Args:
            conversation: Messages whose ``"content"`` values are joined with
                single spaces (a stand-in for a real chat template).
            tokenize: Accepted for API compatibility; tokens are always produced.
            return_tensors: If ``"pt"``, return a ``[1, N]`` torch tensor,
                otherwise a plain ``list[int]``.

        Returns:
            Token IDs derived from the character codes of the joined content.
        """
        # Concatenate all message content, mirroring a chat template's output.
        text = " ".join(msg["content"] for msg in conversation)
        # Character codes modulo 256 keep token IDs in a small, valid range.
        tokens = [ord(ch) % 256 for ch in text]
        if return_tensors == "pt":
            return torch.tensor([tokens])
        return tokens

    def decode(self, token_ids, skip_special_tokens: bool = False, **kwargs) -> str:
        """Mock decode that reverses the encoding process.

        Fix: the original crashed with ``TypeError`` on the batched ``[1, N]``
        tensors produced by ``apply_chat_template(return_tensors="pt")``,
        because ``.tolist()`` on a 2-D tensor yields nested lists and ``chr()``
        was then applied to an inner list. Batched inputs are now flattened.
        """
        if isinstance(token_ids, torch.Tensor):
            # flatten() collapses a [1, N] batch dimension before tolist().
            token_ids = token_ids.flatten().tolist()
        # Also unwrap a single-batch nested Python list such as [[72, 105]].
        if token_ids and isinstance(token_ids[0], list):
            token_ids = token_ids[0]
        # Reverse the encoding: each token is a character code.
        return "".join(chr(t) for t in token_ids)
def test_tokenization_consistency():
    """Test that tokenizing the same string produces the same tokens."""
    env = ChatEnvironment(tokenizer=MockTokenizer())

    # Two independently-constructed but identical messages.
    first: Message = {"role": "user", "content": "Hello, world!"}
    second: Message = {"role": "user", "content": "Hello, world!"}

    first_action = env.message_to_action(first)
    second_action = env.message_to_action(second)

    # Identical inputs must tokenize identically, and non-trivially.
    assert torch.equal(
        first_action.tokens, second_action.tokens
    ), "Same message should produce identical tokens"
    assert first_action.tokens.numel() > 0, "Tokens should not be empty"
    print("β test_tokenization_consistency passed")
def test_message_content_preservation():
    """Test that message content is preserved in the observation."""
    env = ChatEnvironment(tokenizer=MockTokenizer())
    env.reset()

    # Round-trip a user message: the environment decodes the action's tokens
    # back into text, so the last observed message should match what we sent.
    user_content = "What is the capital of France?"
    obs = env.step(env.message_to_action({"role": "user", "content": user_content}))
    assert len(obs.messages) > 0, "Observation should have at least one message"
    decoded = obs.messages[-1]["content"]
    assert decoded == user_content, (
        f"Message content should be preserved. "
        f"Expected: {user_content}, Got: {decoded}"
    )

    # Same round-trip check for an assistant message.
    assistant_content = "The capital of France is Paris."
    obs = env.step(
        env.message_to_action({"role": "assistant", "content": assistant_content})
    )
    assert len(obs.messages) >= 2, "Should have at least 2 messages now"
    decoded = obs.messages[-1]["content"]
    assert decoded == assistant_content, (
        f"Assistant message content should be preserved. "
        f"Expected: {assistant_content}, Got: {decoded}"
    )
    print("β test_message_content_preservation passed")
def test_system_prompt_preserved():
    """Test that system prompt is preserved after reset."""
    system_prompt = "You are a helpful assistant."
    env = ChatEnvironment(tokenizer=MockTokenizer(), system_prompt=system_prompt)

    # Immediately after reset, only the system message should be present.
    obs = env.reset()
    assert len(obs.messages) == 1, "Should have exactly one message (system prompt)"
    assert obs.messages[0]["role"] == "system", "First message should have system role"
    assert (
        obs.messages[0]["content"] == system_prompt
    ), "System prompt content should match"

    # Interact once, then reset again: the system prompt must survive.
    env.step(env.message_to_action({"role": "user", "content": "Hello"}))
    obs = env.reset()
    assert len(obs.messages) == 1, "After reset, should only have system prompt"
    assert (
        obs.messages[0]["content"] == system_prompt
    ), "System prompt should be preserved after reset"
    print("β test_system_prompt_preserved passed")
def test_token_history_accumulation():
    """Test that tokens accumulate correctly in the observation."""
    env = ChatEnvironment(tokenizer=MockTokenizer())
    baseline = env.reset().tokens.numel()

    # First step: token count must grow past the initial state.
    user_action = env.message_to_action({"role": "user", "content": "Hi"})
    count_after_first = env.step(user_action).tokens.numel()
    assert count_after_first > baseline, "Token count should increase after step"

    # Second step: count keeps growing, and the observed history is exactly
    # the concatenation of both actions' (flattened) tokens.
    reply_action = env.message_to_action(
        {"role": "assistant", "content": "Hello there"}
    )
    obs = env.step(reply_action)
    assert (
        obs.tokens.numel() > count_after_first
    ), "Token count should keep increasing with more messages"
    combined = torch.cat(
        [user_action.tokens.flatten(), reply_action.tokens.flatten()]
    )
    assert torch.equal(
        obs.tokens, combined
    ), "Tokens should be concatenation of all actions"
    print("β test_token_history_accumulation passed")
def test_direct_token_action():
    """Test creating actions directly from tokens."""
    env = ChatEnvironment(tokenizer=MockTokenizer())
    env.reset()

    # Bypass message_to_action entirely: feed raw token IDs ("Hello" in ASCII).
    hello_tokens = torch.tensor([[72, 101, 108, 108, 111]])
    obs = env.step(ChatAction(tokens=hello_tokens))

    # A raw-token action should appear as a single assistant message whose
    # tokens are exactly the (flattened) input.
    assert len(obs.messages) == 1, "Should have one message"
    assert obs.messages[0]["role"] == "assistant", "Should default to assistant role"
    assert torch.equal(
        obs.tokens, hello_tokens.flatten()
    ), "Observation tokens should match input tokens"
    print("β test_direct_token_action passed")
def test_empty_tokens_validation():
    """Test that empty tokens raise a ValueError."""
    raised = False
    try:
        ChatAction(tokens=torch.tensor([]))
    except ValueError as err:
        raised = True
        # The error should tell the caller what was wrong with the input.
        assert "empty" in str(err).lower(), "Error message should mention empty tokens"
    assert raised, "Should have raised ValueError for empty tokens"
    print("β test_empty_tokens_validation passed")
def test_message_validation():
    """Test that invalid messages raise appropriate errors."""
    env = ChatEnvironment(tokenizer=MockTokenizer())

    def expect_failure(message, exc_types, description):
        # Each malformed message must be rejected with one of exc_types;
        # reaching the assert means the environment accepted bad input.
        try:
            env.message_to_action(message)  # type: ignore
        except exc_types:
            return
        assert False, f"Should have raised error for {description}"

    expect_failure({"content": "test"}, (ValueError, KeyError), "missing 'role' key")
    expect_failure({"role": "user"}, (ValueError, KeyError), "missing 'content' key")
    expect_failure({"role": "user", "content": None}, ValueError, "None content")
    print("β test_message_validation passed")
def test_reset_clears_history():
    """Test that reset properly clears all message and token history."""
    env = ChatEnvironment(tokenizer=MockTokenizer(), system_prompt="System message")

    # Record the baseline message count, then add a user message.
    baseline = len(env.reset().messages)
    after_step = env.step(
        env.message_to_action({"role": "user", "content": "Test message"})
    )
    assert (
        len(after_step.messages) > baseline
    ), "Message should be added after step"

    # A reset must drop everything except the system prompt.
    after_reset = env.reset()
    assert (
        len(after_reset.messages) == baseline
    ), "Reset should clear history back to initial state"
    assert (
        after_reset.messages[0]["content"] == "System message"
    ), "System prompt should be preserved"
    print("β test_reset_clears_history passed")
def main():
    """Run all tests, print a summary, and return a process exit code."""
    banner = "=" * 60
    print("\n" + banner)
    print("ChatEnvironment Test Suite")
    print(banner + "\n")

    tests = (
        test_tokenization_consistency,
        test_message_content_preservation,
        test_system_prompt_preserved,
        test_token_history_accumulation,
        test_direct_token_action,
        test_empty_tokens_validation,
        test_message_validation,
        test_reset_clears_history,
    )

    failed = []
    for case in tests:
        try:
            case()
        except AssertionError as err:
            # An assertion failure is a test failure, reported compactly.
            print(f"β {case.__name__} failed: {err}")
            failed.append(case.__name__)
        except Exception as err:
            # Any other exception is an error; show the full traceback.
            print(f"β {case.__name__} errored: {err}")
            import traceback

            traceback.print_exc()
            failed.append(case.__name__)

    print("\n" + banner)
    if failed:
        print(f"β {len(failed)}/{len(tests)} tests failed:")
        for name in failed:
            print(f" - {name}")
        print(banner)
        return 1
    print(f"β All {len(tests)} tests passed!")
    print(banner)
    return 0
if __name__ == "__main__":
    # Propagate the suite result (0 = success, 1 = failures) as the exit status.
    # Fix: exit() is the site module's interactive helper and is not guaranteed
    # to exist (e.g. under `python -S`); raising SystemExit is always available
    # and is the conventional way to exit from a script entry point.
    raise SystemExit(main())