Spaces:
Sleeping
Sleeping
import os
import sys
import unittest
from unittest.mock import MagicMock, patch

# Append this file's parent directory to sys.path so the `app_module` import
# below can be resolved from there (presumably where app_module lives).
_PARENT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(_PARENT_DIR)

from app_module import chat_agent_stream  # noqa: E402 -- relies on the sys.path tweak above
# NOTE(review): both test methods take ``mock_detect`` / ``mock_get_llm`` but no
# ``@patch`` decorators were visible, so unittest would have called them with
# only ``self`` (TypeError). The patch targets below are inferred from the
# parameter names -- confirm they match the names actually defined in app_module.
# Stacked patch decorators hand their mocks over bottom-up, so the bottom-most
# target arrives as the first mock argument (``mock_detect``). As a class
# decorator, ``patch`` only wraps methods whose names start with "test".
@patch("app_module.get_llm")
@patch("app_module.detect_language")
class TestAgentSimulation(unittest.TestCase):
    """Simulate conversations through ``chat_agent_stream`` with the LLM mocked out.

    ``get_llm`` is mocked to return a (model, processor) pair of MagicMocks and
    ``TextIteratorStreamer`` is patched, so the tests only inspect the message
    lists passed to ``processor.apply_chat_template``.
    """

    def _configure_mocks(self, mock_detect, mock_get_llm):
        """Wire the mocked LLM loader and language detector; return the processor mock."""
        mock_model = MagicMock()
        mock_processor = MagicMock()
        mock_get_llm.return_value = (mock_model, mock_processor)
        mock_detect.return_value = "English"
        # Stand-in for the tokenized input_ids produced by the chat template.
        mock_processor.apply_chat_template.return_value = MagicMock()
        return mock_processor

    @staticmethod
    def _messages_of(recorded_call):
        """Return the first positional argument (the message list) of a recorded mock call."""
        return recorded_call[0][0]

    def test_history_propagation(self, mock_detect, mock_get_llm):
        """A fact the user states in turn 1 must reach the model input for turn 2."""
        mock_processor = self._configure_mocks(mock_detect, mock_get_llm)
        template = mock_processor.apply_chat_template

        history = []
        query1 = "My name is Julian"
        query2 = "What is my name?"
        mock_streamer = MagicMock()

        # Keep the streamer patched for BOTH turns so neither turn reaches the
        # real TextIteratorStreamer.
        with patch("app_module.TextIteratorStreamer", return_value=mock_streamer):
            # Turn 1: the user introduces themselves.
            mock_streamer.__iter__.return_value = ["Hello", " Julian", "."]
            list(chat_agent_stream(query1, history))  # drain the generator

            self.assertTrue(template.call_args_list, "apply_chat_template was not called for turn 1")
            messages_1 = self._messages_of(template.call_args_list[0])
            self.assertEqual(messages_1[-1]["content"][0]["text"], query1)

            # Update history manually, as the UI wrapper would.
            history.append({"role": "user", "content": query1})
            history.append({"role": "assistant", "content": "Hello Julian."})

            # Turn 2: ask the model to recall the fact.
            calls_before_turn_2 = len(template.call_args_list)
            mock_streamer.__iter__.return_value = ["Your", " name", " is", " Julian."]
            list(chat_agent_stream(query2, history))

        turn_2_calls = template.call_args_list[calls_before_turn_2:]
        # Without this guard the last recorded call could still be turn 1's,
        # whose messages contain query1 and would make the check below pass
        # even if history were never propagated.
        self.assertTrue(turn_2_calls, "apply_chat_template was not called for turn 2")
        messages_2 = self._messages_of(turn_2_calls[-1])

        # str() covers both plain-string and list-of-parts message content.
        found_history = any("My name is Julian" in str(m["content"]) for m in messages_2)
        self.assertTrue(found_history, "Agent input messages did NOT contain previous user instruction!")

    def test_assistant_led_conversation(self, mock_detect, mock_get_llm):
        """Test that history starting with an assistant turn doesn't break role alternation."""
        mock_processor = self._configure_mocks(mock_detect, mock_get_llm)

        # Scenario: history starts with an assistant welcome message.
        history = [{"role": "assistant", "content": "Welcome Seeker!"}]
        query = "Tell me my fate"
        with patch("app_module.TextIteratorStreamer", return_value=MagicMock()):
            list(chat_agent_stream(query, history))

        calls = mock_processor.apply_chat_template.call_args_list
        self.assertTrue(calls, "apply_chat_template was not called")
        roles = [m["role"] for m in self._messages_of(calls[0])]

        # Expected roles: [system, user acknowledgement, assistant welcome, user query]
        self.assertGreaterEqual(len(roles), 2, f"Too few messages to check alternation: {roles}")
        for i, (current, following) in enumerate(zip(roles, roles[1:])):
            self.assertNotEqual(current, following, f"Roles at indices {i} and {i+1} do not alternate: {roles}")

        # Ensure index 1 is 'user' (after system), and the current query closes the prompt.
        self.assertEqual(roles[0], "system")
        self.assertEqual(roles[1], "user")
        self.assertEqual(roles[-1], "user", f"Prompt does not end with the user query: {roles}")
def _run_tests():
    """Hand control to unittest's command-line runner for this module."""
    unittest.main()


if __name__ == "__main__":
    _run_tests()