# sirus/backend/data_sources/tests/test_worker.py
# ranilmukesh
# Deploy SiRUS SQL Agent backend
# b8277c4
"""Tests for the Celery worker implementation.
This module tests the distributed worker functionality including
network error handling, retry policies, and job lifecycle management.
"""
import json
import pytest
import sys
import os
from unittest.mock import MagicMock, patch
from datetime import datetime
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
from backend.data_sources.worker import (
celery_app, process_federated_job,
NetworkError, TransientError, PermanentError,
handle_network_errors, check_job_cancelled,
store_results_with_fallback, send_to_dlq
)
from backend.data_sources.jobs import JobStatus
class TestCeleryWorker:
    """Tests for Celery worker functionality.

    Exercises the Celery app configuration, the ``handle_network_errors``
    decorator, the job-cancellation check, result storage via
    ``store_results_with_fallback`` and the ``send_to_dlq`` helper.
    All Redis/MinIO clients are ``MagicMock`` instances; no broker or
    external service is contacted by the tests themselves.
    """

    def test_celery_app_configuration(self):
        """Test that Celery app is configured correctly."""
        conf = celery_app.conf

        # Check basic configuration
        assert conf.task_serializer == 'json'
        assert conf.accept_content == ['json']
        assert conf.result_serializer == 'json'
        assert conf.timezone == 'UTC'
        assert conf.enable_utc is True

        # Check worker configuration
        assert conf.worker_prefetch_multiplier == 1
        assert conf.task_acks_late is True
        assert conf.worker_max_tasks_per_child == 100

        # Check time limits
        assert conf.task_soft_time_limit == 300
        assert conf.task_time_limit == 600

    def test_network_error_decorator(self):
        """Test the network error handling decorator.

        The wrapped function fails twice with a Redis connection error and
        succeeds on the third call; the decorator is expected to retry until
        that success is returned.
        """
        # Aliased so the builtin ConnectionError is not shadowed.
        from redis.exceptions import ConnectionError as RedisConnectionError

        call_count = 0

        @handle_network_errors
        def failing_function():
            nonlocal call_count
            call_count += 1
            if call_count < 3:
                raise RedisConnectionError("Connection failed")
            return "success"

        # Should retry and eventually succeed
        result = failing_function()
        assert result == "success"
        assert call_count == 3

    def test_network_error_decorator_max_retries(self):
        """Test that decorator raises NetworkError after max retries."""
        from redis.exceptions import ConnectionError as RedisConnectionError

        @handle_network_errors
        def always_failing_function():
            raise RedisConnectionError("Connection always fails")

        # Should raise NetworkError after max retries
        with pytest.raises(NetworkError):
            always_failing_function()

    def test_check_job_cancelled_true(self):
        """Test job cancellation check when job is cancelled."""
        mock_redis = MagicMock()
        mock_redis.get.return_value = b"true"

        result = check_job_cancelled("test-job-123", mock_redis)

        assert result is True
        mock_redis.get.assert_called_once_with("job:test-job-123:cancel")

    def test_check_job_cancelled_false(self):
        """Test job cancellation check when job is not cancelled."""
        mock_redis = MagicMock()
        mock_redis.get.return_value = None  # no cancel flag stored

        result = check_job_cancelled("test-job-123", mock_redis)

        assert result is False

    def test_check_job_cancelled_network_error(self):
        """Test job cancellation check handles network errors gracefully."""
        from redis.exceptions import ConnectionError as RedisConnectionError

        mock_redis = MagicMock()
        mock_redis.get.side_effect = RedisConnectionError("Network error")

        # Should return False (assume not cancelled) when network fails
        result = check_job_cancelled("test-job-123", mock_redis)

        assert result is False

    def test_store_results_redis_small(self):
        """Test storing small results in Redis."""
        mock_redis = MagicMock()
        mock_minio = MagicMock()
        small_results = {"test": "data"}

        result_location = store_results_with_fallback(
            "test-job-123", small_results, mock_redis, mock_minio, "trace-123"
        )

        assert result_location.startswith("redis://")
        mock_redis.setex.assert_called_once()
        mock_minio.put_object.assert_not_called()

    def test_store_results_minio_large(self):
        """Test storing large results in MinIO."""
        mock_redis = MagicMock()
        mock_minio = MagicMock()
        # Create large results (> 1MB)
        large_results = {"data": "x" * (1024 * 1024 + 1)}

        result_location = store_results_with_fallback(
            "test-job-123", large_results, mock_redis, mock_minio, "trace-123"
        )

        assert result_location.startswith("minio://")
        mock_minio.put_object.assert_called_once()

    def test_store_results_fallback_to_minio(self):
        """Test fallback to MinIO when Redis fails."""
        from redis.exceptions import ConnectionError as RedisConnectionError

        mock_redis = MagicMock()
        mock_minio = MagicMock()
        mock_redis.setex.side_effect = RedisConnectionError("Redis down")
        small_results = {"test": "data"}

        result_location = store_results_with_fallback(
            "test-job-123", small_results, mock_redis, mock_minio, "trace-123"
        )

        assert result_location.startswith("minio://")
        mock_minio.put_object.assert_called_once()

    def test_send_to_dlq(self):
        """Test sending failed jobs to Dead Letter Queue."""
        mock_redis = MagicMock()
        job_payload = {"plan": [{"source": "test", "query": {"operation": "table", "name": "test"}}]}

        send_to_dlq("test-job-123", job_payload, "Test error", "permanent_error", mock_redis)

        # Should add to DLQ list
        mock_redis.lpush.assert_called_once()
        mock_redis.expire.assert_called_once()

        # Check DLQ entry structure: positional args are (list key, JSON entry)
        dlq_call = mock_redis.lpush.call_args[0]
        assert dlq_call[0] == "dlq:failed_jobs"
        dlq_entry = json.loads(dlq_call[1])
        assert dlq_entry["job_id"] == "test-job-123"
        assert dlq_entry["error"] == "Test error"
        assert dlq_entry["failure_type"] == "permanent_error"

    # Patch targets use the same module path the test imports from
    # (``backend.data_sources.worker``); patching ``data_sources.worker``
    # would act on a different module object and leave the imported task
    # untouched. Decorators apply bottom-up, so the mock arguments arrive in
    # the order: FederationAgent, tenant config, MinIO, Redis.
    @patch('backend.data_sources.worker.get_redis_client_with_retry')
    @patch('backend.data_sources.worker.get_minio_client_with_retry')
    @patch('backend.data_sources.worker.get_tenant_config_with_retry')
    @patch('backend.data_sources.worker.FederationAgent')
    def test_process_job_success(self, mock_agent_class, mock_tenant_config, mock_minio, mock_redis):
        """Smoke-test that the federated job task is importable and is a Celery task.

        The collaborators of the worker module are patched and wired up as
        they would be for a successful run, but the task body itself is not
        executed here (see the note below).
        """
        # Setup mocks
        mock_redis_instance = MagicMock()
        mock_redis.return_value = mock_redis_instance
        mock_redis_instance.get.return_value = None  # Not cancelled

        mock_minio_instance = MagicMock()
        mock_minio.return_value = mock_minio_instance

        mock_tenant_config.return_value = [{"source_name": "test_db", "source_type": "ibis", "config": {"uri": "duckdb:///:memory:"}}]

        mock_agent = MagicMock()
        mock_agent.execute_federated_plan.return_value = {"step_1": [{"id": 1}]}
        mock_agent_class.return_value = mock_agent

        # Note: We can't easily test the actual Celery task without a running broker,
        # but we can test the core logic by calling the task function directly
        # This is a limitation of the current test setup
        # For now, just verify the imports and basic structure work
        assert callable(process_federated_job)
        assert hasattr(process_federated_job, 'delay')  # Celery task method
class TestErrorClasses:
    """Checks on the worker's custom exception types.

    Each test builds one exception with a message and verifies that the
    message round-trips through ``str()`` and that the type derives from
    ``Exception``.
    """

    def test_network_error(self):
        """NetworkError keeps its message and is an Exception."""
        exc = NetworkError("Network failed")
        assert isinstance(exc, Exception)
        assert str(exc) == "Network failed"

    def test_transient_error(self):
        """TransientError keeps its message and is an Exception."""
        exc = TransientError("Temporary failure")
        assert isinstance(exc, Exception)
        assert str(exc) == "Temporary failure"

    def test_permanent_error(self):
        """PermanentError keeps its message and is an Exception."""
        exc = PermanentError("Permanent failure")
        assert isinstance(exc, Exception)
        assert str(exc) == "Permanent failure"
if __name__ == "__main__":
    # Allow running this file directly as a script. pytest.main returns the
    # run's exit code; propagate it so a failing run does not exit with 0.
    raise SystemExit(pytest.main([__file__, "-v"]))