# Source: sirus/backend/SQL_Agent/tests/test_toolkit_end_to_end.py
# Uploaded by ranilmukesh - commit b8277c4: "Deploy SiRUS SQL Agent backend"
import asyncio
import logging
import json
import importlib
from unittest.mock import AsyncMock, Mock, patch, MagicMock
import httpx
import pytest
from backend.SQL_Agent.data_sources_sql_toolkit import DataSourcesSQLToolkit
from backend.data_sources import schema_search
logger = logging.getLogger(__name__)
# =============================================================================
# PHASE 1 TESTS - Worker Connection Pooling
# =============================================================================
def test_worker_module_has_global_pool_and_minio_client():
    """Check that the worker module exposes the shared-connection globals.

    Only the presence of ``_redis_pool`` and ``_minio_client`` on the module
    is asserted; their values are not inspected here.
    """
    from backend.data_sources import worker as worker_module

    pool_declared = hasattr(worker_module, '_redis_pool')
    assert pool_declared, "_redis_pool global not found"

    minio_declared = hasattr(worker_module, '_minio_client')
    assert minio_declared, "_minio_client global not found"
def test_worker_init_process_connects_to_celery_signal():
    """Verify ``init_worker_process`` is connected to ``worker_process_init``.

    Besides checking that the function exists and is callable, this looks the
    function up among the receivers registered on the Celery signal, which is
    what ``@worker_process_init.connect`` is expected to produce.
    """
    import weakref

    from backend.data_sources import worker as worker_module
    from celery.signals import worker_process_init

    assert hasattr(worker_module, 'init_worker_process'), "init_worker_process function not found"
    init_fn = worker_module.init_worker_process
    assert callable(init_fn)

    # The signal stores (lookup_key, receiver) pairs. Receivers connected
    # weakly are held as weakref objects and must be dereferenced first.
    registered = []
    for _lookup_key, receiver in worker_process_init.receivers:
        if isinstance(receiver, weakref.ReferenceType):
            receiver = receiver()
        registered.append(receiver)

    assert any(candidate is init_fn for candidate in registered), \
        "init_worker_process is not connected to the worker_process_init signal"
def test_worker_init_creates_redis_pool_and_minio():
    """Check that ``init_worker_process`` builds the Redis pool and MinIO client.

    The worker's config getters, ``redis.ConnectionPool.from_url`` and the
    worker's ``Minio`` name are all mocked, so no real connection is attempted.
    The module globals ``_redis_pool`` and ``_minio_client`` are reset to
    ``None`` through ``patch.object`` so their previous values are restored
    when the test finishes, rather than leaving mocks behind for later tests.
    """
    from backend.data_sources import worker as worker_module
    import redis

    redis_url = 'redis://localhost:6379/0'
    minio_config = {
        'endpoint': 'localhost:9000',
        'access_key': 'test_key',
        'secret_key': 'test_secret',
        'secure': False,
    }

    with patch.object(worker_module, 'get_redis_url', return_value=redis_url), \
            patch.object(worker_module, 'get_minio_config', return_value=minio_config), \
            patch('redis.ConnectionPool.from_url') as mock_pool, \
            patch.object(worker_module, 'Minio') as mock_minio, \
            patch.object(worker_module, '_redis_pool', None, create=True), \
            patch.object(worker_module, '_minio_client', None, create=True):
        mock_pool.return_value = MagicMock(spec=redis.ConnectionPool)
        mock_minio.return_value = MagicMock()

        worker_module.init_worker_process()

        # The pool must be built from the configured URL.
        mock_pool.assert_called_once_with('redis://localhost:6379/0')

        # The MinIO client must be constructed once with the configured values.
        mock_minio.assert_called_once()
        call_kwargs = mock_minio.call_args.kwargs
        assert call_kwargs['endpoint'] == 'localhost:9000'
        assert call_kwargs['access_key'] == 'test_key'

        # Checked inside the context: the globals are restored on exit.
        assert worker_module._redis_pool is not None
        assert worker_module._minio_client is not None
