Spaces:
Running
Running
"""
Unit tests for DataSourcesSQLToolkit.

Coverage:
- Initialization and configuration
- Schema fetching and formatting
- Source instructions retrieval
- SQL query execution
- Error handling and edge cases
- Input validation

All HTTP traffic is simulated with mocks (pytest-mock and unittest.mock).
"""
import json
import os
import sys
import uuid
from datetime import datetime
from time import perf_counter
from typing import Any, Dict
from unittest.mock import MagicMock, Mock, call, patch

import httpx
import pytest

# Put the project root (the directory that contains `backend`) on sys.path so
# that `import backend...` resolves even when this file is run directly.
project_root = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir, os.pardir, os.pardir)
)
if project_root not in sys.path:
    sys.path.insert(0, project_root)

from backend.SQL_Agent.data_sources_sql_toolkit import DataSourcesSQLToolkit  # noqa: E402
| # ============================================================================ | |
| # Fixtures | |
| # ============================================================================ | |
@pytest.fixture
def toolkit():
    """Provide a DataSourcesSQLToolkit aimed at a fake API host.

    Function-scoped (pytest's default), so every test gets a fresh instance and
    mocks assigned to ``toolkit.client.get`` / ``toolkit.client.post`` cannot
    leak between tests.
    """
    return DataSourcesSQLToolkit(api_base_url="http://test-api:8000")


@pytest.fixture
def mock_httpx_response():
    """Provide a bare ``Mock`` to stand in for an ``httpx.Response``."""
    return Mock()
| # ============================================================================ | |
| # Test Initialization | |
| # ============================================================================ | |
class TestToolkitInitialization:
    """Construction of the toolkit and the set of tools it exposes."""

    def test_init_default_values(self):
        """With no arguments the toolkit targets the local API on port 8000."""
        expected_url = "http://127.0.0.1:8000"
        instance = DataSourcesSQLToolkit()

        assert instance.api_base_url == expected_url
        assert instance.client is not None
        assert instance.client.base_url == expected_url

    def test_init_custom_values(self):
        """An explicit api_base_url is stored and used by the HTTP client."""
        url = "http://custom-api:9999"
        instance = DataSourcesSQLToolkit(api_base_url=url, timeout=60.0)

        assert instance.api_base_url == url
        assert instance.client.base_url == url

    def test_toolkit_has_required_tools(self, toolkit):
        """Every tool the agent relies on is registered on the toolkit."""
        registered = {tool.__name__ for tool in toolkit.tools}
        for expected in (
            "find_relevant_tables",
            "search_schema",
            "get_available_sources_and_schema",
            "get_source_instructions",
            "execute_sql_query",
        ):
            assert expected in registered
| # ============================================================================ | |
| # Test Schema Formatting | |
| # ============================================================================ | |
class TestSchemaFormatting:
    """Rendering of raw schema JSON into text intended for the LLM."""

    def test_format_schema_with_valid_json(self, toolkit):
        """A well-formed schema renders source, schema, table, column and type."""
        raw_schema = [
            {
                "schema_name": "public",
                "tables": [
                    {
                        "table_name": "users",
                        "fields": [
                            {"name": "id", "type": "INTEGER", "example": "1"},
                            {"name": "email", "type": "VARCHAR", "example": "user@example.com"},
                        ],
                    }
                ],
            }
        ]

        rendered = toolkit._format_schema_for_llm("production_db", json.dumps(raw_schema))

        for fragment in ("production_db", "public", "users", "id", "INTEGER"):
            assert fragment in rendered

    def test_format_schema_with_empty_schema(self, toolkit):
        """A missing schema is reported as not available for that source."""
        rendered = toolkit._format_schema_for_llm("empty_source", None)

        assert "empty_source" in rendered
        assert "Not Available" in rendered

    def test_format_schema_with_invalid_json(self, toolkit):
        """Unparseable schema JSON yields a format/error marker, not an exception."""
        rendered = toolkit._format_schema_for_llm("broken_source", "{ invalid json }")

        assert "broken_source" in rendered
        assert any(marker in rendered for marker in ("Invalid Format", "Error"))

    def test_format_schema_with_multiple_tables(self, toolkit):
        """Every table of a multi-table schema appears in the rendering."""
        users_table = {
            "table_name": "users",
            "fields": [{"name": "id", "type": "INTEGER"}],
        }
        orders_table = {
            "table_name": "orders",
            "fields": [
                {"name": "order_id", "type": "INTEGER"},
                {"name": "user_id", "type": "INTEGER"},
            ],
        }
        raw_schema = [{"schema_name": "public", "tables": [users_table, orders_table]}]

        rendered = toolkit._format_schema_for_llm("test_db", json.dumps(raw_schema))

        assert "users" in rendered
        assert "orders" in rendered
| # ============================================================================ | |
| # Test get_available_sources_and_schema | |
| # ============================================================================ | |
class TestGetAvailableSourcesAndSchema:
    """Tests for retrieving available sources and their schemas."""

    def test_missing_tenant_id(self, toolkit):
        """A missing session state (and therefore tenant_id) yields an error."""
        result = toolkit.get_available_sources_and_schema(session_state=None)
        assert "error" in result
        assert "Session state not found" in result["error"]

    def test_successful_schema_fetch(self, toolkit):
        """Schemas are fetched for each listed source, sending the session's JWT."""
        # First GET: the list_sources call.
        list_response = Mock()
        list_response.status_code = 200
        list_response.json.return_value = {
            "available_sources": ["db1", "db2"],
            "count": 2,
        }
        # Following GETs: one schema fetch per listed source.
        schema_response_1 = Mock()
        schema_response_1.status_code = 200
        schema_response_1.json.return_value = {
            "schema_data": json.dumps([{
                "schema_name": "public",
                "tables": [{"table_name": "users", "fields": []}],
            }])
        }
        schema_response_2 = Mock()
        schema_response_2.status_code = 200
        schema_response_2.json.return_value = {
            "schema_data": json.dumps([{
                "schema_name": "public",
                "tables": [{"table_name": "orders", "fields": []}],
            }])
        }
        toolkit.client.get = Mock(side_effect=[list_response, schema_response_1, schema_response_2])
        session_state = {"tenant_id": "tenant-123", "supabase_jwt": "jwt-token-123"}

        result = toolkit.get_available_sources_and_schema(
            keywords=None,
            session_state=session_state,
        )

        assert "formatted_schema_string" in result
        assert "available_sources" in result
        assert len(result["available_sources"]) == 2

        # Skip the list_sources call; the remaining GETs are the schema fetches.
        # Pin their count so the header check below cannot pass vacuously on an
        # empty list. The loop variable deliberately avoids the name `call`,
        # which would shadow unittest.mock.call imported at module level.
        schema_calls = toolkit.client.get.call_args_list[1:]
        assert len(schema_calls) == 2
        for schema_call in schema_calls:
            headers = schema_call.kwargs["headers"]
            assert headers["Authorization"] == "Bearer jwt-token-123"

    def test_schema_fetch_with_api_error(self, toolkit):
        """A 500 on a schema fetch still yields a result with the expected keys."""
        list_response = Mock()
        list_response.status_code = 200
        list_response.json.return_value = {
            "available_sources": ["db1"],
            "count": 1,
        }
        schema_response = Mock()
        schema_response.status_code = 500
        schema_response.text = "Internal Server Error"
        toolkit.client.get = Mock(side_effect=[list_response, schema_response])
        session_state = {"tenant_id": "tenant-123", "supabase_jwt": "jwt-token-123"}

        result = toolkit.get_available_sources_and_schema(session_state=session_state)

        assert "formatted_schema_string" in result
        assert "available_sources" in result

    def test_schema_fetch_import_error(self, toolkit):
        """A connection failure on the list_sources call is reported as an API connection error."""
        toolkit.client.get = Mock(side_effect=httpx.RequestError("Connection failed"))
        session_state = {"tenant_id": "tenant-123", "supabase_jwt": "jwt-token-123"}

        result = toolkit.get_available_sources_and_schema(session_state=session_state)

        assert "error" in result
        assert "API connection error" in result["error"]
