""" Unit tests for DataSourcesSQLToolkit. Tests cover: - Initialization and configuration - Schema fetching and formatting - Source instructions retrieval - SQL query execution - Error handling and edge cases - Input validation Uses mocking (pytest-mock and unittest.mock) to simulate API responses. """ import pytest import json import uuid import httpx from time import perf_counter from datetime import datetime from unittest.mock import Mock, MagicMock, patch, call from typing import Dict, Any # Import from backend SQL_Agent package import sys import os # Ensure the project root (parent of `backend`) is on sys.path so # `import backend...` works when running this test file directly. project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) if project_root not in sys.path: sys.path.insert(0, project_root) from backend.SQL_Agent.data_sources_sql_toolkit import DataSourcesSQLToolkit # ============================================================================ # Fixtures # ============================================================================ @pytest.fixture def toolkit(): """Create a DataSourcesSQLToolkit instance for testing.""" return DataSourcesSQLToolkit(api_base_url="http://test-api:8000") @pytest.fixture def mock_httpx_response(): """Create a mock httpx response.""" return Mock() # ============================================================================ # Test Initialization # ============================================================================ class TestToolkitInitialization: """Tests for toolkit initialization and setup.""" def test_init_default_values(self): """Test initialization with default values.""" toolkit = DataSourcesSQLToolkit() assert toolkit.api_base_url == "http://127.0.0.1:8000" assert toolkit.client is not None assert toolkit.client.base_url == "http://127.0.0.1:8000" def test_init_custom_values(self): """Test initialization with custom values.""" custom_url = "http://custom-api:9999" custom_timeout = 60.0 toolkit = DataSourcesSQLToolkit(api_base_url=custom_url, timeout=custom_timeout) assert toolkit.api_base_url == custom_url assert toolkit.client.base_url == custom_url def test_toolkit_has_required_tools(self, toolkit): """Test that toolkit is initialized with required tools.""" tool_names = [tool.__name__ for tool in toolkit.tools] assert "find_relevant_tables" in tool_names # NEW TOOL assert "search_schema" in tool_names assert "get_available_sources_and_schema" in tool_names assert "get_source_instructions" in tool_names assert "execute_sql_query" in tool_names # ============================================================================ # Test Schema Formatting # ============================================================================ class TestSchemaFormatting: """Tests for schema formatting for LLM consumption.""" def test_format_schema_with_valid_json(self, toolkit): """Test formatting valid schema JSON.""" source_name = "production_db" schema_json = json.dumps([ { "schema_name": "public", "tables": [ { "table_name": "users", "fields": [ {"name": "id", "type": "INTEGER", "example": "1"}, {"name": "email", "type": "VARCHAR", "example": "user@example.com"} ] } ] } ]) formatted = toolkit._format_schema_for_llm(source_name, schema_json) assert "production_db" in formatted assert "public" in formatted assert "users" in formatted assert "id" in formatted assert "INTEGER" in formatted def test_format_schema_with_empty_schema(self, toolkit): """Test formatting with empty schema.""" source_name = "empty_source" result = toolkit._format_schema_for_llm(source_name, None) assert "empty_source" in result assert "Not Available" in result def test_format_schema_with_invalid_json(self, toolkit): """Test formatting with invalid JSON.""" source_name = "broken_source" invalid_json = "{ invalid json }" result = toolkit._format_schema_for_llm(source_name, invalid_json) assert "broken_source" in result assert "Invalid Format" in result or "Error" in result def test_format_schema_with_multiple_tables(self, toolkit): """Test formatting with multiple tables.""" source_name = "test_db" schema_json = json.dumps([ { "schema_name": "public", "tables": [ { "table_name": "users", "fields": [ {"name": "id", "type": "INTEGER"} ] }, { "table_name": "orders", "fields": [ {"name": "order_id", "type": "INTEGER"}, {"name": "user_id", "type": "INTEGER"} ] } ] } ]) formatted = toolkit._format_schema_for_llm(source_name, schema_json) assert "users" in formatted assert "orders" in formatted # ============================================================================ # Test get_available_sources_and_schema # ============================================================================ class TestGetAvailableSourcesAndSchema: """Tests for retrieving available sources and their schemas.""" def test_missing_tenant_id(self, toolkit): """Test that missing tenant_id returns error.""" result = toolkit.get_available_sources_and_schema(session_state=None) assert "error" in result assert "Session state not found" in result["error"] def test_successful_schema_fetch(self, toolkit): """Test successful schema fetching for valid tenant.""" # Mock the list_sources call list_response = Mock() list_response.status_code = 200 list_response.json.return_value = { "available_sources": ["db1", "db2"], "count": 2 } # Mock httpx client responses for schema fetches schema_response_1 = Mock() schema_response_1.status_code = 200 schema_response_1.json.return_value = { "schema_data": json.dumps([{ "schema_name": "public", "tables": [{"table_name": "users", "fields": []}] }]) } schema_response_2 = Mock() schema_response_2.status_code = 200 schema_response_2.json.return_value = { "schema_data": json.dumps([{ "schema_name": "public", "tables": [{"table_name": "orders", "fields": []}] }]) } toolkit.client.get = Mock(side_effect=[list_response, schema_response_1, schema_response_2]) # Provide session_state with tenant_id and JWT session_state = {"tenant_id": "tenant-123", "supabase_jwt": "jwt-token-123"} result = toolkit.get_available_sources_and_schema( keywords=None, session_state=session_state ) assert "formatted_schema_string" in result assert "available_sources" in result assert len(result["available_sources"]) == 2 # Verify JWT token was used in headers calls = toolkit.client.get.call_args_list for call in calls[1:]: # Skip the list_sources call headers = call[1]["headers"] assert headers["Authorization"] == "Bearer jwt-token-123" def test_schema_fetch_with_api_error(self, toolkit): """Test handling of API errors during schema fetch.""" # Mock list_sources returning one source list_response = Mock() list_response.status_code = 200 list_response.json.return_value = { "available_sources": ["db1"], "count": 1 } # Mock error response for schema fetch schema_response = Mock() schema_response.status_code = 500 schema_response.text = "Internal Server Error" toolkit.client.get = Mock(side_effect=[list_response, schema_response]) session_state = {"tenant_id": "tenant-123", "supabase_jwt": "jwt-token-123"} result = toolkit.get_available_sources_and_schema(session_state=session_state) assert "formatted_schema_string" in result assert "available_sources" in result def test_schema_fetch_import_error(self, toolkit): """Test handling when API connection fails.""" # Mock list_sources to raise connection error toolkit.client.get = Mock(side_effect=httpx.RequestError("Connection failed")) session_state = {"tenant_id": "tenant-123", "supabase_jwt": "jwt-token-123"} result = toolkit.get_available_sources_and_schema(session_state=session_state) assert "error" in result assert "API connection error" in result["error"] # ============================================================================ # Test get_source_instructions # ============================================================================ class TestGetSourceInstructions: """Tests for retrieving source-specific instructions.""" def test_missing_tenant_id(self, toolkit): """Test that missing tenant_id returns error.""" result = toolkit.get_source_instructions(session_state=None) assert "error" in result assert "Session state not found" in result["error"] def test_missing_source_name(self, toolkit): """Test that missing source_name returns error.""" result = toolkit.get_source_instructions(session_state=None) assert "error" in result assert "Session state not found" in result["error"] def test_successful_instructions_fetch(self, toolkit): """Test successful retrieval of instructions.""" mock_response = Mock() mock_response.status_code = 200 mock_response.json.return_value = { "instructions": "Use SELECT ... FROM syntax. Supports CTEs with WITH clause." } toolkit.client.get = Mock(return_value=mock_response) session_state = {"tenant_id": "tenant-123", "source_name": "production_db", "supabase_jwt": "jwt-token-123"} result = toolkit.get_source_instructions( source_name="production_db", session_state=session_state ) assert "instructions" in result assert "SELECT" in result["instructions"] toolkit.client.get.assert_called_once() def test_instructions_not_found(self, toolkit): """Test handling of 404 when instructions not found.""" mock_response = Mock() mock_response.status_code = 404 toolkit.client.get = Mock(return_value=mock_response) session_state = {"tenant_id": "tenant-123", "source_name": "unknown_db", "supabase_jwt": "jwt-token-123"} result = toolkit.get_source_instructions( source_name="unknown_db", session_state=session_state ) assert "error" in result assert "not found" in result["error"].lower() def test_instructions_api_error(self, toolkit): """Test handling of API errors.""" mock_response = Mock() mock_response.status_code = 500 mock_response.text = "Server Error" toolkit.client.get = Mock(return_value=mock_response) session_state = {"tenant_id": "tenant-123", "source_name": "db", "supabase_jwt": "jwt-token-123"} result = toolkit.get_source_instructions( source_name="db", session_state=session_state ) assert "error" in result assert "Failed" in result["error"] def test_instructions_connection_error(self, toolkit): """Test handling of connection errors.""" import httpx toolkit.client.get = Mock(side_effect=httpx.RequestError("Connection refused")) session_state = {"tenant_id": "tenant-123", "source_name": "db", "supabase_jwt": "jwt-token-123"} result = toolkit.get_source_instructions( source_name="db", session_state=session_state ) assert "error" in result assert "connection error" in result["error"].lower() # ============================================================================ # Test execute_sql_query # ============================================================================ class TestExecuteSQLQuery: """Tests for SQL query execution.""" def test_missing_tenant_id(self, toolkit): """Test that missing tenant_id returns error.""" result = toolkit.execute_sql_query( sql_query="SELECT * FROM users", session_state=None ) assert "error" in result assert "Session state not found" in result["error"] def test_missing_source_name(self, toolkit): """Test that missing source_name returns error.""" result = toolkit.execute_sql_query( sql_query="SELECT * FROM users", session_state=None ) assert "error" in result assert "Session state not found" in result["error"] def test_empty_sql_query(self, toolkit): """Test that empty SQL query returns error.""" result = toolkit.execute_sql_query( sql_query=" ", session_state={"tenant_id": "tenant-123", "source_name": "db", "supabase_jwt": "jwt-token-123"} ) assert "error" in result assert "empty" in result["error"].lower() def test_successful_query_execution(self, toolkit): """Test successful SQL query execution.""" mock_response = Mock() mock_response.status_code = 200 mock_response.json.return_value = { "results": [ {"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"} ] } toolkit.client.post = Mock(return_value=mock_response) session_state = {"tenant_id": "tenant-123", "source_name": "production_db", "supabase_jwt": "jwt-token-123"} result = toolkit.execute_sql_query( sql_query="SELECT * FROM users", session_state=session_state ) assert result["status"] == "success" assert len(result["results"]) == 2 assert result["results"][0]["name"] == "Alice" def test_query_blocked_for_unsafe_operations(self, toolkit): """Test that unsafe SQL operations are blocked client-side.""" unsafe_queries = [ "DROP TABLE users", "DELETE FROM users", "UPDATE users SET name = 'hacked'", "INSERT INTO users VALUES (1, 'hacked')", "TRUNCATE TABLE users" ] session_state = {"tenant_id": "tenant-123", "source_name": "db", "supabase_jwt": "jwt-token-123"} for unsafe_query in unsafe_queries: result = toolkit.execute_sql_query( sql_query=unsafe_query, session_state=session_state ) assert "error" in result assert "blocked" in result["error"].lower() def test_query_allowed_for_safe_operations(self, toolkit): """Test that safe SQL operations are allowed.""" mock_response = Mock() mock_response.status_code = 200 mock_response.json.return_value = {"results": []} toolkit.client.post = Mock(return_value=mock_response) session_state = {"tenant_id": "tenant-123", "source_name": "db", "supabase_jwt": "jwt-token-123"} safe_queries = [ "SELECT * FROM users", "WITH cte AS (SELECT * FROM users) SELECT * FROM cte", "SHOW TABLES", "DESCRIBE users", "EXPLAIN SELECT * FROM users" ] for safe_query in safe_queries: result = toolkit.execute_sql_query( source_name="db", sql_query=safe_query, session_state=session_state ) # Should not have error (or should have been processed) # Note: for SHOW, DESCRIBE, EXPLAIN - client-side check passes, # then POST is attempted which we're mocking to succeed assert "status" in result or "error" not in result.get("error", "").lower() def test_query_api_error(self, toolkit): """Test handling of API errors during execution.""" mock_response = Mock() mock_response.status_code = 400 mock_response.json.return_value = {"detail": "Syntax error in query"} toolkit.client.post = Mock(return_value=mock_response) session_state = {"tenant_id": "tenant-123", "source_name": "db", "supabase_jwt": "jwt-token-123"} result = toolkit.execute_sql_query( source_name="db", sql_query="SELECT * FROM users", session_state=session_state ) assert "error" in result assert "Syntax error" in result["error"] def test_query_timeout(self, toolkit): """Test handling of request timeouts.""" import httpx toolkit.client.post = Mock(side_effect=httpx.TimeoutException("Request timed out")) session_state = {"tenant_id": "tenant-123", "source_name": "db", "supabase_jwt": "jwt-token-123"} result = toolkit.execute_sql_query( source_name="db", sql_query="SELECT * FROM users", session_state=session_state ) assert "error" in result assert "timed out" in result["error"].lower() def test_query_connection_error(self, toolkit): """Test handling of connection errors.""" import httpx toolkit.client.post = Mock(side_effect=httpx.RequestError("Connection refused")) session_state = {"tenant_id": "tenant-123", "source_name": "db", "supabase_jwt": "jwt-token-123"} result = toolkit.execute_sql_query( source_name="db", sql_query="SELECT * FROM users", session_state=session_state ) assert "error" in result assert "connection error" in result["error"].lower() def test_query_malformed_response(self, toolkit): """Test handling of malformed API responses.""" mock_response = Mock() mock_response.status_code = 500 mock_response.json.side_effect = json.JSONDecodeError("Invalid JSON", "", 0) mock_response.text = "Internal Server Error" toolkit.client.post = Mock(return_value=mock_response) session_state = {"tenant_id": "tenant-123", "source_name": "db", "supabase_jwt": "jwt-token-123"} result = toolkit.execute_sql_query( source_name="db", sql_query="SELECT * FROM users", session_state=session_state ) assert "error" in result assert "Internal Server Error" in result["error"] # ============================================================================ # Test Multi-tenant Isolation # ============================================================================ class TestMultiTenantIsolation: """Tests for ensuring proper multi-tenant isolation.""" def test_different_tenants_use_different_params(self, toolkit): """Test that different tenants get different API parameters.""" mock_response = Mock() mock_response.status_code = 200 mock_response.json.return_value = {"results": []} toolkit.client.post = Mock(return_value=mock_response) # Execute query for tenant 1 session_state_1 = {"tenant_id": "tenant-1", "supabase_jwt": "jwt-token-1", "source_name": "db"} toolkit.execute_sql_query( sql_query="SELECT * FROM users", session_state=session_state_1 ) # Get first call's payload first_call_payload = toolkit.client.post.call_args[1]["json"] assert first_call_payload["tenant_id"] == "tenant-1" # Execute query for tenant 2 session_state_2 = {"tenant_id": "tenant-2", "supabase_jwt": "jwt-token-2", "source_name": "db"} toolkit.execute_sql_query( sql_query="SELECT * FROM orders", session_state=session_state_2 ) # Get second call's payload second_call_payload = toolkit.client.post.call_args[1]["json"] assert second_call_payload["tenant_id"] == "tenant-2" def test_tenant_id_passed_to_all_endpoints(self, toolkit): """Test that tenant_id is passed in request payload (not query params) for endpoints that need it.""" # Test execute_sql_query - tenant_id should be in payload mock_response = Mock() mock_response.status_code = 200 mock_response.json.return_value = {"results": []} toolkit.client.post = Mock(return_value=mock_response) session_state = {"tenant_id": "tenant-x", "source_name": "db", "supabase_jwt": "jwt-token-123"} toolkit.execute_sql_query( sql_query="SELECT * FROM users", session_state=session_state ) # Get the call's payload call_kwargs = toolkit.client.post.call_args[1] payload = call_kwargs["json"] headers = call_kwargs["headers"] # Verify tenant_id is in payload, not query params assert payload["tenant_id"] == "tenant-x" # Verify JWT is in Authorization header assert headers["Authorization"] == "Bearer jwt-token-123" # ============================================================================ # Test Edge Cases # ============================================================================ class TestEdgeCases: """Tests for edge cases and boundary conditions.""" def test_very_long_sql_query(self, toolkit): """Test handling of very long SQL queries.""" mock_response = Mock() mock_response.status_code = 200 mock_response.json.return_value = {"results": []} toolkit.client.post = Mock(return_value=mock_response) long_query = "SELECT * FROM users WHERE id IN (" + ",".join(str(i) for i in range(1000)) + ")" session_state = {"tenant_id": "tenant-123", "supabase_jwt": "jwt-token-123", "source_name": "db"} result = toolkit.execute_sql_query( source_name="db", sql_query=long_query, session_state=session_state ) assert "error" not in result or result.get("status") == "success" def test_special_characters_in_query(self, toolkit): """Test handling of special characters in SQL queries.""" mock_response = Mock() mock_response.status_code = 200 mock_response.json.return_value = {"results": []} toolkit.client.post = Mock(return_value=mock_response) query_with_special_chars = "SELECT * FROM users WHERE name = 'O''Reilly' AND email LIKE '%@%.com%'" session_state = {"tenant_id": "tenant-123", "supabase_jwt": "jwt-token-123", "source_name": "db"} result = toolkit.execute_sql_query( source_name="db", sql_query=query_with_special_chars, session_state=session_state ) # Should be posted without error toolkit.client.post.assert_called_once() def test_empty_results(self, toolkit): """Test handling of empty result sets.""" mock_response = Mock() mock_response.status_code = 200 mock_response.json.return_value = {"results": []} toolkit.client.post = Mock(return_value=mock_response) session_state = {"tenant_id": "tenant-123", "supabase_jwt": "jwt-token-123", "source_name": "db"} result = toolkit.execute_sql_query( source_name="db", sql_query="SELECT * FROM users WHERE id = -1", session_state=session_state ) assert result["status"] == "success" assert result["results"] == [] def test_large_result_set(self, toolkit): """Test handling of large result sets.""" # Create a large result set large_results = [{"id": i, "name": f"user_{i}"} for i in range(10000)] mock_response = Mock() mock_response.status_code = 200 mock_response.json.return_value = {"results": large_results} toolkit.client.post = Mock(return_value=mock_response) session_state = {"tenant_id": "tenant-123", "supabase_jwt": "jwt-token-123", "source_name": "db"} result = toolkit.execute_sql_query( source_name="db", sql_query="SELECT * FROM users", session_state=session_state ) assert result["status"] == "success" assert len(result["results"]) == 10000 # ============================================================================ # Test find_relevant_tables (Hybrid Keyword Extraction) # ============================================================================ class TestFindRelevantTables: """Tests for the new find_relevant_tables tool with hybrid keyword extraction.""" def test_missing_tenant_id(self, toolkit): """Test that missing tenant_id returns error.""" result = toolkit.find_relevant_tables( tenant_id="", question="What are the total sales?" ) assert "error" in result assert "Tenant ID is required" in result["error"] def test_missing_question(self, toolkit): """Test that missing question returns error.""" result = toolkit.find_relevant_tables( tenant_id="tenant-123", question="" ) assert "error" in result assert "Question is required" in result["error"] @patch("backend.SQL_Agent.data_sources_sql_toolkit.extract_hybrid_keywords") def test_find_relevant_tables_merges_concepts_and_semantic_hints(self, mock_extract, toolkit): """Test that find_relevant_tables properly merges concepts and semantic hints.""" # Mock hybrid keyword extraction mock_extract.return_value = { 'base': ['revenue', 'customers'], 'semantic': ['sales metrics', 'customer data'], 'concepts': ['premium', 'quarterly'], 'combined': ['premium', 'quarterly', 'revenue', 'customers', 'sales metrics', 'customer data'] } # Mock the search_schema call mock_response = Mock() mock_response.status_code = 200 mock_response.json.return_value = { "formatted_schema_string": "Test schema", "matches": [{"table_name": "sales", "score": 15.0, "matched_columns": ["revenue"], "source_name": "db1"}], "available_sources": ["db1"], "total_matches": 1 } toolkit.client.post = Mock(return_value=mock_response) # Create session state session_state = {"keyword_extraction_cache": {}} # Call find_relevant_tables with agent concepts result = toolkit.find_relevant_tables( tenant_id="tenant-123", question="What was revenue from premium customers last quarter?", concepts=["premium", "quarterly"], session_state=session_state ) # Verify extraction was called mock_extract.assert_called_once() call_kwargs = mock_extract.call_args[1] assert call_kwargs['question'] == "What was revenue from premium customers last quarter?" assert call_kwargs['llm_concepts'] == ["premium", "quarterly"] # Verify result structure assert "formatted_schema_string" in result assert "keyword_breakdown" in result assert result["keyword_breakdown"]["base"] == ['revenue', 'customers'] assert result["keyword_breakdown"]["semantic"] == ['sales metrics', 'customer data'] assert result["keyword_breakdown"]["concepts"] == ['premium', 'quarterly'] assert result["original_question"] == "What was revenue from premium customers last quarter?" # Verify merged keywords were sent to API api_call_payload = toolkit.client.post.call_args[1]["json"] assert set(api_call_payload["keywords"]) == {'premium', 'quarterly', 'revenue', 'customers', 'sales metrics', 'customer data'} @patch("backend.SQL_Agent.data_sources_sql_toolkit.extract_hybrid_keywords") def test_find_relevant_tables_handles_gemini_failure(self, mock_extract, toolkit): """Test that find_relevant_tables handles Gemini failure gracefully.""" # Mock hybrid extraction with no semantic hints (Gemini failed) mock_extract.return_value = { 'base': ['sales', 'customers'], 'semantic': [], # Empty because Gemini failed 'concepts': ['premium'], 'combined': ['premium', 'sales', 'customers'] } # Mock API response mock_response = Mock() mock_response.status_code = 200 mock_response.json.return_value = { "formatted_schema_string": "Schema", "matches": [], "available_sources": ["db1"], "total_matches": 0 } toolkit.client.post = Mock(return_value=mock_response) # Call should succeed despite Gemini failure session_state = {"tenant_id": "tenant-123", "supabase_jwt": "jwt-token-123"} result = toolkit.find_relevant_tables( question="Show premium customer sales", concepts=["premium"], session_state=session_state ) assert "error" not in result assert result["keyword_breakdown"]["semantic"] == [] assert "premium" in result["keyword_breakdown"]["combined"] assert "sales" in result["keyword_breakdown"]["combined"] @patch("backend.SQL_Agent.data_sources_sql_toolkit.extract_hybrid_keywords") def test_find_relevant_tables_uses_cache_for_repeat_question(self, mock_extract, toolkit): """Test that find_relevant_tables caches keyword extraction per question.""" # Mock hybrid extraction mock_extract.return_value = { 'base': ['revenue'], 'semantic': ['financial metrics'], 'concepts': [], 'combined': ['revenue', 'financial metrics'] } # Mock API response mock_response = Mock() mock_response.status_code = 200 mock_response.json.return_value = { "formatted_schema_string": "Schema", "matches": [], "available_sources": [], "total_matches": 0 } toolkit.client.post = Mock(return_value=mock_response) # Create session state session_state = {"keyword_extraction_cache": {}} # First call toolkit.find_relevant_tables( question="What is the total revenue?", session_state=session_state ) # Verify extraction was called once assert mock_extract.call_count == 1 assert len(session_state["keyword_extraction_cache"]) == 1 # Second call with same question toolkit.find_relevant_tables( question="What is the total revenue?", session_state=session_state ) # Verify extraction was NOT called again (cache hit) assert mock_extract.call_count == 1 # Still 1, not 2 # Third call with different question toolkit.find_relevant_tables( question="Show me customer data", session_state=session_state ) # Verify extraction was called for new question assert mock_extract.call_count == 2 assert len(session_state["keyword_extraction_cache"]) == 2 def test_find_relevant_tables_without_hybrid_utils(self, toolkit): """Test fallback behavior when hybrid_keyword_utils is not available.""" # Temporarily disable hybrid extraction original_extract = toolkit.__class__.__module__ with patch("backend.SQL_Agent.data_sources_sql_toolkit.extract_hybrid_keywords", None): # Mock API response mock_response = Mock() mock_response.status_code = 200 mock_response.json.return_value = { "formatted_schema_string": "Schema", "matches": [], "available_sources": [], "total_matches": 0 } toolkit.client.post = Mock(return_value=mock_response) # Should fall back to simple keyword extraction session_state = {"tenant_id": "tenant-123", "supabase_jwt": "jwt-token-123"} result = toolkit.find_relevant_tables( question="What is the revenue from premium customers?", concepts=["premium"], session_state=session_state ) # Verify fallback worked assert "error" not in result # Check that API was called with some keywords api_payload = toolkit.client.post.call_args[1]["json"] assert len(api_payload["keywords"]) > 0 @patch("backend.SQL_Agent.data_sources_sql_toolkit.extract_hybrid_keywords") def test_find_relevant_tables_sends_metadata_to_api(self, mock_extract, toolkit): """Test that find_relevant_tables sends metadata to API for analytics.""" # Mock hybrid extraction mock_extract.return_value = { 'base': ['sales'], 'semantic': ['revenue data'], 'concepts': ['monthly'], 'combined': ['monthly', 'sales', 'revenue data'] } # Mock API response mock_response = Mock() mock_response.status_code = 200 mock_response.json.return_value = { "formatted_schema_string": "Schema", "matches": [], "available_sources": [], "total_matches": 0 } toolkit.client.post = Mock(return_value=mock_response) # Call with question session_state = {"tenant_id": "tenant-123", "supabase_jwt": "jwt-token-123"} toolkit.find_relevant_tables( question="Show monthly sales data", concepts=["monthly"], session_state=session_state ) # Verify metadata was sent to API api_payload = toolkit.client.post.call_args[1]["json"] assert "original_question" in api_payload assert api_payload["original_question"] == "Show monthly sales data" assert "keyword_metadata" in api_payload assert api_payload["keyword_metadata"]["base"] == ['sales'] assert api_payload["keyword_metadata"]["semantic"] == ['revenue data'] assert api_payload["keyword_metadata"]["concepts"] == ['monthly'] @patch("backend.SQL_Agent.data_sources_sql_toolkit.extract_hybrid_keywords") def test_find_relevant_tables_with_source_filter(self, mock_extract, toolkit): """Test that find_relevant_tables passes source_names filter correctly.""" # Mock hybrid extraction mock_extract.return_value = { 'base': ['users'], 'semantic': [], 'concepts': [], 'combined': ['users'] } # Mock API response mock_response = Mock() mock_response.status_code = 200 mock_response.json.return_value = { "formatted_schema_string": "Schema", "matches": [], "available_sources": ["db1"], "total_matches": 0 } toolkit.client.post = Mock(return_value=mock_response) # Call with source filter session_state = {"tenant_id": "tenant-123", "supabase_jwt": "jwt-token-123"} toolkit.find_relevant_tables( question="Show user data", source_names=["db1", "db2"], session_state=session_state ) # Verify source filter was passed to API api_payload = toolkit.client.post.call_args[1]["json"] assert "source_names" in api_payload assert api_payload["source_names"] == ["db1", "db2"] # ============================================================================ # End-to-End Integration Tests with Real Database # ============================================================================ class TestEndToEndIntegration: """ End-to-end integration tests with real PostgreSQL database using REAL API calls. Uses requests library to hit the actual running API server. Tests the complete workflow: keyword search -> schema fetch -> SQL execution. """ @pytest.fixture(scope="class") def live_api_setup(self): """Provision a real tenant + source against the live data_sources API using requests.""" import requests import uuid if os.environ.get("SKIP_INTEGRATION_TESTS") == "1": pytest.skip("Integration tests skipped via SKIP_INTEGRATION_TESTS=1") base_url = os.environ.get("DATA_SOURCES_API_BASE_URL", "http://127.0.0.1:8000") admin_key = os.environ.get("SIRUS_ADMIN_API_KEY") if not admin_key: pytest.skip("SIRUS_ADMIN_API_KEY not configured") tenant_id = f"toolkit_e2e_{uuid.uuid4().hex[:8]}" source_name = "scv_sample_db" description = f"E2E Test {tenant_id}" session = requests.Session() session.timeout = 30 # Quick availability check try: health = session.get(f"{base_url}/api/v1/data-sources/list", params={"tenant_id": "__healthcheck__"}) if health.status_code not in (200, 400, 404, 401): pytest.skip(f"API health check failed (status {health.status_code})") except requests.RequestException as exc: pytest.skip(f"API unreachable: {exc}") # Create tenant API key via admin route api_key = None key_id = None headers_admin = {"X-Sirus-Admin-Key": admin_key} create_key_payload = { "description": description, "expires_in_days": 14 } resp = session.post( f"{base_url}/api/v1/data-sources/tenants/{tenant_id}/api-keys", json=create_key_payload, headers=headers_admin ) if resp.status_code == 201: data = resp.json() api_key = data.get("api_key") key_info = data.get("key_info", {}) key_id = key_info.get("key_id") else: # Attempt to reuse an existing key if creation failed list_resp = session.get( f"{base_url}/api/v1/data-sources/tenants/{tenant_id}/api-keys", headers=headers_admin ) if list_resp.status_code == 200: keys = list_resp.json().get("keys", []) if keys: key_id = keys[0].get("key_id") pytest.skip("Existing tenant API keys found but raw value unavailable") pytest.skip(f"Failed to create tenant API key (status {resp.status_code}): {resp.text[:200]}") if not api_key: pytest.skip("Tenant API key could not be provisioned") tenant_headers = {"X-Sirus-Api-Key": api_key} # Upsert tenant source pointing at live Postgres source_payload = { "sources": [ { "source_name": source_name, "source_type": "ibis", "config": { "uri": "postgresql://neondb_owner:npg_dfWNsn2ZGk7c@ep-cool-poetry-a1puamly-pooler.ap-southeast-1.aws.neon.tech:5432/scv-sample?sslmode=require", "table_fetch_example_limit": 5 } } ], "validate_connection": True } upsert_resp = session.post( f"{base_url}/api/v1/data-sources/tenants/{tenant_id}/sources", json=source_payload, headers=tenant_headers ) if upsert_resp.status_code not in (200, 201): pytest.skip( f"Failed to configure tenant source (status {upsert_resp.status_code}): {upsert_resp.text[:200]}" ) setup = { "base_url": base_url, "tenant_id": tenant_id, "source_name": source_name, "api_key": api_key, "key_id": key_id, "session": session, "headers_admin": headers_admin, "tenant_headers": tenant_headers, "timeout": 120.0 } try: yield setup finally: # Cleanup: delete source and revoke key try: session.delete( f"{base_url}/api/v1/data-sources/tenants/{tenant_id}/sources/{source_name}", headers=tenant_headers ) except Exception: pass if key_id: try: session.delete( f"{base_url}/api/v1/data-sources/tenants/{tenant_id}/api-keys/{key_id}", headers=headers_admin ) except Exception: pass session.close() @pytest.fixture def live_toolkit(self, live_api_setup): """Toolkit connected to real API with proper authentication.""" api_key = live_api_setup.get("api_key") base_url = live_api_setup.get("base_url", "http://127.0.0.1:8000") return DataSourcesSQLToolkit(api_base_url=base_url, timeout=120.0, api_key=api_key) @pytest.mark.integration @pytest.mark.skipif( os.environ.get("SKIP_INTEGRATION_TESTS") == "1", reason="Integration tests skipped (set SKIP_INTEGRATION_TESTS=0 to run)" ) def test_e2e_keyword_search_with_real_db(self, live_toolkit, live_api_setup, caplog): """ Test end-to-end workflow with keyword search on real database using REAL API calls. Workflow: 1. Search schema with keywords via real API 2. Verify matched tables 3. Execute SQL query on matched table via real API 4. Measure timing for each step """ import time import logging # Enable detailed logging caplog.set_level(logging.INFO) print("\n" + "="*80) print("END-TO-END INTEGRATION TEST: Keyword Search -> SQL Execution") print("="*80) # Test configuration from provisioned setup tenant_id = live_api_setup["tenant_id"] source_name = live_api_setup["source_name"] # Step 1: Search schema with business keywords print("\n[STEP 1] Searching schema with keywords...") search_start = time.time() keywords = ["customer", "order", "product", "sales"] search_result = live_toolkit.search_schema( keywords=keywords, include_samples=True, session_state={"tenant_id": tenant_id} ) search_duration = time.time() - search_start print(f"✓ Schema search completed in {search_duration:.3f}s") # Verify search results if "error" in search_result: print(f"⚠ Search returned error (API may not be running): {search_result['error']}") pytest.skip("API not available for integration test") assert "matches" in search_result or "formatted_schema_string" in search_result print(f" - Total matches: {search_result.get('total_matches', 0)}") print(f" - Available sources: {search_result.get('available_sources', [])}") print(f" - Cache hit: {search_result.get('cache_hit', False)}") # Log matched tables if "matches" in search_result and search_result["matches"]: print("\n Matched Tables:") for match in search_result["matches"][:5]: # Show top 5 print(f" - {match.get('table_name', 'N/A')} (score: {match.get('score', 0)})") if "matched_columns" in match: print(f" Columns: {', '.join(match['matched_columns'][:5])}") # Step 2: Get source instructions print("\n[STEP 2] Fetching SQL dialect instructions...") instructions_start = time.time() session_state = {"tenant_id": tenant_id, "source_name": source_name} instructions_result = live_toolkit.get_source_instructions( source_name=source_name, session_state=session_state ) instructions_duration = time.time() - instructions_start print(f"✓ Instructions fetched in {instructions_duration:.3f}s") if "error" not in instructions_result: print(f" - Instructions: {instructions_result.get('instructions', 'N/A')[:100]}...") # Step 3: Execute SQL query on discovered table print("\n[STEP 3] Executing SQL query on real database...") # Use a safe SELECT query that should work on most schemas test_queries = [ "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' LIMIT 5", "SELECT current_database(), current_schema()", "SELECT version()" ] for idx, sql_query in enumerate(test_queries, 1): print(f"\n Query {idx}: {sql_query}") query_start = time.time() query_result = live_toolkit.execute_sql_query( source_name=source_name, sql_query=sql_query, session_state={"tenant_id": tenant_id} ) query_duration = time.time() - query_start print(f" ✓ Query executed in {query_duration:.3f}s") if "error" in query_result: print(f" ⚠ Query error: {query_result['error']}") elif "results" in query_result: results = query_result["results"] print(f" - Rows returned: {len(results)}") if results and len(results) > 0: print(f" - Sample row: {results[0]}") # Verify result structure assert "status" in query_result or "error" in query_result or "results" in query_result # Step 4: Test multi-tenant isolation print("\n[STEP 4] Testing multi-tenant isolation...") tenant_2_id = "test_tenant_e2e_002" isolation_start = time.time() # Same query, different tenant isolation_result = live_toolkit.execute_sql_query( source_name=source_name, sql_query="SELECT current_database()", session_state={"tenant_id": tenant_2_id} ) isolation_duration = time.time() - isolation_start print(f"✓ Tenant isolation verified in {isolation_duration:.3f}s") print(f" - Tenant 2 query processed independently") # Step 5: Test caching behavior print("\n[STEP 5] Testing cache behavior...") # Repeat search with same keywords cache_start = time.time() cached_search_result = live_toolkit.search_schema( keywords=keywords, include_samples=False, session_state={"tenant_id": tenant_id} ) cache_duration = time.time() - cache_start print(f"✓ Cached search completed in {cache_duration:.3f}s") print(f" - Cache hit: {cached_search_result.get('cache_hit', False)}") print(f" - Speedup: {(search_duration / cache_duration):.2f}x faster" if cache_duration > 0 else " - Instant cache hit") # Summary print("\n" + "="*80) print("END-TO-END TEST SUMMARY") print("="*80) print(f"Total workflow time: {(search_duration + instructions_duration + query_duration):.3f}s") print(f" - Schema search: {search_duration:.3f}s") print(f" - Instructions: {instructions_duration:.3f}s") print(f" - SQL execution: {query_duration:.3f}s") print(f" - Cache speedup: {(search_duration / cache_duration):.2f}x" if cache_duration > 0 else " - N/A") print("="*80 + "\n") @pytest.mark.integration @pytest.mark.skipif( os.environ.get("SKIP_INTEGRATION_TESTS") == "1", reason="Integration tests skipped" ) def test_e2e_session_state_caching(self, live_toolkit, live_api_setup, caplog): """ Test session state caching across multiple toolkit calls. Simulates agent behavior with session state persistence. """ import time print("\n" + "="*80) print("END-TO-END TEST: Session State Caching") print("="*80) tenant_id = live_api_setup["tenant_id"] # Simulate session state (like agent would provide) session_state = { "schema_search_cache": {}, "extracted_keywords": [], "tool_execution_log": [] } keywords = ["revenue", "customer", "transaction"] # First call - should miss cache print("\n[Call 1] Schema search without cache...") start_1 = time.time() result_1 = live_toolkit.search_schema( keywords=keywords, session_state=session_state ) duration_1 = time.time() - start_1 if "error" in result_1: pytest.skip("API not available") print(f"✓ First call: {duration_1:.3f}s, cache_hit={result_1.get('cache_hit', False)}") # Second call - should hit session cache print("\n[Call 2] Schema search with session cache...") start_2 = time.time() result_2 = live_toolkit.search_schema( keywords=keywords, session_state=session_state ) duration_2 = time.time() - start_2 print(f"✓ Second call: {duration_2:.3f}s, cache_hit={result_2.get('cache_hit', False)}") print(f" Cache source: {result_2.get('cache_source', 'N/A')}") # Verify caching worked assert result_2.get('cache_hit') == True, "Second call should hit cache" assert duration_2 < duration_1, "Cached call should be faster" speedup = duration_1 / duration_2 if duration_2 > 0 else float('inf') print(f"\n✓ Cache speedup: {speedup:.2f}x faster") print(f" Session cache keys: {len(session_state.get('schema_search_cache', {}))}") print("="*80 + "\n") @pytest.mark.integration @pytest.mark.skipif( os.environ.get("SKIP_INTEGRATION_TESTS") == "1", reason="Integration tests skipped" ) def test_e2e_keyword_extraction_workflow(self, live_toolkit, live_api_setup, caplog): """ Test realistic agent workflow: keyword extraction -> search -> SQL execution. Simulates what agent.py would do with execute_query_with_tracking. """ import time import re print("\n" + "="*80) print("END-TO-END TEST: Agent Keyword Extraction Workflow") print("="*80) tenant_id = live_api_setup["tenant_id"] source_name = live_api_setup["source_name"] # Simulate user questions (like agent would receive) user_questions = [ "What are the total sales by customer?", "Show me all products with low inventory", "Get customer contact information for premium accounts" ] for q_idx, question in enumerate(user_questions, 1): print(f"\n[Question {q_idx}] {question}") # Step 1: Extract keywords (simulating agent.py logic) stopwords = {'what', 'are', 'the', 'show', 'me', 'all', 'with', 'get', 'for', 'by'} words = re.findall(r'\b[a-z]+\b', question.lower()) keywords = [w for w in words if w not in stopwords and len(w) > 2] print(f" Extracted keywords: {keywords}") # Step 2: Search schema search_start = time.time() search_result = live_toolkit.search_schema( keywords=keywords[:5], # Limit to top 5 session_state={"tenant_id": tenant_id} ) search_duration = time.time() - search_start if "error" in search_result: print(f" ⚠ API not available: {search_result['error']}") pytest.skip("API not running") print(f" ✓ Search: {search_duration:.3f}s, matches={search_result.get('total_matches', 0)}") # Step 3: Agent would now generate SQL based on matched tables # For this test, we'll use a generic query if search_result.get('total_matches', 0) > 0: sql_query = "SELECT 1 as test_column" # Safe query query_start = time.time() query_result = live_toolkit.execute_sql_query( source_name=source_name, sql_query=sql_query, session_state={"tenant_id": tenant_id} ) query_duration = time.time() - query_start print(f" ✓ Query: {query_duration:.3f}s") if "results" in query_result: print(f" Results: {len(query_result['results'])} rows") print(f" Total workflow: {(search_duration + query_duration if 'query_duration' in locals() else search_duration):.3f}s") print("\n" + "="*80) print("✓ All workflow tests completed successfully") print("="*80 + "\n") @pytest.mark.integration def test_e2e_performance_benchmarks(self, live_toolkit, live_api_setup, caplog): """ Performance benchmark tests for the complete toolkit. Measures timing for various operations under realistic load. """ import time import statistics print("\n" + "="*80) print("PERFORMANCE BENCHMARKS") print("="*80) tenant_id = live_api_setup["tenant_id"] # Benchmark 1: Schema search performance print("\n[Benchmark 1] Schema search (10 iterations)...") search_times = [] for i in range(10): keywords = [f"test_keyword_{i % 3}", "customer", "order"] start = time.time() result = live_toolkit.search_schema(keywords=keywords, session_state={"tenant_id": tenant_id}) duration = time.time() - start if "error" not in result: search_times.append(duration) if search_times: print(f" Average: {statistics.mean(search_times):.3f}s") print(f" Median: {statistics.median(search_times):.3f}s") print(f" Min: {min(search_times):.3f}s") print(f" Max: {max(search_times):.3f}s") print(f" Std Dev: {statistics.stdev(search_times):.3f}s" if len(search_times) > 1 else " N/A") else: print(" ⚠ No successful searches (API may be down)") pytest.skip("API not available") # Benchmark 2: Cache hit performance print("\n[Benchmark 2] Cache hit performance...") keywords = ["benchmark", "cache", "test"] # First call (cache miss) start = time.time() live_toolkit.search_schema(keywords=keywords, session_state={"tenant_id": tenant_id}) miss_time = time.time() - start # Subsequent calls (cache hits) hit_times = [] for _ in range(5): start = time.time() live_toolkit.search_schema(keywords=keywords, session_state={"tenant_id": tenant_id}) hit_times.append(time.time() - start) avg_hit_time = statistics.mean(hit_times) print(f" Cache miss: {miss_time:.3f}s") print(f" Cache hit avg: {avg_hit_time:.3f}s") print(f" Speedup: {(miss_time / avg_hit_time):.2f}x") print("\n" + "="*80) print("✓ Performance benchmarks completed") print("="*80 + "\n") @pytest.mark.integration def test_full_live_user_flow(self, live_api_setup): """Validate an end-to-end user flow against the live data_sources API.""" setup = live_api_setup if setup is None: pytest.skip("Live API setup unavailable") toolkit = DataSourcesSQLToolkit( api_base_url=setup["base_url"], api_key=setup["api_key"], timeout=setup["timeout"] ) session_state = { "tenant_id": setup["tenant_id"], "source_name": setup["source_name"], "tool_execution_log": [], "analysis_metadata": {}, "user_context": {} } metrics: Dict[str, Any] = {} start = perf_counter() list_result = toolkit.list_sources(session_state=session_state) metrics["list_sources_ms"] = round((perf_counter() - start) * 1000, 2) assert "available_sources" in list_result assert setup["source_name"] in list_result.get("available_sources", []) start = perf_counter() search_result = toolkit.search_schema( keywords=["booking", "customer"], session_state=session_state ) metrics["search_schema_ms"] = round((perf_counter() - start) * 1000, 2) assert "available_sources" in search_result assert session_state.get("analysis_metadata", {}).get("last_schema_search") is not None start = perf_counter() find_result = toolkit.find_relevant_tables( question="What were our total bookings last quarter?", concepts=["bookings", "quarter"], session_state=session_state ) metrics["find_relevant_tables_ms"] = round((perf_counter() - start) * 1000, 2) assert "error" not in find_result start = perf_counter() instructions = toolkit.get_source_instructions(session_state=session_state) metrics["get_source_instructions_ms"] = round((perf_counter() - start) * 1000, 2) assert "instructions" in instructions start = perf_counter() good_sql = "SELECT current_database() AS current_db, NOW() AT TIME ZONE 'UTC' AS utc_now" ok_result = toolkit.execute_sql_query( sql_query=good_sql, session_state=session_state ) metrics["execute_sql_success_ms"] = round((perf_counter() - start) * 1000, 2) assert ok_result.get("status") == "success" assert isinstance(ok_result.get("results"), list) metadata = session_state.get("analysis_metadata", {}) last_exec = metadata.get("last_sql_execution", {}) assert last_exec.get("tenant_id") == setup["tenant_id"] assert "row_count" in last_exec start = perf_counter() bad_sql = "SELECT * FROM table_that_does_not_exist__pytest" error_result = toolkit.execute_sql_query( sql_query=bad_sql, session_state=session_state ) metrics["execute_sql_error_ms"] = round((perf_counter() - start) * 1000, 2) assert "error" in error_result # Capture async metadata for diagnostics metrics["total_tool_calls"] = len(session_state.get("tool_execution_log", [])) # Basic sanity on latencies (should be non-zero but reasonable) for label, value in metrics.items(): assert value >= 0 if __name__ == "__main__": pytest.main([__file__, "-v", "--tb=short"])