Spaces:
Runtime error
Runtime error
| """ | |
| Unit tests for Groq LangChain integration service. | |
| Tests cover API authentication, response generation, streaming functionality, | |
| error handling, and various edge cases with mocked responses. | |
| """ | |
| import pytest | |
| import os | |
| from unittest.mock import Mock, patch, MagicMock | |
| from typing import List | |
| from chat_agent.services.groq_client import ( | |
| GroqClient, | |
| ChatMessage, | |
| LanguageContext, | |
| GroqError, | |
| GroqRateLimitError, | |
| GroqAuthenticationError, | |
| GroqNetworkError, | |
| create_language_context, | |
| DEFAULT_LANGUAGE_TEMPLATES | |
| ) | |
class TestGroqClient:
    """Test suite for the GroqClient class.

    Covers construction (environment-driven and explicit parameters),
    standard and streaming response generation, message building with a
    bounded history window, API error translation, the connection check and
    the model-info accessor. All network-facing collaborators are replaced
    with mocks.
    """

    # NOTE(review): the patch targets below were reconstructed. They assume
    # chat_agent.services.groq_client binds the SDK classes under the names
    # `Groq` and `ChatGroq` -- confirm against that module.
    _GROQ_TARGET = 'chat_agent.services.groq_client.Groq'
    _CHATGROQ_TARGET = 'chat_agent.services.groq_client.ChatGroq'

    # ------------------------------------------------------------------
    # Fixtures
    # ------------------------------------------------------------------

    @pytest.fixture
    def mock_env_vars(self):
        """Provide a fixed set of Groq-related environment variables.

        The variables are only present while the requesting test runs;
        ``patch.dict`` restores ``os.environ`` afterwards.
        """
        with patch.dict(os.environ, {
            'GROQ_API_KEY': 'test_api_key',
            'GROQ_MODEL': 'mixtral-8x7b-32768',
            'MAX_TOKENS': '2048',
            'TEMPERATURE': '0.7',
            'STREAM_RESPONSES': 'True',
            'CONTEXT_WINDOW_SIZE': '10'
        }):
            yield

    @pytest.fixture
    def sample_chat_history(self):
        """Return a three-message user/assistant/user Python conversation."""
        return [
            ChatMessage(role="user", content="What is Python?", language="python"),
            ChatMessage(role="assistant", content="Python is a programming language.", language="python"),
            ChatMessage(role="user", content="How do I create a list?", language="python")
        ]

    @pytest.fixture
    def sample_language_context(self):
        """Return a Python LanguageContext with a recognisable prompt template."""
        return LanguageContext(
            language="python",
            prompt_template="You are a Python programming assistant. Language: {language}",
            syntax_highlighting="python"
        )

    # ------------------------------------------------------------------
    # Initialization
    # ------------------------------------------------------------------

    # Stacked @patch decorators inject mocks bottom-up: the decorator closest
    # to the function supplies the first mock argument (mock_chatgroq).
    @patch(_GROQ_TARGET)
    @patch(_CHATGROQ_TARGET)
    def test_client_initialization_success(self, mock_chatgroq, mock_groq, mock_env_vars):
        """Test successful client initialization from environment variables."""
        # Arrange
        mock_groq_instance = Mock()
        mock_chatgroq_instance = Mock()
        mock_groq.return_value = mock_groq_instance
        mock_chatgroq.return_value = mock_chatgroq_instance

        # Act
        client = GroqClient()

        # Assert
        assert client.api_key == 'test_api_key'
        assert client.model == 'mixtral-8x7b-32768'
        assert client.max_tokens == 2048
        assert client.temperature == 0.7
        assert client.stream_responses is True
        mock_groq.assert_called_once_with(api_key='test_api_key')
        mock_chatgroq.assert_called_once()

    def test_client_initialization_no_api_key(self):
        """Test client initialization fails without API key."""
        # clear=True empties os.environ so no key can leak in from the host.
        with patch.dict(os.environ, {}, clear=True):
            with pytest.raises(GroqAuthenticationError, match="Groq API key not provided"):
                GroqClient()

    @patch(_GROQ_TARGET)
    @patch(_CHATGROQ_TARGET)
    def test_client_initialization_with_custom_params(self, mock_chatgroq, mock_groq):
        """Test client initialization with explicitly passed parameters."""
        # Act
        client = GroqClient(api_key="custom_key", model="custom_model")

        # Assert
        assert client.api_key == "custom_key"
        assert client.model == "custom_model"

    # ------------------------------------------------------------------
    # Response generation
    # ------------------------------------------------------------------

    @patch(_GROQ_TARGET)
    @patch(_CHATGROQ_TARGET)
    def test_generate_response_standard(self, mock_chatgroq, mock_groq, mock_env_vars,
                                        sample_chat_history, sample_language_context):
        """Test standard (non-streaming) response generation."""
        # Arrange
        mock_langchain_client = Mock()
        mock_response = Mock()
        mock_response.content = "Here's how to create a list in Python: my_list = []"
        mock_langchain_client.invoke.return_value = mock_response

        # Skip real client setup and inject the mock LangChain client instead.
        with patch.object(GroqClient, '_initialize_clients'):
            client = GroqClient()
            client.langchain_client = mock_langchain_client

            # Act
            response = client.generate_response(
                prompt="How do I create a list?",
                chat_history=sample_chat_history,
                language_context=sample_language_context,
                stream=False
            )

            # Assert
            assert response == "Here's how to create a list in Python: my_list = []"
            mock_langchain_client.invoke.assert_called_once()

    @patch(_GROQ_TARGET)
    @patch(_CHATGROQ_TARGET)
    def test_generate_response_streaming(self, mock_chatgroq, mock_groq, mock_env_vars,
                                         sample_chat_history, sample_language_context):
        """Test streaming response generation yields each chunk's content."""
        # Arrange: two chunks exposing `choices[0].delta.content`.
        mock_groq_client = Mock()
        mock_chunk1 = Mock()
        mock_chunk1.choices = [Mock()]
        mock_chunk1.choices[0].delta.content = "Here's "
        mock_chunk2 = Mock()
        mock_chunk2.choices = [Mock()]
        mock_chunk2.choices[0].delta.content = "how to create a list"
        mock_groq_client.chat.completions.create.return_value = [mock_chunk1, mock_chunk2]

        with patch.object(GroqClient, '_initialize_clients'):
            client = GroqClient()
            client.groq_client = mock_groq_client

            # Act: drain the generator so every chunk is produced.
            response_chunks = list(client.stream_response(
                prompt="How do I create a list?",
                chat_history=sample_chat_history,
                language_context=sample_language_context
            ))

            # Assert
            assert response_chunks == ["Here's ", "how to create a list"]
            mock_groq_client.chat.completions.create.assert_called_once()

    # ------------------------------------------------------------------
    # Message building
    # ------------------------------------------------------------------

    @patch(_GROQ_TARGET)
    @patch(_CHATGROQ_TARGET)
    def test_build_messages(self, mock_chatgroq, mock_groq, mock_env_vars,
                            sample_chat_history, sample_language_context):
        """Test message building with context and history."""
        # Arrange
        with patch.object(GroqClient, '_initialize_clients'):
            client = GroqClient()

            # Act
            messages = client._build_messages(
                prompt="New question",
                chat_history=sample_chat_history,
                language_context=sample_language_context
            )

            # Assert
            assert len(messages) == 5  # system + 3 history + current
            assert messages[0].role == "system"
            assert "Python programming assistant" in messages[0].content
            assert messages[-1].role == "user"
            assert messages[-1].content == "New question"

    @patch(_GROQ_TARGET)
    @patch(_CHATGROQ_TARGET)
    def test_build_messages_context_window_limit(self, mock_chatgroq, mock_groq, mock_env_vars,
                                                 sample_language_context):
        """Test message building respects the context window limit."""
        # Arrange
        long_history = [
            ChatMessage(role="user", content=f"Question {i}", language="python")
            for i in range(15)  # More than context window of 10
        ]
        with patch.object(GroqClient, '_initialize_clients'):
            client = GroqClient()

            # Act
            messages = client._build_messages(
                prompt="New question",
                chat_history=long_history,
                language_context=sample_language_context
            )

            # Assert
            # Should have system + last 10 from history + current = 12 messages
            assert len(messages) == 12
            assert messages[0].role == "system"
            assert messages[-1].content == "New question"

    # ------------------------------------------------------------------
    # Error handling
    # ------------------------------------------------------------------

    @patch(_GROQ_TARGET)
    @patch(_CHATGROQ_TARGET)
    def test_handle_rate_limit_error(self, mock_chatgroq, mock_groq, mock_env_vars):
        """Test handling of rate limit errors."""
        # Arrange
        with patch.object(GroqClient, '_initialize_clients'):
            client = GroqClient()
            rate_limit_error = Exception("Rate limit exceeded (429)")

            # Act
            result = client._handle_api_error(rate_limit_error)

            # Assert
            assert "high demand" in result.lower()
            assert "try again" in result.lower()

    @patch(_GROQ_TARGET)
    @patch(_CHATGROQ_TARGET)
    def test_handle_authentication_error(self, mock_chatgroq, mock_groq, mock_env_vars):
        """Test handling of authentication errors."""
        # Arrange
        with patch.object(GroqClient, '_initialize_clients'):
            client = GroqClient()
            auth_error = Exception("Authentication failed (401)")

            # Act & Assert: this error is expected to be re-raised, not
            # turned into a message string like the other error kinds.
            with pytest.raises(GroqAuthenticationError):
                client._handle_api_error(auth_error)

    @patch(_GROQ_TARGET)
    @patch(_CHATGROQ_TARGET)
    def test_handle_network_error(self, mock_chatgroq, mock_groq, mock_env_vars):
        """Test handling of network errors."""
        # Arrange
        with patch.object(GroqClient, '_initialize_clients'):
            client = GroqClient()
            network_error = Exception("Network connection failed")

            # Act
            result = client._handle_api_error(network_error)

            # Assert
            assert "connection" in result.lower()
            assert "try again" in result.lower()

    @patch(_GROQ_TARGET)
    @patch(_CHATGROQ_TARGET)
    def test_handle_quota_error(self, mock_chatgroq, mock_groq, mock_env_vars):
        """Test handling of quota/billing errors."""
        # Arrange
        with patch.object(GroqClient, '_initialize_clients'):
            client = GroqClient()
            quota_error = Exception("Quota exceeded for billing")

            # Act
            result = client._handle_api_error(quota_error)

            # Assert
            assert "temporarily unavailable" in result.lower()

    @patch(_GROQ_TARGET)
    @patch(_CHATGROQ_TARGET)
    def test_handle_unexpected_error(self, mock_chatgroq, mock_groq, mock_env_vars):
        """Test handling of unexpected errors."""
        # Arrange
        with patch.object(GroqClient, '_initialize_clients'):
            client = GroqClient()
            unexpected_error = Exception("Something went wrong")

            # Act
            result = client._handle_api_error(unexpected_error)

            # Assert
            assert "unexpected error" in result.lower()
            assert "rephrasing" in result.lower()

    # ------------------------------------------------------------------
    # Connection check and model info
    # ------------------------------------------------------------------

    @patch(_GROQ_TARGET)
    @patch(_CHATGROQ_TARGET)
    def test_test_connection_success(self, mock_chatgroq, mock_groq, mock_env_vars):
        """Test successful connection test."""
        # Arrange
        mock_langchain_client = Mock()
        mock_response = Mock()
        mock_response.content = "Hello! How can I help you?"
        mock_langchain_client.invoke.return_value = mock_response

        with patch.object(GroqClient, '_initialize_clients'):
            client = GroqClient()
            client.langchain_client = mock_langchain_client

            # Act
            result = client.test_connection()

            # Assert
            assert result is True
            mock_langchain_client.invoke.assert_called_once()

    @patch(_GROQ_TARGET)
    @patch(_CHATGROQ_TARGET)
    def test_test_connection_failure(self, mock_chatgroq, mock_groq, mock_env_vars):
        """Test connection test failure is reported as False."""
        # Arrange
        mock_langchain_client = Mock()
        mock_langchain_client.invoke.side_effect = Exception("Connection failed")

        with patch.object(GroqClient, '_initialize_clients'):
            client = GroqClient()
            client.langchain_client = mock_langchain_client

            # Act
            result = client.test_connection()

            # Assert
            assert result is False

    @patch(_GROQ_TARGET)
    @patch(_CHATGROQ_TARGET)
    def test_get_model_info(self, mock_chatgroq, mock_groq, mock_env_vars):
        """Test getting model configuration information."""
        # Arrange
        with patch.object(GroqClient, '_initialize_clients'):
            client = GroqClient()

            # Act
            info = client.get_model_info()

            # Assert
            assert info['model'] == 'mixtral-8x7b-32768'
            assert info['max_tokens'] == 2048
            assert info['temperature'] == 0.7
            assert info['stream_responses'] is True
            assert info['api_key_configured'] is True

    def test_streaming_response_with_error(self, mock_env_vars, sample_chat_history, sample_language_context):
        """Test streaming response handles errors gracefully."""
        # Arrange: the underlying API call raises instead of returning chunks.
        mock_groq_client = Mock()
        mock_groq_client.chat.completions.create.side_effect = Exception("API Error")

        with patch.object(GroqClient, '_initialize_clients'):
            client = GroqClient()
            client.groq_client = mock_groq_client

            # Act
            response_chunks = list(client.stream_response(
                prompt="Test question",
                chat_history=sample_chat_history,
                language_context=sample_language_context
            ))

            # Assert: exactly one chunk, carrying the error text.
            assert len(response_chunks) == 1
            assert "Error:" in response_chunks[0]