def test_worker_task_uses_global_pool_not_new_connection():
    """Verify the federated task reuses the module-level pool and MinIO client.

    This is a source-text check: the source of ``process_federated_job`` must
    contain ``connection_pool=_redis_pool`` and mention ``_minio_client``.
    It does not execute the task.
    """
    import inspect

    from backend.data_sources import worker as worker_module

    source = inspect.getsource(worker_module.process_federated_job)

    assert 'connection_pool=_redis_pool' in source, \
        "process_federated_job should use redis.Redis(connection_pool=_redis_pool)"
    assert '_minio_client' in source, \
        "process_federated_job should use global _minio_client"
def test_worker_uses_orjson_for_serialization():
    """Verify the worker's ``json`` name is bound to ``orjson``.

    Checking only that a ``json`` attribute exists would also pass with the
    standard-library module, so the binding is compared by identity against
    the imported ``orjson`` module.
    """
    import orjson

    from backend.data_sources import worker as worker_module

    assert hasattr(worker_module, 'json'), "worker module should have json attribute"
    assert worker_module.json is orjson, \
        "worker module should import orjson aliased as json"
# =============================================================================
# PHASE 2 TESTS - API ORJSONResponse and Concurrent Fetch
# =============================================================================
def test_api_heavy_endpoints_use_orjson_response():
    """Source-text check that api.py references and applies ORJSONResponse.

    The module source must mention ``ORJSONResponse`` and contain at least one
    ``response_class=ORJSONResponse`` usage.
    """
    import inspect

    from backend.data_sources import api as api_module

    api_source = inspect.getsource(api_module)

    mentions_orjson = 'ORJSONResponse' in api_source
    assert mentions_orjson, "ORJSONResponse should be imported in api.py"

    used_as_response_class = 'response_class=ORJSONResponse' in api_source
    assert used_as_response_class, \
        "Heavy endpoints should specify response_class=ORJSONResponse"
def test_api_has_async_fetch_schema_helper():
    """Verify api.py defines the async ``_fetch_schema_for`` helper.

    This is a source-text check on the api module. It requires the helper to
    be declared with ``async def``, and requires ``asyncio.to_thread`` and
    ``asyncio.gather`` to appear in the module source.
    """
    import inspect

    from backend.data_sources import api as api_module

    source = inspect.getsource(api_module)

    # Match the full "async def" form: a bare name match would also accept a
    # synchronous helper or a mere reference to the name.
    assert 'async def _fetch_schema_for' in source, \
        "_fetch_schema_for async helper should exist for concurrent schema fetching"
    assert 'asyncio.to_thread' in source, \
        "asyncio.to_thread should be used to run sync get_schema in thread pool"
    assert 'asyncio.gather' in source, \
        "asyncio.gather should be used for concurrent execution"
@pytest.mark.asyncio
async def test_asyncio_gather_concurrent_execution_pattern():
    """Demonstrate that ``asyncio.gather`` overlaps awaiting coroutines.

    Three coroutines each sleep 0.1s. Run one after another they would take
    about 0.3s; the test requires the whole batch to finish in under 0.25s.
    """
    import time

    async def slow_fetch(name: str, delay: float):
        """Stand-in for a slow schema fetch: sleep, then return a result tuple."""
        await asyncio.sleep(delay)
        return (name, f"schema_{name}", None)

    source_names = ("source1", "source2", "source3")

    started_at = time.perf_counter()
    results = await asyncio.gather(*(slow_fetch(name, 0.1) for name in source_names))
    elapsed = time.perf_counter() - started_at

    assert elapsed < 0.25, f"asyncio.gather should run concurrently, took {elapsed:.3f}s"
    assert len(results) == 3
    assert results[0] == ("source1", "schema_source1", None)
@pytest.mark.asyncio
async def test_asyncio_to_thread_offloads_blocking_calls():
    """Check that ``asyncio.to_thread`` runs a sync callable and returns its value."""
    import time

    def blocking_get_schema():
        """Stand-in for blocking I/O: sleep briefly, then return a fixed value."""
        time.sleep(0.05)
        return "schema_data"

    fetched = await asyncio.to_thread(blocking_get_schema)

    assert fetched == "schema_data"
