import asyncio
import logging
import json
import importlib
from unittest.mock import AsyncMock, Mock, patch, MagicMock

import httpx
import pytest

from backend.SQL_Agent.data_sources_sql_toolkit import DataSourcesSQLToolkit
from backend.data_sources import schema_search

logger = logging.getLogger(__name__)


# =============================================================================
# PHASE 1 TESTS - Worker Connection Pooling
# =============================================================================


def test_worker_module_has_global_pool_and_minio_client():
    """Verify worker.py declares _redis_pool and _minio_client globals for connection reuse."""
    from backend.data_sources import worker as worker_module

    # Module should have these global attributes defined (initially None)
    assert hasattr(worker_module, '_redis_pool'), "_redis_pool global not found"
    assert hasattr(worker_module, '_minio_client'), "_minio_client global not found"


def test_worker_init_process_connects_to_celery_signal():
    """Verify init_worker_process is registered as a worker_process_init receiver.

    Celery signals keep their receivers in ``Signal.receivers`` as
    ``(lookup_key, receiver)`` pairs; with the default ``weak=True`` the
    receiver is stored as a weak reference, so each entry is dereferenced
    before being compared with the worker's handler.
    """
    import weakref

    from celery.signals import worker_process_init

    from backend.data_sources import worker as worker_module

    assert hasattr(worker_module, 'init_worker_process'), "init_worker_process function not found"
    handler = worker_module.init_worker_process
    assert callable(handler)

    connected = []
    for _lookup_key, receiver in worker_process_init.receivers:
        # weakref.WeakMethod is a subclass of weakref.ref, so one check covers both.
        target = receiver() if isinstance(receiver, weakref.ReferenceType) else receiver
        connected.append(target)

    assert any(target is handler for target in connected), \
        "init_worker_process is not connected to celery.signals.worker_process_init"


def test_worker_init_creates_redis_pool_and_minio():
    """Test init_worker_process creates a ConnectionPool and a Minio client.

    The module globals are patched to ``None`` for the duration of the test so
    the mock objects installed by ``init_worker_process`` are rolled back
    afterwards instead of leaking into other tests.
    """
    from backend.data_sources import worker as worker_module
    import redis

    minio_config = {
        'endpoint': 'localhost:9000',
        'access_key': 'test_key',
        'secret_key': 'test_secret',
        'secure': False,
    }

    with patch.object(worker_module, 'get_redis_url', return_value='redis://localhost:6379/0'), \
            patch.object(worker_module, 'get_minio_config', return_value=minio_config), \
            patch('redis.ConnectionPool.from_url') as mock_pool, \
            patch.object(worker_module, 'Minio') as mock_minio, \
            patch.object(worker_module, '_redis_pool', None), \
            patch.object(worker_module, '_minio_client', None):
        mock_pool.return_value = MagicMock(spec=redis.ConnectionPool)
        mock_minio.return_value = MagicMock()

        worker_module.init_worker_process()

        # Pool created from the configured URL
        mock_pool.assert_called_once_with('redis://localhost:6379/0')

        # Minio initialized from the configured values
        mock_minio.assert_called_once()
        call_kwargs = mock_minio.call_args.kwargs
        assert call_kwargs['endpoint'] == 'localhost:9000'
        assert call_kwargs['access_key'] == 'test_key'

        # Globals must hold exactly the objects the factories produced
        # (checked inside the `with`, before patch.object restores them).
        assert worker_module._redis_pool is mock_pool.return_value
        assert worker_module._minio_client is mock_minio.return_value


def test_worker_task_uses_global_pool_not_new_connection():
    """Verify federated task uses redis.Redis(connection_pool=_redis_pool)."""
    from backend.data_sources import worker as worker_module
    import inspect

    # Read the source code to verify the pattern is used
    source = inspect.getsource(worker_module.process_federated_job)

    # The task should use connection_pool=_redis_pool pattern
    assert 'connection_pool=_redis_pool' in source, \
        "process_federated_job should use redis.Redis(connection_pool=_redis_pool)"

    # Also check it references _minio_client
    assert '_minio_client' in source, \
        "process_federated_job should use global _minio_client"


