Spaces:
Runtime error
Runtime error
| """ | |
| Unit tests for comprehensive error handling and logging. | |
| Tests the error handling utilities, circuit breaker pattern, | |
| and fallback response mechanisms. | |
| """ | |
| import pytest | |
| import logging | |
| import time | |
| from unittest.mock import Mock, patch, MagicMock | |
| from datetime import datetime | |
| from chat_agent.utils.error_handler import ( | |
| ErrorSeverity, ErrorCategory, ChatAgentError, ErrorHandler, | |
| error_handler_decorator, get_error_handler | |
| ) | |
| from chat_agent.utils.circuit_breaker import ( | |
| CircuitState, CircuitBreakerConfig, CircuitBreaker, | |
| circuit_breaker, CircuitBreakerManager | |
| ) | |
| from chat_agent.utils.logging_config import ( | |
| StructuredFormatter, ChatAgentFilter, LoggingConfig, | |
| PerformanceLogger, setup_logging | |
| ) | |
class TestChatAgentError:
    """Exercise construction, defaults and serialization of ChatAgentError."""

    def test_error_initialization(self):
        """Every explicitly supplied constructor argument is stored as given."""
        ctx = {'session_id': 'test-123', 'operation': 'test_op'}
        err = ChatAgentError(
            message="Test error",
            category=ErrorCategory.API_ERROR,
            severity=ErrorSeverity.HIGH,
            user_message="User friendly message",
            error_code="TEST_001",
            context=ctx,
        )

        assert err.category == ErrorCategory.API_ERROR
        assert err.severity == ErrorSeverity.HIGH
        assert err.user_message == "User friendly message"
        assert err.error_code == "TEST_001"
        assert err.context == ctx
        # The error stamps its own creation time.
        assert isinstance(err.timestamp, datetime)

    def test_error_default_values(self):
        """Omitted arguments fall back to SYSTEM_ERROR / MEDIUM and non-empty texts."""
        err = ChatAgentError("Test error")

        assert err.category == ErrorCategory.SYSTEM_ERROR
        assert err.severity == ErrorSeverity.MEDIUM
        assert err.user_message is not None
        assert err.error_code is not None
        assert err.context == {}

    def test_error_to_dict(self):
        """to_dict() exposes enum values as strings plus the core fields."""
        serialized = ChatAgentError(
            message="Test error",
            category=ErrorCategory.VALIDATION_ERROR,
            severity=ErrorSeverity.LOW,
        ).to_dict()

        assert serialized['category'] == 'validation_error'
        assert serialized['severity'] == 'low'
        for field_name in ('error_code', 'message', 'timestamp', 'context'):
            assert field_name in serialized

    def test_default_user_messages(self):
        """Each category gets its own canned user-facing message."""
        # (category, phrase expected somewhere in its default user message)
        cases = (
            (ErrorCategory.API_ERROR, "connecting to my services"),
            (ErrorCategory.DATABASE_ERROR, "technical difficulties"),
            (ErrorCategory.RATE_LIMIT_ERROR, "high demand"),
        )
        for category, phrase in cases:
            assert phrase in ChatAgentError("Test", category=category).user_message


