# scratch_chat/tests/unit/test_chat_agent.py
# Uploaded by WebashalarForML ("Upload 178 files"), commit 330b6e4 (verified).
"""
Unit tests for ChatAgent service.
Tests the core functionality of the ChatAgent class including message processing,
language switching, streaming responses, and error handling.
"""
import unittest
from datetime import datetime, timezone
from unittest.mock import Mock, patch

from chat_agent.services.chat_agent import ChatAgent, ChatAgentError, create_chat_agent
from chat_agent.services.groq_client import GroqClient, ChatMessage, LanguageContext
from chat_agent.services.language_context import LanguageContextManager
from chat_agent.services.session_manager import SessionManager, SessionNotFoundError
from chat_agent.services.chat_history import ChatHistoryManager
class TestChatAgent(unittest.TestCase):
    """Test cases for ChatAgent class.

    The GroqClient, SessionManager and ChatHistoryManager collaborators are
    mocks with canned return values; the LanguageContextManager is a real
    instance. Each test gets a fresh ChatAgent from setUp().
    """

    def setUp(self):
        """Set up test fixtures: mocked collaborators and the ChatAgent under test."""
        # Create mock dependencies
        self.mock_groq_client = Mock(spec=GroqClient)
        self.mock_groq_client.generate_response.return_value = "Test LLM response"
        # A plain iterator is single-use: it only works because setUp runs
        # before every test. test_stream_response_success expects these 3 chunks.
        self.mock_groq_client.stream_response.return_value = iter(["Test ", "stream ", "response"])
        self.mock_groq_client.get_model_info.return_value = {"model": "test-model"}

        # Real (non-mock) language manager.
        self.language_context_manager = LanguageContextManager()

        # Session manager always returns the same "test-session" python session.
        self.mock_session_manager = Mock(spec=SessionManager)
        mock_session = Mock()
        mock_session.id = "test-session"
        mock_session.language = "python"
        mock_session.message_count = 0
        self.mock_session_manager.get_session.return_value = mock_session

        # History manager returns one canned message for store/get calls.
        self.mock_chat_history_manager = Mock(spec=ChatHistoryManager)
        mock_message = Mock()
        mock_message.id = "test-message-id"
        mock_message.role = "user"
        mock_message.content = "test content"
        mock_message.language = "python"
        # Naive UTC timestamp, equivalent to the deprecated datetime.utcnow().
        mock_message.timestamp = datetime.now(timezone.utc).replace(tzinfo=None)
        mock_message.message_metadata = {}
        self.mock_chat_history_manager.store_message.return_value = mock_message
        self.mock_chat_history_manager.get_recent_history.return_value = [mock_message]
        self.mock_chat_history_manager.get_message_count.return_value = 1
        self.mock_chat_history_manager.get_cache_stats.return_value = {"cached_messages": 1}

        # Create ChatAgent instance
        self.chat_agent = ChatAgent(
            self.mock_groq_client,
            self.language_context_manager,
            self.mock_session_manager,
            self.mock_chat_history_manager,
        )

    def test_initialization(self):
        """ChatAgent keeps references to the exact dependency objects it was given."""
        self.assertIs(self.chat_agent.groq_client, self.mock_groq_client)
        self.assertIs(self.chat_agent.language_context_manager, self.language_context_manager)
        self.assertIs(self.chat_agent.session_manager, self.mock_session_manager)
        self.assertIs(self.chat_agent.chat_history_manager, self.mock_chat_history_manager)

    def test_process_message_success(self):
        """Test successful message processing."""
        result = self.chat_agent.process_message("test-session", "Hello, world!")

        self.assertEqual(result['response'], "Test LLM response")
        self.assertEqual(result['language'], "python")
        self.assertEqual(result['session_id'], "test-session")
        self.assertIn('message_id', result)
        self.assertIn('timestamp', result)
        self.assertIn('metadata', result)

        # Verify service calls
        self.mock_session_manager.get_session.assert_called_with("test-session")
        self.mock_session_manager.update_session_activity.assert_called_with("test-session")
        self.mock_session_manager.increment_message_count.assert_called_with("test-session")
        # Two stores: the user's message and the assistant's reply.
        self.assertEqual(self.mock_chat_history_manager.store_message.call_count, 2)
        self.mock_groq_client.generate_response.assert_called_once()

    def test_process_message_with_language_override(self):
        """Test message processing with language override."""
        result = self.chat_agent.process_message("test-session", "Hello!", "javascript")

        self.assertEqual(result['language'], "javascript")
        self.mock_session_manager.set_session_language.assert_called_with("test-session", "javascript")

    def test_process_message_session_not_found(self):
        """A SessionNotFoundError from the session manager surfaces as ChatAgentError."""
        self.mock_session_manager.get_session.side_effect = SessionNotFoundError("Not found")

        with self.assertRaises(ChatAgentError) as context:
            self.chat_agent.process_message("invalid-session", "Hello!")

        self.assertIn("Session error", str(context.exception))

    def test_switch_language_success(self):
        """Test successful language switching."""
        result = self.chat_agent.switch_language("test-session", "javascript")

        self.assertTrue(result['success'])
        self.assertEqual(result['new_language'], "javascript")
        self.assertEqual(result['previous_language'], "python")
        self.assertEqual(result['session_id'], "test-session")
        self.assertIn('message', result)

        # Verify service calls
        self.mock_session_manager.set_session_language.assert_called_with("test-session", "javascript")
        self.mock_chat_history_manager.store_message.assert_called_once()

    def test_switch_language_invalid_language(self):
        """Test language switching with invalid language."""
        with self.assertRaises(ChatAgentError) as context:
            self.chat_agent.switch_language("test-session", "invalid-lang")

        self.assertIn("Unsupported language", str(context.exception))

    def test_stream_response_success(self):
        """Streaming yields a start event, one chunk event per LLM chunk, then complete."""
        stream_results = list(self.chat_agent.stream_response("test-session", "Hello!"))

        self.assertGreaterEqual(len(stream_results), 5)  # start + 3 chunks + complete

        # Check start event
        start_event = stream_results[0]
        self.assertEqual(start_event['type'], 'start')
        self.assertEqual(start_event['session_id'], "test-session")

        # Check chunk events
        chunk_events = [event for event in stream_results if event['type'] == 'chunk']
        self.assertEqual(len(chunk_events), 3)

        # Check complete event
        complete_event = stream_results[-1]
        self.assertEqual(complete_event['type'], 'complete')
        self.assertEqual(complete_event['session_id'], "test-session")

    def test_get_chat_history_success(self):
        """Test successful chat history retrieval."""
        history = self.chat_agent.get_chat_history("test-session", 10)

        self.assertIsInstance(history, list)
        self.assertEqual(len(history), 1)

        message = history[0]
        self.assertIn('id', message)
        self.assertIn('role', message)
        self.assertIn('content', message)
        self.assertIn('language', message)
        self.assertIn('timestamp', message)

        self.mock_chat_history_manager.get_recent_history.assert_called_with("test-session", 10)

    def test_get_session_info_success(self):
        """Test successful session info retrieval."""
        info = self.chat_agent.get_session_info("test-session")

        self.assertIn('session', info)
        self.assertIn('language_context', info)
        self.assertIn('statistics', info)
        self.assertIn('supported_languages', info)

        # Verify service calls
        self.mock_session_manager.get_session.assert_called_with("test-session")
        self.mock_chat_history_manager.get_message_count.assert_called_with("test-session")
        self.mock_chat_history_manager.get_cache_stats.assert_called_with("test-session")
class TestChatAgentFactory(unittest.TestCase):
    """Test cases for the create_chat_agent factory function."""

    def test_create_chat_agent(self):
        """create_chat_agent returns a ChatAgent holding the exact dependencies passed in."""
        groq_client = Mock(spec=GroqClient)
        language_manager = LanguageContextManager()
        session_manager = Mock(spec=SessionManager)
        history_manager = Mock(spec=ChatHistoryManager)

        agent = create_chat_agent(groq_client, language_manager, session_manager, history_manager)

        self.assertIsInstance(agent, ChatAgent)
        # Identity, not equality: the agent must hold the very objects injected.
        self.assertIs(agent.groq_client, groq_client)
        self.assertIs(agent.language_context_manager, language_manager)
        self.assertIs(agent.session_manager, session_manager)
        self.assertIs(agent.chat_history_manager, history_manager)
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()