Spaces:
Running
Running
| import asyncio | |
| import logging | |
| import json | |
| import importlib | |
| from unittest.mock import AsyncMock, Mock, patch, MagicMock | |
| import httpx | |
| import pytest | |
| from backend.SQL_Agent.data_sources_sql_toolkit import DataSourcesSQLToolkit | |
| from backend.data_sources import schema_search | |
| logger = logging.getLogger(__name__) | |
# =============================================================================
# PHASE 1 TESTS - Worker Connection Pooling
# =============================================================================
def test_worker_module_has_global_pool_and_minio_client():
    """Check that worker.py exposes the ``_redis_pool`` and ``_minio_client`` globals.

    Only the presence of the attributes is asserted; their values are not
    inspected here.
    """
    from backend.data_sources import worker as worker_module

    # Both names must exist at module level so tasks can share one instance.
    for global_name in ('_redis_pool', '_minio_client'):
        assert hasattr(worker_module, global_name), f"{global_name} global not found"
def test_worker_init_process_connects_to_celery_signal():
    """Check that worker.py defines a callable ``init_worker_process`` hook.

    NOTE(review): despite the test name, registration of the hook on
    ``worker_process_init`` is not asserted here; only existence and
    callability are checked.
    """
    from backend.data_sources import worker as worker_module

    hook_name = 'init_worker_process'
    assert hasattr(worker_module, hook_name), "init_worker_process function not found"

    # Imported only to confirm the signal is importable; it is not used in an
    # assertion below.
    from celery.signals import worker_process_init  # noqa: F401

    init_hook = getattr(worker_module, hook_name)
    assert callable(init_hook)
def test_worker_init_creates_redis_pool_and_minio():
    """Test that init_worker_process builds the Redis pool and the MinIO client.

    The config getters, ``redis.ConnectionPool.from_url`` and the ``Minio``
    name in the worker module are all mocked, so no real connection is made.
    The two module globals are reset through ``patch.object`` so that their
    previous values are restored when the test finishes and the mocks created
    here cannot leak into other tests.
    """
    from backend.data_sources import worker as worker_module
    import redis

    redis_url = 'redis://localhost:6379/0'
    minio_config = {
        'endpoint': 'localhost:9000',
        'access_key': 'test_key',
        'secret_key': 'test_secret',
        'secure': False,
    }

    with patch.object(worker_module, 'get_redis_url', return_value=redis_url), \
            patch.object(worker_module, 'get_minio_config', return_value=minio_config), \
            patch('redis.ConnectionPool.from_url') as mock_pool, \
            patch.object(worker_module, 'Minio') as mock_minio, \
            patch.object(worker_module, '_redis_pool', None, create=True), \
            patch.object(worker_module, '_minio_client', None, create=True):
        mock_pool.return_value = MagicMock(spec=redis.ConnectionPool)
        mock_minio.return_value = MagicMock()

        # Globals start as None (set by the patches above); run the hook.
        worker_module.init_worker_process()

        # The pool must be built from the configured URL.
        mock_pool.assert_called_once_with(redis_url)

        # The MinIO client must be built once, from the configured values.
        mock_minio.assert_called_once()
        call_kwargs = mock_minio.call_args.kwargs
        assert call_kwargs['endpoint'] == 'localhost:9000'
        assert call_kwargs['access_key'] == 'test_key'

        # Checked inside the ``with`` block: on exit the patches put the
        # original global values back.
        assert worker_module._redis_pool is not None
        assert worker_module._minio_client is not None
def test_worker_task_uses_global_pool_not_new_connection():
    """Check the source of ``process_federated_job`` for the shared-resource pattern.

    This is a textual check on the function's source code: it looks for the
    ``connection_pool=_redis_pool`` keyword and a reference to
    ``_minio_client``. The task itself is not executed.
    """
    import inspect

    from backend.data_sources import worker as worker_module

    task_source = inspect.getsource(worker_module.process_federated_job)

    # Redis client must be built on top of the shared pool.
    assert 'connection_pool=_redis_pool' in task_source, \
        "process_federated_job should use redis.Redis(connection_pool=_redis_pool)"

    # MinIO access must go through the shared client.
    assert '_minio_client' in task_source, \
        "process_federated_job should use global _minio_client"