| # ============================================================================ | |
| # Test get_source_instructions | |
| # ============================================================================ | |
class TestGetSourceInstructions:
    """Tests for retrieving source-specific instructions."""

    def test_missing_tenant_id(self, toolkit):
        """A missing session state (and therefore tenant_id) yields an error."""
        result = toolkit.get_source_instructions(session_state=None)
        assert "error" in result
        assert "Session state not found" in result["error"]

    def test_missing_source_name(self, toolkit):
        """A session that names no source is rejected before any HTTP request."""
        toolkit.client.get = Mock()
        # Tenant and JWT are present, but no source_name is given, either as an
        # argument or in the session. This distinguishes the test from
        # test_missing_tenant_id. The exact error wording is not pinned because
        # it is not established by the other tests in this file.
        session_state = {"tenant_id": "tenant-123", "supabase_jwt": "jwt-token-123"}

        result = toolkit.get_source_instructions(session_state=session_state)

        assert "error" in result
        toolkit.client.get.assert_not_called()

    def test_successful_instructions_fetch(self, toolkit):
        """Instructions returned by the API are passed through to the caller."""
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "instructions": "Use SELECT ... FROM syntax. Supports CTEs with WITH clause."
        }
        toolkit.client.get = Mock(return_value=mock_response)
        session_state = {"tenant_id": "tenant-123", "source_name": "production_db", "supabase_jwt": "jwt-token-123"}

        result = toolkit.get_source_instructions(
            source_name="production_db",
            session_state=session_state,
        )

        assert "instructions" in result
        assert "SELECT" in result["instructions"]
        toolkit.client.get.assert_called_once()

    def test_instructions_not_found(self, toolkit):
        """A 404 from the API is surfaced as a 'not found' error."""
        mock_response = Mock()
        mock_response.status_code = 404
        toolkit.client.get = Mock(return_value=mock_response)
        session_state = {"tenant_id": "tenant-123", "source_name": "unknown_db", "supabase_jwt": "jwt-token-123"}

        result = toolkit.get_source_instructions(
            source_name="unknown_db",
            session_state=session_state,
        )

        assert "error" in result
        assert "not found" in result["error"].lower()

    def test_instructions_api_error(self, toolkit):
        """A 5xx from the API is surfaced as a 'Failed ...' error."""
        mock_response = Mock()
        mock_response.status_code = 500
        mock_response.text = "Server Error"
        toolkit.client.get = Mock(return_value=mock_response)
        session_state = {"tenant_id": "tenant-123", "source_name": "db", "supabase_jwt": "jwt-token-123"}

        result = toolkit.get_source_instructions(
            source_name="db",
            session_state=session_state,
        )

        assert "error" in result
        assert "Failed" in result["error"]

    def test_instructions_connection_error(self, toolkit):
        """A transport-level failure is surfaced as a connection error."""
        # httpx is already imported at module level; no local import needed.
        toolkit.client.get = Mock(side_effect=httpx.RequestError("Connection refused"))
        session_state = {"tenant_id": "tenant-123", "source_name": "db", "supabase_jwt": "jwt-token-123"}

        result = toolkit.get_source_instructions(
            source_name="db",
            session_state=session_state,
        )

        assert "error" in result
        assert "connection error" in result["error"].lower()
