# sirus/backend/data_sources/tests/test_data_sources_sql_toolkit.py
# Author: ranilmukesh
# Commit b8277c4: "Deploy SiRUS SQL Agent backend"
"""
Unit tests for DataSourcesSQLToolkit.
Tests cover:
- Initialization and configuration
- Schema fetching and formatting
- Source instructions retrieval
- SQL query execution
- Error handling and edge cases
- Input validation
Uses mocking (pytest-mock and unittest.mock) to simulate API responses.
"""
import pytest
import json
import uuid
import httpx
from time import perf_counter
from datetime import datetime
from unittest.mock import Mock, MagicMock, patch, call
from typing import Dict, Any
# Import from backend SQL_Agent package
import sys
import os
# Ensure the project root (parent of `backend`) is on sys.path so
# `import backend...` works when running this test file directly.
# Climb three directory levels (tests -> data_sources -> backend) so the
# directory that *contains* the `backend` package becomes importable.
project_root = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir, os.pardir, os.pardir)
)
if project_root not in sys.path:
    sys.path.insert(0, project_root)
from backend.SQL_Agent.data_sources_sql_toolkit import DataSourcesSQLToolkit
# ============================================================================
# Fixtures
# ============================================================================
@pytest.fixture
def toolkit():
    """Build a toolkit aimed at a fake API host; tests replace its client methods with mocks."""
    fake_api_url = "http://test-api:8000"
    instance = DataSourcesSQLToolkit(api_base_url=fake_api_url)
    return instance
@pytest.fixture
def mock_httpx_response():
    """Provide a bare ``Mock`` to stand in for an httpx response object."""
    response_stub = Mock()
    return response_stub
# ============================================================================
# Test Initialization
# ============================================================================
class TestToolkitInitialization:
    """Construction of the toolkit and the set of tools it registers."""

    def test_init_default_values(self):
        """With no arguments the toolkit targets the local API on port 8000."""
        expected_url = "http://127.0.0.1:8000"
        instance = DataSourcesSQLToolkit()
        assert instance.api_base_url == expected_url
        assert instance.client is not None
        assert instance.client.base_url == expected_url

    def test_init_custom_values(self):
        """An explicit base URL (and timeout) is accepted and reaches the client."""
        url, timeout = "http://custom-api:9999", 60.0
        instance = DataSourcesSQLToolkit(api_base_url=url, timeout=timeout)
        assert instance.api_base_url == url
        assert instance.client.base_url == url

    def test_toolkit_has_required_tools(self, toolkit):
        """Every agent-facing tool is registered on ``toolkit.tools``."""
        registered = {tool_fn.__name__ for tool_fn in toolkit.tools}
        required = (
            "find_relevant_tables",
            "search_schema",
            "get_available_sources_and_schema",
            "get_source_instructions",
            "execute_sql_query",
        )
        for tool_name in required:
            assert tool_name in registered
# ============================================================================
# Test Schema Formatting
# ============================================================================
class TestSchemaFormatting:
    """``_format_schema_for_llm`` turns stored schema JSON into LLM-readable text."""

    @staticmethod
    def _public_schema_json(*tables):
        """Serialize ``tables`` as a one-element list holding the ``public`` schema."""
        return json.dumps([{"schema_name": "public", "tables": list(tables)}])

    def test_format_schema_with_valid_json(self, toolkit):
        """Source, schema, table, column names and types all appear in the output."""
        users_table = {
            "table_name": "users",
            "fields": [
                {"name": "id", "type": "INTEGER", "example": "1"},
                {"name": "email", "type": "VARCHAR", "example": "user@example.com"},
            ],
        }
        rendered = toolkit._format_schema_for_llm(
            "production_db", self._public_schema_json(users_table)
        )
        for fragment in ("production_db", "public", "users", "id", "INTEGER"):
            assert fragment in rendered

    def test_format_schema_with_empty_schema(self, toolkit):
        """A ``None`` schema is reported as not available for that source."""
        rendered = toolkit._format_schema_for_llm("empty_source", None)
        assert "empty_source" in rendered
        assert "Not Available" in rendered

    def test_format_schema_with_invalid_json(self, toolkit):
        """Unparseable schema JSON is flagged rather than raising."""
        rendered = toolkit._format_schema_for_llm("broken_source", "{ invalid json }")
        assert "broken_source" in rendered
        assert any(marker in rendered for marker in ("Invalid Format", "Error"))

    def test_format_schema_with_multiple_tables(self, toolkit):
        """Every table of a multi-table schema is rendered."""
        users_table = {
            "table_name": "users",
            "fields": [
                {"name": "id", "type": "INTEGER"},
            ],
        }
        orders_table = {
            "table_name": "orders",
            "fields": [
                {"name": "order_id", "type": "INTEGER"},
                {"name": "user_id", "type": "INTEGER"},
            ],
        }
        rendered = toolkit._format_schema_for_llm(
            "test_db", self._public_schema_json(users_table, orders_table)
        )
        assert "users" in rendered
        assert "orders" in rendered
# ============================================================================
# Test get_available_sources_and_schema
# ============================================================================
class TestGetAvailableSourcesAndSchema:
    """Tests for retrieving available sources and their schemas."""

    def test_missing_tenant_id(self, toolkit):
        """Without a session state (hence no tenant_id) the tool returns an error."""
        result = toolkit.get_available_sources_and_schema(session_state=None)
        assert "error" in result
        assert "Session state not found" in result["error"]

    def test_successful_schema_fetch(self, toolkit):
        """Schemas are fetched for each listed source, sending the session's JWT."""
        # First GET: the list_sources call for the tenant.
        list_response = Mock()
        list_response.status_code = 200
        list_response.json.return_value = {
            "available_sources": ["db1", "db2"],
            "count": 2
        }
        # Following GETs: one schema fetch per listed source, in order.
        schema_response_1 = Mock()
        schema_response_1.status_code = 200
        schema_response_1.json.return_value = {
            "schema_data": json.dumps([{
                "schema_name": "public",
                "tables": [{"table_name": "users", "fields": []}]
            }])
        }
        schema_response_2 = Mock()
        schema_response_2.status_code = 200
        schema_response_2.json.return_value = {
            "schema_data": json.dumps([{
                "schema_name": "public",
                "tables": [{"table_name": "orders", "fields": []}]
            }])
        }
        toolkit.client.get = Mock(side_effect=[list_response, schema_response_1, schema_response_2])
        # Provide session_state with tenant_id and JWT
        session_state = {"tenant_id": "tenant-123", "supabase_jwt": "jwt-token-123"}
        result = toolkit.get_available_sources_and_schema(
            keywords=None,
            session_state=session_state
        )
        assert "formatted_schema_string" in result
        assert "available_sources" in result
        assert len(result["available_sources"]) == 2
        # Verify the JWT was forwarded on the per-source schema requests.
        # The loop variable is not named `call`, so it does not shadow the
        # `unittest.mock.call` imported at module level.
        recorded_calls = toolkit.client.get.call_args_list
        # Guard against a vacuous loop when no schema requests were made.
        assert len(recorded_calls) > 1, "expected schema fetches after list_sources"
        for recorded in recorded_calls[1:]:  # Skip the list_sources call
            headers = recorded.kwargs["headers"]
            assert headers["Authorization"] == "Bearer jwt-token-123"

    def test_schema_fetch_with_api_error(self, toolkit):
        """A 500 on a schema fetch still yields a (degraded) result, not a crash."""
        # Mock list_sources returning one source
        list_response = Mock()
        list_response.status_code = 200
        list_response.json.return_value = {
            "available_sources": ["db1"],
            "count": 1
        }
        # Mock error response for schema fetch
        schema_response = Mock()
        schema_response.status_code = 500
        schema_response.text = "Internal Server Error"
        toolkit.client.get = Mock(side_effect=[list_response, schema_response])
        session_state = {"tenant_id": "tenant-123", "supabase_jwt": "jwt-token-123"}
        result = toolkit.get_available_sources_and_schema(session_state=session_state)
        assert "formatted_schema_string" in result
        assert "available_sources" in result

    def test_schema_fetch_import_error(self, toolkit):
        """An ``httpx.RequestError`` from the API surfaces as an API connection error."""
        # Mock list_sources to raise connection error
        toolkit.client.get = Mock(side_effect=httpx.RequestError("Connection failed"))
        session_state = {"tenant_id": "tenant-123", "supabase_jwt": "jwt-token-123"}
        result = toolkit.get_available_sources_and_schema(session_state=session_state)
        assert "error" in result
        assert "API connection error" in result["error"]