def test_worker_uses_orjson_for_serialization():
    """Verify the worker module's ``json`` name is bound to orjson.

    A bare ``hasattr`` check would also pass if the worker imported the
    standard-library ``json`` module, so the binding is compared by identity
    against the imported ``orjson`` module.
    """
    from backend.data_sources import worker as worker_module
    import orjson

    assert hasattr(worker_module, 'json'), "worker module should have json attribute"
    # ``import orjson as json`` makes the two names refer to the same module.
    assert worker_module.json is orjson, \
        "worker module's json binding should be orjson (import orjson as json)"
# =============================================================================
# PHASE 2 TESTS - API ORJSONResponse and Concurrent Fetch
# =============================================================================
def test_api_heavy_endpoints_use_orjson_response():
    """Check that api.py mentions ORJSONResponse and uses it as a response class.

    This is a textual check on the module source; it does not identify which
    endpoints carry the ``response_class=ORJSONResponse`` argument.
    """
    import inspect

    from backend.data_sources import api as api_module

    api_source = inspect.getsource(api_module)

    # The name must appear somewhere in the module (e.g. in an import).
    assert 'ORJSONResponse' in api_source, "ORJSONResponse should be imported in api.py"

    # At least one route must pass it as the response class.
    assert 'response_class=ORJSONResponse' in api_source, \
        "Heavy endpoints should specify response_class=ORJSONResponse"
def test_api_has_async_fetch_schema_helper():
    """Check api.py source for the concurrent schema-fetch building blocks.

    Textual check only: the module source must mention ``_fetch_schema_for``,
    ``asyncio.to_thread`` and ``asyncio.gather``.

    NOTE(review): the helper check matches any occurrence of the name, so it
    does not prove the helper is declared with ``async def``.
    """
    import inspect

    from backend.data_sources import api as api_module

    api_source = inspect.getsource(api_module)

    # Any text containing 'async def _fetch_schema_for' also contains
    # '_fetch_schema_for', so a single substring test covers both spellings.
    assert '_fetch_schema_for' in api_source, \
        "_fetch_schema_for async helper should exist for concurrent schema fetching"
    assert 'asyncio.to_thread' in api_source, \
        "asyncio.to_thread should be used to run sync get_schema in thread pool"
    assert 'asyncio.gather' in api_source, \
        "asyncio.gather should be used for concurrent execution"
async def test_asyncio_gather_concurrent_execution_pattern():
    """Check that ``asyncio.gather`` overlaps the waits of several coroutines.

    Three coroutines each sleep for 0.1s. Run one after another they would
    need about 0.3s; the test requires the whole batch to finish in under
    0.25s and checks the result count and the first result.
    """
    import time

    async def slow_fetch(name: str, delay: float):
        """Sleep for ``delay`` seconds, then return a (name, schema, error) tuple."""
        await asyncio.sleep(delay)
        return (name, f"schema_{name}", None)

    started_at = time.perf_counter()
    pending = [slow_fetch(f"source{index}", 0.1) for index in (1, 2, 3)]
    results = await asyncio.gather(*pending)
    elapsed = time.perf_counter() - started_at

    # Roughly one delay (plus margin), not the sum of all three.
    assert elapsed < 0.25, f"asyncio.gather should run concurrently, took {elapsed:.3f}s"
    assert len(results) == 3
    assert results[0] == ("source1", "schema_source1", None)
async def test_asyncio_to_thread_offloads_blocking_calls():
    """Check that ``asyncio.to_thread`` returns the value of a blocking callable."""
    import time

    expected = "schema_data"

    def blocking_get_schema():
        """Block briefly with ``time.sleep`` and hand back a fixed value."""
        time.sleep(0.05)
        return "schema_data"

    # The awaited result is whatever the callable returned.
    fetched = await asyncio.to_thread(blocking_get_schema)
    assert fetched == expected
