File size: 15,663 Bytes
b8277c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
import asyncio
import logging
import json
import importlib
from unittest.mock import AsyncMock, Mock, patch, MagicMock

import httpx
import pytest

from backend.SQL_Agent.data_sources_sql_toolkit import DataSourcesSQLToolkit
from backend.data_sources import schema_search

# Module-level logger named after this module; the tests shown in this file do not reference it.
logger = logging.getLogger(__name__)


# =============================================================================
# PHASE 1 TESTS - Worker Connection Pooling
# =============================================================================


def test_worker_module_has_global_pool_and_minio_client():
    """Check that worker.py exposes the module-level connection holders.

    Both ``_redis_pool`` and ``_minio_client`` must exist as attributes of the
    worker module. Only their presence is asserted; their values are not
    inspected.
    """
    from backend.data_sources import worker as worker_module

    # Attribute name -> message reported when the attribute is missing.
    expected_globals = (
        ('_redis_pool', "_redis_pool global not found"),
        ('_minio_client', "_minio_client global not found"),
    )
    for attr_name, failure_message in expected_globals:
        assert hasattr(worker_module, attr_name), failure_message


def test_worker_init_process_connects_to_celery_signal():
    """Verify init_worker_process is connected to Celery's worker_process_init signal.

    The function must exist on the worker module, be callable, and appear
    among the receivers registered on ``celery.signals.worker_process_init``
    (which is what decorating it with ``@worker_process_init.connect`` does).
    """
    import weakref

    from backend.data_sources import worker as worker_module

    assert hasattr(worker_module, 'init_worker_process'), "init_worker_process function not found"
    init_fn = worker_module.init_worker_process
    assert callable(init_fn)

    from celery.signals import worker_process_init

    # Each entry in Signal.receivers is a (lookup_key, receiver) pair. With the
    # default weak=True the receiver is held as a weak reference, so it has to
    # be dereferenced before comparing it with the function object.
    registered = [
        entry() if isinstance(entry, weakref.ReferenceType) else entry
        for _lookup_key, entry in worker_process_init.receivers
    ]
    assert any(receiver is init_fn for receiver in registered), \
        "init_worker_process is not connected to worker_process_init"


def test_worker_init_creates_redis_pool_and_minio():
    """Test init_worker_process creates ConnectionPool and Minio client.

    The config getters, ``redis.ConnectionPool.from_url`` and the worker's
    ``Minio`` name are mocked. The module globals ``_redis_pool`` and
    ``_minio_client`` are patched to ``None`` for the duration of the test and
    restored afterwards, so the mocks stored by ``init_worker_process`` do not
    leak into other tests.
    """
    from backend.data_sources import worker as worker_module
    import redis

    minio_settings = {
        'endpoint': 'localhost:9000',
        'access_key': 'test_key',
        'secret_key': 'test_secret',
        'secure': False,
    }

    # Mock config functions to return test values, and reset the globals via
    # patch.object so their original values come back when the block exits.
    with patch.object(worker_module, 'get_redis_url', return_value='redis://localhost:6379/0'), \
         patch.object(worker_module, 'get_minio_config', return_value=minio_settings), \
         patch('redis.ConnectionPool.from_url') as mock_pool, \
         patch.object(worker_module, 'Minio') as mock_minio, \
         patch.object(worker_module, '_redis_pool', None), \
         patch.object(worker_module, '_minio_client', None):

        mock_pool.return_value = MagicMock(spec=redis.ConnectionPool)
        mock_minio.return_value = MagicMock()

        # Call init
        worker_module.init_worker_process()

        # Assert pool created with correct url
        mock_pool.assert_called_once_with('redis://localhost:6379/0')

        # Assert Minio initialized with config
        mock_minio.assert_called_once()
        call_kwargs = mock_minio.call_args.kwargs
        assert call_kwargs['endpoint'] == 'localhost:9000'
        assert call_kwargs['access_key'] == 'test_key'

        # Globals should be set (checked inside the block, before restoration)
        assert worker_module._redis_pool is not None
        assert worker_module._minio_client is not None


def test_worker_task_uses_global_pool_not_new_connection():
    """Verify federated task uses redis.Redis(connection_pool=_redis_pool).

    This is a source-text check: the source of ``process_federated_job`` is
    read with ``inspect.getsource`` and searched for the expected snippets.
    It does not execute the task.
    """
    import inspect

    from backend.data_sources import worker as worker_module

    source = inspect.getsource(worker_module.process_federated_job)

    # The task should use connection_pool=_redis_pool pattern
    assert 'connection_pool=_redis_pool' in source, \
        "process_federated_job should use redis.Redis(connection_pool=_redis_pool)"

    # Also check it references _minio_client
    assert '_minio_client' in source, \
        "process_federated_job should use global _minio_client"