class TestErrorHandler:
    """Test ErrorHandler classification, severity logging, fallbacks and error responses."""

    def setup_method(self):
        """Create a fresh ErrorHandler backed by a mock logger for each test."""
        self.logger = Mock(spec=logging.Logger)
        self.error_handler = ErrorHandler(self.logger)

    def test_error_classification_api_error(self):
        """A message mentioning an API is classified as API_ERROR."""
        api_error = Exception("Groq API connection failed")
        chat_error = self.error_handler._classify_error(api_error)
        assert chat_error.category == ErrorCategory.API_ERROR
        # The original message must survive classification.
        assert "Groq API connection failed" in str(chat_error)

    def test_error_classification_rate_limit(self):
        """A 429 / rate-limit message is a MEDIUM-severity RATE_LIMIT_ERROR."""
        rate_error = Exception("Rate limit exceeded (429)")
        chat_error = self.error_handler._classify_error(rate_error)
        assert chat_error.category == ErrorCategory.RATE_LIMIT_ERROR
        assert chat_error.severity == ErrorSeverity.MEDIUM

    def test_error_classification_database_error(self):
        """A database message is a HIGH-severity DATABASE_ERROR."""
        db_error = Exception("PostgreSQL connection failed")
        chat_error = self.error_handler._classify_error(db_error)
        assert chat_error.category == ErrorCategory.DATABASE_ERROR
        assert chat_error.severity == ErrorSeverity.HIGH

    def test_error_classification_network_error(self):
        """A timeout message is a NETWORK_ERROR."""
        network_error = Exception("Connection timeout")
        chat_error = self.error_handler._classify_error(network_error)
        assert chat_error.category == ErrorCategory.NETWORK_ERROR

    def test_error_classification_validation_error(self):
        """An invalid-input message is a LOW-severity VALIDATION_ERROR."""
        validation_error = Exception("Invalid input format")
        chat_error = self.error_handler._classify_error(validation_error)
        assert chat_error.category == ErrorCategory.VALIDATION_ERROR
        assert chat_error.severity == ErrorSeverity.LOW

    def test_error_logging_levels(self):
        """Each severity is logged through exactly one matching logger method."""
        # (severity, error message, logger method expected to be called once)
        cases = (
            (ErrorSeverity.CRITICAL, "Critical", "critical"),
            (ErrorSeverity.HIGH, "High", "error"),
            (ErrorSeverity.MEDIUM, "Medium", "warning"),
            (ErrorSeverity.LOW, "Low", "info"),
        )
        for severity, message, method_name in cases:
            # Start every case from a clean mock so call counts don't leak.
            self.logger.reset_mock()
            chat_error = ChatAgentError(message, severity=severity)
            self.error_handler._log_error(chat_error, Exception("test"))
            getattr(self.logger, method_name).assert_called_once()

    def test_fallback_responses(self):
        """API errors get a fallback offering tips and a retry hint."""
        api_error = ChatAgentError("Test", category=ErrorCategory.API_ERROR)
        fallback = self.error_handler.get_fallback_response(api_error)
        assert "programming tips" in fallback
        assert "try again" in fallback

    def test_handle_api_response_error(self):
        """API error responses are unsuccessful dicts carrying error + fallback."""
        error = Exception("Test API error")
        response = self.error_handler.handle_api_response_error(error)
        assert response['success'] is False
        assert 'error' in response
        assert 'fallback_response' in response
        assert isinstance(response['error'], dict)

    # NOTE(review): the `@patch(...)` decorator was lost when this file was
    # scraped. Without it pytest fails with "fixture 'mock_emit' not found".
    # The target assumes chat_agent/utils/error_handler.py imports `emit` at
    # module level. Confirm against that module.
    @patch('chat_agent.utils.error_handler.emit')
    def test_handle_websocket_error(self, mock_emit):
        """WebSocket errors are emitted once as an 'error' event with fallback."""
        error = Exception("Test WebSocket error")
        self.error_handler.handle_websocket_error(error)
        mock_emit.assert_called_once()
        call_args = mock_emit.call_args[0]
        assert call_args[0] == 'error'
        assert 'error' in call_args[1]
        assert 'fallback_response' in call_args[1]


class TestErrorHandlerDecorator:
    """Test error_handler_decorator wrapping of plain functions.

    NOTE(review): the decorator applications in this class were lost when the
    file was scraped. That left tests that never touched
    error_handler_decorator, and one (`test_decorator_with_exception`) that
    could not pass. The keyword names `return_fallback` and
    `emit_to_websocket` are inferred from what each test asserts. Confirm them
    against error_handler_decorator's real signature.
    """

    def setup_method(self):
        """Create a mock logger (passed to the decorator) and an ErrorHandler."""
        self.logger = Mock(spec=logging.Logger)
        self.error_handler = ErrorHandler(self.logger)

    def test_decorator_success(self):
        """A decorated function that succeeds returns its value unchanged."""
        @error_handler_decorator(self.logger)
        def test_function(x, y):
            return x + y

        result = test_function(2, 3)
        assert result == 5

    def test_decorator_with_exception(self):
        """By default a raised exception is re-raised as ChatAgentError."""
        @error_handler_decorator(self.logger)
        def test_function():
            raise ValueError("Test error")

        with pytest.raises(ChatAgentError):
            test_function()

    def test_decorator_with_fallback(self):
        """With fallback enabled, a failure returns a user-facing string."""
        @error_handler_decorator(self.logger, return_fallback=True)
        def test_function():
            raise ValueError("Test error")

        result = test_function()
        assert isinstance(result, str)
        assert "try again" in result.lower()

    # NOTE(review): patch target assumes error_handler.py imports `emit` at
    # module level. Confirm.
    @patch('chat_agent.utils.error_handler.emit')
    def test_decorator_with_websocket_emit(self, mock_emit):
        """With WebSocket emission enabled, a failure emits once and returns None."""
        @error_handler_decorator(self.logger, emit_to_websocket=True)
        def test_function():
            raise ValueError("Test error")

        result = test_function()
        assert result is None
        mock_emit.assert_called_once()