# =============================================================================
# PHASE 3 TESTS - Toolkit Async Client
# =============================================================================
@pytest.mark.asyncio
async def test_list_sources_returns_coroutine_inside_event_loop_and_awaits():
    """Inside a running loop, ``list_sources`` must hand back an awaitable coroutine.

    The toolkit's ``_async_client`` is replaced by an ``AsyncMock`` whose
    ``get`` resolves to a canned 200 response; the test then checks that the
    coroutine result carries that payload and that ``get`` was awaited.
    """
    toolkit = DataSourcesSQLToolkit(api_base_url="http://testserver", timeout=5)

    # Canned HTTP response served by the mocked async client.
    canned_response = httpx.Response(
        200,
        json={"available_sources": ["s1", "s2"], "count": 2, "tenant_id": "tenant-x"},
    )
    fake_async_client = AsyncMock()
    fake_async_client.get = AsyncMock(return_value=canned_response)
    toolkit._async_client = fake_async_client

    pending = toolkit.list_sources(
        session_state={"tenant_id": "tenant-x", "supabase_jwt": "jwt-token"}
    )
    assert asyncio.iscoroutine(pending), "Expected a coroutine when called inside an event loop"

    result = await pending

    fake_async_client.get.assert_awaited()
    assert result["available_sources"] == ["s1", "s2"]
    assert result["count"] == 2
def test_list_sources_sync_runs_asyncclient_and_returns_result():
    """From a sync caller, ``list_sources`` must return the final result directly.

    The mocked async client's ``get`` has to be awaited along the way, and the
    returned mapping must carry the canned payload.
    """
    toolkit = DataSourcesSQLToolkit(api_base_url="http://testserver", timeout=5)

    canned_response = httpx.Response(
        200,
        json={"available_sources": ["alpha"], "count": 1, "tenant_id": "t1"},
    )
    fake_async_client = AsyncMock()
    fake_async_client.get = AsyncMock(return_value=canned_response)
    toolkit._async_client = fake_async_client

    result = toolkit.list_sources(
        session_state={"tenant_id": "t1", "supabase_jwt": "token"}
    )

    fake_async_client.get.assert_awaited()
    assert result["available_sources"] == ["alpha"]
    assert result["count"] == 1
def test_search_schema_prefers_async_and_caches():
    """``search_schema`` must go through the async client and populate the cache.

    After one call with a mocked async ``post``, the mock must have been
    awaited, the result must expose the canned ``formatted_schema_string``,
    and ``toolkit._search_cache`` must hold at least one entry.
    """
    toolkit = DataSourcesSQLToolkit(api_base_url="http://testserver", timeout=5)

    canned_response = httpx.Response(
        200,
        json={
            "formatted_schema_string": "schema-content",
            "matches": [{"table": "users"}],
            "available_sources": ["s1"],
            "total_matches": 1,
        },
    )
    fake_async_client = AsyncMock()
    fake_async_client.post = AsyncMock(return_value=canned_response)
    toolkit._async_client = fake_async_client

    # The session state supplies the tenant_id and JWT for the call.
    result = toolkit.search_schema(
        keywords=["user"],
        session_state={"tenant_id": "tX", "supabase_jwt": "test-jwt"},
    )

    fake_async_client.post.assert_awaited()
    assert result["formatted_schema_string"] == "schema-content"
    assert len(toolkit._search_cache) >= 1
def test_execute_sql_query_sync_path_and_error_handling():
    """Check ``execute_sql_query`` returns structured results on a 200 response.

    The toolkit's sync ``client`` is replaced by a ``Mock`` whose ``post``
    returns a canned success payload; the test asserts the status, that
    ``results`` is a list, and the reported row count.
    """
    toolkit = DataSourcesSQLToolkit(api_base_url="http://testserver", timeout=5)

    success_payload = {"status": "success", "results": [{"id": 1}], "rows_returned": 1}
    resp = httpx.Response(200, json=success_payload)

    # Swap the sync HTTP client for a mock so no request leaves the process.
    mock_client = Mock()
    mock_client.post = Mock(return_value=resp)
    toolkit.client = mock_client

    mock_session = {"tenant_id": "t1", "source_name": "s1", "supabase_jwt": "fake-token"}

    result = toolkit.execute_sql_query(
        tenant_id='t1',
        source_name='s1',
        sql_query='SELECT * FROM big_table',
        session_state=mock_session,
    )

    assert result["status"] == "success"
    assert isinstance(result["results"], list) and result["rows_returned"] == 1
