# space/tools/sql_tool.py
"""SQL execution tool supporting BigQuery and MotherDuck backends."""

import os
import re
import json
import logging

import pandas as pd
from typing import Optional

from utils.config import AppConfig
from utils.tracing import Tracer

logger = logging.getLogger(__name__)

# MotherDuck database names that mean "no specific database" (workspace mode).
RESERVED_MD_WORKSPACE_NAMES = {"", "workspace", "default"}
MAX_QUERY_LENGTH = 50000   # characters; guards against runaway inputs
MAX_RESULT_ROWS = 100000   # rows returned to callers; larger results are truncated


class SQLToolError(Exception):
    """Custom exception for SQL tool errors."""
    pass


class SQLTool:
    """
    SQL execution tool supporting BigQuery and MotherDuck backends.
    Includes input validation, error handling, and secure query execution.
    """

    def __init__(self, cfg: AppConfig, tracer: Tracer):
        """Initialize the tool and connect to the configured backend.

        Args:
            cfg: Application config; provides ``sql_backend`` and
                backend-specific settings (``gcp_project``, ``motherduck_*``).
            tracer: Event tracer used to record query attempts and outcomes.

        Raises:
            SQLToolError: If the backend name is unknown or initialization fails.
        """
        self.cfg = cfg
        self.tracer = tracer
        self.backend = cfg.sql_backend
        self.client = None

        logger.info(f"Initializing SQLTool with backend: {self.backend}")

        try:
            if self.backend == "bigquery":
                self._init_bigquery()
            elif self.backend == "motherduck":
                self._init_motherduck()
            else:
                raise SQLToolError(f"Unknown SQL backend: {self.backend}")
            logger.info(f"SQLTool initialized successfully with {self.backend}")
        except SQLToolError:
            # BUGFIX: propagate our own errors unchanged instead of wrapping
            # them in a second SQLToolError, which obscured the real message
            # (e.g. "SQL backend initialization failed: Unknown SQL backend: ...").
            raise
        except Exception as e:
            logger.error(f"Failed to initialize SQLTool: {e}")
            raise SQLToolError(f"SQL backend initialization failed: {e}") from e

    def _init_bigquery(self):
        """Initialize BigQuery client with service account credentials.

        Reads GCP_SERVICE_ACCOUNT_JSON from the environment, accepting either
        inline JSON or a path to a key file.

        Raises:
            SQLToolError: On missing/invalid credentials or missing dependency.
        """
        try:
            from google.cloud import bigquery
            from google.oauth2 import service_account

            key_json = os.getenv("GCP_SERVICE_ACCOUNT_JSON")
            if not key_json:
                raise SQLToolError(
                    "Missing GCP_SERVICE_ACCOUNT_JSON environment variable. "
                    "Please configure BigQuery credentials."
                )

            # Parse credentials: inline JSON vs. a path to a key file.
            try:
                if key_json.strip().startswith("{"):
                    info = json.loads(key_json)
                else:
                    # Assume it's a file path
                    with open(key_json, 'r') as f:
                        info = json.load(f)
            except json.JSONDecodeError as e:
                raise SQLToolError(f"Invalid JSON in GCP_SERVICE_ACCOUNT_JSON: {e}")
            except FileNotFoundError:
                raise SQLToolError(f"GCP service account file not found: {key_json}")

            # Validate required fields up front so failures are actionable.
            # (renamed loop var from `f` to avoid shadowing the file handle above)
            required_fields = ["type", "project_id", "private_key", "client_email"]
            missing = [field for field in required_fields if field not in info]
            if missing:
                raise SQLToolError(
                    f"GCP service account JSON missing required fields: {missing}"
                )

            creds = service_account.Credentials.from_service_account_info(info)
            # Explicit config wins; fall back to the project baked into the key.
            project = self.cfg.gcp_project or info.get("project_id")
            if not project:
                raise SQLToolError("GCP project ID not specified in config or credentials")

            self.client = bigquery.Client(credentials=creds, project=project)
            logger.info(f"BigQuery client initialized for project: {project}")

        except ImportError as e:
            raise SQLToolError(
                "BigQuery dependencies not installed. "
                "Install with: pip install google-cloud-bigquery"
            ) from e

    def _init_motherduck(self):
        """Initialize MotherDuck/DuckDB client with version validation.

        Raises:
            SQLToolError: On missing token, failed connection, or missing dependency.
        """
        try:
            import duckdb

            # Version compatibility check - be more flexible
            version = duckdb.__version__
            logger.info(f"DuckDB version: {version}")

            # Warn if not on recommended version, but don't fail
            if not version.startswith("1.3"):
                logger.warning(
                    f"DuckDB {version} detected. Recommended: 1.3.x for MotherDuck compatibility. "
                    "Some features may not work as expected."
                )

            # Get configuration (config value takes precedence over env var).
            token = (self.cfg.motherduck_token or os.getenv("MOTHERDUCK_TOKEN") or "").strip()
            if not token:
                raise SQLToolError(
                    "Missing MOTHERDUCK_TOKEN. "
                    "Get your token from: https://motherduck.com/docs/key-tasks/authenticating-to-motherduck"
                )

            db_name = (self.cfg.motherduck_db or "workspace").strip()
            allow_create = os.getenv("ALLOW_CREATE_DB", "true").lower() == "true"

            # Connect based on database name
            if db_name in RESERVED_MD_WORKSPACE_NAMES:
                # Workspace mode - no specific database context
                connection_string = f"md:?motherduck_token={token}"
                logger.info("Connecting to MotherDuck workspace")
                self.client = duckdb.connect(connection_string)
            else:
                # Try connecting to specific database
                try:
                    connection_string = f"md:{db_name}?motherduck_token={token}"
                    logger.info(f"Connecting to MotherDuck database: {db_name}")
                    self.client = duckdb.connect(connection_string)
                except Exception as db_err:
                    logger.warning(f"Direct connection to '{db_name}' failed: {db_err}")
                    # Fallback: connect to workspace and setup database
                    connection_string = f"md:?motherduck_token={token}"
                    self.client = duckdb.connect(connection_string)
                    self._ensure_db_context(db_name, allow_create)

            # Test connection
            try:
                self.client.execute("SELECT 1").fetchone()
                logger.info("MotherDuck connection test successful")
            except Exception as e:
                raise SQLToolError(f"MotherDuck connection test failed: {e}")

        except ImportError as e:
            raise SQLToolError(
                "DuckDB not installed. Install with: pip install duckdb"
            ) from e

    def _ensure_db_context(self, db_name: str, allow_create: bool):
        """
        Ensure database context is set for MotherDuck.
        Creates database if it doesn't exist and allow_create is True.

        Raises:
            SQLToolError: If the database is missing and creation is disabled
                or fails.
        """
        if db_name in RESERVED_MD_WORKSPACE_NAMES:
            return

        safe_name = self._quote_ident(db_name)

        # Try to USE the database first
        try:
            self.client.execute(f"USE {safe_name};")
            logger.info(f"Using existing database: {db_name}")
            return
        except Exception as use_err:
            logger.info(f"Database '{db_name}' not found: {use_err}")

        if not allow_create:
            raise SQLToolError(
                f"Database '{db_name}' does not exist and ALLOW_CREATE_DB is disabled. "
                f"Either create the database manually or set ALLOW_CREATE_DB=true"
            )

        # Attempt to create and use the database
        try:
            logger.info(f"Creating database: {db_name}")
            self.client.execute(f"CREATE DATABASE IF NOT EXISTS {safe_name};")
            self.client.execute(f"USE {safe_name};")
            logger.info(f"Database '{db_name}' created and selected")
        except Exception as create_err:
            raise SQLToolError(
                f"Failed to create database '{db_name}': {create_err}"
            ) from create_err

    @staticmethod
    def _quote_ident(name: str) -> str:
        """
        Safely sanitize SQL identifiers.
        Replaces non-alphanumeric characters with underscores; despite the
        name, this does not add quote characters.
        """
        if not name:
            return "unnamed"
        # Remove dangerous characters
        safe = re.sub(r"[^a-zA-Z0-9_]", "_", name)
        # Ensure it doesn't start with a number
        if safe[0].isdigit():
            safe = "_" + safe
        return safe

    def _validate_sql(self, sql: str) -> tuple[bool, str]:
        """
        Validate SQL query for basic safety.
        Returns (is_valid, error_message).
        """
        if not sql or not sql.strip():
            return False, "Empty SQL query"

        if len(sql) > MAX_QUERY_LENGTH:
            return False, f"Query too long (max {MAX_QUERY_LENGTH} characters)"

        # Dangerous patterns check
        sql_lower = sql.lower()

        # Block multiple statements (simple check).
        # BUGFIX: the old check (count(';') > 1) allowed two statements joined
        # by a single semicolon, e.g. "SELECT 1; DROP TABLE x". A lone
        # trailing semicolon is fine; any other ';' means multiple statements.
        if sql.strip().rstrip(';').count(';') > 0:
            return False, "Multiple SQL statements not allowed"

        # Block dangerous keywords in non-SELECT queries
        dangerous_patterns = [
            (r'\bdrop\s+table\b', "DROP TABLE"),
            (r'\bdrop\s+database\b', "DROP DATABASE"),
            (r'\bdelete\s+from\b', "DELETE FROM"),
            (r'\btruncate\b', "TRUNCATE"),
            (r'\bexec\s*\(', "EXEC"),
            (r'\bexecute\s*\(', "EXECUTE"),
        ]
        for pattern, name in dangerous_patterns:
            if re.search(pattern, sql_lower):
                logger.warning(f"Blocked query with {name} pattern")
                return False, f"Query contains blocked operation: {name}"

        return True, ""

    def _nl_to_sql(self, message: str) -> str:
        """
        Convert natural language to SQL query.
        This is a simple heuristic - replace with proper NL2SQL model for production.
        """
        m = message.lower()

        # If it's already SQL, return as-is (after validation)
        if re.match(r'^\s*select\s', m, re.IGNORECASE):
            return message.strip()

        # Template-based generation (customize for your schema)
        if "avg" in m or "average" in m:
            if "by month" in m or "monthly" in m:
                return """
                    SELECT DATE_TRUNC('month', date_col) AS month,
                           AVG(metric_col) AS avg_metric
                    FROM analytics.fact_table
                    GROUP BY 1
                    ORDER BY 1 DESC
                    LIMIT 100;
                """

        if "top" in m:
            # Extract number if present
            match = re.search(r'top\s+(\d+)', m)
            limit = match.group(1) if match else "10"
            return f"""
                SELECT *
                FROM analytics.fact_table
                ORDER BY metric_col DESC
                LIMIT {limit};
            """

        if "count" in m:
            return """
                SELECT category_col, COUNT(*) AS count
                FROM analytics.fact_table
                GROUP BY 1
                ORDER BY 2 DESC
                LIMIT 100;
            """

        # Default fallback
        return """
            SELECT *
            FROM analytics.fact_table
            LIMIT 100;
        """

    def run(self, message: str) -> pd.DataFrame:
        """
        Execute SQL query from natural language or SQL statement.

        Args:
            message: Natural language query or SQL statement

        Returns:
            DataFrame with query results (truncated to MAX_RESULT_ROWS)

        Raises:
            SQLToolError: If query validation or execution fails
        """
        try:
            # Convert to SQL
            sql = self._nl_to_sql(message)
            logger.info(f"Generated SQL query (first 200 chars): {sql[:200]}")

            # Validate SQL
            is_valid, error_msg = self._validate_sql(sql)
            if not is_valid:
                raise SQLToolError(f"Invalid SQL query: {error_msg}")

            # Log query attempt
            self.tracer.trace_event("sql_query", {
                "sql": sql[:1000],  # Limit logged SQL length
                "backend": self.backend,
                "message": message[:500]
            })

            # Execute based on backend
            if self.backend == "bigquery":
                result = self._execute_bigquery(sql)
            else:  # motherduck
                result = self._execute_duckdb(sql)

            # Validate result
            if not isinstance(result, pd.DataFrame):
                raise SQLToolError("Query did not return a DataFrame")

            # Check result size
            if len(result) > MAX_RESULT_ROWS:
                logger.warning(f"Result truncated from {len(result)} to {MAX_RESULT_ROWS} rows")
                result = result.head(MAX_RESULT_ROWS)

            logger.info(f"Query successful: {len(result)} rows, {len(result.columns)} columns")
            self.tracer.trace_event("sql_success", {
                "rows": len(result),
                "columns": len(result.columns)
            })

            return result

        except SQLToolError:
            raise
        except Exception as e:
            error_msg = f"SQL execution failed: {str(e)}"
            logger.error(error_msg)
            self.tracer.trace_event("sql_error", {"error": error_msg})
            raise SQLToolError(error_msg) from e

    def _execute_bigquery(self, sql: str) -> pd.DataFrame:
        """Execute query on BigQuery."""
        try:
            query_job = self.client.query(sql)
            df = query_job.to_dataframe()
            return df
        except Exception as e:
            raise SQLToolError(f"BigQuery execution error: {str(e)}") from e

    def _execute_duckdb(self, sql: str) -> pd.DataFrame:
        """Execute query on DuckDB/MotherDuck."""
        try:
            result = self.client.execute(sql)
            df = result.fetch_df()
            return df
        except Exception as e:
            raise SQLToolError(f"DuckDB execution error: {str(e)}") from e

    def test_connection(self) -> bool:
        """Test database connection by running a trivial query."""
        try:
            test_query = "SELECT 1 AS test"
            result = self.run(test_query)
            return len(result) == 1 and result.iloc[0, 0] == 1
        except Exception as e:
            logger.error(f"Connection test failed: {e}")
            return False