class TestCircuitBreaker:
    """State-machine tests for CircuitBreaker: CLOSED -> OPEN -> HALF_OPEN -> CLOSED."""

    def setup_method(self):
        """Build a breaker: opens after 3 failures, closes after 2 half-open successes."""
        self.logger = Mock(spec=logging.Logger)
        self.config = CircuitBreakerConfig(
            failure_threshold=3,
            recovery_timeout=1,
            success_threshold=2,
            timeout=1.0,
        )
        self.circuit_breaker = CircuitBreaker(
            "test_circuit", self.config, logger=self.logger
        )

    def test_circuit_breaker_initialization(self):
        """A new breaker keeps its name and starts CLOSED."""
        breaker = self.circuit_breaker
        assert breaker.name == "test_circuit"
        assert breaker.state == CircuitState.CLOSED
        assert breaker.is_closed
        assert not breaker.is_open
        assert not breaker.is_half_open

    def test_successful_call(self):
        """Positional arguments are forwarded and success keeps the breaker CLOSED."""
        def success_function(x, y):
            return x + y

        assert self.circuit_breaker.call(success_function, 2, 3) == 5
        assert self.circuit_breaker.state == CircuitState.CLOSED

    def test_circuit_opening_on_failures(self):
        """Reaching failure_threshold consecutive failures trips the breaker."""
        def failing_function():
            raise ValueError("Test failure")

        # Each failure must still propagate to the caller.
        for _ in range(self.config.failure_threshold):
            with pytest.raises(ValueError):
                self.circuit_breaker.call(failing_function)

        assert self.circuit_breaker.state == CircuitState.OPEN
        assert self.circuit_breaker.is_open

    def test_circuit_open_behavior(self):
        """An open breaker without a fallback rejects calls with an API ChatAgentError."""
        self.circuit_breaker._open_circuit()

        def test_function():
            return "should not execute"

        with pytest.raises(ChatAgentError) as raised:
            self.circuit_breaker.call(test_function)

        rejection = raised.value
        assert rejection.category == ErrorCategory.API_ERROR
        assert "circuit breaker" in str(rejection).lower()

    def test_circuit_with_fallback(self):
        """An open breaker with a fallback returns the fallback's result instead."""
        def fallback_function(*args, **kwargs):
            return "fallback response"

        guarded = CircuitBreaker(
            "test_fallback", self.config, fallback_function, self.logger
        )
        guarded._open_circuit()

        def test_function():
            return "should not execute"

        assert guarded.call(test_function) == "fallback response"

    def test_circuit_recovery_to_half_open(self):
        """Once recovery_timeout has passed, the next call runs in HALF_OPEN."""
        self.circuit_breaker._open_circuit()
        # Real wall-clock wait, just past the configured recovery window.
        time.sleep(self.config.recovery_timeout + 0.1)

        def test_function():
            return "success"

        outcome = self.circuit_breaker.call(test_function)
        assert outcome == "success"
        # success_threshold is 2, so one success is not expected to close it.
        assert self.circuit_breaker.state == CircuitState.HALF_OPEN

    def test_circuit_closing_from_half_open(self):
        """success_threshold consecutive successes close a half-open breaker."""
        self.circuit_breaker._half_open_circuit()

        def success_function():
            return "success"

        needed = self.config.success_threshold
        outcomes = [self.circuit_breaker.call(success_function) for _ in range(needed)]
        assert outcomes == ["success"] * needed
        assert self.circuit_breaker.state == CircuitState.CLOSED

    def test_circuit_stats(self):
        """get_stats() counts requests, successes and failures."""
        def success_function():
            return "success"

        def failing_function():
            raise ValueError("failure")

        self.circuit_breaker.call(success_function)
        try:
            self.circuit_breaker.call(failing_function)
        except ValueError:
            pass  # the failure is expected; only the counters matter here

        snapshot = self.circuit_breaker.get_stats()
        assert snapshot.total_requests == 2
        assert snapshot.total_successes == 1
        assert snapshot.total_failures == 1
        assert snapshot.state == CircuitState.CLOSED

    def test_circuit_reset(self):
        """reset() returns an open breaker to CLOSED."""
        breaker = self.circuit_breaker
        breaker._open_circuit()
        assert breaker.is_open

        breaker.reset()
        assert breaker.is_closed