def test_worker_uses_orjson_for_serialization():
    """Verify worker imports orjson aliased as json for fast serialization.

    Checking only that the worker module has a ``json`` attribute would also
    pass with the stdlib ``json`` module, so the binding is compared by
    identity against ``orjson``.
    """
    from backend.data_sources import worker as worker_module

    import orjson

    assert hasattr(worker_module, 'json'), "worker module should have json attribute"
    assert worker_module.json is orjson, \
        "worker module's json binding should be orjson, not the stdlib json module"


# =============================================================================
# PHASE 2 TESTS - API ORJSONResponse and Concurrent Fetch
# =============================================================================


def test_api_heavy_endpoints_use_orjson_response():
    """Check api.py's source text for ORJSONResponse usage.

    This is a text search over the module source: it looks for the name
    ``ORJSONResponse`` and for ``response_class=ORJSONResponse``. It does not
    determine which endpoints carry the setting.
    """
    import inspect

    from backend.data_sources import api as api_module

    api_source = inspect.getsource(api_module)

    # Snippet that must appear in the source -> message if it is absent.
    required_snippets = (
        ('ORJSONResponse', "ORJSONResponse should be imported in api.py"),
        ('response_class=ORJSONResponse',
         "Heavy endpoints should specify response_class=ORJSONResponse"),
    )
    for snippet, failure_message in required_snippets:
        assert snippet in api_source, failure_message


def test_api_has_async_fetch_schema_helper():
    """Check api.py's source text for the concurrent schema-fetch pattern.

    Three snippets must be present in the module source: the helper name
    ``_fetch_schema_for``, ``asyncio.to_thread`` and ``asyncio.gather``.
    """
    import inspect

    from backend.data_sources import api as api_module

    api_source = inspect.getsource(api_module)

    # NOTE: only the helper's name is required. Any source containing
    # 'async def _fetch_schema_for' also contains '_fetch_schema_for', so the
    # name check alone decides the outcome; `async def` itself is not enforced.
    required_snippets = (
        ('_fetch_schema_for',
         "_fetch_schema_for async helper should exist for concurrent schema fetching"),
        ('asyncio.to_thread',
         "asyncio.to_thread should be used to run sync get_schema in thread pool"),
        ('asyncio.gather',
         "asyncio.gather should be used for concurrent execution"),
    )
    for snippet, failure_message in required_snippets:
        assert snippet in api_source, failure_message


@pytest.mark.asyncio
async def test_asyncio_gather_concurrent_execution_pattern():
    """Test that asyncio.gather runs multiple coroutines concurrently (validates Phase 2 pattern).

    Concurrency is checked by counting how many fetches are in flight at the
    same time, not by measuring wall-clock time, so the result does not depend
    on how loaded the machine is.
    """
    in_flight = 0
    peak_in_flight = 0

    async def slow_fetch(name: str, delay: float):
        """Simulate slow schema fetch while tracking overlap with its siblings."""
        nonlocal in_flight, peak_in_flight
        in_flight += 1
        peak_in_flight = max(peak_in_flight, in_flight)
        try:
            await asyncio.sleep(delay)
        finally:
            in_flight -= 1
        return (name, f"schema_{name}", None)

    tasks = [
        slow_fetch("source1", 0.01),
        slow_fetch("source2", 0.01),
        slow_fetch("source3", 0.01),
    ]
    results = await asyncio.gather(*tasks)

    # Run one after another, the peak would be 1; all three must overlap.
    assert peak_in_flight == 3, \
        f"asyncio.gather should run concurrently, peak in-flight was {peak_in_flight}"
    assert len(results) == 3
    assert results[0] == ("source1", "schema_source1", None)


@pytest.mark.asyncio
async def test_asyncio_to_thread_offloads_blocking_calls():
    """Test asyncio.to_thread correctly offloads sync code (validates Phase 2 pattern).

    Besides the return value, the test checks that the blocking callable ran
    on a different thread from the one running the event loop.
    """
    import threading
    import time

    loop_thread_id = threading.get_ident()

    def blocking_get_schema():
        """Simulate blocking I/O and report which thread executed it."""
        time.sleep(0.05)
        return "schema_data", threading.get_ident()

    # Run blocking call via to_thread
    result, worker_thread_id = await asyncio.to_thread(blocking_get_schema)

    assert result == "schema_data"
    assert worker_thread_id != loop_thread_id, \
        "asyncio.to_thread should run the callable outside the event loop thread"


