"""Tests for the Celery worker implementation.

This module tests the distributed worker functionality, including network
error handling, retry policies, and job lifecycle management.
"""

import json
import pytest
import sys
import os
from unittest.mock import MagicMock, patch
from datetime import datetime

# Imported under an alias so the builtin ``ConnectionError`` is not shadowed.
from redis.exceptions import ConnectionError as RedisConnectionError

# Put the directory two levels above this file on sys.path so that the
# ``backend.*`` imports below resolve when the file is run directly.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))

from backend.data_sources.worker import (
    celery_app,
    process_federated_job,
    NetworkError,
    TransientError,
    PermanentError,
    handle_network_errors,
    check_job_cancelled,
    store_results_with_fallback,
    send_to_dlq,
)
from backend.data_sources.jobs import JobStatus

# Dotted path of the module under test. ``patch`` targets must use the same
# path as the import above: patching another path (e.g. "data_sources.worker")
# either fails to import or patches a separate module object, in which case the
# code under test never sees the mocks.
WORKER_MODULE = "backend.data_sources.worker"


class TestCeleryWorker:
    """Tests for Celery worker functionality."""

    def test_celery_app_configuration(self):
        """Test that the Celery app is configured correctly."""
        # Check basic configuration
        assert celery_app.conf.task_serializer == 'json'
        assert celery_app.conf.accept_content == ['json']
        assert celery_app.conf.result_serializer == 'json'
        assert celery_app.conf.timezone == 'UTC'
        assert celery_app.conf.enable_utc is True

        # Check worker configuration
        assert celery_app.conf.worker_prefetch_multiplier == 1
        assert celery_app.conf.task_acks_late is True
        assert celery_app.conf.worker_max_tasks_per_child == 100

        # Check time limits
        assert celery_app.conf.task_soft_time_limit == 300
        assert celery_app.conf.task_time_limit == 600

    def test_network_error_decorator(self):
        """Test that the network error decorator retries until success."""
        call_count = 0

        @handle_network_errors
        def failing_function():
            nonlocal call_count
            call_count += 1
            # Fail on the first two calls, succeed on the third.
            if call_count < 3:
                raise RedisConnectionError("Connection failed")
            return "success"

        # Should retry and eventually succeed
        result = failing_function()
        assert result == "success"
        assert call_count == 3

    def test_network_error_decorator_max_retries(self):
        """Test that the decorator raises NetworkError after max retries."""

        @handle_network_errors
        def always_failing_function():
            raise RedisConnectionError("Connection always fails")

        # Should raise NetworkError after max retries
        with pytest.raises(NetworkError):
            always_failing_function()

    def test_check_job_cancelled_true(self):
        """Test the job cancellation check when the job is cancelled."""
        mock_redis = MagicMock()
        mock_redis.get.return_value = b"true"

        result = check_job_cancelled("test-job-123", mock_redis)

        assert result is True
        mock_redis.get.assert_called_once_with("job:test-job-123:cancel")

    def test_check_job_cancelled_false(self):
        """Test the job cancellation check when the job is not cancelled."""
        mock_redis = MagicMock()
        mock_redis.get.return_value = None

        result = check_job_cancelled("test-job-123", mock_redis)

        assert result is False

    def test_check_job_cancelled_network_error(self):
        """Test that the cancellation check handles network errors gracefully."""
        mock_redis = MagicMock()
        mock_redis.get.side_effect = RedisConnectionError("Network error")

        # Should return False (assume not cancelled) when the network fails
        result = check_job_cancelled("test-job-123", mock_redis)

        assert result is False

    def test_store_results_redis_small(self):
        """Test storing small results in Redis."""
        mock_redis = MagicMock()
        mock_minio = MagicMock()

        small_results = {"test": "data"}

        result_location = store_results_with_fallback(
            "test-job-123", small_results, mock_redis, mock_minio, "trace-123"
        )

        assert result_location.startswith("redis://")
        mock_redis.setex.assert_called_once()
        mock_minio.put_object.assert_not_called()

    def test_store_results_minio_large(self):
        """Test storing large results in MinIO."""
        mock_redis = MagicMock()
        mock_minio = MagicMock()

        # Create large results (> 1MB)
        large_results = {"data": "x" * (1024 * 1024 + 1)}

        result_location = store_results_with_fallback(
            "test-job-123", large_results, mock_redis, mock_minio, "trace-123"
        )

        assert result_location.startswith("minio://")
        mock_minio.put_object.assert_called_once()

    def test_store_results_fallback_to_minio(self):
        """Test the fallback to MinIO when Redis fails."""
        mock_redis = MagicMock()
        mock_minio = MagicMock()
        mock_redis.setex.side_effect = RedisConnectionError("Redis down")

        small_results = {"test": "data"}

        result_location = store_results_with_fallback(
            "test-job-123", small_results, mock_redis, mock_minio, "trace-123"
        )

        assert result_location.startswith("minio://")
        mock_minio.put_object.assert_called_once()

    def test_send_to_dlq(self):
        """Test sending failed jobs to the Dead Letter Queue."""
        mock_redis = MagicMock()

        job_payload = {"plan": [{"source": "test", "query": {"operation": "table", "name": "test"}}]}

        send_to_dlq("test-job-123", job_payload, "Test error", "permanent_error", mock_redis)

        # Should add to the DLQ list
        mock_redis.lpush.assert_called_once()
        mock_redis.expire.assert_called_once()

        # Check the DLQ entry structure
        dlq_call = mock_redis.lpush.call_args[0]
        assert dlq_call[0] == "dlq:failed_jobs"

        dlq_entry = json.loads(dlq_call[1])
        assert dlq_entry["job_id"] == "test-job-123"
        assert dlq_entry["error"] == "Test error"
        assert dlq_entry["failure_type"] == "permanent_error"

    @patch(f"{WORKER_MODULE}.get_redis_client_with_retry")
    @patch(f"{WORKER_MODULE}.get_minio_client_with_retry")
    @patch(f"{WORKER_MODULE}.get_tenant_config_with_retry")
    @patch(f"{WORKER_MODULE}.FederationAgent")
    def test_process_job_success(self, mock_agent_class, mock_tenant_config, mock_minio, mock_redis):
        """Smoke-test the ``process_federated_job`` task wiring.

        The worker's collaborators are patched so the task body could run
        without Redis, MinIO, or a broker. For now the test only checks that
        ``process_federated_job`` is a callable Celery task; it does not
        execute the task.
        """
        # ``@patch`` decorators apply bottom-up, so the mock arguments arrive
        # in reverse order of the decorator lines above.
        mock_redis_instance = MagicMock()
        mock_redis.return_value = mock_redis_instance
        mock_redis_instance.get.return_value = None  # Not cancelled

        mock_minio.return_value = MagicMock()

        mock_tenant_config.return_value = [
            {"source_name": "test_db", "source_type": "ibis", "config": {"uri": "duckdb:///:memory:"}}
        ]

        mock_agent = MagicMock()
        mock_agent.execute_federated_plan.return_value = {"step_1": [{"id": 1}]}
        mock_agent_class.return_value = mock_agent

        # TODO: execute the task body directly (e.g. Celery's ``Task.apply``,
        # which runs locally without a broker) with a payload such as
        # {"plan": [{"source": "test_db", "query": {"operation": "table",
        # "name": "orders"}}]}, once the task's signature is confirmed, and
        # assert on the stored result location.
        assert callable(process_federated_job)
        assert hasattr(process_federated_job, 'delay')  # Celery task method


class TestErrorClasses:
    """Tests for custom error classes."""

    def test_network_error(self):
        """Test the NetworkError class."""
        error = NetworkError("Network failed")
        assert str(error) == "Network failed"
        assert isinstance(error, Exception)

    def test_transient_error(self):
        """Test the TransientError class."""
        error = TransientError("Temporary failure")
        assert str(error) == "Temporary failure"
        assert isinstance(error, Exception)

    def test_permanent_error(self):
        """Test the PermanentError class."""
        error = PermanentError("Permanent failure")
        assert str(error) == "Permanent failure"
        assert isinstance(error, Exception)


if __name__ == "__main__":
    pytest.main([__file__, "-v"])