Spaces:
Sleeping
Sleeping
| import os | |
| import random | |
| import sys | |
| from pathlib import Path | |
# Put the parent of this file's directory at the front of the import path so
# the `aworld` and `tests` imports that follow resolve when run directly.
project_root = Path(__file__).parent.parent
sys.path.insert(0, os.fspath(project_root))
| from aworld.core.task import Task | |
| from aworld.core.agent.swarm import Swarm | |
| from aworld.runner import Runners | |
| from aworld.agents.llm_agent import Agent | |
| from aworld.config.conf import AgentConfig, ContextRuleConfig, ModelConfig, OptimizationConfig, LlmCompressionConfig | |
| from aworld.core.context.base import Context | |
| from aworld.runners.hook.hook_factory import HookFactory | |
| from tests.base_test import BaseTest | |
class TestContextManagement(BaseTest):
    """Test cases for Context Management system based on README examples"""

    def __init__(self):
        """Set up test fixtures.

        Stores the LLM endpoint settings shared by every agent in this suite
        and mirrors them into the ``LLM_API_KEY``, ``LLM_BASE_URL`` and
        ``LLM_MODEL_NAME`` environment variables.
        """
        # NOTE(review): these settings appear to target a locally running
        # OpenAI-compatible server (localhost:1234); the tests issue real
        # model calls, so that server must be up.
        self.mock_model_name = "qwen/qwen3-1.7b"
        self.mock_base_url = "http://localhost:1234/v1"
        self.mock_api_key = "lm-studio"
        # These writes are process-wide and are never restored.
        os.environ["LLM_API_KEY"] = self.mock_api_key
        os.environ["LLM_BASE_URL"] = self.mock_base_url
        os.environ["LLM_MODEL_NAME"] = self.mock_model_name

    def init_agent(self, config_type: str = "1", context_rule: ContextRuleConfig = None):
        """Build an ``Agent`` pointed at the configured LLM endpoint.

        Args:
            config_type: ``"1"`` passes the model settings as top-level
                ``AgentConfig`` fields; any other value nests them inside
                ``llm_config=ModelConfig(...)``.
            context_rule: Optional context rule forwarded to the agent
                unchanged (``None`` by default).

        Returns:
            A new ``Agent`` named ``my_agent`` plus a random integer suffix.
        """
        if config_type == "1":
            conf = AgentConfig(
                llm_model_name=self.mock_model_name,
                llm_base_url=self.mock_base_url,
                llm_api_key=self.mock_api_key
            )
        else:
            conf = AgentConfig(
                llm_config=ModelConfig(
                    llm_model_name=self.mock_model_name,
                    llm_base_url=self.mock_base_url,
                    llm_api_key=self.mock_api_key
                )
            )
        return Agent(
            conf=conf,
            # Random suffix, presumably to keep agent names distinct across calls.
            name="my_agent" + str(random.randint(0, 1000000)),
            system_prompt="You are a helpful assistant.",
            agent_prompt="You are a helpful assistant.",
            context_rule=context_rule
        )

    class _AssertRaisesContext:
        """Context manager for assertRaises"""

        def __init__(self, expected_exception):
            self.expected_exception = expected_exception

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc_value, traceback):
            # Fail when nothing was raised, or when the wrong type was raised.
            if exc_type is None:
                raise AssertionError(f"Expected {self.expected_exception.__name__} to be raised, but no exception was raised")
            if not issubclass(exc_type, self.expected_exception):
                raise AssertionError(f"Expected {self.expected_exception.__name__} to be raised, but got {exc_type.__name__}: {exc_value}")
            return True  # Suppress the exception

    def fail(self, msg=None):
        """Fail immediately with the given message"""
        raise AssertionError(msg or "Test failed")

    def run_agent(self, input, agent: Agent):
        """Run ``agent`` alone in a one-step swarm and return the response."""
        swarm = Swarm(agent, max_steps=1)
        return Runners.sync_run(
            input=input,
            swarm=swarm
        )

    def run_multi_agent(self, input, agent1: Agent, agent2: Agent):
        """Run ``agent1`` and ``agent2`` in one swarm (``max_steps=1``)."""
        swarm = Swarm(agent1, agent2, max_steps=1)
        return Runners.sync_run(
            input=input,
            swarm=swarm
        )

    def run_task(self, context: Context, agent: Agent):
        """Run a fixed task with an explicit ``context`` via ``sync_run_task``."""
        swarm = Swarm(agent, max_steps=1)
        task = Task(input="""What is an agent.""", swarm=swarm, context=context)
        return Runners.sync_run_task(task)

    def test_default_context_configuration(self):
        """An agent built without a context rule still gets one populated."""
        # No need to explicitly configure context_rule, system automatically uses default configuration
        # Default configuration is equivalent to:
        # context_rule=ContextRuleConfig(
        #     optimization_config=OptimizationConfig(
        #         enabled=True,
        #         max_token_budget_ratio=1.0  # Use 100% of context window
        #     ),
        #     llm_compression_config=LlmCompressionConfig(
        #         enabled=False  # Compression disabled by default
        #     )
        # )
        mock_agent = self.init_agent("1")
        response = self.run_agent(input="""What is an agent. describe within 20 words""", agent=mock_agent)
        self.assertIsNotNone(response.answer)
        self.assertEqual(mock_agent.agent_context.model_config.llm_model_name, self.mock_model_name)
        # Test default context rule behavior
        self.assertIsNotNone(mock_agent.agent_context.context_rule)
        self.assertIsNotNone(mock_agent.agent_context.context_rule.optimization_config)

    def test_custom_context_configuration(self):
        """Test custom context configuration (README Configuration example)"""
        # Create custom context rules.
        # NOTE(review): the tiny budget ratio and 100-token compression trigger
        # look chosen to force optimization/compression on a short prompt.
        mock_agent = self.init_agent(context_rule=ContextRuleConfig(
            optimization_config=OptimizationConfig(
                enabled=True,
                max_token_budget_ratio=0.00015
            ),
            llm_compression_config=LlmCompressionConfig(
                enabled=True,
                trigger_compress_token_length=100,
                compress_model=ModelConfig(
                    llm_model_name=self.mock_model_name,
                    llm_base_url=self.mock_base_url,
                    llm_api_key=self.mock_api_key,
                )
            )
        ))
        response = self.run_agent(input="""describe What is an agent in details""", agent=mock_agent)
        self.assertIsNotNone(response.answer)
        # Test configuration values
        self.assertTrue(mock_agent.agent_context.context_rule.optimization_config.enabled)
        self.assertTrue(mock_agent.agent_context.context_rule.llm_compression_config.enabled)

    def test_state_management_and_recovery(self):
        """Context state written by one agent must be visible to the next.

        ``StateModifyAgent`` sets ``context.state['policy_executed']`` after
        its policy runs, and ``StateTrackingAgent`` asserts the flag is set
        when it runs. After the run, both agents' ``context.state`` is
        checked for the flag.
        """
        class StateModifyAgent(Agent):
            async def async_policy(self, observation, info=None, **kwargs):
                result = await super().async_policy(observation, info, **kwargs)
                self.context.state['policy_executed'] = True
                return result

        class StateTrackingAgent(Agent):
            async def async_policy(self, observation, info=None, **kwargs):
                result = await super().async_policy(observation, info, **kwargs)
                assert self.context.state.get('policy_executed') is True, (
                    "expected 'policy_executed' to have been set by StateModifyAgent"
                )
                return result

        # Create custom agent instance
        custom_agent = StateModifyAgent(
            conf=AgentConfig(
                llm_model_name=self.mock_model_name,
                llm_base_url=self.mock_base_url,
                llm_api_key=self.mock_api_key
            ),
            name="state_modify_agent",
            system_prompt="You are a Python expert who provides detailed and practical answers.",
            agent_prompt="You are a Python expert who provides detailed and practical answers.",
        )
        # Create a second agent for multi-agent testing
        second_agent = StateTrackingAgent(
            conf=AgentConfig(
                llm_model_name=self.mock_model_name,
                llm_base_url=self.mock_base_url,
                llm_api_key=self.mock_api_key
            ),
            name="state_tracking_agent",
            system_prompt="You are a helpful assistant.",
            agent_prompt="You are a helpful assistant.",
        )
        response = self.run_multi_agent(
            input="What is an agent. describe within 20 words",
            agent1=custom_agent,
            agent2=second_agent
        )
        self.assertIsNotNone(response.answer)
        # Verify state changes after execution. Look the flag up WITHOUT a
        # truthy default: a default of True would make these checks pass even
        # if the flag was never written. Both agents are read through
        # `.context.state`, the same attribute the agents above write/read.
        self.assertTrue(custom_agent.context.state.get('policy_executed'))
        self.assertTrue(second_agent.context.state.get('policy_executed'))