# ============================================================================
# Test get_source_instructions
# ============================================================================
class TestGetSourceInstructions:
    """Tests for retrieving source-specific instructions."""

    def test_missing_tenant_id(self, toolkit):
        """Without a session state (hence no tenant_id) the tool returns an error."""
        result = toolkit.get_source_instructions(session_state=None)
        assert "error" in result
        assert "Session state not found" in result["error"]

    def test_missing_source_name(self, toolkit):
        """A session without a source name, and none passed explicitly, is rejected."""
        # The stubbed transport would answer successfully if reached, so the
        # assertion only holds when the toolkit rejects the missing source
        # name itself. The stub also keeps this unit test off the network.
        ok_response = Mock()
        ok_response.status_code = 200
        ok_response.json.return_value = {"instructions": "unused"}
        toolkit.client.get = Mock(return_value=ok_response)
        session_state = {"tenant_id": "tenant-123", "supabase_jwt": "jwt-token-123"}
        result = toolkit.get_source_instructions(session_state=session_state)
        assert "error" in result

    def test_successful_instructions_fetch(self, toolkit):
        """Test successful retrieval of instructions."""
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "instructions": "Use SELECT ... FROM syntax. Supports CTEs with WITH clause."
        }
        toolkit.client.get = Mock(return_value=mock_response)
        session_state = {"tenant_id": "tenant-123", "source_name": "production_db", "supabase_jwt": "jwt-token-123"}
        result = toolkit.get_source_instructions(
            source_name="production_db",
            session_state=session_state
        )
        assert "instructions" in result
        assert "SELECT" in result["instructions"]
        toolkit.client.get.assert_called_once()

    def test_instructions_not_found(self, toolkit):
        """Test handling of 404 when instructions not found."""
        mock_response = Mock()
        mock_response.status_code = 404
        toolkit.client.get = Mock(return_value=mock_response)
        session_state = {"tenant_id": "tenant-123", "source_name": "unknown_db", "supabase_jwt": "jwt-token-123"}
        result = toolkit.get_source_instructions(
            source_name="unknown_db",
            session_state=session_state
        )
        assert "error" in result
        assert "not found" in result["error"].lower()

    def test_instructions_api_error(self, toolkit):
        """Test handling of API errors."""
        mock_response = Mock()
        mock_response.status_code = 500
        mock_response.text = "Server Error"
        toolkit.client.get = Mock(return_value=mock_response)
        session_state = {"tenant_id": "tenant-123", "source_name": "db", "supabase_jwt": "jwt-token-123"}
        result = toolkit.get_source_instructions(
            source_name="db",
            session_state=session_state
        )
        assert "error" in result
        assert "Failed" in result["error"]

    def test_instructions_connection_error(self, toolkit):
        """Test handling of connection errors (uses the module-level ``httpx`` import)."""
        toolkit.client.get = Mock(side_effect=httpx.RequestError("Connection refused"))
        session_state = {"tenant_id": "tenant-123", "source_name": "db", "supabase_jwt": "jwt-token-123"}
        result = toolkit.get_source_instructions(
            source_name="db",
            session_state=session_state
        )
        assert "error" in result
        assert "connection error" in result["error"].lower()
