Spaces:
Sleeping
Sleeping
File size: 7,861 Bytes
8bf4d58 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 |
"""Database query tool with safety checks."""
import logging
from typing import List, Dict, Any, Optional
import re
from sqlalchemy import create_engine, text, inspect
from sqlalchemy.exc import SQLAlchemyError
from src.core.config import get_settings
# Module-level logger named after this module; used by DatabaseQuery below.
logger = logging.getLogger(__name__)
class DatabaseQuery:
    """Database query tool with SQL injection prevention.

    Wraps a SQLAlchemy engine and exposes a small agent-facing API:
    read-only ``SELECT`` execution (``query``), table introspection
    (``get_table_schema``, ``list_tables``) and a JSON-schema description
    of the tool (``get_tool_schema``). Public methods report failures in
    their return value instead of raising.

    NOTE(review): ``is_safe_query`` is a keyword denylist, not a SQL parser.
    It is deliberately conservative (it may reject harmless queries that
    mention a blocked word inside a string literal). It is not a substitute
    for connecting with a read-only database role.
    """

    # Dangerous SQL keywords that should not be allowed. Each is matched as
    # a whole word anywhere in the upper-cased query text. "INTO" blocks
    # ``SELECT ... INTO <table|OUTFILE>``, which writes data even though
    # the statement starts with SELECT.
    DANGEROUS_KEYWORDS = {
        "DROP", "DELETE", "TRUNCATE", "ALTER", "CREATE", "INSERT",
        "UPDATE", "GRANT", "REVOKE", "EXEC", "EXECUTE", "MERGE",
        "INTO",
    }

    # Allowed SQL keywords (SELECT queries only). Note: is_safe_query does
    # not consult this set; it is kept for callers that may reference it.
    ALLOWED_KEYWORDS = {
        "SELECT", "FROM", "WHERE", "JOIN", "INNER", "LEFT", "RIGHT",
        "FULL", "OUTER", "ON", "GROUP", "BY", "ORDER", "HAVING",
        "LIMIT", "OFFSET", "AS", "AND", "OR", "NOT", "IN", "LIKE",
        "BETWEEN", "IS", "NULL", "DISTINCT", "COUNT", "SUM", "AVG",
        "MAX", "MIN", "CASE", "WHEN", "THEN", "ELSE", "END",
    }

    def __init__(self, database_url: Optional[str] = None):
        """Initialize the tool and create a SQLAlchemy engine.

        Args:
            database_url: SQLAlchemy database URL. When omitted (or empty),
                ``get_settings().database_url`` is used instead.

        If no URL is available, or ``create_engine`` raises, ``self.engine``
        is left as ``None``. The query/introspection methods then return an
        error dict or an empty list instead of raising.
        """
        self.settings = get_settings()
        self.database_url = database_url or self.settings.database_url
        self.engine = None
        if not self.database_url:
            logger.warning("No database URL configured")
            return
        try:
            self.engine = create_engine(self.database_url)
        except Exception as e:
            logger.error("Error connecting to database: %s", e)
            self.engine = None
            return
        # Log only the text after the last "@" so that credentials embedded
        # in the URL (user:password@host) never reach the logs.
        if "@" in self.database_url:
            location = self.database_url.split("@")[-1]
        else:
            location = "local"
        logger.info("Connected to database: %s", location)

    @staticmethod
    def _strip_trailing_semicolon(sql: str) -> str:
        """Return *sql* without surrounding whitespace and one trailing ';'."""
        stripped = sql.strip()
        if stripped.endswith(";"):
            stripped = stripped[:-1].rstrip()
        return stripped

    def is_safe_query(self, query: str) -> tuple[bool, Optional[str]]:
        """
        Check if a SQL query is safe to execute.

        A query is accepted only if it is a string that starts with SELECT,
        contains no word from ``DANGEROUS_KEYWORDS``, contains no semicolon
        other than a single trailing one, and contains no SQL comment
        (``--`` or ``/*``).

        Args:
            query: SQL query string

        Returns:
            Tuple of (is_safe, error_message); error_message is None when
            the query is accepted.
        """
        if not isinstance(query, str):
            return False, "Query must be a string"

        query_upper = query.upper().strip()

        # Must start with SELECT
        if not query_upper.startswith("SELECT"):
            return False, "Only SELECT queries are allowed"

        # Check for dangerous keywords. Iterate in sorted order so the
        # reported keyword is deterministic when several are present.
        for keyword in sorted(self.DANGEROUS_KEYWORDS):
            if re.search(rf"\b{re.escape(keyword)}\b", query_upper):
                return False, f"Dangerous keyword '{keyword}' is not allowed"

        # A single trailing semicolon is harmless. Any other semicolon could
        # start a second statement ("SELECT 1; SELECT ..."), so reject it.
        if ";" in self._strip_trailing_semicolon(query):
            return False, "Multiple statements not allowed"

        # Check for comments that might hide malicious code
        if "--" in query or "/*" in query:
            return False, "SQL comments are not allowed"

        return True, None

    def query(
        self,
        sql: str,
        limit: int = 100,
    ) -> Dict[str, Any]:
        """
        Execute a safe SELECT query.

        Args:
            sql: SQL SELECT query (validated with ``is_safe_query``).
            limit: Maximum number of rows to return; must be an integer
                >= 0. A ``LIMIT`` clause is appended when the query has
                none. The returned rows are capped at ``limit`` even if
                the query's own LIMIT is larger.

        Returns:
            Dictionary with query results. On success the keys are
            ``success`` (True), ``results`` (list of column->value dicts),
            ``row_count`` and ``columns``. On failure the keys are
            ``success`` (False), ``error`` (message) and ``results`` ([]).
        """
        if not self.engine:
            return {
                "success": False,
                "error": "Database not configured",
                "results": [],
            }

        # Check if query is safe
        is_safe, error = self.is_safe_query(sql)
        if not is_safe:
            return {
                "success": False,
                "error": error,
                "results": [],
            }

        # `limit` is interpolated into the SQL text below and may arrive
        # from an agent as an arbitrary JSON value. Coerce it to int so
        # that only digits can ever reach the statement.
        try:
            limit = int(limit)
        except (TypeError, ValueError):
            limit = -1
        if limit < 0:
            return {
                "success": False,
                "error": "limit must be a non-negative integer",
                "results": [],
            }

        # Drop a trailing ";" (and any whitespace around it) so that an
        # appended LIMIT stays inside the statement.
        statement = self._strip_trailing_semicolon(sql)
        # Match LIMIT as a whole word. A plain substring test would treat
        # names like "limits" or "rate_limit" as an existing LIMIT clause.
        if not re.search(r"\bLIMIT\b", statement.upper()):
            statement = f"{statement} LIMIT {limit}"

        try:
            with self.engine.connect() as connection:
                result = connection.execute(text(statement))
                columns = list(result.keys())
                # fetchmany enforces `limit` even when the caller's own
                # LIMIT (possibly inside a subquery) is larger.
                rows = result.fetchmany(limit) if limit > 0 else []

            # Convert to list of dictionaries
            results = [dict(zip(columns, row)) for row in rows]

            return {
                "success": True,
                "results": results,
                "row_count": len(results),
                "columns": columns,
            }

        except SQLAlchemyError as e:
            logger.error("Database query error: %s", e)
            return {
                "success": False,
                "error": str(e),
                "results": [],
            }
        except Exception as e:
            logger.error("Unexpected error executing query: %s", e)
            return {
                "success": False,
                "error": str(e),
                "results": [],
            }

    def get_table_schema(self, table_name: str) -> Dict[str, Any]:
        """
        Get schema information for a table.

        Args:
            table_name: Name of the table

        Returns:
            Dictionary with ``success``, ``table``, ``columns`` (name, type
            string, nullable), ``primary_keys`` (column names) and
            ``foreign_keys``. On failure the keys are ``success`` (False)
            and ``error``.
        """
        if not self.engine:
            return {
                "success": False,
                "error": "Database not configured",
            }

        try:
            inspector = inspect(self.engine)
            columns = inspector.get_columns(table_name)
            # Inspector.get_primary_keys() was removed in SQLAlchemy 2.0;
            # get_pk_constraint() is the supported replacement.
            pk_constraint = inspector.get_pk_constraint(table_name) or {}
            primary_keys = pk_constraint.get("constrained_columns") or []
            foreign_keys = inspector.get_foreign_keys(table_name)

            return {
                "success": True,
                "table": table_name,
                "columns": [
                    {
                        "name": col["name"],
                        "type": str(col["type"]),
                        "nullable": col.get("nullable", True),
                    }
                    for col in columns
                ],
                "primary_keys": primary_keys,
                "foreign_keys": [
                    {
                        # Unnamed constraints (e.g. SQLite) report no name.
                        "name": fk.get("name"),
                        "constrained_columns": fk["constrained_columns"],
                        "referred_table": fk["referred_table"],
                        "referred_columns": fk["referred_columns"],
                    }
                    for fk in foreign_keys
                ],
            }

        except Exception as e:
            logger.error("Error getting table schema: %s", e)
            return {
                "success": False,
                "error": str(e),
            }

    def list_tables(self) -> List[str]:
        """List all tables in the database; returns [] when unavailable."""
        if not self.engine:
            return []

        try:
            inspector = inspect(self.engine)
            return inspector.get_table_names()
        except Exception as e:
            logger.error("Error listing tables: %s", e)
            return []

    def get_tool_schema(self) -> Dict[str, Any]:
        """Get tool schema (JSON-schema style) for agent integration."""
        return {
            "name": "database_query",
            "description": "Execute safe SELECT queries on the database",
            "parameters": {
                "type": "object",
                "properties": {
                    "sql": {
                        "type": "string",
                        "description": "SQL SELECT query to execute",
                    },
                    "limit": {
                        "type": "integer",
                        "description": "Maximum number of rows to return (default: 100)",
                        "default": 100,
                        "minimum": 0,
                    },
                },
                "required": ["sql"],
            },
        }
# Shared DatabaseQuery for the whole process; created lazily by get_database_query().
_database_query: Optional[DatabaseQuery] = None


def get_database_query() -> DatabaseQuery:
    """Return the shared DatabaseQuery, building it on the first call.

    The instance is constructed with no arguments, so its database URL comes
    from the application settings. Every later call returns the same object.
    """
    global _database_query
    instance = _database_query
    if instance is None:
        instance = _database_query = DatabaseQuery()
    return instance
|