File size: 7,861 Bytes
8bf4d58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
"""Database query tool with safety checks."""

import logging
from typing import List, Dict, Any, Optional
import re
from sqlalchemy import create_engine, text, inspect
from sqlalchemy.exc import SQLAlchemyError
from src.core.config import get_settings

# Module-level logger named after this module; used by the query tool below.
logger = logging.getLogger(__name__)


class DatabaseQuery:
    """Database query tool with SQL injection prevention.

    Queries are screened lexically by :meth:`is_safe_query`. A query must be a
    single ``SELECT`` statement with no blocked keywords and no comments.
    :meth:`query` additionally caps how many rows are returned.

    NOTE(review): these are text-level checks, not a SQL parser. For a hard
    guarantee, also connect with a read-only database role.
    """

    # Dangerous SQL keywords that should not be allowed. ``INTO`` is included
    # because ``SELECT ... INTO`` creates tables (PostgreSQL, SQL Server) or
    # writes files (MySQL ``INTO OUTFILE``), even though it starts with SELECT.
    DANGEROUS_KEYWORDS = {
        "DROP", "DELETE", "TRUNCATE", "ALTER", "CREATE", "INSERT",
        "UPDATE", "GRANT", "REVOKE", "EXEC", "EXECUTE", "MERGE",
        "INTO",
    }

    # Allowed SQL keywords (SELECT queries only).
    # NOTE(review): no method in this class consults this set; presumably it
    # is kept for external callers — confirm before removing.
    ALLOWED_KEYWORDS = {
        "SELECT", "FROM", "WHERE", "JOIN", "INNER", "LEFT", "RIGHT",
        "FULL", "OUTER", "ON", "GROUP", "BY", "ORDER", "HAVING",
        "LIMIT", "OFFSET", "AS", "AND", "OR", "NOT", "IN", "LIKE",
        "BETWEEN", "IS", "NULL", "DISTINCT", "COUNT", "SUM", "AVG",
        "MAX", "MIN", "CASE", "WHEN", "THEN", "ELSE", "END",
    }

    # Whole-word, case-insensitive LIMIT keyword. This stops identifiers such
    # as ``speed_limit`` or ``LIMITED`` from being mistaken for a LIMIT clause.
    _LIMIT_PATTERN = re.compile(r"\bLIMIT\b", re.IGNORECASE)

    def __init__(self, database_url: Optional[str] = None):
        """Initialize database query tool.

        Args:
            database_url: SQLAlchemy database URL. When not given, the
                ``database_url`` from ``get_settings()`` is used.

        If no URL is available or engine creation raises, ``self.engine`` is
        set to ``None``. The query/schema methods then return failure results,
        and ``list_tables`` returns an empty list.
        """
        self.settings = get_settings()
        self.database_url = database_url or self.settings.database_url

        if not self.database_url:
            logger.warning("No database URL configured")
            self.engine = None
            return

        try:
            self.engine = create_engine(self.database_url)
        except Exception as e:
            logger.error("Error connecting to database: %s", e)
            self.engine = None
            return

        # Log only the part after '@' so credentials never reach the logs.
        # NOTE(review): create_engine() is lazy and opens no connection here,
        # so "Connected" really means "engine created".
        location = (
            self.database_url.split("@")[-1] if "@" in self.database_url else "local"
        )
        logger.info("Connected to database: %s", location)

    @staticmethod
    def _strip_trailing_semicolon(sql: str) -> str:
        """Return *sql* without surrounding whitespace and one trailing ';'."""
        stripped = sql.strip()
        if stripped.endswith(";"):
            stripped = stripped[:-1].rstrip()
        return stripped

    @staticmethod
    def _coerce_limit(limit: Any) -> Optional[int]:
        """Return *limit* as a positive ``int``, or ``None`` if it is not one.

        Values accepted by ``int()`` (e.g. ``"50"``) are converted. Booleans
        and anything ``int()`` rejects yield ``None``.
        """
        if isinstance(limit, bool):
            return None
        try:
            value = int(limit)
        except (TypeError, ValueError, OverflowError):
            return None
        return value if value > 0 else None

    @classmethod
    def _apply_limit(cls, sql: str, limit: int) -> str:
        """Return *sql* without its trailing ';'.

        ``LIMIT <limit>`` is appended when the statement has no whole-word
        LIMIT keyword.
        """
        statement = cls._strip_trailing_semicolon(sql)
        if not cls._LIMIT_PATTERN.search(statement):
            statement = f"{statement} LIMIT {limit}"
        return statement

    @staticmethod
    def _query_failure(error: Optional[str]) -> Dict[str, Any]:
        """Build the failure payload returned by :meth:`query`."""
        return {
            "success": False,
            "error": error,
            "results": [],
        }

    def is_safe_query(self, query: str) -> tuple[bool, Optional[str]]:
        """
        Check if a SQL query is safe to execute.

        A query is accepted only when all of the following hold:

        - it starts with ``SELECT``;
        - it contains none of ``DANGEROUS_KEYWORDS`` as a whole word
          (case-insensitive);
        - it has no ``;`` other than one optional trailing terminator;
        - it contains no ``--`` or ``/*`` comment markers.

        These checks also apply inside string literals, so they err on the
        side of rejecting.

        Args:
            query: SQL query string

        Returns:
            Tuple of (is_safe, error_message); error_message is None when safe.
        """
        if not isinstance(query, str):
            return False, "Query must be a string"

        query_upper = query.upper().strip()

        # Must start with SELECT
        if not query_upper.startswith("SELECT"):
            return False, "Only SELECT queries are allowed"

        # Check for dangerous keywords. Sorting makes the reported keyword
        # deterministic when several are present (set order is not).
        for keyword in sorted(self.DANGEROUS_KEYWORDS):
            if re.search(rf"\b{re.escape(keyword)}\b", query_upper):
                return False, f"Dangerous keyword '{keyword}' is not allowed"

        # Reject any ';' except a single trailing terminator: a semicolon
        # anywhere else could separate a second statement.
        if ";" in self._strip_trailing_semicolon(query):
            return False, "Multiple statements not allowed"

        # Check for comments that might hide malicious code
        if "--" in query or "/*" in query:
            return False, "SQL comments are not allowed"

        return True, None

    def query(
        self,
        sql: str,
        limit: int = 100,
    ) -> Dict[str, Any]:
        """
        Execute a safe SELECT query.

        The statement is first screened with :meth:`is_safe_query`. If it has
        no whole-word ``LIMIT``, ``LIMIT <limit>`` is appended. In every case
        at most ``limit`` rows are fetched, so a larger caller-supplied LIMIT
        (or one that appears only in a subquery) cannot exceed the cap.

        Args:
            sql: SQL SELECT query
            limit: Maximum number of rows to return. Must be a positive
                integer; values accepted by ``int()`` are converted.

        Returns:
            Dictionary with ``success`` and ``results``. On success it also
            holds ``row_count`` and ``columns``; on failure, an ``error``
            message.
        """
        if not self.engine:
            return self._query_failure("Database not configured")

        # Check if query is safe
        is_safe, error = self.is_safe_query(sql)
        if not is_safe:
            return self._query_failure(error)

        # ``limit`` is spliced into the SQL text after the safety check, so it
        # must be a genuine integer (it may arrive from an agent's JSON).
        max_rows = self._coerce_limit(limit)
        if max_rows is None:
            return self._query_failure("limit must be a positive integer")

        statement = self._apply_limit(sql, max_rows)

        try:
            # NOTE(review): text() treats ':name' as a bind parameter, so a
            # literal like ':abc' in the query will fail to execute.
            with self.engine.connect() as connection:
                result = connection.execute(text(statement))
                columns = list(result.keys())
                # fetchmany enforces the row cap regardless of the SQL's LIMIT.
                rows = result.fetchmany(max_rows)
        except SQLAlchemyError as e:
            logger.error("Database query error: %s", e)
            return self._query_failure(str(e))
        except Exception as e:
            logger.error("Unexpected error executing query: %s", e)
            return self._query_failure(str(e))

        # Convert to list of dictionaries keyed by column name
        results = [dict(zip(columns, row)) for row in rows]
        return {
            "success": True,
            "results": results,
            "row_count": len(results),
            "columns": columns,
        }

    def get_table_schema(self, table_name: str) -> Dict[str, Any]:
        """
        Get schema information for a table.

        Args:
            table_name: Name of the table

        Returns:
            Dictionary with ``success``. On success it also holds ``table``,
            ``columns`` (name/type/nullable), ``primary_keys`` and
            ``foreign_keys``; on failure, an ``error`` message.
        """
        if not self.engine:
            return {
                "success": False,
                "error": "Database not configured",
            }

        try:
            inspector = inspect(self.engine)
            columns = inspector.get_columns(table_name)
            # get_pk_constraint() replaces get_primary_keys(), which was
            # deprecated and is gone in SQLAlchemy 2.0.
            pk_constraint = inspector.get_pk_constraint(table_name) or {}
            primary_keys = list(pk_constraint.get("constrained_columns") or [])
            foreign_keys = inspector.get_foreign_keys(table_name)

            return {
                "success": True,
                "table": table_name,
                "columns": [
                    {
                        "name": col["name"],
                        "type": str(col["type"]),
                        "nullable": col.get("nullable", True),
                    }
                    for col in columns
                ],
                "primary_keys": primary_keys,
                "foreign_keys": [
                    {
                        # Unnamed constraints (e.g. in SQLite) report None.
                        "name": fk.get("name"),
                        "constrained_columns": fk["constrained_columns"],
                        "referred_table": fk["referred_table"],
                        "referred_columns": fk["referred_columns"],
                    }
                    for fk in foreign_keys
                ],
            }
        except Exception as e:
            logger.error("Error getting table schema: %s", e)
            return {
                "success": False,
                "error": str(e),
            }

    def list_tables(self) -> List[str]:
        """List all tables in the database.

        Returns an empty list when no engine is configured or reflection fails.
        """
        if not self.engine:
            return []

        try:
            inspector = inspect(self.engine)
            return inspector.get_table_names()
        except Exception as e:
            logger.error("Error listing tables: %s", e)
            return []

    def get_tool_schema(self) -> Dict[str, Any]:
        """Get the JSON-schema tool description used for agent integration."""
        return {
            "name": "database_query",
            "description": "Execute safe SELECT queries on the database",
            "parameters": {
                "type": "object",
                "properties": {
                    "sql": {
                        "type": "string",
                        "description": "SQL SELECT query to execute",
                    },
                    "limit": {
                        "type": "integer",
                        "description": "Maximum number of rows to return (default: 100)",
                        "default": 100,
                    },
                },
                "required": ["sql"],
            },
        }


# Lazily created module-level instance; populated on first call to
# get_database_query() and reused afterwards.
_database_query: Optional[DatabaseQuery] = None


def get_database_query() -> DatabaseQuery:
    """Return the shared DatabaseQuery, constructing it on the first call.

    Later calls return the same cached object. No lock guards creation, so
    concurrent first calls could each construct an instance.
    """
    global _database_query
    if _database_query is not None:
        return _database_query
    _database_query = DatabaseQuery()
    return _database_query