# ============================================================================
# Test execute_sql_query
# ============================================================================
class TestExecuteSQLQuery:
    """Tests for SQL query execution."""

    def test_missing_tenant_id(self, toolkit):
        """Without a session state (hence no tenant_id) the tool returns an error."""
        result = toolkit.execute_sql_query(
            sql_query="SELECT * FROM users",
            session_state=None
        )
        assert "error" in result
        assert "Session state not found" in result["error"]

    def test_missing_source_name(self, toolkit):
        """A session without a source name, and none passed explicitly, is rejected."""
        # The stubbed transport would report success if reached, so the
        # assertion only holds when the toolkit rejects the request itself.
        # The stub also keeps this unit test off the network.
        ok_response = Mock()
        ok_response.status_code = 200
        ok_response.json.return_value = {"results": []}
        toolkit.client.post = Mock(return_value=ok_response)
        result = toolkit.execute_sql_query(
            sql_query="SELECT * FROM users",
            session_state={"tenant_id": "tenant-123", "supabase_jwt": "jwt-token-123"}
        )
        assert "error" in result

    def test_empty_sql_query(self, toolkit):
        """Test that empty SQL query returns error."""
        result = toolkit.execute_sql_query(
            sql_query=" ",
            session_state={"tenant_id": "tenant-123", "source_name": "db", "supabase_jwt": "jwt-token-123"}
        )
        assert "error" in result
        assert "empty" in result["error"].lower()

    def test_successful_query_execution(self, toolkit):
        """Test successful SQL query execution."""
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "results": [
                {"id": 1, "name": "Alice"},
                {"id": 2, "name": "Bob"}
            ]
        }
        toolkit.client.post = Mock(return_value=mock_response)
        session_state = {"tenant_id": "tenant-123", "source_name": "production_db", "supabase_jwt": "jwt-token-123"}
        result = toolkit.execute_sql_query(
            sql_query="SELECT * FROM users",
            session_state=session_state
        )
        assert result["status"] == "success"
        assert len(result["results"]) == 2
        assert result["results"][0]["name"] == "Alice"

    def test_query_blocked_for_unsafe_operations(self, toolkit):
        """Test that unsafe SQL operations are blocked client-side."""
        unsafe_queries = [
            "DROP TABLE users",
            "DELETE FROM users",
            "UPDATE users SET name = 'hacked'",
            "INSERT INTO users VALUES (1, 'hacked')",
            "TRUNCATE TABLE users"
        ]
        session_state = {"tenant_id": "tenant-123", "source_name": "db", "supabase_jwt": "jwt-token-123"}
        for unsafe_query in unsafe_queries:
            result = toolkit.execute_sql_query(
                sql_query=unsafe_query,
                session_state=session_state
            )
            assert "error" in result, f"not rejected: {unsafe_query!r}"
            assert "blocked" in result["error"].lower(), f"not blocked: {unsafe_query!r}"

    def test_query_allowed_for_safe_operations(self, toolkit):
        """Test that read-only SQL operations are not blocked by the client-side guard."""
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = {"results": []}
        toolkit.client.post = Mock(return_value=mock_response)
        session_state = {"tenant_id": "tenant-123", "source_name": "db", "supabase_jwt": "jwt-token-123"}
        safe_queries = [
            "SELECT * FROM users",
            "WITH cte AS (SELECT * FROM users) SELECT * FROM cte",
            "SHOW TABLES",
            "DESCRIBE users",
            "EXPLAIN SELECT * FROM users"
        ]
        for safe_query in safe_queries:
            result = toolkit.execute_sql_query(
                source_name="db",
                sql_query=safe_query,
                session_state=session_state
            )
            # Mirror of the unsafe-operations test: a safe query must never be
            # rejected by the guard. SHOW/DESCRIBE/EXPLAIN pass the guard and
            # reach the mocked POST. `or ""` tolerates an explicit None error.
            error_text = (result.get("error") or "").lower()
            assert "blocked" not in error_text, (
                f"safe query was blocked: {safe_query!r} -> {result}"
            )

    def test_query_api_error(self, toolkit):
        """Test handling of API errors during execution."""
        mock_response = Mock()
        mock_response.status_code = 400
        mock_response.json.return_value = {"detail": "Syntax error in query"}
        toolkit.client.post = Mock(return_value=mock_response)
        session_state = {"tenant_id": "tenant-123", "source_name": "db", "supabase_jwt": "jwt-token-123"}
        result = toolkit.execute_sql_query(
            source_name="db",
            sql_query="SELECT * FROM users",
            session_state=session_state
        )
        assert "error" in result
        assert "Syntax error" in result["error"]

    def test_query_timeout(self, toolkit):
        """Test handling of request timeouts (uses the module-level ``httpx`` import)."""
        toolkit.client.post = Mock(side_effect=httpx.TimeoutException("Request timed out"))
        session_state = {"tenant_id": "tenant-123", "source_name": "db", "supabase_jwt": "jwt-token-123"}
        result = toolkit.execute_sql_query(
            source_name="db",
            sql_query="SELECT * FROM users",
            session_state=session_state
        )
        assert "error" in result
        assert "timed out" in result["error"].lower()

    def test_query_connection_error(self, toolkit):
        """Test handling of connection errors (uses the module-level ``httpx`` import)."""
        toolkit.client.post = Mock(side_effect=httpx.RequestError("Connection refused"))
        session_state = {"tenant_id": "tenant-123", "source_name": "db", "supabase_jwt": "jwt-token-123"}
        result = toolkit.execute_sql_query(
            source_name="db",
            sql_query="SELECT * FROM users",
            session_state=session_state
        )
        assert "error" in result
        assert "connection error" in result["error"].lower()

    def test_query_malformed_response(self, toolkit):
        """A non-JSON error body falls back to the raw response text."""
        mock_response = Mock()
        mock_response.status_code = 500
        mock_response.json.side_effect = json.JSONDecodeError("Invalid JSON", "", 0)
        mock_response.text = "Internal Server Error"
        toolkit.client.post = Mock(return_value=mock_response)
        session_state = {"tenant_id": "tenant-123", "source_name": "db", "supabase_jwt": "jwt-token-123"}
        result = toolkit.execute_sql_query(
            source_name="db",
            sql_query="SELECT * FROM users",
            session_state=session_state
        )
        assert "error" in result
        assert "Internal Server Error" in result["error"]
# ============================================================================
# Test Multi-tenant Isolation
# ============================================================================
class TestMultiTenantIsolation:
    """Each outgoing request carries the calling tenant's identity and credentials."""

    @staticmethod
    def _stub_successful_post(toolkit):
        """Replace ``toolkit.client.post`` with a mock answering 200 and no rows."""
        ok_response = Mock()
        ok_response.status_code = 200
        ok_response.json.return_value = {"results": []}
        toolkit.client.post = Mock(return_value=ok_response)

    def test_different_tenants_use_different_params(self, toolkit):
        """Consecutive queries from two tenants each send their own tenant_id."""
        self._stub_successful_post(toolkit)
        scenarios = (
            ("tenant-1", "jwt-token-1", "SELECT * FROM users"),
            ("tenant-2", "jwt-token-2", "SELECT * FROM orders"),
        )
        for tenant, jwt, sql in scenarios:
            toolkit.execute_sql_query(
                sql_query=sql,
                session_state={"tenant_id": tenant, "supabase_jwt": jwt, "source_name": "db"},
            )
            # Inspect the most recent POST, i.e. the one just issued for `tenant`.
            sent_payload = toolkit.client.post.call_args.kwargs["json"]
            assert sent_payload["tenant_id"] == tenant

    def test_tenant_id_passed_to_all_endpoints(self, toolkit):
        """tenant_id travels in the JSON body and the JWT as a bearer header."""
        self._stub_successful_post(toolkit)
        toolkit.execute_sql_query(
            sql_query="SELECT * FROM users",
            session_state={"tenant_id": "tenant-x", "source_name": "db", "supabase_jwt": "jwt-token-123"},
        )
        post_kwargs = toolkit.client.post.call_args.kwargs
        sent_payload, sent_headers = post_kwargs["json"], post_kwargs["headers"]
        # Body, not query params, carries the tenant...
        assert sent_payload["tenant_id"] == "tenant-x"
        # ...while the user's JWT goes in the Authorization header.
        assert sent_headers["Authorization"] == "Bearer jwt-token-123"
