"""Database query tool with safety checks.""" import logging from typing import List, Dict, Any, Optional import re from sqlalchemy import create_engine, text, inspect from sqlalchemy.exc import SQLAlchemyError from src.core.config import get_settings logger = logging.getLogger(__name__) class DatabaseQuery: """Database query tool with SQL injection prevention.""" # Dangerous SQL keywords that should not be allowed DANGEROUS_KEYWORDS = { "DROP", "DELETE", "TRUNCATE", "ALTER", "CREATE", "INSERT", "UPDATE", "GRANT", "REVOKE", "EXEC", "EXECUTE", "MERGE", } # Allowed SQL keywords (SELECT queries only) ALLOWED_KEYWORDS = { "SELECT", "FROM", "WHERE", "JOIN", "INNER", "LEFT", "RIGHT", "FULL", "OUTER", "ON", "GROUP", "BY", "ORDER", "HAVING", "LIMIT", "OFFSET", "AS", "AND", "OR", "NOT", "IN", "LIKE", "BETWEEN", "IS", "NULL", "DISTINCT", "COUNT", "SUM", "AVG", "MAX", "MIN", "CASE", "WHEN", "THEN", "ELSE", "END", } def __init__(self, database_url: Optional[str] = None): """Initialize database query tool.""" self.settings = get_settings() self.database_url = database_url or self.settings.database_url if not self.database_url: logger.warning("No database URL configured") self.engine = None else: try: self.engine = create_engine(self.database_url) logger.info(f"Connected to database: {self.database_url.split('@')[-1] if '@' in self.database_url else 'local'}") except Exception as e: logger.error(f"Error connecting to database: {e}") self.engine = None def is_safe_query(self, query: str) -> tuple[bool, Optional[str]]: """ Check if a SQL query is safe to execute. Args: query: SQL query string Returns: Tuple of (is_safe, error_message) """ query_upper = query.upper().strip() # Must start with SELECT if not query_upper.startswith("SELECT"): return False, "Only SELECT queries are allowed" # Check for dangerous keywords for keyword in self.DANGEROUS_KEYWORDS: if re.search(rf"\b{keyword}\b", query_upper): return False, f"Dangerous keyword '{keyword}' is not allowed" # Check for semicolons (potential for multiple statements) if ";" in query and query.count(";") > 1: return False, "Multiple statements not allowed" # Check for comments that might hide malicious code if "--" in query or "/*" in query: return False, "SQL comments are not allowed" return True, None def query( self, sql: str, limit: int = 100, ) -> Dict[str, Any]: """ Execute a safe SELECT query. Args: sql: SQL SELECT query limit: Maximum number of rows to return Returns: Dictionary with query results """ if not self.engine: return { "success": False, "error": "Database not configured", "results": [], } # Check if query is safe is_safe, error = self.is_safe_query(sql) if not is_safe: return { "success": False, "error": error, "results": [], } try: # Add LIMIT if not present sql_upper = sql.upper() if "LIMIT" not in sql_upper: sql = f"{sql.rstrip(';')} LIMIT {limit}" # Execute query with self.engine.connect() as connection: result = connection.execute(text(sql)) rows = result.fetchall() columns = result.keys() # Convert to list of dictionaries results = [] for row in rows: results.append(dict(zip(columns, row))) return { "success": True, "results": results, "row_count": len(results), "columns": list(columns), } except SQLAlchemyError as e: logger.error(f"Database query error: {e}") return { "success": False, "error": str(e), "results": [], } except Exception as e: logger.error(f"Unexpected error executing query: {e}") return { "success": False, "error": str(e), "results": [], } def get_table_schema(self, table_name: str) -> Dict[str, Any]: """ Get schema information for a table. Args: table_name: Name of the table Returns: Dictionary with table schema """ if not self.engine: return { "success": False, "error": "Database not configured", } try: inspector = inspect(self.engine) columns = inspector.get_columns(table_name) primary_keys = inspector.get_primary_keys(table_name) foreign_keys = inspector.get_foreign_keys(table_name) return { "success": True, "table": table_name, "columns": [ { "name": col["name"], "type": str(col["type"]), "nullable": col.get("nullable", True), } for col in columns ], "primary_keys": primary_keys, "foreign_keys": [ { "name": fk["name"], "constrained_columns": fk["constrained_columns"], "referred_table": fk["referred_table"], "referred_columns": fk["referred_columns"], } for fk in foreign_keys ], } except Exception as e: logger.error(f"Error getting table schema: {e}") return { "success": False, "error": str(e), } def list_tables(self) -> List[str]: """List all tables in the database.""" if not self.engine: return [] try: inspector = inspect(self.engine) return inspector.get_table_names() except Exception as e: logger.error(f"Error listing tables: {e}") return [] def get_tool_schema(self) -> Dict[str, Any]: """Get tool schema for agent integration.""" return { "name": "database_query", "description": "Execute safe SELECT queries on the database", "parameters": { "type": "object", "properties": { "sql": { "type": "string", "description": "SQL SELECT query to execute", }, "limit": { "type": "integer", "description": "Maximum number of rows to return (default: 100)", "default": 100, }, }, "required": ["sql"], }, } # Global instance _database_query: Optional[DatabaseQuery] = None def get_database_query() -> DatabaseQuery: """Get or create the global database query instance.""" global _database_query if _database_query is None: _database_query = DatabaseQuery() return _database_query