| # ============================================================================ | |
| # Test execute_sql_query | |
| # ============================================================================ | |
class TestExecuteSQLQuery:
    """Tests for SQL query execution."""

    def test_missing_tenant_id(self, toolkit):
        """A missing session state (and therefore tenant_id) yields an error."""
        result = toolkit.execute_sql_query(
            sql_query="SELECT * FROM users",
            session_state=None,
        )
        assert "error" in result
        assert "Session state not found" in result["error"]

    def test_missing_source_name(self, toolkit):
        """A session that names no source is rejected before any HTTP request."""
        toolkit.client.post = Mock()
        # Tenant and JWT are present, but no source_name is given, either as an
        # argument or in the session. This distinguishes the test from
        # test_missing_tenant_id. The exact error wording is not pinned because
        # it is not established by the other tests in this file.
        session_state = {"tenant_id": "tenant-123", "supabase_jwt": "jwt-token-123"}

        result = toolkit.execute_sql_query(
            sql_query="SELECT * FROM users",
            session_state=session_state,
        )

        assert "error" in result
        toolkit.client.post.assert_not_called()

    def test_empty_sql_query(self, toolkit):
        """A whitespace-only query is rejected as empty."""
        result = toolkit.execute_sql_query(
            sql_query="   ",
            session_state={"tenant_id": "tenant-123", "source_name": "db", "supabase_jwt": "jwt-token-123"},
        )
        assert "error" in result
        assert "empty" in result["error"].lower()

    def test_successful_query_execution(self, toolkit):
        """Rows returned by the API are exposed under 'results' with status 'success'."""
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "results": [
                {"id": 1, "name": "Alice"},
                {"id": 2, "name": "Bob"},
            ]
        }
        toolkit.client.post = Mock(return_value=mock_response)
        session_state = {"tenant_id": "tenant-123", "source_name": "production_db", "supabase_jwt": "jwt-token-123"}

        result = toolkit.execute_sql_query(
            sql_query="SELECT * FROM users",
            session_state=session_state,
        )

        assert result["status"] == "success"
        assert len(result["results"]) == 2
        assert result["results"][0]["name"] == "Alice"

    def test_query_blocked_for_unsafe_operations(self, toolkit):
        """Unsafe SQL operations are blocked client-side, without calling the API."""
        unsafe_queries = [
            "DROP TABLE users",
            "DELETE FROM users",
            "UPDATE users SET name = 'hacked'",
            "INSERT INTO users VALUES (1, 'hacked')",
            "TRUNCATE TABLE users",
        ]
        # A mock (instead of the real client) both avoids network access and
        # lets us prove the block happened before any request was sent.
        toolkit.client.post = Mock()
        session_state = {"tenant_id": "tenant-123", "source_name": "db", "supabase_jwt": "jwt-token-123"}

        for unsafe_query in unsafe_queries:
            result = toolkit.execute_sql_query(
                sql_query=unsafe_query,
                session_state=session_state,
            )
            assert "error" in result, unsafe_query
            assert "blocked" in result["error"].lower(), unsafe_query

        toolkit.client.post.assert_not_called()

    def test_query_allowed_for_safe_operations(self, toolkit):
        """Read-only statements pass client-side checks and execute successfully."""
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = {"results": []}
        toolkit.client.post = Mock(return_value=mock_response)
        session_state = {"tenant_id": "tenant-123", "source_name": "db", "supabase_jwt": "jwt-token-123"}
        safe_queries = [
            "SELECT * FROM users",
            "WITH cte AS (SELECT * FROM users) SELECT * FROM cte",
            "SHOW TABLES",
            "DESCRIBE users",
            "EXPLAIN SELECT * FROM users",
        ]

        for safe_query in safe_queries:
            result = toolkit.execute_sql_query(
                source_name="db",
                sql_query=safe_query,
                session_state=session_state,
            )
            # Require an explicit success. Merely checking that the error
            # *text* lacks the word "error" would let "Query blocked ..." pass.
            assert "error" not in result, f"{safe_query!r} was rejected: {result['error']}"
            assert result.get("status") == "success", safe_query

        # Each safe statement must actually reach the (mocked) API.
        assert toolkit.client.post.call_count == len(safe_queries)

    def test_query_api_error(self, toolkit):
        """The API's 'detail' message is surfaced for a 4xx response."""
        mock_response = Mock()
        mock_response.status_code = 400
        mock_response.json.return_value = {"detail": "Syntax error in query"}
        toolkit.client.post = Mock(return_value=mock_response)
        session_state = {"tenant_id": "tenant-123", "source_name": "db", "supabase_jwt": "jwt-token-123"}

        result = toolkit.execute_sql_query(
            source_name="db",
            sql_query="SELECT * FROM users",
            session_state=session_state,
        )

        assert "error" in result
        assert "Syntax error" in result["error"]

    def test_query_timeout(self, toolkit):
        """An httpx timeout is reported as a timed-out error."""
        # httpx is already imported at module level; no local import needed.
        toolkit.client.post = Mock(side_effect=httpx.TimeoutException("Request timed out"))
        session_state = {"tenant_id": "tenant-123", "source_name": "db", "supabase_jwt": "jwt-token-123"}

        result = toolkit.execute_sql_query(
            source_name="db",
            sql_query="SELECT * FROM users",
            session_state=session_state,
        )

        assert "error" in result
        assert "timed out" in result["error"].lower()

    def test_query_connection_error(self, toolkit):
        """A transport-level failure is reported as a connection error."""
        toolkit.client.post = Mock(side_effect=httpx.RequestError("Connection refused"))
        session_state = {"tenant_id": "tenant-123", "source_name": "db", "supabase_jwt": "jwt-token-123"}

        result = toolkit.execute_sql_query(
            source_name="db",
            sql_query="SELECT * FROM users",
            session_state=session_state,
        )

        assert "error" in result
        assert "connection error" in result["error"].lower()

    def test_query_malformed_response(self, toolkit):
        """A non-JSON error body falls back to the raw response text."""
        mock_response = Mock()
        mock_response.status_code = 500
        mock_response.json.side_effect = json.JSONDecodeError("Invalid JSON", "", 0)
        mock_response.text = "Internal Server Error"
        toolkit.client.post = Mock(return_value=mock_response)
        session_state = {"tenant_id": "tenant-123", "source_name": "db", "supabase_jwt": "jwt-token-123"}

        result = toolkit.execute_sql_query(
            source_name="db",
            sql_query="SELECT * FROM users",
            session_state=session_state,
        )

        assert "error" in result
        assert "Internal Server Error" in result["error"]
| # ============================================================================ | |
| # Test Multi-tenant Isolation | |
| # ============================================================================ | |
class TestMultiTenantIsolation:
    """Requests must carry the identity of the tenant that issued them."""

    def test_different_tenants_use_different_params(self, toolkit):
        """Successive queries from two tenants each send their own tenant_id."""
        ok_response = Mock()
        ok_response.status_code = 200
        ok_response.json.return_value = {"results": []}
        toolkit.client.post = Mock(return_value=ok_response)

        # Run tenant-1 first, then tenant-2. After each run, inspect the most
        # recent POST payload.
        runs = [
            ("tenant-1", "jwt-token-1", "SELECT * FROM users"),
            ("tenant-2", "jwt-token-2", "SELECT * FROM orders"),
        ]
        for tenant, jwt, query in runs:
            toolkit.execute_sql_query(
                sql_query=query,
                session_state={"tenant_id": tenant, "supabase_jwt": jwt, "source_name": "db"},
            )
            sent_payload = toolkit.client.post.call_args[1]["json"]
            assert sent_payload["tenant_id"] == tenant

    def test_tenant_id_passed_to_all_endpoints(self, toolkit):
        """tenant_id travels in the JSON body and the JWT in the Authorization header."""
        ok_response = Mock()
        ok_response.status_code = 200
        ok_response.json.return_value = {"results": []}
        toolkit.client.post = Mock(return_value=ok_response)

        toolkit.execute_sql_query(
            sql_query="SELECT * FROM users",
            session_state={"tenant_id": "tenant-x", "source_name": "db", "supabase_jwt": "jwt-token-123"},
        )

        post_kwargs = toolkit.client.post.call_args[1]
        # Body, not query string, carries the tenant.
        assert post_kwargs["json"]["tenant_id"] == "tenant-x"
        assert post_kwargs["headers"]["Authorization"] == "Bearer jwt-token-123"