# ============================================================================
# Test Edge Cases
# ============================================================================
class TestEdgeCases:
    """Tests for edge cases and boundary conditions."""

    def test_very_long_sql_query(self, toolkit):
        """Test handling of very long SQL queries."""
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = {"results": []}
        toolkit.client.post = Mock(return_value=mock_response)
        long_query = "SELECT * FROM users WHERE id IN (" + ",".join(str(i) for i in range(1000)) + ")"
        session_state = {"tenant_id": "tenant-123", "supabase_jwt": "jwt-token-123", "source_name": "db"}
        result = toolkit.execute_sql_query(
            source_name="db",
            sql_query=long_query,
            session_state=session_state
        )
        assert "error" not in result or result.get("status") == "success"

    def test_special_characters_in_query(self, toolkit):
        """Escaped quotes and LIKE wildcards are posted once and succeed."""
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = {"results": []}
        toolkit.client.post = Mock(return_value=mock_response)
        query_with_special_chars = "SELECT * FROM users WHERE name = 'O''Reilly' AND email LIKE '%@%.com%'"
        session_state = {"tenant_id": "tenant-123", "supabase_jwt": "jwt-token-123", "source_name": "db"}
        result = toolkit.execute_sql_query(
            source_name="db",
            sql_query=query_with_special_chars,
            session_state=session_state
        )
        # Should be posted exactly once and reported as a success (the result
        # was previously computed but never checked).
        toolkit.client.post.assert_called_once()
        assert result.get("status") == "success", result

    def test_empty_results(self, toolkit):
        """Test handling of empty result sets."""
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = {"results": []}
        toolkit.client.post = Mock(return_value=mock_response)
        session_state = {"tenant_id": "tenant-123", "supabase_jwt": "jwt-token-123", "source_name": "db"}
        result = toolkit.execute_sql_query(
            source_name="db",
            sql_query="SELECT * FROM users WHERE id = -1",
            session_state=session_state
        )
        assert result["status"] == "success"
        assert result["results"] == []

    def test_large_result_set(self, toolkit):
        """Test handling of large result sets."""
        # Create a large result set
        large_results = [{"id": i, "name": f"user_{i}"} for i in range(10000)]
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = {"results": large_results}
        toolkit.client.post = Mock(return_value=mock_response)
        session_state = {"tenant_id": "tenant-123", "supabase_jwt": "jwt-token-123", "source_name": "db"}
        result = toolkit.execute_sql_query(
            source_name="db",
            sql_query="SELECT * FROM users",
            session_state=session_state
        )
        assert result["status"] == "success"
        assert len(result["results"]) == 10000
# ============================================================================
# Test find_relevant_tables (Hybrid Keyword Extraction)
# ============================================================================
class TestFindRelevantTables:
    """Tests for the find_relevant_tables tool with hybrid keyword extraction."""

    def test_missing_tenant_id(self, toolkit):
        """Test that missing tenant_id returns error."""
        result = toolkit.find_relevant_tables(
            tenant_id="",
            question="What are the total sales?"
        )
        assert "error" in result
        assert "Tenant ID is required" in result["error"]

    def test_missing_question(self, toolkit):
        """Test that missing question returns error."""
        result = toolkit.find_relevant_tables(
            tenant_id="tenant-123",
            question=""
        )
        assert "error" in result
        assert "Question is required" in result["error"]

    @patch("backend.SQL_Agent.data_sources_sql_toolkit.extract_hybrid_keywords")
    def test_find_relevant_tables_merges_concepts_and_semantic_hints(self, mock_extract, toolkit):
        """Test that find_relevant_tables properly merges concepts and semantic hints."""
        # Mock hybrid keyword extraction
        mock_extract.return_value = {
            'base': ['revenue', 'customers'],
            'semantic': ['sales metrics', 'customer data'],
            'concepts': ['premium', 'quarterly'],
            'combined': ['premium', 'quarterly', 'revenue', 'customers', 'sales metrics', 'customer data']
        }
        # Mock the search_schema call
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "formatted_schema_string": "Test schema",
            "matches": [{"table_name": "sales", "score": 15.0, "matched_columns": ["revenue"], "source_name": "db1"}],
            "available_sources": ["db1"],
            "total_matches": 1
        }
        toolkit.client.post = Mock(return_value=mock_response)
        # Create session state
        session_state = {"keyword_extraction_cache": {}}
        # Call find_relevant_tables with agent concepts
        result = toolkit.find_relevant_tables(
            tenant_id="tenant-123",
            question="What was revenue from premium customers last quarter?",
            concepts=["premium", "quarterly"],
            session_state=session_state
        )
        # Verify extraction was called
        mock_extract.assert_called_once()
        call_kwargs = mock_extract.call_args[1]
        assert call_kwargs['question'] == "What was revenue from premium customers last quarter?"
        assert call_kwargs['llm_concepts'] == ["premium", "quarterly"]
        # Verify result structure
        assert "formatted_schema_string" in result
        assert "keyword_breakdown" in result
        assert result["keyword_breakdown"]["base"] == ['revenue', 'customers']
        assert result["keyword_breakdown"]["semantic"] == ['sales metrics', 'customer data']
        assert result["keyword_breakdown"]["concepts"] == ['premium', 'quarterly']
        assert result["original_question"] == "What was revenue from premium customers last quarter?"
        # Verify merged keywords were sent to API
        api_call_payload = toolkit.client.post.call_args[1]["json"]
        assert set(api_call_payload["keywords"]) == {'premium', 'quarterly', 'revenue', 'customers', 'sales metrics', 'customer data'}

    @patch("backend.SQL_Agent.data_sources_sql_toolkit.extract_hybrid_keywords")
    def test_find_relevant_tables_handles_gemini_failure(self, mock_extract, toolkit):
        """Test that find_relevant_tables handles Gemini failure gracefully."""
        # Mock hybrid extraction with no semantic hints (Gemini failed)
        mock_extract.return_value = {
            'base': ['sales', 'customers'],
            'semantic': [],  # Empty because Gemini failed
            'concepts': ['premium'],
            'combined': ['premium', 'sales', 'customers']
        }
        # Mock API response
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "formatted_schema_string": "Schema",
            "matches": [],
            "available_sources": ["db1"],
            "total_matches": 0
        }
        toolkit.client.post = Mock(return_value=mock_response)
        # Call should succeed despite Gemini failure
        session_state = {"tenant_id": "tenant-123", "supabase_jwt": "jwt-token-123"}
        result = toolkit.find_relevant_tables(
            question="Show premium customer sales",
            concepts=["premium"],
            session_state=session_state
        )
        assert "error" not in result
        assert result["keyword_breakdown"]["semantic"] == []
        assert "premium" in result["keyword_breakdown"]["combined"]
        assert "sales" in result["keyword_breakdown"]["combined"]

    @patch("backend.SQL_Agent.data_sources_sql_toolkit.extract_hybrid_keywords")
    def test_find_relevant_tables_uses_cache_for_repeat_question(self, mock_extract, toolkit):
        """Test that find_relevant_tables caches keyword extraction per question."""
        # Mock hybrid extraction
        mock_extract.return_value = {
            'base': ['revenue'],
            'semantic': ['financial metrics'],
            'concepts': [],
            'combined': ['revenue', 'financial metrics']
        }
        # Mock API response
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "formatted_schema_string": "Schema",
            "matches": [],
            "available_sources": [],
            "total_matches": 0
        }
        toolkit.client.post = Mock(return_value=mock_response)
        # Create session state
        session_state = {"keyword_extraction_cache": {}}
        # First call
        toolkit.find_relevant_tables(
            question="What is the total revenue?",
            session_state=session_state
        )
        # Verify extraction was called once
        assert mock_extract.call_count == 1
        assert len(session_state["keyword_extraction_cache"]) == 1
        # Second call with same question
        toolkit.find_relevant_tables(
            question="What is the total revenue?",
            session_state=session_state
        )
        # Verify extraction was NOT called again (cache hit)
        assert mock_extract.call_count == 1  # Still 1, not 2
        # Third call with different question
        toolkit.find_relevant_tables(
            question="Show me customer data",
            session_state=session_state
        )
        # Verify extraction was called for new question
        assert mock_extract.call_count == 2
        assert len(session_state["keyword_extraction_cache"]) == 2

    def test_find_relevant_tables_without_hybrid_utils(self, toolkit):
        """Test fallback behavior when hybrid_keyword_utils is not available."""
        # Temporarily disable hybrid extraction by patching the module-level
        # name to None for the duration of the `with` block.
        with patch("backend.SQL_Agent.data_sources_sql_toolkit.extract_hybrid_keywords", None):
            # Mock API response
            mock_response = Mock()
            mock_response.status_code = 200
            mock_response.json.return_value = {
                "formatted_schema_string": "Schema",
                "matches": [],
                "available_sources": [],
                "total_matches": 0
            }
            toolkit.client.post = Mock(return_value=mock_response)
            # Should fall back to simple keyword extraction
            session_state = {"tenant_id": "tenant-123", "supabase_jwt": "jwt-token-123"}
            result = toolkit.find_relevant_tables(
                question="What is the revenue from premium customers?",
                concepts=["premium"],
                session_state=session_state
            )
            # Verify fallback worked
            assert "error" not in result
            # Check that API was called with some keywords
            api_payload = toolkit.client.post.call_args[1]["json"]
            assert len(api_payload["keywords"]) > 0

    @patch("backend.SQL_Agent.data_sources_sql_toolkit.extract_hybrid_keywords")
    def test_find_relevant_tables_sends_metadata_to_api(self, mock_extract, toolkit):
        """Test that find_relevant_tables sends metadata to API for analytics."""
        # Mock hybrid extraction
        mock_extract.return_value = {
            'base': ['sales'],
            'semantic': ['revenue data'],
            'concepts': ['monthly'],
            'combined': ['monthly', 'sales', 'revenue data']
        }
        # Mock API response
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "formatted_schema_string": "Schema",
            "matches": [],
            "available_sources": [],
            "total_matches": 0
        }
        toolkit.client.post = Mock(return_value=mock_response)
        # Call with question
        session_state = {"tenant_id": "tenant-123", "supabase_jwt": "jwt-token-123"}
        toolkit.find_relevant_tables(
            question="Show monthly sales data",
            concepts=["monthly"],
            session_state=session_state
        )
        # Verify metadata was sent to API
        api_payload = toolkit.client.post.call_args[1]["json"]
        assert "original_question" in api_payload
        assert api_payload["original_question"] == "Show monthly sales data"
        assert "keyword_metadata" in api_payload
        assert api_payload["keyword_metadata"]["base"] == ['sales']
        assert api_payload["keyword_metadata"]["semantic"] == ['revenue data']
        assert api_payload["keyword_metadata"]["concepts"] == ['monthly']

    @patch("backend.SQL_Agent.data_sources_sql_toolkit.extract_hybrid_keywords")
    def test_find_relevant_tables_with_source_filter(self, mock_extract, toolkit):
        """Test that find_relevant_tables passes source_names filter correctly."""
        # Mock hybrid extraction
        mock_extract.return_value = {
            'base': ['users'],
            'semantic': [],
            'concepts': [],
            'combined': ['users']
        }
        # Mock API response
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "formatted_schema_string": "Schema",
            "matches": [],
            "available_sources": ["db1"],
            "total_matches": 0
        }
        toolkit.client.post = Mock(return_value=mock_response)
        # Call with source filter
        session_state = {"tenant_id": "tenant-123", "supabase_jwt": "jwt-token-123"}
        toolkit.find_relevant_tables(
            question="Show user data",
            source_names=["db1", "db2"],
            session_state=session_state
        )
        # Verify source filter was passed to API
        api_payload = toolkit.client.post.call_args[1]["json"]
        assert "source_names" in api_payload
        assert api_payload["source_names"] == ["db1", "db2"]