# =============================================================================
# PHASE 3 TESTS - Toolkit Async Client
# =============================================================================


@pytest.mark.asyncio
async def test_list_sources_returns_coroutine_inside_event_loop_and_awaits():
    """Inside a running event loop, list_sources must hand back an awaitable coroutine.

    The toolkit's ``_async_client`` is swapped for an AsyncMock whose ``get``
    resolves to a canned 200 response. The test asserts that the call returns
    a coroutine, that awaiting it awaited the mock, and that the payload
    fields come through.
    """
    expected_payload = {"available_sources": ["s1", "s2"], "count": 2, "tenant_id": "tenant-x"}

    # Stand-in async client: awaiting .get yields the canned response.
    fake_async_client = AsyncMock()
    fake_async_client.get = AsyncMock(return_value=httpx.Response(200, json=expected_payload))

    toolkit = DataSourcesSQLToolkit(api_base_url="http://testserver", timeout=5)
    toolkit._async_client = fake_async_client

    pending = toolkit.list_sources(
        session_state={"tenant_id": "tenant-x", "supabase_jwt": "jwt-token"}
    )
    assert asyncio.iscoroutine(pending), "Expected a coroutine when called inside an event loop"

    outcome = await pending

    fake_async_client.get.assert_awaited()
    assert outcome["available_sources"] == ["s1", "s2"]
    assert outcome["count"] == 2


def test_list_sources_sync_runs_asyncclient_and_returns_result():
    """Called from synchronous code, list_sources must return the final result directly.

    The toolkit's ``_async_client`` is an AsyncMock whose ``get`` resolves to
    a canned 200 response. The test asserts the mock was awaited and that the
    payload fields are present in the returned value.
    """
    expected_payload = {"available_sources": ["alpha"], "count": 1, "tenant_id": "t1"}

    # Stand-in async client: awaiting .get yields the canned response.
    fake_async_client = AsyncMock()
    fake_async_client.get = AsyncMock(return_value=httpx.Response(200, json=expected_payload))

    toolkit = DataSourcesSQLToolkit(api_base_url="http://testserver", timeout=5)
    toolkit._async_client = fake_async_client

    outcome = toolkit.list_sources(session_state={"tenant_id": "t1", "supabase_jwt": "token"})

    # No event loop is running here, yet the async client must still have been awaited.
    fake_async_client.get.assert_awaited()
    assert outcome["available_sources"] == ["alpha"]
    assert outcome["count"] == 1


def test_search_schema_prefers_async_and_caches():
    """search_schema should go through the async client and populate ``_search_cache``.

    The toolkit's ``_async_client`` is an AsyncMock whose ``post`` resolves to
    a canned 200 response. After one call, the mock must have been awaited,
    the formatted schema string must be returned, and the toolkit's
    ``_search_cache`` must hold at least one entry.
    """
    canned_payload = {
        "formatted_schema_string": "schema-content",
        "matches": [{"table": "users"}],
        "available_sources": ["s1"],
        "total_matches": 1,
    }

    # Stand-in async client: awaiting .post yields the canned response.
    fake_async_client = AsyncMock()
    fake_async_client.post = AsyncMock(return_value=httpx.Response(200, json=canned_payload))

    toolkit = DataSourcesSQLToolkit(api_base_url="http://testserver", timeout=5)
    toolkit._async_client = fake_async_client

    # session_state carries the tenant_id and JWT used for this call.
    outcome = toolkit.search_schema(
        keywords=["user"],
        session_state={"tenant_id": "tX", "supabase_jwt": "test-jwt"},
    )

    fake_async_client.post.assert_awaited()

    assert outcome["formatted_schema_string"] == "schema-content"
    cached_entries = len(toolkit._search_cache)
    assert cached_entries >= 1


