# gemma-sage/tests/test_simulation.py
# Commit de6582c (neuralworm): Fix Role Alternation crash for assistant-led
# conversations and extend simulation tests
import unittest
from unittest.mock import MagicMock, patch
import sys
import os
# Add parent directory to path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from app_module import chat_agent_stream
class TestAgentSimulation(unittest.TestCase):
    """Simulated conversations driven through ``chat_agent_stream``.

    The model/processor pair, the language detector and the token streamer
    are all replaced with mocks. Each test inspects the message list passed
    as the first positional argument to ``processor.apply_chat_template``.
    """

    @patch('app_module.get_llm')
    @patch('app_module.detect_language')
    def test_history_propagation(self, mock_detect, mock_get_llm):
        """A user message from turn 1 must be present in the model input of turn 2."""
        # Setup mocks: get_llm is stubbed to return a (model, processor) tuple.
        mock_model = MagicMock()
        mock_processor = MagicMock()
        mock_get_llm.return_value = (mock_model, mock_processor)
        mock_detect.return_value = "English"
        mock_processor.apply_chat_template.return_value = MagicMock()  # input_ids
        template = mock_processor.apply_chat_template

        # Turn 1: user says "My name is Julian" with an empty history.
        history = []
        query1 = "My name is Julian"

        # The streamer mock stands in for generation output.
        mock_streamer = MagicMock()
        mock_streamer.__iter__.return_value = ["Hello", " Julian", "."]

        with patch('app_module.TextIteratorStreamer', return_value=mock_streamer):
            # Drain the generator so the whole turn executes.
            list(chat_agent_stream(query1, history))

            # Fail with a readable message instead of an IndexError below.
            self.assertTrue(
                template.call_args_list,
                "apply_chat_template was never called during turn 1",
            )
            messages_1 = template.call_args_list[0][0][0]
            self.assertEqual(messages_1[-1]['content'][0]['text'], query1)

            # Update history manually, as the calling wrapper would.
            history.append({"role": "user", "content": query1})
            history.append({"role": "assistant", "content": "Hello Julian."})

            # Remember how many template calls turn 1 made. Without this, the
            # "last call" inspected below could still be a turn-1 call, whose
            # messages trivially contain query1 (a false positive).
            calls_before_turn_2 = template.call_count

            # Turn 2: user asks "What is my name?"
            query2 = "What is my name?"
            mock_streamer.__iter__.return_value = ["Your", " name", " is", " Julian."]
            list(chat_agent_stream(query2, history))

        self.assertGreater(
            template.call_count,
            calls_before_turn_2,
            "apply_chat_template was not called during turn 2",
        )

        # Verify the input to the model for turn 2 carries the earlier message.
        messages_2 = template.call_args_list[-1][0][0]
        found_history = any(
            "My name is Julian" in str(m['content']) for m in messages_2
        )
        self.assertTrue(found_history, "Agent input messages did NOT contain previous user instruction!")

    @patch('app_module.get_llm')
    @patch('app_module.detect_language')
    def test_assistant_led_conversation(self, mock_detect, mock_get_llm):
        """Test that history starting with an assistant turn doesn't break role alternation."""
        # Setup mocks
        mock_model = MagicMock()
        mock_processor = MagicMock()
        mock_get_llm.return_value = (mock_model, mock_processor)
        mock_detect.return_value = "English"
        mock_processor.apply_chat_template.return_value = MagicMock()

        # Scenario: history starts with an assistant welcome message.
        history = [{"role": "assistant", "content": "Welcome Seeker!"}]
        query = "Tell me my fate"

        with patch('app_module.TextIteratorStreamer', return_value=MagicMock()):
            list(chat_agent_stream(query, history))

        self.assertTrue(
            mock_processor.apply_chat_template.call_args_list,
            "apply_chat_template was never called",
        )
        messages = mock_processor.apply_chat_template.call_args_list[0][0][0]

        # Expected roles: [system, user acknowledgement, assistant welcome, user query]
        roles = [m['role'] for m in messages]

        # Alternation check: no two adjacent messages may share a role.
        for i, (current, following) in enumerate(zip(roles, roles[1:])):
            self.assertNotEqual(current, following, f"Roles at indices {i} and {i+1} do not alternate: {roles}")

        # Ensure index 1 is 'user' (after system).
        self.assertGreaterEqual(len(roles), 2, f"Expected at least two messages, got roles: {roles}")
        self.assertEqual(roles[0], "system")
        self.assertEqual(roles[1], "user")
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()