def test_worker_uses_orjson_for_serialization():
    """Verify worker imports orjson aliased as json for fast serialization."""
    import orjson

    from backend.data_sources import worker as worker_module

    assert hasattr(worker_module, 'json'), "worker module should have json attribute"
    # `import orjson as json` binds the very same module object, so identity
    # distinguishes it from the stdlib json module.
    assert worker_module.json is orjson, \
        "worker module's json binding should be orjson (import orjson as json)"
# =============================================================================
# PHASE 2 TESTS - API ORJSONResponse and Concurrent Fetch
# =============================================================================


def test_api_heavy_endpoints_use_orjson_response():
    """Verify heavy endpoints like /results/{job_id} use ORJSONResponse."""
    from backend.data_sources import api as api_module
    import inspect

    source = inspect.getsource(api_module)

    # Check ORJSONResponse is imported
    assert 'ORJSONResponse' in source, "ORJSONResponse should be imported in api.py"

    # Check it's used in response_class parameter
    assert 'response_class=ORJSONResponse' in source, \
        "Heavy endpoints should specify response_class=ORJSONResponse"


def test_api_has_async_fetch_schema_helper():
    """Verify api.py defines an async _fetch_schema_for helper for concurrent fetches."""
    from backend.data_sources import api as api_module
    import inspect

    source = inspect.getsource(api_module)

    # Require the coroutine definition itself: a bare name match would also
    # accept a sync helper or a mere mention of the name.
    assert 'async def _fetch_schema_for' in source, \
        "_fetch_schema_for async helper should exist for concurrent schema fetching"
    assert 'asyncio.to_thread' in source, \
        "asyncio.to_thread should be used to run sync get_schema in thread pool"
    assert 'asyncio.gather' in source, \
        "asyncio.gather should be used for concurrent execution"


@pytest.mark.asyncio
async def test_asyncio_gather_concurrent_execution_pattern():
    """Test that asyncio.gather runs multiple coroutines concurrently (validates Phase 2 pattern)."""
    import time

    async def slow_fetch(name: str, delay: float):
        """Simulate slow schema fetch."""
        await asyncio.sleep(delay)
        return (name, f"schema_{name}", None)

    # Without concurrency, 3x 0.1s delays = 0.3s
    # With concurrency, should be ~0.1s
    start = time.perf_counter()
    tasks = [
        slow_fetch("source1", 0.1),
        slow_fetch("source2", 0.1),
        slow_fetch("source3", 0.1),
    ]
    results = await asyncio.gather(*tasks)
    elapsed = time.perf_counter() - start

    # Should complete in roughly 0.1s (with some margin), not 0.3s
    assert elapsed < 0.25, f"asyncio.gather should run concurrently, took {elapsed:.3f}s"
    assert len(results) == 3
    assert results[0] == ("source1", "schema_source1", None)


@pytest.mark.asyncio
async def test_asyncio_to_thread_offloads_blocking_calls():
    """Test asyncio.to_thread runs sync code on a worker thread (validates Phase 2 pattern)."""
    import threading
    import time

    loop_thread_id = threading.get_ident()
    executing_thread_ids = []

    def blocking_get_schema():
        """Simulate blocking I/O and record which thread executed it."""
        executing_thread_ids.append(threading.get_ident())
        time.sleep(0.05)
        return "schema_data"

    # Run blocking call via to_thread
    result = await asyncio.to_thread(blocking_get_schema)

    assert result == "schema_data"
    # The blocking call must not have run on the event-loop thread.
    assert executing_thread_ids, "blocking function was never executed"
    assert executing_thread_ids[0] != loop_thread_id, \
        "asyncio.to_thread should execute the function off the event-loop thread"