class TestCircuitBreakerDecorator:
    """Test the @circuit_breaker decorator.

    NOTE(review): the decorator lines were lost when this file was scraped.
    Without them `hasattr(..., 'circuit_breaker')` fails and the local
    `config` goes unused. They are restored here as `circuit_breaker(name,
    config)`, mirroring CircuitBreaker's positional (name, config, ...) order
    used elsewhere in this file. Confirm against circuit_breaker's signature.
    """

    def test_decorator_success(self):
        """A decorated function runs normally and exposes its breaker."""
        @circuit_breaker("test_decorator")
        def test_function(x, y):
            return x + y

        result = test_function(2, 3)
        assert result == 5
        assert hasattr(test_function, 'circuit_breaker')
        assert test_function.circuit_breaker.name == "test_decorator"

    def test_decorator_with_failures(self):
        """The decorated function's breaker opens after failure_threshold failures."""
        config = CircuitBreakerConfig(failure_threshold=2)

        # A distinct name keeps this breaker separate from test_decorator_success's.
        @circuit_breaker("test_decorator_failures", config)
        def failing_function():
            raise ValueError("Test failure")

        # Derive the count from the config rather than repeating the literal.
        for _ in range(config.failure_threshold):
            with pytest.raises(ValueError):
                failing_function()

        assert failing_function.circuit_breaker.is_open


class TestCircuitBreakerManager:
    """Registry behaviour of CircuitBreakerManager."""

    def setup_method(self):
        """Start every test with an empty manager using a mock logger."""
        self.logger = Mock(spec=logging.Logger)
        self.manager = CircuitBreakerManager(self.logger)

    def test_create_breaker(self):
        """create_breaker() applies the given name and config."""
        created = self.manager.create_breaker(
            "test_managed", CircuitBreakerConfig(failure_threshold=5)
        )
        assert created.name == "test_managed"
        assert created.config.failure_threshold == 5

    def test_get_breaker(self):
        """get_breaker() returns the registered instance, or None if unknown."""
        created = self.manager.create_breaker("test_get")
        assert self.manager.get_breaker("test_get") is created
        assert self.manager.get_breaker("nonexistent") is None

    def test_get_all_stats(self):
        """get_all_stats() has exactly one entry per registered breaker."""
        names = ("test1", "test2")
        for name in names:
            self.manager.create_breaker(name)

        all_stats = self.manager.get_all_stats()
        for name in names:
            assert name in all_stats
        assert len(all_stats) == len(names)

    def test_reset_all(self):
        """reset_all() closes every registered breaker."""
        breakers = [self.manager.create_breaker(name) for name in ("test1", "test2")]

        for breaker in breakers:
            breaker._open_circuit()
        for breaker in breakers:
            assert breaker.is_open

        self.manager.reset_all()
        for breaker in breakers:
            assert breaker.is_closed


