Spaces:
Running
Running
| """Tests for the Celery worker implementation. | |
| This module tests the distributed worker functionality including | |
| network error handling, retry policies, and job lifecycle management. | |
| """ | |
| import json | |
| import pytest | |
| import sys | |
| import os | |
| from unittest.mock import MagicMock, patch | |
| from datetime import datetime | |
| sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))) | |
| from backend.data_sources.worker import ( | |
| celery_app, process_federated_job, | |
| NetworkError, TransientError, PermanentError, | |
| handle_network_errors, check_job_cancelled, | |
| store_results_with_fallback, send_to_dlq | |
| ) | |
| from backend.data_sources.jobs import JobStatus | |
class TestCeleryWorker:
    """Tests for the Celery worker configuration and helper functions.

    Redis and MinIO clients are replaced by ``MagicMock`` objects, so no
    broker or storage service is contacted by these tests.
    """

    def test_celery_app_configuration(self):
        """Check serializer, timezone, worker and time-limit settings."""
        # Basic configuration
        assert celery_app.conf.task_serializer == 'json'
        assert celery_app.conf.accept_content == ['json']
        assert celery_app.conf.result_serializer == 'json'
        assert celery_app.conf.timezone == 'UTC'
        assert celery_app.conf.enable_utc is True

        # Worker configuration
        assert celery_app.conf.worker_prefetch_multiplier == 1
        assert celery_app.conf.task_acks_late is True
        assert celery_app.conf.worker_max_tasks_per_child == 100

        # Time limits
        assert celery_app.conf.task_soft_time_limit == 300
        assert celery_app.conf.task_time_limit == 600

    def test_network_error_decorator(self):
        """Check that a decorated function is retried until it succeeds."""
        # Aliased so the builtin ConnectionError is not shadowed.
        from redis.exceptions import ConnectionError as RedisConnectionError

        call_count = 0

        # Without the decorator the first call would simply raise and the
        # retry behaviour under test would never be exercised.
        # NOTE(review): assumes handle_network_errors is applied bare (no
        # arguments) and allows at least three attempts — confirm against
        # backend.data_sources.worker.
        @handle_network_errors
        def failing_function():
            nonlocal call_count
            call_count += 1
            if call_count < 3:
                raise RedisConnectionError("Connection failed")
            return "success"

        # The first two attempts fail; the third returns normally.
        result = failing_function()
        assert result == "success"
        assert call_count == 3

    def test_network_error_decorator_max_retries(self):
        """Check that NetworkError is raised once retries are exhausted."""
        from redis.exceptions import ConnectionError as RedisConnectionError

        # NOTE(review): same assumption as above — bare decorator form.
        @handle_network_errors
        def always_failing_function():
            raise RedisConnectionError("Connection always fails")

        with pytest.raises(NetworkError):
            always_failing_function()

    def test_check_job_cancelled_true(self):
        """Check that a stored cancel flag is reported as cancelled."""
        mock_redis = MagicMock()
        mock_redis.get.return_value = b"true"

        result = check_job_cancelled("test-job-123", mock_redis)

        assert result is True
        mock_redis.get.assert_called_once_with("job:test-job-123:cancel")

    def test_check_job_cancelled_false(self):
        """Check that a missing cancel flag is reported as not cancelled."""
        mock_redis = MagicMock()
        mock_redis.get.return_value = None

        result = check_job_cancelled("test-job-123", mock_redis)

        assert result is False

    def test_check_job_cancelled_network_error(self):
        """Check that a Redis connection failure is treated as not cancelled."""
        from redis.exceptions import ConnectionError as RedisConnectionError

        mock_redis = MagicMock()
        mock_redis.get.side_effect = RedisConnectionError("Network error")

        # Expected to return False (assume not cancelled) when the lookup fails.
        result = check_job_cancelled("test-job-123", mock_redis)

        assert result is False

    def test_store_results_redis_small(self):
        """Check that a small result payload is written to Redis only."""
        mock_redis = MagicMock()
        mock_minio = MagicMock()
        small_results = {"test": "data"}

        result_location = store_results_with_fallback(
            "test-job-123", small_results, mock_redis, mock_minio, "trace-123"
        )

        assert result_location.startswith("redis://")
        mock_redis.setex.assert_called_once()
        mock_minio.put_object.assert_not_called()

    def test_store_results_minio_large(self):
        """Check that a payload larger than 1 MiB is written to MinIO."""
        mock_redis = MagicMock()
        mock_minio = MagicMock()
        # One character past 1 MiB, before JSON overhead.
        large_results = {"data": "x" * (1024 * 1024 + 1)}

        result_location = store_results_with_fallback(
            "test-job-123", large_results, mock_redis, mock_minio, "trace-123"
        )

        assert result_location.startswith("minio://")
        mock_minio.put_object.assert_called_once()

    def test_store_results_fallback_to_minio(self):
        """Check that a failing Redis write falls back to MinIO."""
        from redis.exceptions import ConnectionError as RedisConnectionError

        mock_redis = MagicMock()
        mock_minio = MagicMock()
        mock_redis.setex.side_effect = RedisConnectionError("Redis down")
        small_results = {"test": "data"}

        result_location = store_results_with_fallback(
            "test-job-123", small_results, mock_redis, mock_minio, "trace-123"
        )

        assert result_location.startswith("minio://")
        mock_minio.put_object.assert_called_once()

    def test_send_to_dlq(self):
        """Check the entry pushed to the dead letter queue for a failed job."""
        mock_redis = MagicMock()
        job_payload = {"plan": [{"source": "test", "query": {"operation": "table", "name": "test"}}]}

        send_to_dlq("test-job-123", job_payload, "Test error", "permanent_error", mock_redis)

        # One push to the list and one expiry call are expected.
        mock_redis.lpush.assert_called_once()
        mock_redis.expire.assert_called_once()

        # Positional arguments of the push: (list key, JSON-encoded entry).
        dlq_call = mock_redis.lpush.call_args[0]
        assert dlq_call[0] == "dlq:failed_jobs"
        dlq_entry = json.loads(dlq_call[1])
        assert dlq_entry["job_id"] == "test-job-123"
        assert dlq_entry["error"] == "Test error"
        assert dlq_entry["failure_type"] == "permanent_error"

    def test_process_job_success(
        self,
        mock_agent_class=None,
        mock_tenant_config=None,
        mock_minio=None,
        mock_redis=None,
    ):
        """Smoke-check that ``process_federated_job`` is a Celery task.

        The four ``mock_*`` parameters default to ``None`` so that pytest does
        not try to resolve them as fixtures when no ``@patch`` decorators
        supply them. If ``@patch`` decorators are added, the mocks they pass
        positionally are used instead of the local stand-ins.
        """
        if mock_agent_class is None:
            mock_agent_class = MagicMock()
        if mock_tenant_config is None:
            mock_tenant_config = MagicMock()
        if mock_minio is None:
            mock_minio = MagicMock()
        if mock_redis is None:
            mock_redis = MagicMock()

        # Wire up the collaborators the task would see.
        mock_redis_instance = MagicMock()
        mock_redis.return_value = mock_redis_instance
        mock_redis_instance.get.return_value = None  # Not cancelled

        mock_minio_instance = MagicMock()
        mock_minio.return_value = mock_minio_instance

        mock_tenant_config.return_value = [{"source_name": "test_db", "source_type": "ibis", "config": {"uri": "duckdb:///:memory:"}}]

        mock_agent = MagicMock()
        mock_agent.execute_federated_plan.return_value = {"step_1": [{"id": 1}]}
        mock_agent_class.return_value = mock_agent

        # Note: the actual Celery task is not executed here because that
        # would need a running broker. This is a limitation of the current
        # test setup; for now only the import and task structure are checked.
        assert callable(process_federated_job)
        assert hasattr(process_federated_job, 'delay')  # Celery task method
class TestErrorClasses:
    """Sanity checks for the worker's custom exception types."""

    def test_network_error(self):
        """NetworkError is an Exception that keeps its message."""
        exc = NetworkError("Network failed")
        assert isinstance(exc, Exception)
        assert str(exc) == "Network failed"

    def test_transient_error(self):
        """TransientError is an Exception that keeps its message."""
        exc = TransientError("Temporary failure")
        assert isinstance(exc, Exception)
        assert str(exc) == "Temporary failure"

    def test_permanent_error(self):
        """PermanentError is an Exception that keeps its message."""
        exc = PermanentError("Permanent failure")
        assert isinstance(exc, Exception)
        assert str(exc) == "Permanent failure"
if __name__ == "__main__":
    # pytest.main returns the run's exit code; propagate it so that running
    # this file as a script exits non-zero when tests fail.
    raise SystemExit(pytest.main([__file__, "-v"]))