# =============================================================================
# PHASE 3 TESTS - Toolkit Async Client
# =============================================================================


@pytest.mark.asyncio
async def test_list_sources_returns_coroutine_inside_event_loop_and_awaits():
    """When called inside an event loop, list_sources returns a coroutine which can be awaited.

    This test ensures the toolkit prefers the AsyncClient and that the async path is exercised.
    """
    toolkit = DataSourcesSQLToolkit(api_base_url="http://testserver", timeout=5)

    # Prepare a fake async response
    response_data = {"available_sources": ["s1", "s2"], "count": 2, "tenant_id": "tenant-x"}
    resp = httpx.Response(200, json=response_data)

    # Replace async client with an AsyncMock having an awaitable .get
    async_client = AsyncMock()
    async_client.get = AsyncMock(return_value=resp)
    toolkit._async_client = async_client

    session_state = {"tenant_id": "tenant-x", "supabase_jwt": "jwt-token"}

    # In event loop, list_sources should return a coroutine
    coro = toolkit.list_sources(session_state=session_state)
    assert asyncio.iscoroutine(coro), "Expected a coroutine when called inside an event loop"

    result = await coro

    # Async client should have been awaited
    async_client.get.assert_awaited()
    assert result["available_sources"] == ["s1", "s2"]
    assert result["count"] == 2


def test_list_sources_sync_runs_asyncclient_and_returns_result():
    """When called from sync context, list_sources runs the async impl via asyncio.run and returns result."""
    toolkit = DataSourcesSQLToolkit(api_base_url="http://testserver", timeout=5)

    response_data = {"available_sources": ["alpha"], "count": 1, "tenant_id": "t1"}
    resp = httpx.Response(200, json=response_data)

    # Provide AsyncMock for async client
    async_client = AsyncMock()
    async_client.get = AsyncMock(return_value=resp)
    toolkit._async_client = async_client

    session_state = {"tenant_id": "t1", "supabase_jwt": "token"}

    result = toolkit.list_sources(session_state=session_state)

    # The async client's get must have been awaited via asyncio.run
    async_client.get.assert_awaited()
    assert result["available_sources"] == ["alpha"]
    assert result["count"] == 1


def test_search_schema_prefers_async_and_caches():
    """Ensure search_schema uses async client when available and caches results in toolkit._search_cache."""
    toolkit = DataSourcesSQLToolkit(api_base_url="http://testserver", timeout=5)

    payload = {
        "formatted_schema_string": "schema-content",
        "matches": [{"table": "users"}],
        "available_sources": ["s1"],
        "total_matches": 1,
    }
    resp = httpx.Response(200, json=payload)

    async_client = AsyncMock()
    async_client.post = AsyncMock(return_value=resp)
    toolkit._async_client = async_client

    # Must provide session_state with tenant_id for _resolve_tenant
    session_state = {"tenant_id": "tX", "supabase_jwt": "test-jwt"}

    result = toolkit.search_schema(keywords=["user"], session_state=session_state)

    # If running in a sync context, _run_coro_sync invoked asyncio.run, so AsyncMock should be awaited
    async_client.post.assert_awaited()

    # The result should be the payload mapped to our normalized keys
    assert result["formatted_schema_string"] == "schema-content"

    # Toolkit-level cache should contain at least one entry
    assert len(toolkit._search_cache) >= 1


def test_execute_sql_query_sync_path_and_error_handling():
    """Test execute_sql_query uses the sync client and returns structured results on success."""
    toolkit = DataSourcesSQLToolkit(api_base_url="http://testserver", timeout=5)

    success_payload = {"status": "success", "results": [{"id": 1}], "rows_returned": 1}
    resp = httpx.Response(200, json=success_payload)

    # Replace sync client with a simple Mock holding a post method
    mock_client = Mock()
    mock_client.post = Mock(return_value=resp)
    toolkit.client = mock_client

    session_state = {"tenant_id": "t1", "source_name": "s1", "supabase_jwt": "fake-token"}

    result = toolkit.execute_sql_query(
        tenant_id='t1',
        source_name='s1',
        sql_query='SELECT * FROM big_table',
        session_state=session_state,
    )

    # The request must have gone through the replaced sync client.
    mock_client.post.assert_called()
    assert result["status"] == "success"
    assert isinstance(result["results"], list) and result["rows_returned"] == 1