class TestLanguageContext:
    """Checks for the create_language_context factory."""

    def test_create_language_context_python(self):
        """A "python" context carries the Python template and highlighting."""
        ctx = create_language_context("python")

        assert ctx.language == "python"
        assert "Python" in ctx.prompt_template
        assert ctx.syntax_highlighting == "python"

    def test_create_language_context_javascript(self):
        """A "javascript" context carries the JavaScript template and highlighting."""
        ctx = create_language_context("javascript")

        assert ctx.language == "javascript"
        assert "JavaScript" in ctx.prompt_template
        assert ctx.syntax_highlighting == "javascript"

    def test_create_language_context_java(self):
        """A "java" context carries the Java template and highlighting."""
        ctx = create_language_context("java")

        assert ctx.language == "java"
        assert "Java" in ctx.prompt_template
        assert ctx.syntax_highlighting == "java"

    def test_create_language_context_cpp(self):
        """A "cpp" context carries the C++ template and highlighting."""
        ctx = create_language_context("cpp")

        assert ctx.language == "cpp"
        assert "C++" in ctx.prompt_template
        assert ctx.syntax_highlighting == "cpp"

    def test_create_language_context_unsupported_defaults_to_python(self):
        """An unknown language keeps its own name but gets the Python template."""
        requested = "unsupported_language"
        ctx = create_language_context(requested)

        assert ctx.language == "unsupported_language"
        # The prompt falls back to the Python template.
        assert "Python" in ctx.prompt_template
        assert ctx.syntax_highlighting == "unsupported_language"

    def test_create_language_context_case_insensitive(self):
        """An upper-case name still resolves to the Python template."""
        ctx = create_language_context("PYTHON")

        # The language is stored as given; only the highlighting is lower-cased.
        assert ctx.language == "PYTHON"
        assert "Python" in ctx.prompt_template
        assert ctx.syntax_highlighting == "python"