class TestHookSystem(TestContextManagement):
    """Hook-system tests; reuses the agent/run helpers of TestContextManagement.

    No ``__init__`` override is needed: the inherited one sets up the shared
    LLM fixtures.
    """

    def test_hook_registration(self):
        """Test hook registration and retrieval"""
        # The import also brings the hook classes into scope for the
        # isinstance checks below; presumably importing the module is what
        # registers them with HookFactory — confirm in tests/test_llm_hook.
        from tests.test_llm_hook import TestPreLLMHook, TestPostLLMHook

        # Test that hooks are registered in _cls attribute
        self.assertIn("TestPreLLMHook", HookFactory._cls)
        self.assertIn("TestPostLLMHook", HookFactory._cls)
        # Test hook creation using __call__ method
        pre_hook = HookFactory("TestPreLLMHook")
        post_hook = HookFactory("TestPostLLMHook")
        self.assertIsInstance(pre_hook, TestPreLLMHook)
        self.assertIsInstance(post_hook, TestPostLLMHook)

    def test_hook_execution(self):
        """Run an agent with the test LLM hooks imported and expect an answer."""
        # Imported only for its side effect (presumably hook registration);
        # the names themselves are not used here.
        from tests.test_llm_hook import TestPreLLMHook, TestPostLLMHook  # noqa: F401

        mock_agent = self.init_agent("1")
        response = self.run_agent(input="""What is an agent. describe within 20 words""", agent=mock_agent)
        # NOTE(review): this only checks that an answer came back; it does not
        # verify that the pre/post hooks actually ran.
        self.assertIsNotNone(response.answer)

    def test_task_context_transfer(self):
        """Run a task with a pre-populated Context and the context-check hook."""
        # Imported only for its side effect (presumably registering a
        # pre-LLM hook that inspects the task context).
        from tests.test_context_hook import CheckContextPreLLMHook  # noqa: F401

        mock_agent = self.init_agent("1")
        context = Context.instance()
        context.state.update({"task": "What is an agent."})
        # NOTE(review): the run result is not asserted; any verification
        # presumably happens inside the imported hook.
        self.run_task(context=context, agent=mock_agent)
if __name__ == '__main__':
    # Run the suites directly (outside a test runner). There is no error
    # handling, so the first failing check aborts the rest.
    context_suite = TestContextManagement()
    for check_name in (
        "test_default_context_configuration",
        "test_custom_context_configuration",
        "test_state_management_and_recovery",
    ):
        getattr(context_suite, check_name)()

    # Each hook check runs on its own freshly constructed TestHookSystem.
    for check_name in (
        "test_hook_registration",
        "test_hook_execution",
        "test_task_context_transfer",
    ):
        getattr(TestHookSystem(), check_name)()