# =============================================================================
# PHASE 3 TESTS - Toolkit Async Client
# =============================================================================
async def test_list_sources_returns_coroutine_inside_event_loop_and_awaits():
    """Inside a running event loop, ``list_sources`` must hand back a coroutine.

    The toolkit's ``_async_client`` is replaced by an ``AsyncMock`` whose
    ``get`` resolves to a canned 200 response. The test asserts that the call
    returns a coroutine, that awaiting it awaits the mocked ``get``, and that
    the decoded payload is returned.
    """
    expected_payload = {"available_sources": ["s1", "s2"], "count": 2, "tenant_id": "tenant-x"}

    # Fake async HTTP client: only ``get`` needs a meaningful return value.
    fake_async_client = AsyncMock()
    fake_async_client.get = AsyncMock(return_value=httpx.Response(200, json=expected_payload))

    toolkit = DataSourcesSQLToolkit(api_base_url="http://testserver", timeout=5)
    toolkit._async_client = fake_async_client

    session_state = {"tenant_id": "tenant-x", "supabase_jwt": "jwt-token"}

    # Called from async code, so a coroutine (not a finished result) is expected.
    pending = toolkit.list_sources(session_state=session_state)
    assert asyncio.iscoroutine(pending), "Expected a coroutine when called inside an event loop"

    result = await pending

    fake_async_client.get.assert_awaited()
    assert result["available_sources"] == ["s1", "s2"]
    assert result["count"] == 2
def test_list_sources_sync_runs_asyncclient_and_returns_result():
    """From synchronous code, ``list_sources`` must return the decoded result directly.

    The toolkit's ``_async_client`` is an ``AsyncMock``; the test asserts its
    ``get`` was awaited even though the caller is synchronous.
    NOTE(review): presumably the toolkit drives the coroutine itself (e.g.
    via ``asyncio.run``) — confirm in the toolkit implementation.
    """
    expected_payload = {"available_sources": ["alpha"], "count": 1, "tenant_id": "t1"}

    fake_async_client = AsyncMock()
    fake_async_client.get = AsyncMock(return_value=httpx.Response(200, json=expected_payload))

    toolkit = DataSourcesSQLToolkit(api_base_url="http://testserver", timeout=5)
    toolkit._async_client = fake_async_client

    session_state = {"tenant_id": "t1", "supabase_jwt": "token"}
    result = toolkit.list_sources(session_state=session_state)

    # The async transport must have been used, not bypassed.
    fake_async_client.get.assert_awaited()
    assert result["available_sources"] == ["alpha"]
    assert result["count"] == 1
def test_search_schema_prefers_async_and_caches():
    """``search_schema`` must go through the async client and populate ``_search_cache``.

    The async client's ``post`` is mocked to return a canned 200 payload. The
    test asserts that ``post`` was awaited, that the formatted schema string
    is passed through, and that the toolkit cache holds at least one entry.
    """
    api_payload = {
        "formatted_schema_string": "schema-content",
        "matches": [{"table": "users"}],
        "available_sources": ["s1"],
        "total_matches": 1,
    }

    fake_async_client = AsyncMock()
    fake_async_client.post = AsyncMock(return_value=httpx.Response(200, json=api_payload))

    toolkit = DataSourcesSQLToolkit(api_base_url="http://testserver", timeout=5)
    toolkit._async_client = fake_async_client

    # session_state carries the tenant id and JWT used by the call.
    session_state = {"tenant_id": "tX", "supabase_jwt": "test-jwt"}
    result = toolkit.search_schema(keywords=["user"], session_state=session_state)

    # Called from sync code, yet the async ``post`` must have been awaited.
    fake_async_client.post.assert_awaited()

    assert result["formatted_schema_string"] == "schema-content"

    # A successful search must leave something in the toolkit-level cache.
    assert len(toolkit._search_cache) >= 1
def test_execute_sql_query_sync_path_and_error_handling():
    """Test that ``execute_sql_query`` returns the structured payload on success.

    The toolkit's synchronous ``client`` is replaced by a ``Mock`` whose
    ``post`` returns a canned 200 response.

    NOTE(review): only the success path is exercised; no error response is
    simulated despite the test name.
    """
    success_payload = {"status": "success", "results": [{"id": 1}], "rows_returned": 1}

    # Stand-in for the sync HTTP client: ``post`` yields the canned response.
    mock_client = Mock()
    mock_client.post = Mock(return_value=httpx.Response(200, json=success_payload))

    toolkit = DataSourcesSQLToolkit(api_base_url="http://testserver", timeout=5)
    toolkit.client = mock_client

    # The call must receive a session_state carrying tenant, source and JWT.
    mock_session = {"tenant_id": "t1", "source_name": "s1", "supabase_jwt": "fake-token"}

    result = toolkit.execute_sql_query(
        tenant_id='t1',
        source_name='s1',
        sql_query='SELECT * FROM big_table',
        session_state=mock_session,
    )

    assert result["status"] == "success"
    assert isinstance(result["results"], list) and result["rows_returned"] == 1