# ============================================================================
# End-to-End Integration Tests with Real Database
# ============================================================================
class TestEndToEndIntegration:
    """
    End-to-end integration tests against a live data_sources API backed by a
    real PostgreSQL database.

    The class-scoped ``live_api_setup`` fixture uses the ``requests`` library
    to provision a throwaway tenant (API key + source). The tests then drive
    the toolkit against that tenant through the full workflow:
    keyword search -> schema fetch -> SQL execution.

    Required environment:
        SIRUS_ADMIN_API_KEY        admin key used to mint the tenant API key.
        DATA_SOURCES_E2E_DB_URI    connection URI of the PostgreSQL database the
                                   tenant source points at. It holds
                                   credentials, so it must come from the
                                   environment and never from the repository.
    Optional environment:
        DATA_SOURCES_API_BASE_URL  API base URL (default http://127.0.0.1:8000).
        SKIP_INTEGRATION_TESTS=1   skip every test in this class.
    """

    @pytest.fixture(scope="class")
    def live_api_setup(self):
        """Provision a real tenant + source against the live data_sources API.

        Yields a dict with ``base_url``, ``tenant_id``, ``source_name``,
        ``api_key``, ``key_id``, ``session``, ``headers_admin``,
        ``tenant_headers`` and ``timeout``.

        The fixture skips rather than fails when the environment is not
        configured or the API is unreachable. Cleanup (source delete + key
        revoke) runs on teardown and also when provisioning is aborted
        part-way, so a skipped run does not leak the tenant API key it minted.
        """
        import requests

        if os.environ.get("SKIP_INTEGRATION_TESTS") == "1":
            pytest.skip("Integration tests skipped via SKIP_INTEGRATION_TESTS=1")

        base_url = os.environ.get("DATA_SOURCES_API_BASE_URL", "http://127.0.0.1:8000")
        admin_key = os.environ.get("SIRUS_ADMIN_API_KEY")
        if not admin_key:
            pytest.skip("SIRUS_ADMIN_API_KEY not configured")
        # The database URI embeds credentials; read it from the environment
        # instead of hard-coding it in version control.
        db_uri = os.environ.get("DATA_SOURCES_E2E_DB_URI")
        if not db_uri:
            pytest.skip("DATA_SOURCES_E2E_DB_URI not configured")

        # requests.Session has no session-wide timeout attribute, so the
        # timeout is passed explicitly on every call below.
        request_timeout = 30

        tenant_id = f"toolkit_e2e_{uuid.uuid4().hex[:8]}"
        source_name = "scv_sample_db"
        description = f"E2E Test {tenant_id}"
        headers_admin = {"X-Sirus-Admin-Key": admin_key}
        # Filled in as provisioning progresses; cleanup below inspects them.
        tenant_headers: Dict[str, str] = {}
        key_id = None

        session = requests.Session()
        try:
            # Quick availability check: any of these codes means the server is up.
            try:
                health = session.get(
                    f"{base_url}/api/v1/data-sources/list",
                    params={"tenant_id": "__healthcheck__"},
                    timeout=request_timeout,
                )
            except requests.RequestException as exc:
                pytest.skip(f"API unreachable: {exc}")
            if health.status_code not in (200, 400, 404, 401):
                pytest.skip(f"API health check failed (status {health.status_code})")

            # Create tenant API key via admin route
            resp = session.post(
                f"{base_url}/api/v1/data-sources/tenants/{tenant_id}/api-keys",
                json={"description": description, "expires_in_days": 14},
                headers=headers_admin,
                timeout=request_timeout,
            )
            if resp.status_code != 201:
                # An existing key cannot be reused because the listing does not
                # expose the raw key value; only report why we are skipping.
                list_resp = session.get(
                    f"{base_url}/api/v1/data-sources/tenants/{tenant_id}/api-keys",
                    headers=headers_admin,
                    timeout=request_timeout,
                )
                if list_resp.status_code == 200 and list_resp.json().get("keys"):
                    pytest.skip("Existing tenant API keys found but raw value unavailable")
                pytest.skip(
                    f"Failed to create tenant API key (status {resp.status_code}): {resp.text[:200]}"
                )

            data = resp.json()
            api_key = data.get("api_key")
            key_id = (data.get("key_info") or {}).get("key_id")
            if not api_key:
                pytest.skip("Tenant API key could not be provisioned")
            tenant_headers = {"X-Sirus-Api-Key": api_key}

            # Upsert tenant source pointing at live Postgres
            source_payload = {
                "sources": [
                    {
                        "source_name": source_name,
                        "source_type": "ibis",
                        "config": {
                            "uri": db_uri,
                            "table_fetch_example_limit": 5,
                        },
                    }
                ],
                "validate_connection": True,
            }
            upsert_resp = session.post(
                f"{base_url}/api/v1/data-sources/tenants/{tenant_id}/sources",
                json=source_payload,
                headers=tenant_headers,
                timeout=request_timeout,
            )
            if upsert_resp.status_code not in (200, 201):
                pytest.skip(
                    f"Failed to configure tenant source (status {upsert_resp.status_code}): {upsert_resp.text[:200]}"
                )

            yield {
                "base_url": base_url,
                "tenant_id": tenant_id,
                "source_name": source_name,
                "api_key": api_key,
                "key_id": key_id,
                "session": session,
                "headers_admin": headers_admin,
                "tenant_headers": tenant_headers,
                "timeout": 120.0,
            }
        finally:
            # Best-effort cleanup. A failed cleanup call must not mask the test
            # outcome, so network errors are deliberately ignored here.
            if tenant_headers:
                try:
                    session.delete(
                        f"{base_url}/api/v1/data-sources/tenants/{tenant_id}/sources/{source_name}",
                        headers=tenant_headers,
                        timeout=request_timeout,
                    )
                except requests.RequestException:
                    pass
            if key_id:
                try:
                    session.delete(
                        f"{base_url}/api/v1/data-sources/tenants/{tenant_id}/api-keys/{key_id}",
                        headers=headers_admin,
                        timeout=request_timeout,
                    )
                except requests.RequestException:
                    pass
            session.close()

    @pytest.fixture
    def live_toolkit(self, live_api_setup):
        """Toolkit connected to the real API, authenticated with the tenant key."""
        api_key = live_api_setup.get("api_key")
        base_url = live_api_setup.get("base_url", "http://127.0.0.1:8000")
        return DataSourcesSQLToolkit(api_base_url=base_url, timeout=120.0, api_key=api_key)

    @pytest.mark.integration
    @pytest.mark.skipif(
        os.environ.get("SKIP_INTEGRATION_TESTS") == "1",
        reason="Integration tests skipped (set SKIP_INTEGRATION_TESTS=0 to run)"
    )
    def test_e2e_keyword_search_with_real_db(self, live_toolkit, live_api_setup, caplog):
        """
        Test end-to-end workflow with keyword search on real database using REAL API calls.

        Workflow:
        1. Search schema with keywords via real API
        2. Fetch SQL dialect instructions for the source
        3. Execute SQL queries on the real database via real API
        4. Re-run a query under a second tenant id (diagnostic output only)
        5. Repeat the search to observe caching
        Each step is timed with ``perf_counter``.
        """
        import logging

        # Enable detailed logging
        caplog.set_level(logging.INFO)

        print("\n" + "=" * 80)
        print("END-TO-END INTEGRATION TEST: Keyword Search -> SQL Execution")
        print("=" * 80)

        # Test configuration from provisioned setup
        tenant_id = live_api_setup["tenant_id"]
        source_name = live_api_setup["source_name"]

        # Step 1: Search schema with business keywords
        print("\n[STEP 1] Searching schema with keywords...")
        keywords = ["customer", "order", "product", "sales"]
        search_start = perf_counter()
        search_result = live_toolkit.search_schema(
            keywords=keywords,
            include_samples=True,
            session_state={"tenant_id": tenant_id},
        )
        search_duration = perf_counter() - search_start
        print(f"✓ Schema search completed in {search_duration:.3f}s")

        if "error" in search_result:
            print(f"⚠ Search returned error (API may not be running): {search_result['error']}")
            pytest.skip("API not available for integration test")

        assert "matches" in search_result or "formatted_schema_string" in search_result
        print(f"  - Total matches: {search_result.get('total_matches', 0)}")
        print(f"  - Available sources: {search_result.get('available_sources', [])}")
        print(f"  - Cache hit: {search_result.get('cache_hit', False)}")

        matches = search_result.get("matches") or []
        if matches:
            print("\n  Matched Tables:")
            for match in matches[:5]:  # Show top 5
                print(f"    - {match.get('table_name', 'N/A')} (score: {match.get('score', 0)})")
                if "matched_columns" in match:
                    print(f"      Columns: {', '.join(match['matched_columns'][:5])}")

        # Step 2: Get source instructions
        print("\n[STEP 2] Fetching SQL dialect instructions...")
        instructions_start = perf_counter()
        instructions_result = live_toolkit.get_source_instructions(
            source_name=source_name,
            session_state={"tenant_id": tenant_id, "source_name": source_name},
        )
        instructions_duration = perf_counter() - instructions_start
        print(f"✓ Instructions fetched in {instructions_duration:.3f}s")
        if "error" not in instructions_result:
            # str() guards against a None/non-string value before slicing.
            instructions_text = str(instructions_result.get("instructions", "N/A"))
            print(f"  - Instructions: {instructions_text[:100]}...")

        # Step 3: Execute safe SELECT queries that should work on most schemas
        print("\n[STEP 3] Executing SQL query on real database...")
        test_queries = [
            "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' LIMIT 5",
            "SELECT current_database(), current_schema()",
            "SELECT version()",
        ]
        # Sum over all queries so the summary reports total SQL time, not
        # just the last query's.
        total_query_duration = 0.0
        for idx, sql_query in enumerate(test_queries, 1):
            print(f"\n  Query {idx}: {sql_query}")
            query_start = perf_counter()
            query_result = live_toolkit.execute_sql_query(
                source_name=source_name,
                sql_query=sql_query,
                session_state={"tenant_id": tenant_id},
            )
            query_duration = perf_counter() - query_start
            total_query_duration += query_duration
            print(f"  ✓ Query executed in {query_duration:.3f}s")

            if "error" in query_result:
                print(f"  ⚠ Query error: {query_result['error']}")
            elif "results" in query_result:
                results = query_result["results"]
                print(f"    - Rows returned: {len(results)}")
                if results:
                    print(f"    - Sample row: {results[0]}")

            # Verify result structure
            assert "status" in query_result or "error" in query_result or "results" in query_result

        # Step 4: Same query under a different tenant id.
        # NOTE(review): nothing is asserted here because the expected
        # cross-tenant behaviour (error vs. key-scoped tenant) depends on the
        # API's auth rules; the outcome is printed for inspection only.
        print("\n[STEP 4] Testing multi-tenant isolation...")
        tenant_2_id = "test_tenant_e2e_002"
        isolation_start = perf_counter()
        isolation_result = live_toolkit.execute_sql_query(
            source_name=source_name,
            sql_query="SELECT current_database()",
            session_state={"tenant_id": tenant_2_id},
        )
        isolation_duration = perf_counter() - isolation_start
        print(f"✓ Tenant 2 query completed in {isolation_duration:.3f}s")
        print(f"  - Tenant 2 returned error: {'error' in isolation_result}")

        # Step 5: Repeat search with same keywords to observe caching
        print("\n[STEP 5] Testing cache behavior...")
        cache_start = perf_counter()
        cached_search_result = live_toolkit.search_schema(
            keywords=keywords,
            include_samples=False,
            session_state={"tenant_id": tenant_id},
        )
        cache_duration = perf_counter() - cache_start
        print(f"✓ Cached search completed in {cache_duration:.3f}s")
        print(f"  - Cache hit: {cached_search_result.get('cache_hit', False)}")
        if cache_duration > 0:
            print(f"  - Speedup: {(search_duration / cache_duration):.2f}x faster")
        else:
            print("  - Instant cache hit")

        # Summary
        print("\n" + "=" * 80)
        print("END-TO-END TEST SUMMARY")
        print("=" * 80)
        total_workflow = search_duration + instructions_duration + total_query_duration
        print(f"Total workflow time: {total_workflow:.3f}s")
        print(f"  - Schema search: {search_duration:.3f}s")
        print(f"  - Instructions: {instructions_duration:.3f}s")
        print(f"  - SQL execution: {total_query_duration:.3f}s ({len(test_queries)} queries)")
        if cache_duration > 0:
            print(f"  - Cache speedup: {(search_duration / cache_duration):.2f}x")
        else:
            print("  - N/A")
        print("=" * 80 + "\n")

    @pytest.mark.integration
    @pytest.mark.skipif(
        os.environ.get("SKIP_INTEGRATION_TESTS") == "1",
        reason="Integration tests skipped"
    )
    def test_e2e_session_state_caching(self, live_toolkit, live_api_setup, caplog):
        """
        Test session state caching across multiple toolkit calls.

        Simulates agent behavior with session state persistence. The same
        ``session_state`` dict is passed to two identical searches, and the
        second one must report a cache hit and finish faster.
        """
        print("\n" + "=" * 80)
        print("END-TO-END TEST: Session State Caching")
        print("=" * 80)

        # Simulate session state (like agent would provide).
        # NOTE(review): no tenant_id is set here; the tenant presumably comes
        # from the toolkit's API key. Confirm against the toolkit.
        session_state = {
            "schema_search_cache": {},
            "extracted_keywords": [],
            "tool_execution_log": [],
        }
        keywords = ["revenue", "customer", "transaction"]

        # First call - should miss cache
        print("\n[Call 1] Schema search without cache...")
        start_1 = perf_counter()
        result_1 = live_toolkit.search_schema(keywords=keywords, session_state=session_state)
        duration_1 = perf_counter() - start_1
        if "error" in result_1:
            pytest.skip("API not available")
        print(f"✓ First call: {duration_1:.3f}s, cache_hit={result_1.get('cache_hit', False)}")

        # Second call - should hit session cache
        print("\n[Call 2] Schema search with session cache...")
        start_2 = perf_counter()
        result_2 = live_toolkit.search_schema(keywords=keywords, session_state=session_state)
        duration_2 = perf_counter() - start_2
        print(f"✓ Second call: {duration_2:.3f}s, cache_hit={result_2.get('cache_hit', False)}")
        print(f"  Cache source: {result_2.get('cache_source', 'N/A')}")

        # Verify caching worked
        assert result_2.get("cache_hit"), "Second call should hit cache"
        assert duration_2 < duration_1, "Cached call should be faster"

        speedup = duration_1 / duration_2 if duration_2 > 0 else float("inf")
        print(f"\n✓ Cache speedup: {speedup:.2f}x faster")
        print(f"  Session cache keys: {len(session_state.get('schema_search_cache', {}))}")
        print("=" * 80 + "\n")

    @pytest.mark.integration
    @pytest.mark.skipif(
        os.environ.get("SKIP_INTEGRATION_TESTS") == "1",
        reason="Integration tests skipped"
    )
    def test_e2e_keyword_extraction_workflow(self, live_toolkit, live_api_setup, caplog):
        """
        Test realistic agent workflow: keyword extraction -> search -> SQL execution.

        Simulates what agent.py would do with execute_query_with_tracking,
        using a simple stopword filter as the keyword extractor.
        """
        import re

        print("\n" + "=" * 80)
        print("END-TO-END TEST: Agent Keyword Extraction Workflow")
        print("=" * 80)

        tenant_id = live_api_setup["tenant_id"]
        source_name = live_api_setup["source_name"]

        # Simulate user questions (like agent would receive)
        user_questions = [
            "What are the total sales by customer?",
            "Show me all products with low inventory",
            "Get customer contact information for premium accounts",
        ]
        # Loop-invariant extraction helpers, built once.
        stopwords = frozenset({'what', 'are', 'the', 'show', 'me', 'all', 'with', 'get', 'for', 'by'})
        word_pattern = re.compile(r'\b[a-z]+\b')

        for q_idx, question in enumerate(user_questions, 1):
            print(f"\n[Question {q_idx}] {question}")

            # Step 1: Extract keywords (simulating agent.py logic)
            keywords = [
                word for word in word_pattern.findall(question.lower())
                if word not in stopwords and len(word) > 2
            ]
            print(f"  Extracted keywords: {keywords}")

            # Step 2: Search schema
            search_start = perf_counter()
            search_result = live_toolkit.search_schema(
                keywords=keywords[:5],  # Limit to top 5
                session_state={"tenant_id": tenant_id},
            )
            search_duration = perf_counter() - search_start

            if "error" in search_result:
                print(f"  ⚠ API not available: {search_result['error']}")
                pytest.skip("API not running")

            print(f"  ✓ Search: {search_duration:.3f}s, matches={search_result.get('total_matches', 0)}")

            # Reset per question so a previous question's query time is never
            # added when this question runs no query.
            query_duration = 0.0

            # Step 3: Agent would generate SQL from matched tables; this test
            # uses a generic, always-valid query instead.
            if search_result.get("total_matches", 0) > 0:
                query_start = perf_counter()
                query_result = live_toolkit.execute_sql_query(
                    source_name=source_name,
                    sql_query="SELECT 1 as test_column",  # Safe query
                    session_state={"tenant_id": tenant_id},
                )
                query_duration = perf_counter() - query_start
                print(f"  ✓ Query: {query_duration:.3f}s")
                if "results" in query_result:
                    print(f"    Results: {len(query_result['results'])} rows")

            print(f"  Total workflow: {(search_duration + query_duration):.3f}s")

        print("\n" + "=" * 80)
        print("✓ All workflow tests completed successfully")
        print("=" * 80 + "\n")

    @pytest.mark.integration
    def test_e2e_performance_benchmarks(self, live_toolkit, live_api_setup, caplog):
        """
        Performance benchmark tests for the complete toolkit.

        Measures timing for repeated schema searches and for cache hits versus
        a cache miss. The results are printed only; no timing thresholds are
        asserted.
        """
        import statistics

        print("\n" + "=" * 80)
        print("PERFORMANCE BENCHMARKS")
        print("=" * 80)

        tenant_id = live_api_setup["tenant_id"]

        # Benchmark 1: Schema search performance
        print("\n[Benchmark 1] Schema search (10 iterations)...")
        search_times = []
        for i in range(10):
            keywords = [f"test_keyword_{i % 3}", "customer", "order"]
            start = perf_counter()
            result = live_toolkit.search_schema(keywords=keywords, session_state={"tenant_id": tenant_id})
            duration = perf_counter() - start
            if "error" not in result:
                search_times.append(duration)

        if not search_times:
            print("  ⚠ No successful searches (API may be down)")
            pytest.skip("API not available")

        print(f"  Average: {statistics.mean(search_times):.3f}s")
        print(f"  Median:  {statistics.median(search_times):.3f}s")
        print(f"  Min:     {min(search_times):.3f}s")
        print(f"  Max:     {max(search_times):.3f}s")
        if len(search_times) > 1:
            print(f"  Std Dev: {statistics.stdev(search_times):.3f}s")
        else:
            print("  N/A")

        # Benchmark 2: Cache hit performance
        print("\n[Benchmark 2] Cache hit performance...")
        keywords = ["benchmark", "cache", "test"]

        # First call (cache miss)
        start = perf_counter()
        live_toolkit.search_schema(keywords=keywords, session_state={"tenant_id": tenant_id})
        miss_time = perf_counter() - start

        # Subsequent calls (cache hits)
        hit_times = []
        for _ in range(5):
            start = perf_counter()
            live_toolkit.search_schema(keywords=keywords, session_state={"tenant_id": tenant_id})
            hit_times.append(perf_counter() - start)

        avg_hit_time = statistics.mean(hit_times)
        print(f"  Cache miss: {miss_time:.3f}s")
        print(f"  Cache hit avg: {avg_hit_time:.3f}s")
        # Guard the ratio: a fast cache hit can measure as 0.0.
        if avg_hit_time > 0:
            print(f"  Speedup: {(miss_time / avg_hit_time):.2f}x")
        else:
            print("  Speedup: N/A (cache hits below timer resolution)")

        print("\n" + "=" * 80)
        print("✓ Performance benchmarks completed")
        print("=" * 80 + "\n")

    @pytest.mark.integration
    def test_full_live_user_flow(self, live_api_setup):
        """Validate an end-to-end user flow against the live data_sources API.

        Runs list_sources -> search_schema -> find_relevant_tables ->
        get_source_instructions -> execute_sql_query (success and failure).
        It checks the results and the metadata each call records in
        ``session_state``.
        """
        setup = live_api_setup
        toolkit = DataSourcesSQLToolkit(
            api_base_url=setup["base_url"],
            api_key=setup["api_key"],
            timeout=setup["timeout"]
        )

        session_state = {
            "tenant_id": setup["tenant_id"],
            "source_name": setup["source_name"],
            "tool_execution_log": [],
            "analysis_metadata": {},
            "user_context": {}
        }
        metrics: Dict[str, Any] = {}

        start = perf_counter()
        list_result = toolkit.list_sources(session_state=session_state)
        metrics["list_sources_ms"] = round((perf_counter() - start) * 1000, 2)
        assert "available_sources" in list_result
        assert setup["source_name"] in list_result.get("available_sources", [])

        start = perf_counter()
        search_result = toolkit.search_schema(
            keywords=["booking", "customer"],
            session_state=session_state
        )
        metrics["search_schema_ms"] = round((perf_counter() - start) * 1000, 2)
        assert "available_sources" in search_result
        assert session_state.get("analysis_metadata", {}).get("last_schema_search") is not None

        start = perf_counter()
        find_result = toolkit.find_relevant_tables(
            question="What were our total bookings last quarter?",
            concepts=["bookings", "quarter"],
            session_state=session_state
        )
        metrics["find_relevant_tables_ms"] = round((perf_counter() - start) * 1000, 2)
        assert "error" not in find_result

        start = perf_counter()
        instructions = toolkit.get_source_instructions(session_state=session_state)
        metrics["get_source_instructions_ms"] = round((perf_counter() - start) * 1000, 2)
        assert "instructions" in instructions

        start = perf_counter()
        good_sql = "SELECT current_database() AS current_db, NOW() AT TIME ZONE 'UTC' AS utc_now"
        ok_result = toolkit.execute_sql_query(
            sql_query=good_sql,
            session_state=session_state
        )
        metrics["execute_sql_success_ms"] = round((perf_counter() - start) * 1000, 2)
        assert ok_result.get("status") == "success"
        assert isinstance(ok_result.get("results"), list)

        metadata = session_state.get("analysis_metadata", {})
        last_exec = metadata.get("last_sql_execution", {})
        assert last_exec.get("tenant_id") == setup["tenant_id"]
        assert "row_count" in last_exec

        start = perf_counter()
        bad_sql = "SELECT * FROM table_that_does_not_exist__pytest"
        error_result = toolkit.execute_sql_query(
            sql_query=bad_sql,
            session_state=session_state
        )
        metrics["execute_sql_error_ms"] = round((perf_counter() - start) * 1000, 2)
        assert "error" in error_result

        # Diagnostics only: number of entries in session_state's
        # tool_execution_log after the flow.
        metrics["total_tool_calls"] = len(session_state.get("tool_execution_log", []))

        # Basic sanity on latencies and counts (all must be non-negative).
        for value in metrics.values():
            assert value >= 0
if __name__ == "__main__":
    # pytest.main returns the run's exit code. Propagating it makes a direct
    # `python test_file.py` run exit non-zero when tests fail.
    sys.exit(pytest.main([__file__, "-v", "--tb=short"]))