# =============================================================================
# PHASE 4 TESTS - Schema Search Cache Increase
# =============================================================================
def test_split_identifier_cache_increased_and_tokens():
    """Check the ``_split_identifier`` cache size, its tokens, and that it caches.

    The cache is cleared first so the hit counter starts from a known state.
    """
    split_identifier = schema_search._split_identifier
    split_identifier.cache_clear()

    maxsize = split_identifier.cache_info().maxsize
    assert maxsize == 4096, f"Expected cache maxsize 4096, got {maxsize}"

    tokens1 = split_identifier("OrderID")
    tokens2 = split_identifier("order_id")
    # Repeat the first input; it is expected to be served from the cache.
    split_identifier("OrderID")

    assert tokens1 == ["order", "id"], f"Expected ['order', 'id'], got {tokens1}"
    assert tokens2 == ["order", "id"], f"Expected ['order', 'id'], got {tokens2}"

    hits = split_identifier.cache_info().hits
    assert hits >= 1, f"Expected at least 1 cache hit, got {hits}"
# =============================================================================
# PHASE 5 TESTS - Large Result Handling (MinIO + Pandas)
# =============================================================================
def test_execute_sql_query_minio_result_triggers_summary():
    """Check the large-result path: a ``minio://`` result yields summary + head rows.

    The API response is mocked to return a ``minio://`` URI as ``results``.
    ``minio.Minio`` and ``get_minio_config`` are patched so the object fetch
    returns a 300-row DataFrame serialized with ``DataFrame.to_json()``. The
    toolkit is then expected to return a summary, 200 head rows, a row count
    of 300 and ``rows_limited`` set to ``True``.
    """
    import pandas as pd

    # 300-row frame, serialized the way the stored result object is faked here.
    df = pd.DataFrame({
        'a': range(1, 301),
        'b': [x * 2 for x in range(1, 301)],
    })
    json_bytes = df.to_json().encode()

    # Fake MinIO client whose get_object().read() yields the JSON bytes.
    fake_minio_response = MagicMock()
    fake_minio_response.read.return_value = json_bytes
    fake_minio_client = MagicMock()
    fake_minio_client.get_object.return_value = fake_minio_response
    fake_minio_config = {'endpoint': 'localhost:9000', 'access_key': 'x', 'secret_key': 'y', 'secure': False}

    # Fake API response that points at the stored object instead of inline rows.
    fake_http_response = MagicMock()
    fake_http_response.status_code = 200
    fake_http_response.json.return_value = {
        'status': 'success',
        'results': 'minio://testbucket/job-results/123.json'
    }

    with patch('backend.core.minio.config.get_minio_config', return_value=fake_minio_config), \
            patch('minio.Minio', return_value=fake_minio_client):
        toolkit = DataSourcesSQLToolkit(api_base_url="http://testserver", timeout=5)
        toolkit.client = MagicMock()
        toolkit.client.post.return_value = fake_http_response

        mock_session = {"tenant_id": "t1", "source_name": "s1", "supabase_jwt": "fake-token"}

        # Must run inside the patch context so the MinIO fetch hits the fakes.
        result = toolkit.execute_sql_query(
            tenant_id='t1',
            source_name='s1',
            sql_query='SELECT * FROM big_table',
            session_state=mock_session,
        )

    assert result.get('status') == 'success', f"Expected status='success', got {result}"
    assert 'summary' in result, f"Expected 'summary' in result, got keys: {result.keys()}"
    assert 'results' in result
    assert isinstance(result['results'], list)
    assert len(result['results']) == 200, f"Expected 200 head rows, got {len(result['results'])}"
    assert 'Result too large' in result['message']
    assert result['row_count'] == 300
    assert result['rows_limited'] is True