class TestChatMessage:
    """Checks for constructing ChatMessage instances."""

    def test_chat_message_creation(self):
        """Every field passed to the constructor is readable back unchanged."""
        supplied = {
            "role": "user",
            "content": "Hello, world!",
            "language": "python",
            "timestamp": "2023-01-01T00:00:00Z",
        }

        message = ChatMessage(**supplied)

        # Compare field by field, in the order the fields were supplied.
        for field_name, expected in supplied.items():
            assert getattr(message, field_name) == expected

    def test_chat_message_optional_fields(self):
        """Omitted language and timestamp default to None."""
        message = ChatMessage(role="assistant", content="Hi there!")

        assert message.role == "assistant"
        assert message.content == "Hi there!"
        assert message.language is None
        assert message.timestamp is None
class TestDefaultLanguageTemplates:
    """Test suite for the DEFAULT_LANGUAGE_TEMPLATES mapping."""

    # Display names that str.title() cannot derive from the dictionary key:
    # "cpp".title() is "Cpp" and "javascript".title() is "Javascript".
    # NOTE(review): "JavaScript" matches what the language-context test in
    # this file expects in the javascript prompt template -- confirm against
    # the template text itself.
    _DISPLAY_NAMES = {
        "cpp": "C++",
        "javascript": "JavaScript",
    }

    def test_all_templates_exist(self):
        """Test that all expected language templates exist and are non-empty."""
        expected_languages = ["python", "javascript", "java", "cpp"]
        for lang in expected_languages:
            assert lang in DEFAULT_LANGUAGE_TEMPLATES, f"missing template for {lang!r}"
            assert len(DEFAULT_LANGUAGE_TEMPLATES[lang]) > 0, f"empty template for {lang!r}"

    def test_templates_contain_language_name(self):
        """Test that each template mentions its language's display name."""
        for lang, template in DEFAULT_LANGUAGE_TEMPLATES.items():
            # Fall back to title-casing for keys without a special spelling.
            display_name = self._DISPLAY_NAMES.get(lang, lang.title())
            assert display_name in template, (
                f"template for {lang!r} does not mention {display_name!r}"
            )

    def test_templates_are_educational(self):
        """Test that templates are focused on education."""
        for lang, template in DEFAULT_LANGUAGE_TEMPLATES.items():
            lowered = template.lower()
            assert "student" in lowered or "learn" in lowered, (
                f"template for {lang!r} mentions neither 'student' nor 'learn'"
            )
            assert "beginner" in lowered, f"template for {lang!r} lacks 'beginner'"
            assert "example" in lowered, f"template for {lang!r} lacks 'example'"
# Running this file as a script hands it to pytest.
if __name__ == "__main__":
    pytest.main(args=[__file__])