| # ============================================================================ | |
| # Test Edge Cases | |
| # ============================================================================ | |
class TestEdgeCases:
    """Tests for edge cases and boundary conditions."""

    def test_very_long_sql_query(self, toolkit):
        """A query with a 1000-element IN list is executed successfully."""
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = {"results": []}
        toolkit.client.post = Mock(return_value=mock_response)
        long_query = "SELECT * FROM users WHERE id IN (" + ",".join(str(i) for i in range(1000)) + ")"
        session_state = {"tenant_id": "tenant-123", "supabase_jwt": "jwt-token-123", "source_name": "db"}

        result = toolkit.execute_sql_query(
            source_name="db",
            sql_query=long_query,
            session_state=session_state,
        )

        # Same success contract as test_successful_query_execution. The previous
        # `"error" not in result or ...` disjunction was weaker than intended.
        assert "error" not in result
        assert result.get("status") == "success"

    def test_special_characters_in_query(self, toolkit):
        """Escaped quotes and LIKE wildcards do not trip client-side validation."""
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = {"results": []}
        toolkit.client.post = Mock(return_value=mock_response)
        query_with_special_chars = "SELECT * FROM users WHERE name = 'O''Reilly' AND email LIKE '%@%.com%'"
        session_state = {"tenant_id": "tenant-123", "supabase_jwt": "jwt-token-123", "source_name": "db"}

        result = toolkit.execute_sql_query(
            source_name="db",
            sql_query=query_with_special_chars,
            session_state=session_state,
        )

        # The query must be sent exactly once and succeed.
        toolkit.client.post.assert_called_once()
        assert result.get("status") == "success"

    def test_empty_results(self, toolkit):
        """An empty result set is returned as an empty list with status 'success'."""
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = {"results": []}
        toolkit.client.post = Mock(return_value=mock_response)
        session_state = {"tenant_id": "tenant-123", "supabase_jwt": "jwt-token-123", "source_name": "db"}

        result = toolkit.execute_sql_query(
            source_name="db",
            sql_query="SELECT * FROM users WHERE id = -1",
            session_state=session_state,
        )

        assert result["status"] == "success"
        assert result["results"] == []

    def test_large_result_set(self, toolkit):
        """All 10,000 rows of a large result set are returned."""
        large_results = [{"id": i, "name": f"user_{i}"} for i in range(10000)]
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = {"results": large_results}
        toolkit.client.post = Mock(return_value=mock_response)
        session_state = {"tenant_id": "tenant-123", "supabase_jwt": "jwt-token-123", "source_name": "db"}

        result = toolkit.execute_sql_query(
            source_name="db",
            sql_query="SELECT * FROM users",
            session_state=session_state,
        )

        assert result["status"] == "success"
        assert len(result["results"]) == 10000
| # ============================================================================ | |
| # Test find_relevant_tables (Hybrid Keyword Extraction) | |
| # ============================================================================ | |
| class TestFindRelevantTables: | |
| """Tests for the new find_relevant_tables tool with hybrid keyword extraction.""" | |
| def test_missing_tenant_id(self, toolkit): | |
| """Test that missing tenant_id returns error.""" | |
| result = toolkit.find_relevant_tables( | |
| tenant_id="", | |
| question="What are the total sales?" | |
| ) | |
| assert "error" in result | |
| assert "Tenant ID is required" in result["error"] | |
| def test_missing_question(self, toolkit): | |
| """Test that missing question returns error.""" | |
| result = toolkit.find_relevant_tables( | |
| tenant_id="tenant-123", | |
| question="" | |
| ) | |
| assert "error" in result | |
| assert "Question is required" in result["error"] | |
| def test_find_relevant_tables_merges_concepts_and_semantic_hints(self, mock_extract, toolkit): | |
| """Test that find_relevant_tables properly merges concepts and semantic hints.""" | |
| # Mock hybrid keyword extraction | |
| mock_extract.return_value = { | |
| 'base': ['revenue', 'customers'], | |
| 'semantic': ['sales metrics', 'customer data'], | |
| 'concepts': ['premium', 'quarterly'], | |
| 'combined': ['premium', 'quarterly', 'revenue', 'customers', 'sales metrics', 'customer data'] | |
| } | |
| # Mock the search_schema call | |
| mock_response = Mock() | |
| mock_response.status_code = 200 | |
| mock_response.json.return_value = { | |
| "formatted_schema_string": "Test schema", | |
| "matches": [{"table_name": "sales", "score": 15.0, "matched_columns": ["revenue"], "source_name": "db1"}], | |
| "available_sources": ["db1"], | |
| "total_matches": 1 | |
| } | |
| toolkit.client.post = Mock(return_value=mock_response) | |
| # Create session state | |
| session_state = {"keyword_extraction_cache": {}} | |
| # Call find_relevant_tables with agent concepts | |
| result = toolkit.find_relevant_tables( | |
| tenant_id="tenant-123", | |
| question="What was revenue from premium customers last quarter?", | |
| concepts=["premium", "quarterly"], | |
| session_state=session_state | |
| ) | |
| # Verify extraction was called | |
| mock_extract.assert_called_once() | |
| call_kwargs = mock_extract.call_args[1] | |
| assert call_kwargs['question'] == "What was revenue from premium customers last quarter?" | |
| assert call_kwargs['llm_concepts'] == ["premium", "quarterly"] | |
| # Verify result structure | |
| assert "formatted_schema_string" in result | |
| assert "keyword_breakdown" in result | |
| assert result["keyword_breakdown"]["base"] == ['revenue', 'customers'] | |
| assert result["keyword_breakdown"]["semantic"] == ['sales metrics', 'customer data'] | |
| assert result["keyword_breakdown"]["concepts"] == ['premium', 'quarterly'] | |
| assert result["original_question"] == "What was revenue from premium customers last quarter?" | |
| # Verify merged keywords were sent to API | |
| api_call_payload = toolkit.client.post.call_args[1]["json"] | |
| assert set(api_call_payload["keywords"]) == {'premium', 'quarterly', 'revenue', 'customers', 'sales metrics', 'customer data'} | |
| def test_find_relevant_tables_handles_gemini_failure(self, mock_extract, toolkit): | |
| """Test that find_relevant_tables handles Gemini failure gracefully.""" | |
| # Mock hybrid extraction with no semantic hints (Gemini failed) | |
| mock_extract.return_value = { | |
| 'base': ['sales', 'customers'], | |
| 'semantic': [], # Empty because Gemini failed | |
| 'concepts': ['premium'], | |
| 'combined': ['premium', 'sales', 'customers'] | |
| } | |
| # Mock API response | |
| mock_response = Mock() | |
| mock_response.status_code = 200 | |
| mock_response.json.return_value = { | |
| "formatted_schema_string": "Schema", | |
| "matches": [], | |
| "available_sources": ["db1"], | |
| "total_matches": 0 | |
| } | |
| toolkit.client.post = Mock(return_value=mock_response) | |
| # Call should succeed despite Gemini failure | |
| session_state = {"tenant_id": "tenant-123", "supabase_jwt": "jwt-token-123"} | |
| result = toolkit.find_relevant_tables( | |
| question="Show premium customer sales", | |
| concepts=["premium"], | |
| session_state=session_state | |
| ) | |
| assert "error" not in result | |
| assert result["keyword_breakdown"]["semantic"] == [] | |
| assert "premium" in result["keyword_breakdown"]["combined"] | |
| assert "sales" in result["keyword_breakdown"]["combined"] | |
| def test_find_relevant_tables_uses_cache_for_repeat_question(self, mock_extract, toolkit): | |
| """Test that find_relevant_tables caches keyword extraction per question.""" | |
| # Mock hybrid extraction | |
| mock_extract.return_value = { | |
| 'base': ['revenue'], | |
| 'semantic': ['financial metrics'], | |
| 'concepts': [], | |
| 'combined': ['revenue', 'financial metrics'] | |
| } | |
| # Mock API response | |
| mock_response = Mock() | |
| mock_response.status_code = 200 | |
| mock_response.json.return_value = { | |
| "formatted_schema_string": "Schema", | |
| "matches": [], | |
| "available_sources": [], | |
| "total_matches": 0 | |
| } | |
| toolkit.client.post = Mock(return_value=mock_response) | |
| # Create session state | |
| session_state = {"keyword_extraction_cache": {}} | |
| # First call | |
| toolkit.find_relevant_tables( | |
| question="What is the total revenue?", | |
| session_state=session_state | |
| ) | |
| # Verify extraction was called once | |
| assert mock_extract.call_count == 1 | |
| assert len(session_state["keyword_extraction_cache"]) == 1 | |
| # Second call with same question | |
| toolkit.find_relevant_tables( | |
| question="What is the total revenue?", | |
| session_state=session_state | |
| ) | |
| # Verify extraction was NOT called again (cache hit) | |
| assert mock_extract.call_count == 1 # Still 1, not 2 | |
| # Third call with different question | |
| toolkit.find_relevant_tables( | |
| question="Show me customer data", | |
| session_state=session_state | |
| ) | |
| # Verify extraction was called for new question | |
| assert mock_extract.call_count == 2 | |
| assert len(session_state["keyword_extraction_cache"]) == 2 | |
| def test_find_relevant_tables_without_hybrid_utils(self, toolkit): | |
| """Test fallback behavior when hybrid_keyword_utils is not available.""" | |
| # Temporarily disable hybrid extraction | |
| original_extract = toolkit.__class__.__module__ | |
| with patch("backend.SQL_Agent.data_sources_sql_toolkit.extract_hybrid_keywords", None): | |
| # Mock API response | |
| mock_response = Mock() | |
| mock_response.status_code = 200 | |
| mock_response.json.return_value = { | |
| "formatted_schema_string": "Schema", | |
| "matches": [], | |
| "available_sources": [], | |
| "total_matches": 0 | |
| } | |
| toolkit.client.post = Mock(return_value=mock_response) | |
| # Should fall back to simple keyword extraction | |
| session_state = {"tenant_id": "tenant-123", "supabase_jwt": "jwt-token-123"} | |
| result = toolkit.find_relevant_tables( | |
| question="What is the revenue from premium customers?", | |
| concepts=["premium"], | |
| session_state=session_state | |
| ) | |
| # Verify fallback worked | |
| assert "error" not in result | |
| # Check that API was called with some keywords | |
| api_payload = toolkit.client.post.call_args[1]["json"] | |
| assert len(api_payload["keywords"]) > 0 | |
@patch("backend.SQL_Agent.data_sources_sql_toolkit.extract_hybrid_keywords")
def test_find_relevant_tables_sends_metadata_to_api(self, mock_extract, toolkit):
    """Test that find_relevant_tables sends metadata to API for analytics.

    ``mock_extract`` replaces ``extract_hybrid_keywords``. The test checks that
    the original question and the per-category keyword lists (base / semantic /
    concepts) appear in the JSON payload posted to the API.
    """
    # Mock hybrid extraction
    mock_extract.return_value = {
        "base": ["sales"],
        "semantic": ["revenue data"],
        "concepts": ["monthly"],
        "combined": ["monthly", "sales", "revenue data"],
    }

    # Mock API response
    mock_response = Mock()
    mock_response.status_code = 200
    mock_response.json.return_value = {
        "formatted_schema_string": "Schema",
        "matches": [],
        "available_sources": [],
        "total_matches": 0,
    }
    toolkit.client.post = Mock(return_value=mock_response)

    # Call with question
    session_state = {"tenant_id": "tenant-123", "supabase_jwt": "jwt-token-123"}
    toolkit.find_relevant_tables(
        question="Show monthly sales data",
        concepts=["monthly"],
        session_state=session_state,
    )

    # Verify metadata was sent to API
    api_payload = toolkit.client.post.call_args[1]["json"]
    assert "original_question" in api_payload
    assert api_payload["original_question"] == "Show monthly sales data"
    assert "keyword_metadata" in api_payload
    assert api_payload["keyword_metadata"]["base"] == ["sales"]
    assert api_payload["keyword_metadata"]["semantic"] == ["revenue data"]
    assert api_payload["keyword_metadata"]["concepts"] == ["monthly"]
