File size: 4,276 Bytes
e09f395
 
 
 
 
 
 
 
 
 
 
 
 
d1af7e9
e09f395
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d1af7e9
e09f395
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d1af7e9
e09f395
 
 
 
 
 
 
 
 
 
 
 
 
 
 
de6582c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d1af7e9
e09f395
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import unittest
from unittest.mock import MagicMock, patch
import sys
import os

# Add parent directory to path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from app_module import chat_agent_stream

class TestAgentSimulation(unittest.TestCase):
    """Simulate conversations through ``chat_agent_stream`` with all model I/O mocked.

    ``get_llm``, ``detect_language`` and ``TextIteratorStreamer`` are patched in
    ``app_module``. The assertions inspect the ``messages`` list passed as the
    first positional argument to ``processor.apply_chat_template``.
    """

    @staticmethod
    def _configure_llm_mocks(mock_detect, mock_get_llm):
        """Wire the patched ``get_llm``/``detect_language`` and return the processor mock.

        ``get_llm`` returns a ``(model, processor)`` pair of MagicMocks, and
        ``detect_language`` reports ``"English"``.
        """
        mock_processor = MagicMock()
        mock_get_llm.return_value = (MagicMock(), mock_processor)
        mock_detect.return_value = "English"
        # Stand-in for the tokenized input_ids the processor would return.
        mock_processor.apply_chat_template.return_value = MagicMock()
        return mock_processor

    @staticmethod
    def _messages_of(recorded_call):
        """Return the first positional argument (the messages list) of a recorded mock call."""
        return recorded_call.args[0]

    @patch('app_module.get_llm')
    @patch('app_module.detect_language')
    def test_history_propagation(self, mock_detect, mock_get_llm):
        """A user turn from history must be forwarded to the model on the next turn."""
        mock_processor = self._configure_llm_mocks(mock_detect, mock_get_llm)
        template = mock_processor.apply_chat_template

        # Every TextIteratorStreamer(...) construction yields this same mock;
        # its token sequence is swapped between turns.
        mock_streamer = MagicMock()
        mock_streamer.__iter__.return_value = ["Hello", " Julian", "."]

        history = []
        query1 = "My name is Julian"
        query2 = "What is my name?"

        with patch('app_module.TextIteratorStreamer', return_value=mock_streamer):
            # Turn 1: drain the generator; only the model input is checked.
            list(chat_agent_stream(query1, history))

            self.assertTrue(template.called, "apply_chat_template was never called on turn 1")
            messages_1 = self._messages_of(template.call_args_list[0])
            self.assertEqual(messages_1[-1]['content'][0]['text'], query1)

            # Record turn 1 the way the calling wrapper would.
            history.append({"role": "user", "content": query1})
            history.append({"role": "assistant", "content": "Hello Julian."})

            # Turn 2
            calls_before_turn_2 = template.call_count
            mock_streamer.__iter__.return_value = ["Your", " name", " is", " Julian."]
            list(chat_agent_stream(query2, history))

            # Without this check, call_args_list[-1] could still be the turn-1
            # call (which contains query1), making the history check vacuous.
            self.assertGreater(
                template.call_count,
                calls_before_turn_2,
                "apply_chat_template was not called on turn 2",
            )
            messages_2 = self._messages_of(template.call_args_list[-1])

            found_history = any(query1 in str(m['content']) for m in messages_2)
            self.assertTrue(found_history, "Agent input messages did NOT contain previous user instruction!")

    @patch('app_module.get_llm')
    @patch('app_module.detect_language')
    def test_assistant_led_conversation(self, mock_detect, mock_get_llm):
        """Test that history starting with an assistant turn doesn't break role alternation."""
        mock_processor = self._configure_llm_mocks(mock_detect, mock_get_llm)
        template = mock_processor.apply_chat_template

        # Scenario: history opens with an assistant welcome message.
        history = [{"role": "assistant", "content": "Welcome Seeker!"}]
        query = "Tell me my fate"

        with patch('app_module.TextIteratorStreamer', return_value=MagicMock()):
            list(chat_agent_stream(query, history))

        self.assertTrue(template.called, "apply_chat_template was never called")
        roles = [m['role'] for m in self._messages_of(template.call_args_list[0])]

        # Expected roles: [system, user acknowledgement, assistant welcome, user query].
        # Check the length first so a short list fails cleanly instead of raising IndexError.
        self.assertGreaterEqual(len(roles), 2, f"Too few messages to check alternation: {roles}")
        self.assertEqual(roles[0], "system")
        self.assertEqual(roles[1], "user")
        # The current query must be the final (user) turn.
        self.assertEqual(roles[-1], "user", f"Last message is not the user query: {roles}")

        for i, (current_role, next_role) in enumerate(zip(roles, roles[1:])):
            self.assertNotEqual(current_role, next_role, f"Roles at indices {i} and {i+1} do not alternate: {roles}")



if __name__ == "__main__":
    # Running this file directly executes every TestCase defined in it.
    unittest.main()