# =============================================================================
# PHASE 4 TESTS - Schema Search Cache Increase
# =============================================================================
def test_split_identifier_cache_increased_and_tokens():
    """Check ``_split_identifier``'s cache size, its tokens, and that repeats hit the cache.

    Expects a cache ``maxsize`` of 4096, ``["order", "id"]`` for both
    ``"OrderID"`` and ``"order_id"``, and at least one recorded cache hit
    after the same input is passed twice.
    """
    split_identifier = schema_search._split_identifier

    # Start from an empty cache so the hit counter reflects this test only.
    split_identifier.cache_clear()

    cache_size = split_identifier.cache_info().maxsize
    assert cache_size == 4096, f"Expected cache maxsize 4096, got {cache_size}"

    tokens1 = split_identifier("OrderID")
    tokens2 = split_identifier("order_id")
    split_identifier("OrderID")  # repeated input: expected to be served from the cache

    assert tokens1 == ["order", "id"], f"Expected ['order', 'id'], got {tokens1}"
    assert tokens2 == ["order", "id"], f"Expected ['order', 'id'], got {tokens2}"

    # The repeated call above must have registered as a hit.
    assert split_identifier.cache_info().hits >= 1, \
        f"Expected at least 1 cache hit, got {split_identifier.cache_info().hits}"
# =============================================================================
# PHASE 5 TESTS - Large Result Handling (MinIO + Pandas)
# =============================================================================
def test_execute_sql_query_minio_result_triggers_summary():
    """Test the large-result path: a ``minio://`` result is fetched and summarised.

    The API response is mocked to return a ``minio://`` URI as ``results``.
    The MinIO client is mocked to serve a 300-row DataFrame serialised with
    ``DataFrame.to_json()``. The test expects a success status, a
    ``summary`` key, 200 head rows, a "Result too large" message,
    ``row_count == 300`` and ``rows_limited is True``.
    """
    import pandas as pd

    # 300-row frame, serialised to the bytes the fake MinIO object will serve.
    row_ids = range(1, 301)
    frame = pd.DataFrame({
        'a': row_ids,
        'b': [x * 2 for x in row_ids],
    })
    stored_bytes = frame.to_json().encode()

    # Fake MinIO: get_object(...).read() yields the serialised frame.
    minio_object = MagicMock()
    minio_object.read.return_value = stored_bytes
    minio_client = MagicMock()
    minio_client.get_object.return_value = minio_object
    minio_config = {'endpoint': 'localhost:9000', 'access_key': 'x', 'secret_key': 'y', 'secure': False}

    # Fake API reply pointing at the stored object instead of inline rows.
    api_response = MagicMock()
    api_response.status_code = 200
    api_response.json.return_value = {
        'status': 'success',
        'results': 'minio://testbucket/job-results/123.json'
    }

    # NOTE(review): these patches replace names on the ``backend.core.minio.config``
    # and ``minio`` modules; they only take effect if the toolkit resolves the
    # names through those modules at call time — confirm against the toolkit.
    with patch('backend.core.minio.config.get_minio_config', return_value=minio_config), \
            patch('minio.Minio', return_value=minio_client):
        toolkit = DataSourcesSQLToolkit(api_base_url="http://testserver", timeout=5)
        toolkit.client = MagicMock()
        toolkit.client.post.return_value = api_response

        # The call must receive a session_state carrying tenant, source and JWT.
        mock_session = {"tenant_id": "t1", "source_name": "s1", "supabase_jwt": "fake-token"}

        result = toolkit.execute_sql_query(
            tenant_id='t1',
            source_name='s1',
            sql_query='SELECT * FROM big_table',
            session_state=mock_session,
        )

    # Summary plus a truncated head of the rows is expected.
    assert result.get('status') == 'success', f"Expected status='success', got {result}"
    assert 'summary' in result, f"Expected 'summary' in result, got keys: {result.keys()}"
    assert 'results' in result
    head_rows = result['results']
    assert isinstance(head_rows, list)
    assert len(head_rows) == 200, f"Expected 200 head rows, got {len(head_rows)}"
    assert 'Result too large' in result['message']
    assert result['row_count'] == 300
    assert result['rows_limited'] is True