@patch("backend.SQL_Agent.data_sources_sql_toolkit.extract_hybrid_keywords")
def test_find_relevant_tables_with_source_filter(self, mock_extract, toolkit):
    """Test that find_relevant_tables passes source_names filter correctly.

    ``mock_extract`` replaces ``extract_hybrid_keywords``. The ``source_names``
    argument must be forwarded unchanged in the JSON payload posted to the API.
    """
    # Mock hybrid extraction
    mock_extract.return_value = {
        "base": ["users"],
        "semantic": [],
        "concepts": [],
        "combined": ["users"],
    }

    # Mock API response
    mock_response = Mock()
    mock_response.status_code = 200
    mock_response.json.return_value = {
        "formatted_schema_string": "Schema",
        "matches": [],
        "available_sources": ["db1"],
        "total_matches": 0,
    }
    toolkit.client.post = Mock(return_value=mock_response)

    # Call with source filter
    session_state = {"tenant_id": "tenant-123", "supabase_jwt": "jwt-token-123"}
    toolkit.find_relevant_tables(
        question="Show user data",
        source_names=["db1", "db2"],
        session_state=session_state,
    )

    # Verify source filter was passed to API
    api_payload = toolkit.client.post.call_args[1]["json"]
    assert "source_names" in api_payload
    assert api_payload["source_names"] == ["db1", "db2"]