class TestLoggingConfiguration:
    """Tests for StructuredFormatter, ChatAgentFilter, LoggingConfig and PerformanceLogger."""

    @staticmethod
    def _make_record():
        """Build the plain INFO LogRecord shared by the formatter and filter tests."""
        return logging.LogRecord(
            name="test_logger",
            level=logging.INFO,
            pathname="test.py",
            lineno=10,
            msg="Test message",
            args=(),
            exc_info=None,
        )

    def test_structured_formatter(self):
        """The formatter emits JSON including level, message, extras and timestamp."""
        import json

        formatter = StructuredFormatter()
        record = self._make_record()
        # Extra attributes the formatter is expected to copy into the JSON.
        record.error_code = "TEST_001"
        record.session_id = "session-123"

        payload = json.loads(formatter.format(record))
        assert payload['level'] == 'INFO'
        assert payload['message'] == 'Test message'
        assert payload['error_code'] == 'TEST_001'
        assert payload['session_id'] == 'session-123'
        assert 'timestamp' in payload

    def test_chat_agent_filter(self):
        """A slow processing_time passes the filter and is flagged."""
        chat_filter = ChatAgentFilter()
        record = self._make_record()
        record.processing_time = 6.0  # slow enough to trigger an alert

        assert chat_filter.filter(record) is True
        assert hasattr(record, 'performance_alert')
        assert record.performance_alert is True

    def test_logging_config_setup(self):
        """setup_logging() returns every named logger, configured at DEBUG."""
        loggers = LoggingConfig("test_app", "DEBUG").setup_logging()

        for key in ('main', 'error', 'performance', 'security', 'api', 'websocket', 'database'):
            assert key in loggers

        main_logger = loggers['main']
        assert main_logger.level == logging.DEBUG
        assert len(main_logger.handlers) > 0

    def test_performance_logger(self):
        """Normal operations log at info; slow ones log a warning."""
        logger = Mock(spec=logging.Logger)
        perf_logger = PerformanceLogger(logger)

        perf_logger.log_operation("test_op", 1.0, {"key": "value"})
        logger.info.assert_called_once()

        logger.reset_mock()
        perf_logger.log_operation("slow_op", 6.0, {"key": "value"})
        logger.warning.assert_called_once()

    def test_performance_logger_api_call(self):
        """API calls go through logger.log; a 5xx status uses WARNING."""
        logger = Mock(spec=logging.Logger)
        perf_logger = PerformanceLogger(logger)

        perf_logger.log_api_call("/api/test", "GET", 200, 0.5)
        logger.log.assert_called_once()

        logger.reset_mock()
        perf_logger.log_api_call("/api/test", "POST", 500, 1.0)
        logger.log.assert_called_once()

        # First positional argument of logger.log is the level.
        level = logger.log.call_args[0][0]
        assert level == logging.WARNING


class TestIntegrationScenarios:
    """Scenarios combining ErrorHandler, CircuitBreaker and PerformanceLogger."""

    def setup_method(self):
        """Build an error handler and a breaker (failure_threshold=2) with a fixed fallback."""
        self.logger = Mock(spec=logging.Logger)
        self.error_handler = ErrorHandler(self.logger)

        def fallback_response(*args, **kwargs):
            return "Fallback response from circuit breaker"

        config = CircuitBreakerConfig(failure_threshold=2, recovery_timeout=1)
        self.circuit_breaker = CircuitBreaker(
            "integration_test", config, fallback_response, self.logger
        )

    def test_api_failure_with_circuit_breaker(self):
        """Failures propagate until the breaker opens; then the fallback answers."""
        def failing_api_call():
            raise Exception("API connection failed")

        # First failure: error propagates, breaker still closed.
        with pytest.raises(Exception):
            self.circuit_breaker.call(failing_api_call)
        assert self.circuit_breaker.is_closed

        # Second failure reaches the threshold and opens the breaker.
        with pytest.raises(Exception):
            self.circuit_breaker.call(failing_api_call)
        assert self.circuit_breaker.is_open

        # While open, the fallback is returned instead of calling the function.
        result = self.circuit_breaker.call(failing_api_call)
        assert result == "Fallback response from circuit breaker"

    def test_error_classification_with_circuit_breaker(self):
        """An error propagated through the breaker is classified as API_ERROR."""
        def api_error_function():
            raise Exception("Groq API rate limit exceeded")

        # pytest.raises (instead of try/except) fails the test if the breaker
        # swallows the error. The try/except form skipped the assertion
        # entirely in that case and passed vacuously.
        with pytest.raises(Exception) as exc_info:
            self.circuit_breaker.call(api_error_function)

        chat_error = self.error_handler.handle_error(exc_info.value)
        assert chat_error.category == ErrorCategory.API_ERROR

    def test_performance_monitoring_with_errors(self):
        """A failed operation is still reported through PerformanceLogger."""
        logger = Mock(spec=logging.Logger)
        perf_logger = PerformanceLogger(logger)

        # perf_counter is monotonic, unlike time.time(), so it suits durations.
        start_time = time.perf_counter()
        try:
            time.sleep(0.1)  # simulate work
            raise Exception("Operation failed")
        except Exception as exc:
            duration = time.perf_counter() - start_time
            self.error_handler.handle_error(exc, {'duration': duration})
            # Log inside the except block: Python unbinds the `as` name once
            # the block ends, so `exc` would be undefined afterwards.
            perf_logger.log_operation("failed_operation", duration, {'error': str(exc)})

        logger.info.assert_called_once()
        extra = logger.info.call_args[1]['extra']
        assert 'processing_time' in extra
        assert 'context' in extra


if __name__ == '__main__':
    # pytest.main() returns the exit status instead of exiting. Propagate it
    # so a failing run exits non-zero when this file is executed directly.
    raise SystemExit(pytest.main([__file__]))