def test_execute_sql_query_sync_path_and_error_handling():
    """Test execute_sql_query uses the sync client and returns structured results on success.

    The toolkit's ``client`` is replaced by a Mock whose ``post`` returns a
    canned 200 response. Despite the test's name, only the success path is
    asserted here; no error response is exercised.
    """
    toolkit = DataSourcesSQLToolkit(api_base_url="http://testserver", timeout=5)

    success_payload = {"status": "success", "results": [{"id": 1}], "rows_returned": 1}
    resp = httpx.Response(200, json=success_payload)

    # Replace sync client with a simple Mock holding a post method
    mock_client = Mock()
    mock_client.post = Mock(return_value=resp)
    toolkit.client = mock_client

    session_state = {"tenant_id": "t1", "source_name": "s1", "supabase_jwt": "fake-token"}

    result = toolkit.execute_sql_query(
        tenant_id='t1',
        source_name='s1',
        sql_query='SELECT * FROM big_table',
        session_state=session_state,
    )

    # The request must have gone through the mocked sync client.
    mock_client.post.assert_called()
    assert result["status"] == "success"
    assert isinstance(result["results"], list) and result["rows_returned"] == 1


# =============================================================================
# PHASE 4 TESTS - Schema Search Cache Increase
# =============================================================================


def test_split_identifier_cache_increased_and_tokens():
    """Check the cache size and tokenization of ``schema_search._split_identifier``.

    The cache must report ``maxsize == 4096``, both ``"OrderID"`` and
    ``"order_id"`` must tokenize to ``["order", "id"]``, and repeating an
    input must register at least one cache hit.
    """
    split_identifier = schema_search._split_identifier

    # Start from an empty cache so the hit counter reflects this test only.
    split_identifier.cache_clear()

    configured_maxsize = split_identifier.cache_info().maxsize
    assert configured_maxsize == 4096, f"Expected cache maxsize 4096, got {configured_maxsize}"

    camel_tokens = split_identifier("OrderID")
    snake_tokens = split_identifier("order_id")
    split_identifier("OrderID")  # repeated input: expected to be served from the cache

    assert camel_tokens == ["order", "id"], f"Expected ['order', 'id'], got {camel_tokens}"
    assert snake_tokens == ["order", "id"], f"Expected ['order', 'id'], got {snake_tokens}"

    hit_count = split_identifier.cache_info().hits
    assert hit_count >= 1, f"Expected at least 1 cache hit, got {hit_count}"


# =============================================================================
# PHASE 5 TESTS - Large Result Handling (MinIO + Pandas)
# =============================================================================

def test_execute_sql_query_minio_result_triggers_summary():
    """Test that execute_sql_query detects minio:// result, fetches file, and returns summary + head.

    The mocked API response carries a ``minio://`` URI as its ``results``.
    ``minio.Minio`` is patched to return a fake client whose ``get_object``
    yields the JSON dump of a 300-row DataFrame. The test expects a summary,
    the first 200 rows, and the row-count/limit flags in the returned dict.
    """
    import pandas as pd

    # Prepare a fake DataFrame and its JSON bytes
    df = pd.DataFrame({
        'a': range(1, 301),
        'b': [x * 2 for x in range(1, 301)],
    })
    json_bytes = df.to_json().encode()

    # Mock MinIO response
    fake_minio_response = MagicMock()
    fake_minio_response.read.return_value = json_bytes

    fake_minio_client = MagicMock()
    fake_minio_client.get_object.return_value = fake_minio_response

    fake_minio_config = {'endpoint': 'localhost:9000', 'access_key': 'x', 'secret_key': 'y', 'secure': False}

    # Mock HTTP response from API
    fake_http_response = MagicMock()
    fake_http_response.status_code = 200
    fake_http_response.json.return_value = {
        'status': 'success',
        'results': 'minio://testbucket/job-results/123.json'
    }

    with patch('backend.core.minio.config.get_minio_config', return_value=fake_minio_config), \
         patch('minio.Minio', return_value=fake_minio_client):

        toolkit = DataSourcesSQLToolkit(api_base_url="http://testserver", timeout=5)
        toolkit.client = MagicMock()
        toolkit.client.post.return_value = fake_http_response

        session_state = {"tenant_id": "t1", "source_name": "s1", "supabase_jwt": "fake-token"}

        result = toolkit.execute_sql_query(
            tenant_id='t1',
            source_name='s1',
            sql_query='SELECT * FROM big_table',
            session_state=session_state,
        )

    # The large result must have been pulled from the (fake) MinIO client.
    fake_minio_client.get_object.assert_called()

    # Should return summary and head rows
    assert result.get('status') == 'success', f"Expected status='success', got {result}"
    assert 'summary' in result, f"Expected 'summary' in result, got keys: {result.keys()}"
    assert 'results' in result
    assert isinstance(result['results'], list)
    assert len(result['results']) == 200, f"Expected 200 head rows, got {len(result['results'])}"
    assert 'Result too large' in result['message']
    assert result['row_count'] == 300
    assert result['rows_limited'] is True
    assert result['rows_limited'] is True