| # ============================================================================ | |
| # End-to-End Integration Tests with Real Database | |
| # ============================================================================ | |
class TestEndToEndIntegration:
    """
    End-to-end integration tests with real PostgreSQL database using REAL API calls.

    Uses the requests library to hit the actual running API server and tests the
    complete workflow: keyword search -> schema fetch -> SQL execution.

    Environment (tests skip when a required variable is missing):
      - SIRUS_ADMIN_API_KEY (required): admin key used to mint a short-lived
        tenant API key.
      - SCV_SAMPLE_DB_URI (required): connection URI of the sample Postgres
        database the tenant source points at. Database credentials must come
        from the environment and never be committed to this file.
      - DATA_SOURCES_API_BASE_URL (optional): defaults to http://127.0.0.1:8000.
      - SKIP_INTEGRATION_TESTS=1 (optional): skip the whole suite.
    """

    # requests has no session-wide timeout (assigning ``Session.timeout`` is a
    # no-op), so a timeout is passed explicitly on every provisioning call.
    _HTTP_TIMEOUT = 30.0
    # The source upsert requests ``validate_connection``, so allow it more time.
    _UPSERT_TIMEOUT = 120.0

    @pytest.fixture
    def live_api_setup(self):
        """Provision a real tenant + source against the live data_sources API using requests.

        Yields a dict with ``base_url``, ``tenant_id``, ``source_name``,
        ``api_key``, ``key_id``, the open ``session``, ``headers_admin``,
        ``tenant_headers`` and the toolkit ``timeout``. Teardown deletes the
        source, revokes the created API key (best effort) and closes the session.
        This happens even when provisioning is aborted with ``pytest.skip``
        after the key was created.
        """
        import uuid

        requests = pytest.importorskip("requests")

        if os.environ.get("SKIP_INTEGRATION_TESTS") == "1":
            pytest.skip("Integration tests skipped via SKIP_INTEGRATION_TESTS=1")

        base_url = os.environ.get("DATA_SOURCES_API_BASE_URL", "http://127.0.0.1:8000")
        admin_key = os.environ.get("SIRUS_ADMIN_API_KEY")
        if not admin_key:
            pytest.skip("SIRUS_ADMIN_API_KEY not configured")
        # NOTE: credentials previously hard-coded here must be rotated.
        db_uri = os.environ.get("SCV_SAMPLE_DB_URI")
        if not db_uri:
            pytest.skip("SCV_SAMPLE_DB_URI not configured")

        tenant_id = f"toolkit_e2e_{uuid.uuid4().hex[:8]}"
        source_name = "scv_sample_db"
        description = f"E2E Test {tenant_id}"
        timeout = self._HTTP_TIMEOUT
        headers_admin = {"X-Sirus-Admin-Key": admin_key}
        tenant_url = f"{base_url}/api/v1/data-sources/tenants/{tenant_id}"

        session = requests.Session()
        # Only set once *we* created them, so teardown never touches foreign keys.
        key_id = None
        tenant_headers = None
        try:
            # Quick availability check
            try:
                health = session.get(
                    f"{base_url}/api/v1/data-sources/list",
                    params={"tenant_id": "__healthcheck__"},
                    timeout=timeout,
                )
            except requests.RequestException as exc:
                pytest.skip(f"API unreachable: {exc}")
            if health.status_code not in (200, 400, 404, 401):
                pytest.skip(f"API health check failed (status {health.status_code})")

            # Create tenant API key via admin route
            resp = session.post(
                f"{tenant_url}/api-keys",
                json={"description": description, "expires_in_days": 14},
                headers=headers_admin,
                timeout=timeout,
            )
            if resp.status_code != 201:
                # An existing key cannot be reused: its raw value is not returned.
                list_resp = session.get(
                    f"{tenant_url}/api-keys", headers=headers_admin, timeout=timeout
                )
                if list_resp.status_code == 200 and list_resp.json().get("keys", []):
                    pytest.skip("Existing tenant API keys found but raw value unavailable")
                pytest.skip(
                    f"Failed to create tenant API key (status {resp.status_code}): {resp.text[:200]}"
                )

            data = resp.json()
            api_key = data.get("api_key")
            key_id = data.get("key_info", {}).get("key_id")
            if not api_key:
                pytest.skip("Tenant API key could not be provisioned")

            tenant_headers = {"X-Sirus-Api-Key": api_key}

            # Upsert tenant source pointing at live Postgres
            source_payload = {
                "sources": [
                    {
                        "source_name": source_name,
                        "source_type": "ibis",
                        "config": {
                            "uri": db_uri,
                            "table_fetch_example_limit": 5,
                        },
                    }
                ],
                "validate_connection": True,
            }
            upsert_resp = session.post(
                f"{tenant_url}/sources",
                json=source_payload,
                headers=tenant_headers,
                timeout=self._UPSERT_TIMEOUT,
            )
            if upsert_resp.status_code not in (200, 201):
                pytest.skip(
                    f"Failed to configure tenant source (status {upsert_resp.status_code}): {upsert_resp.text[:200]}"
                )

            yield {
                "base_url": base_url,
                "tenant_id": tenant_id,
                "source_name": source_name,
                "api_key": api_key,
                "key_id": key_id,
                "session": session,
                "headers_admin": headers_admin,
                "tenant_headers": tenant_headers,
                "timeout": 120.0,
            }
        finally:
            # Cleanup: delete source and revoke key (best effort; errors ignored)
            if tenant_headers is not None:
                try:
                    session.delete(
                        f"{tenant_url}/sources/{source_name}",
                        headers=tenant_headers,
                        timeout=timeout,
                    )
                except requests.RequestException:
                    pass
            if key_id:
                try:
                    session.delete(
                        f"{tenant_url}/api-keys/{key_id}",
                        headers=headers_admin,
                        timeout=timeout,
                    )
                except requests.RequestException:
                    pass
            session.close()

    @pytest.fixture
    def live_toolkit(self, live_api_setup):
        """Return a toolkit pointed at the provisioned API, authenticated with the tenant key."""
        api_key = live_api_setup.get("api_key")
        base_url = live_api_setup.get("base_url", "http://127.0.0.1:8000")
        return DataSourcesSQLToolkit(
            api_base_url=base_url,
            timeout=live_api_setup.get("timeout", 120.0),
            api_key=api_key,
        )

    def test_e2e_keyword_search_with_real_db(self, live_toolkit, live_api_setup, caplog):
        """
        Test end-to-end workflow with keyword search on real database using REAL API calls.

        Workflow:
        1. Search schema with keywords via real API
        2. Verify matched tables
        3. Execute SQL query on matched table via real API
        4. Measure timing for each step
        """
        import time
        import logging

        # Enable detailed logging
        caplog.set_level(logging.INFO)

        print("\n" + "=" * 80)
        print("END-TO-END INTEGRATION TEST: Keyword Search -> SQL Execution")
        print("=" * 80)

        # Test configuration from provisioned setup
        tenant_id = live_api_setup["tenant_id"]
        source_name = live_api_setup["source_name"]

        # Step 1: Search schema with business keywords
        print("\n[STEP 1] Searching schema with keywords...")
        keywords = ["customer", "order", "product", "sales"]
        search_start = time.perf_counter()
        search_result = live_toolkit.search_schema(
            keywords=keywords,
            include_samples=True,
            session_state={"tenant_id": tenant_id},
        )
        search_duration = time.perf_counter() - search_start
        print(f"✓ Schema search completed in {search_duration:.3f}s")

        # Verify search results
        if "error" in search_result:
            print(f"⚠ Search returned error (API may not be running): {search_result['error']}")
            pytest.skip("API not available for integration test")

        assert "matches" in search_result or "formatted_schema_string" in search_result
        print(f" - Total matches: {search_result.get('total_matches', 0)}")
        print(f" - Available sources: {search_result.get('available_sources', [])}")
        print(f" - Cache hit: {search_result.get('cache_hit', False)}")

        # Log matched tables
        if search_result.get("matches"):
            print("\n Matched Tables:")
            for match in search_result["matches"][:5]:  # Show top 5
                print(f" - {match.get('table_name', 'N/A')} (score: {match.get('score', 0)})")
                if "matched_columns" in match:
                    print(f" Columns: {', '.join(match['matched_columns'][:5])}")

        # Step 2: Get source instructions
        print("\n[STEP 2] Fetching SQL dialect instructions...")
        instructions_start = time.perf_counter()
        session_state = {"tenant_id": tenant_id, "source_name": source_name}
        instructions_result = live_toolkit.get_source_instructions(
            source_name=source_name,
            session_state=session_state,
        )
        instructions_duration = time.perf_counter() - instructions_start
        print(f"✓ Instructions fetched in {instructions_duration:.3f}s")
        if "error" not in instructions_result:
            # str() guards against a null "instructions" value before slicing.
            print(f" - Instructions: {str(instructions_result.get('instructions', 'N/A'))[:100]}...")

        # Step 3: Execute SQL queries
        print("\n[STEP 3] Executing SQL query on real database...")
        # Use safe SELECT queries that should work on most schemas
        test_queries = [
            "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' LIMIT 5",
            "SELECT current_database(), current_schema()",
            "SELECT version()",
        ]
        total_query_duration = 0.0
        for idx, sql_query in enumerate(test_queries, 1):
            print(f"\n Query {idx}: {sql_query}")
            query_start = time.perf_counter()
            query_result = live_toolkit.execute_sql_query(
                source_name=source_name,
                sql_query=sql_query,
                session_state={"tenant_id": tenant_id},
            )
            query_duration = time.perf_counter() - query_start
            total_query_duration += query_duration
            print(f" ✓ Query executed in {query_duration:.3f}s")
            if "error" in query_result:
                print(f" ⚠ Query error: {query_result['error']}")
            elif "results" in query_result:
                results = query_result["results"]
                print(f" - Rows returned: {len(results)}")
                if results:
                    print(f" - Sample row: {results[0]}")
            # Verify result structure
            assert "status" in query_result or "error" in query_result or "results" in query_result

        # Step 4: Multi-tenant isolation probe. Only the outcome is reported;
        # nothing is asserted about it.
        print("\n[STEP 4] Testing multi-tenant isolation...")
        tenant_2_id = "test_tenant_e2e_002"
        isolation_start = time.perf_counter()
        # Same query, different tenant
        isolation_result = live_toolkit.execute_sql_query(
            source_name=source_name,
            sql_query="SELECT current_database()",
            session_state={"tenant_id": tenant_2_id},
        )
        isolation_duration = time.perf_counter() - isolation_start
        print(f"✓ Tenant isolation query completed in {isolation_duration:.3f}s")
        print(f" - Tenant 2 result: {isolation_result.get('error', 'no error reported')}")

        # Step 5: Test caching behavior
        print("\n[STEP 5] Testing cache behavior...")
        # Repeat search with same keywords
        cache_start = time.perf_counter()
        cached_search_result = live_toolkit.search_schema(
            keywords=keywords,
            include_samples=False,
            session_state={"tenant_id": tenant_id},
        )
        cache_duration = time.perf_counter() - cache_start
        print(f"✓ Cached search completed in {cache_duration:.3f}s")
        print(f" - Cache hit: {cached_search_result.get('cache_hit', False)}")
        print(f" - Speedup: {(search_duration / cache_duration):.2f}x faster" if cache_duration > 0 else " - Instant cache hit")

        # Summary
        print("\n" + "=" * 80)
        print("END-TO-END TEST SUMMARY")
        print("=" * 80)
        print(f"Total workflow time: {(search_duration + instructions_duration + total_query_duration):.3f}s")
        print(f" - Schema search: {search_duration:.3f}s")
        print(f" - Instructions: {instructions_duration:.3f}s")
        print(f" - SQL execution: {total_query_duration:.3f}s")
        print(f" - Cache speedup: {(search_duration / cache_duration):.2f}x" if cache_duration > 0 else " - N/A")
        print("=" * 80 + "\n")

    def test_e2e_session_state_caching(self, live_toolkit, live_api_setup, caplog):
        """
        Test session state caching across multiple toolkit calls.

        Simulates agent behavior with session state persistence: the second
        identical ``search_schema`` call must report ``cache_hit`` and be faster.
        """
        import time

        print("\n" + "=" * 80)
        print("END-TO-END TEST: Session State Caching")
        print("=" * 80)

        tenant_id = live_api_setup["tenant_id"]

        # Simulate session state (like agent would provide). The tenant id is
        # required so the search is scoped to the provisioned tenant.
        session_state = {
            "tenant_id": tenant_id,
            "schema_search_cache": {},
            "extracted_keywords": [],
            "tool_execution_log": [],
        }
        keywords = ["revenue", "customer", "transaction"]

        # First call - should miss cache
        print("\n[Call 1] Schema search without cache...")
        start_1 = time.perf_counter()
        result_1 = live_toolkit.search_schema(keywords=keywords, session_state=session_state)
        duration_1 = time.perf_counter() - start_1
        if "error" in result_1:
            pytest.skip("API not available")
        print(f"✓ First call: {duration_1:.3f}s, cache_hit={result_1.get('cache_hit', False)}")

        # Second call - should hit session cache
        print("\n[Call 2] Schema search with session cache...")
        start_2 = time.perf_counter()
        result_2 = live_toolkit.search_schema(keywords=keywords, session_state=session_state)
        duration_2 = time.perf_counter() - start_2
        print(f"✓ Second call: {duration_2:.3f}s, cache_hit={result_2.get('cache_hit', False)}")
        print(f" Cache source: {result_2.get('cache_source', 'N/A')}")

        # Verify caching worked
        assert result_2.get("cache_hit") is True, "Second call should hit cache"
        assert duration_2 < duration_1, "Cached call should be faster"
        speedup = duration_1 / duration_2 if duration_2 > 0 else float("inf")
        print(f"\n✓ Cache speedup: {speedup:.2f}x faster")
        print(f" Session cache keys: {len(session_state.get('schema_search_cache', {}))}")
        print("=" * 80 + "\n")

    def test_e2e_keyword_extraction_workflow(self, live_toolkit, live_api_setup, caplog):
        """
        Test realistic agent workflow: keyword extraction -> search -> SQL execution.

        Simulates what agent.py would do with execute_query_with_tracking.
        """
        import time
        import re

        print("\n" + "=" * 80)
        print("END-TO-END TEST: Agent Keyword Extraction Workflow")
        print("=" * 80)

        tenant_id = live_api_setup["tenant_id"]
        source_name = live_api_setup["source_name"]

        # Simulate user questions (like agent would receive)
        user_questions = [
            "What are the total sales by customer?",
            "Show me all products with low inventory",
            "Get customer contact information for premium accounts",
        ]
        # Loop-invariant keyword-extraction config (simulating agent.py logic)
        stopwords = {'what', 'are', 'the', 'show', 'me', 'all', 'with', 'get', 'for', 'by'}
        word_pattern = re.compile(r'\b[a-z]+\b')

        for q_idx, question in enumerate(user_questions, 1):
            print(f"\n[Question {q_idx}] {question}")

            # Step 1: Extract keywords
            words = word_pattern.findall(question.lower())
            keywords = [w for w in words if w not in stopwords and len(w) > 2]
            print(f" Extracted keywords: {keywords}")

            # Step 2: Search schema
            search_start = time.perf_counter()
            search_result = live_toolkit.search_schema(
                keywords=keywords[:5],  # Limit to top 5
                session_state={"tenant_id": tenant_id},
            )
            search_duration = time.perf_counter() - search_start
            if "error" in search_result:
                print(f" ⚠ API not available: {search_result['error']}")
                pytest.skip("API not running")
            print(f" ✓ Search: {search_duration:.3f}s, matches={search_result.get('total_matches', 0)}")

            # Step 3: Agent would now generate SQL based on matched tables.
            # For this test, we use a generic query. Reset per question so a
            # question without matches does not reuse the previous timing.
            query_duration = 0.0
            if search_result.get('total_matches', 0) > 0:
                sql_query = "SELECT 1 as test_column"  # Safe query
                query_start = time.perf_counter()
                query_result = live_toolkit.execute_sql_query(
                    source_name=source_name,
                    sql_query=sql_query,
                    session_state={"tenant_id": tenant_id},
                )
                query_duration = time.perf_counter() - query_start
                print(f" ✓ Query: {query_duration:.3f}s")
                if "results" in query_result:
                    print(f" Results: {len(query_result['results'])} rows")

            print(f" Total workflow: {(search_duration + query_duration):.3f}s")

        print("\n" + "=" * 80)
        print("✓ All workflow tests completed successfully")
        print("=" * 80 + "\n")

    def test_e2e_performance_benchmarks(self, live_toolkit, live_api_setup, caplog):
        """
        Performance benchmark tests for the complete toolkit.

        Measures timing for various operations under realistic load.
        """
        import time
        import statistics

        print("\n" + "=" * 80)
        print("PERFORMANCE BENCHMARKS")
        print("=" * 80)

        tenant_id = live_api_setup["tenant_id"]

        # Benchmark 1: Schema search performance
        print("\n[Benchmark 1] Schema search (10 iterations)...")
        search_times = []
        for i in range(10):
            keywords = [f"test_keyword_{i % 3}", "customer", "order"]
            start = time.perf_counter()
            result = live_toolkit.search_schema(keywords=keywords, session_state={"tenant_id": tenant_id})
            duration = time.perf_counter() - start
            if "error" not in result:
                search_times.append(duration)

        if not search_times:
            print(" ⚠ No successful searches (API may be down)")
            pytest.skip("API not available")
        print(f" Average: {statistics.mean(search_times):.3f}s")
        print(f" Median: {statistics.median(search_times):.3f}s")
        print(f" Min: {min(search_times):.3f}s")
        print(f" Max: {max(search_times):.3f}s")
        print(f" Std Dev: {statistics.stdev(search_times):.3f}s" if len(search_times) > 1 else " N/A")

        # Benchmark 2: Cache hit performance
        print("\n[Benchmark 2] Cache hit performance...")
        keywords = ["benchmark", "cache", "test"]

        # First call (cache miss)
        start = time.perf_counter()
        live_toolkit.search_schema(keywords=keywords, session_state={"tenant_id": tenant_id})
        miss_time = time.perf_counter() - start

        # Subsequent calls (cache hits)
        hit_times = []
        for _ in range(5):
            start = time.perf_counter()
            live_toolkit.search_schema(keywords=keywords, session_state={"tenant_id": tenant_id})
            hit_times.append(time.perf_counter() - start)

        avg_hit_time = statistics.mean(hit_times)
        print(f" Cache miss: {miss_time:.3f}s")
        print(f" Cache hit avg: {avg_hit_time:.3f}s")
        # Guard: a very fast cache hit can measure as 0.0s.
        print(f" Speedup: {(miss_time / avg_hit_time):.2f}x" if avg_hit_time > 0 else " Speedup: N/A")

        print("\n" + "=" * 80)
        print("✓ Performance benchmarks completed")
        print("=" * 80 + "\n")

    def test_full_live_user_flow(self, live_api_setup):
        """Validate an end-to-end user flow against the live data_sources API.

        Exercises list_sources -> search_schema -> find_relevant_tables ->
        get_source_instructions -> execute_sql_query (success and failure). It
        checks the results and the ``analysis_metadata`` written into
        ``session_state``, and records per-step latencies in ms.
        """
        setup = live_api_setup
        toolkit = DataSourcesSQLToolkit(
            api_base_url=setup["base_url"],
            api_key=setup["api_key"],
            timeout=setup["timeout"],
        )
        session_state = {
            "tenant_id": setup["tenant_id"],
            "source_name": setup["source_name"],
            "tool_execution_log": [],
            "analysis_metadata": {},
            "user_context": {},
        }
        metrics: Dict[str, Any] = {}

        start = perf_counter()
        list_result = toolkit.list_sources(session_state=session_state)
        metrics["list_sources_ms"] = round((perf_counter() - start) * 1000, 2)
        assert "available_sources" in list_result
        assert setup["source_name"] in list_result.get("available_sources", [])

        start = perf_counter()
        search_result = toolkit.search_schema(
            keywords=["booking", "customer"],
            session_state=session_state,
        )
        metrics["search_schema_ms"] = round((perf_counter() - start) * 1000, 2)
        assert "available_sources" in search_result
        assert session_state.get("analysis_metadata", {}).get("last_schema_search") is not None

        start = perf_counter()
        find_result = toolkit.find_relevant_tables(
            question="What were our total bookings last quarter?",
            concepts=["bookings", "quarter"],
            session_state=session_state,
        )
        metrics["find_relevant_tables_ms"] = round((perf_counter() - start) * 1000, 2)
        assert "error" not in find_result

        start = perf_counter()
        instructions = toolkit.get_source_instructions(session_state=session_state)
        metrics["get_source_instructions_ms"] = round((perf_counter() - start) * 1000, 2)
        assert "instructions" in instructions

        start = perf_counter()
        good_sql = "SELECT current_database() AS current_db, NOW() AT TIME ZONE 'UTC' AS utc_now"
        ok_result = toolkit.execute_sql_query(sql_query=good_sql, session_state=session_state)
        metrics["execute_sql_success_ms"] = round((perf_counter() - start) * 1000, 2)
        assert ok_result.get("status") == "success"
        assert isinstance(ok_result.get("results"), list)

        metadata = session_state.get("analysis_metadata", {})
        last_exec = metadata.get("last_sql_execution", {})
        assert last_exec.get("tenant_id") == setup["tenant_id"]
        assert "row_count" in last_exec

        start = perf_counter()
        bad_sql = "SELECT * FROM table_that_does_not_exist__pytest"
        error_result = toolkit.execute_sql_query(sql_query=bad_sql, session_state=session_state)
        metrics["execute_sql_error_ms"] = round((perf_counter() - start) * 1000, 2)
        assert "error" in error_result

        # Capture tool-call count for diagnostics
        metrics["total_tool_calls"] = len(session_state.get("tool_execution_log", []))

        # Basic sanity on latencies (should be non-negative)
        for value in metrics.values():
            assert value >= 0
if __name__ == "__main__":
    # Propagate pytest's exit status so running this file directly reports
    # failures to the shell/CI instead of always exiting 0.
    raise SystemExit(pytest.main([__file__, "-v", "--tb=short"]))