# =============================================================================
# PHASE 4 TESTS - Schema Search Cache Increase
# =============================================================================


def test_split_identifier_cache_increased_and_tokens():
    """Verify the _split_identifier cache maxsize was increased and tokenization behaves as expected."""
    func = schema_search._split_identifier

    # Clear cache to get accurate hit counts for this test
    func.cache_clear()

    info = func.cache_info()
    assert info.maxsize == 4096, f"Expected cache maxsize 4096, got {info.maxsize}"

    tokens1 = func("OrderID")
    tokens2 = func("order_id")
    tokens3 = func("OrderID")  # Same input again: served from the cache

    assert tokens1 == ["order", "id"], f"Expected ['order', 'id'], got {tokens1}"
    assert tokens2 == ["order", "id"], f"Expected ['order', 'id'], got {tokens2}"
    assert tokens3 == tokens1, f"Cached result {tokens3} differs from first result {tokens1}"

    # After calling again with same input, hits should be > 0
    hits = func.cache_info().hits
    assert hits >= 1, f"Expected at least 1 cache hit, got {hits}"


# =============================================================================
# PHASE 5 TESTS - Large Result Handling (MinIO + Pandas)
# =============================================================================


def test_execute_sql_query_minio_result_triggers_summary():
    """Test that execute_sql_query detects minio:// result, fetches file, and returns summary + head."""
    import pandas as pd

    # Prepare a fake DataFrame and its JSON bytes
    df = pd.DataFrame({
        'a': range(1, 301),
        'b': [x * 2 for x in range(1, 301)],
    })
    json_bytes = df.to_json().encode()

    # Mock MinIO response
    fake_minio_response = MagicMock()
    fake_minio_response.read.return_value = json_bytes

    fake_minio_client = MagicMock()
    fake_minio_client.get_object.return_value = fake_minio_response

    fake_minio_config = {'endpoint': 'localhost:9000', 'access_key': 'x', 'secret_key': 'y', 'secure': False}

    # Mock HTTP response from API
    fake_http_response = MagicMock()
    fake_http_response.status_code = 200
    fake_http_response.json.return_value = {
        'status': 'success',
        'results': 'minio://testbucket/job-results/123.json',
    }

    session_state = {"tenant_id": "t1", "source_name": "s1", "supabase_jwt": "fake-token"}

    # Keep the patches active for both construction and the call, since the
    # toolkit may resolve the MinIO config/client at either point.
    with patch('backend.core.minio.config.get_minio_config', return_value=fake_minio_config), \
            patch('minio.Minio', return_value=fake_minio_client):
        toolkit = DataSourcesSQLToolkit(api_base_url="http://testserver", timeout=5)
        toolkit.client = MagicMock()
        toolkit.client.post.return_value = fake_http_response

        result = toolkit.execute_sql_query(
            tenant_id='t1',
            source_name='s1',
            sql_query='SELECT * FROM big_table',
            session_state=session_state,
        )

    # The minio:// reference must have been dereferenced through the mocked client.
    fake_minio_client.get_object.assert_called()

    # Should return summary and head rows
    assert result.get('status') == 'success', f"Expected status='success', got {result}"
    assert 'summary' in result, f"Expected 'summary' in result, got keys: {result.keys()}"
    assert 'results' in result
    assert isinstance(result['results'], list)
    assert len(result['results']) == 200, f"Expected 200 head rows, got {len(result['results'])}"
    assert 'Result too large' in result['message']
    assert result['row_count'] == 300
    assert result['rows_limited'] is True