File size: 14,163 Bytes
b07564d
f4dc602
 
b07564d
6c6d38f
f4dc602
6c6d38f
f4dc602
 
 
6c6d38f
e002acf
6c6d38f
 
 
 
 
 
 
 
85b8a4e
 
f4dc602
6c6d38f
 
 
 
 
f4dc602
 
 
6c6d38f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f4dc602
 
6c6d38f
f4dc602
 
6c6d38f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e002acf
6c6d38f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b07564d
6c6d38f
 
 
 
 
 
 
 
 
 
9d6bac9
6c6d38f
 
2336094
e002acf
6c6d38f
 
 
 
 
 
 
 
 
85b8a4e
6c6d38f
 
 
 
85b8a4e
6c6d38f
85b8a4e
6c6d38f
 
 
 
 
 
 
 
 
2336094
6c6d38f
 
 
 
 
 
 
 
 
 
 
 
 
2336094
 
6c6d38f
 
2336094
85b8a4e
 
6c6d38f
 
 
 
2336094
6c6d38f
 
85b8a4e
2336094
6c6d38f
 
2336094
6c6d38f
 
 
2336094
6c6d38f
 
85b8a4e
6c6d38f
 
 
 
85b8a4e
6c6d38f
 
 
 
2336094
 
e002acf
6c6d38f
 
e002acf
6c6d38f
 
 
 
 
 
 
 
 
 
2336094
6c6d38f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2336094
6c6d38f
 
 
 
f4dc602
6c6d38f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f4dc602
6c6d38f
 
 
 
 
 
 
 
 
 
 
 
b07564d
6c6d38f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
# space/tools/sql_tool.py
import os
import re
import json
import logging
import pandas as pd
from typing import Optional
from utils.config import AppConfig
from utils.tracing import Tracer

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Database names that make the MotherDuck initializer connect in workspace
# mode (no specific database context) instead of attaching a named database.
RESERVED_MD_WORKSPACE_NAMES = {"default", "workspace", ""}
# Longest SQL text, in characters, accepted by query validation.
MAX_QUERY_LENGTH = 50_000
# Result frames with more rows than this are truncated before being returned.
MAX_RESULT_ROWS = 100_000


class SQLToolError(Exception):
    """Raised when the SQL tool fails to initialize, validate, or run a query."""


class SQLTool:
    """
    SQL execution tool supporting BigQuery and MotherDuck backends.
    Includes input validation, error handling, and secure query execution.
    """
    
    def __init__(self, cfg: AppConfig, tracer: Tracer):
        """
        Store the configuration and tracer, then connect to the configured backend.

        Args:
            cfg: Application config; ``cfg.sql_backend`` selects the backend.
            tracer: Tracer used later to record query events.

        Raises:
            SQLToolError: If the backend name is unknown or its setup fails.
        """
        self.cfg = cfg
        self.tracer = tracer
        self.backend = cfg.sql_backend
        self.client = None

        logger.info(f"Initializing SQLTool with backend: {self.backend}")

        # Backend name -> method that builds self.client for it.
        backend_setups = (
            ("bigquery", self._init_bigquery),
            ("motherduck", self._init_motherduck),
        )

        try:
            for backend_name, setup in backend_setups:
                if self.backend == backend_name:
                    setup()
                    break
            else:
                # No entry matched the configured backend.
                raise SQLToolError(f"Unknown SQL backend: {self.backend}")

            logger.info(f"SQLTool initialized successfully with {self.backend}")

        except Exception as e:
            logger.error(f"Failed to initialize SQLTool: {e}")
            raise SQLToolError(f"SQL backend initialization failed: {e}") from e
    
    def _init_bigquery(self):
        """
        Initialize the BigQuery client with service account credentials.

        The GCP_SERVICE_ACCOUNT_JSON environment variable may hold either the
        credentials JSON itself (value starts with "{") or a path to a JSON
        key file. The project comes from ``cfg.gcp_project`` when set,
        otherwise from the credentials' ``project_id``.

        Raises:
            SQLToolError: If the variable is missing, the JSON is invalid or
                incomplete, the key file is not found, no project can be
                determined, or the BigQuery packages are not installed.
        """
        try:
            from google.cloud import bigquery
            from google.oauth2 import service_account

            key_json = os.getenv("GCP_SERVICE_ACCOUNT_JSON")
            if not key_json:
                raise SQLToolError(
                    "Missing GCP_SERVICE_ACCOUNT_JSON environment variable. "
                    "Please configure BigQuery credentials."
                )

            # Parse credentials
            try:
                if key_json.strip().startswith("{"):
                    info = json.loads(key_json)
                else:
                    # Assume it's a file path; JSON key files are read as
                    # UTF-8 rather than the platform's default encoding.
                    with open(key_json, 'r', encoding="utf-8") as f:
                        info = json.load(f)
            except json.JSONDecodeError as e:
                raise SQLToolError(f"Invalid JSON in GCP_SERVICE_ACCOUNT_JSON: {e}") from e
            except FileNotFoundError as e:
                raise SQLToolError(f"GCP service account file not found: {key_json}") from e

            # Valid JSON that is not an object (e.g. a list or a number)
            # cannot be a service account key; reject it before field access.
            if not isinstance(info, dict):
                raise SQLToolError(
                    "GCP service account JSON must be a JSON object"
                )

            # Validate required fields
            required_fields = ["type", "project_id", "private_key", "client_email"]
            missing = [f for f in required_fields if f not in info]
            if missing:
                raise SQLToolError(
                    f"GCP service account JSON missing required fields: {missing}"
                )

            creds = service_account.Credentials.from_service_account_info(info)
            project = self.cfg.gcp_project or info.get("project_id")

            if not project:
                raise SQLToolError("GCP project ID not specified in config or credentials")

            self.client = bigquery.Client(credentials=creds, project=project)
            logger.info(f"BigQuery client initialized for project: {project}")

        except ImportError as e:
            raise SQLToolError(
                "BigQuery dependencies not installed. "
                "Install with: pip install google-cloud-bigquery"
            ) from e
    
    def _init_motherduck(self):
        """
        Initialize the MotherDuck/DuckDB client.

        Logs the DuckDB version (warning when it is not 1.3.x), reads the
        token from ``cfg.motherduck_token`` or the MOTHERDUCK_TOKEN
        environment variable, connects either in workspace mode or to the
        configured database, and finally runs ``SELECT 1`` as a smoke test.
        When the direct connection to a named database fails, it falls back
        to a workspace connection and selects (or creates) the database.

        Raises:
            SQLToolError: If the token is missing, the connection test fails,
                the database cannot be selected/created, or DuckDB is not
                installed.
        """
        try:
            import duckdb

            # Version compatibility check - be more flexible
            version = duckdb.__version__
            logger.info(f"DuckDB version: {version}")

            # Warn if not on recommended version, but don't fail
            if not version.startswith("1.3"):
                logger.warning(
                    f"DuckDB {version} detected. Recommended: 1.3.x for MotherDuck compatibility. "
                    "Some features may not work as expected."
                )

            # Get configuration
            token = (self.cfg.motherduck_token or os.getenv("MOTHERDUCK_TOKEN") or "").strip()
            if not token:
                raise SQLToolError(
                    "Missing MOTHERDUCK_TOKEN. "
                    "Get your token from: https://motherduck.com/docs/key-tasks/authenticating-to-motherduck"
                )

            db_name = (self.cfg.motherduck_db or "workspace").strip()
            allow_create = os.getenv("ALLOW_CREATE_DB", "true").lower() == "true"

            # Workspace mode - no specific database context
            workspace_connection = f"md:?motherduck_token={token}"

            # Connect based on database name
            if db_name in RESERVED_MD_WORKSPACE_NAMES:
                logger.info("Connecting to MotherDuck workspace")
                self.client = duckdb.connect(workspace_connection)
            else:
                # Try connecting to specific database
                try:
                    connection_string = f"md:{db_name}?motherduck_token={token}"
                    logger.info(f"Connecting to MotherDuck database: {db_name}")
                    self.client = duckdb.connect(connection_string)
                except Exception as db_err:
                    # The connection string carries the token, and the driver
                    # error may echo it back; mask it before logging.
                    reason = str(db_err).replace(token, "***")
                    logger.warning(f"Direct connection to '{db_name}' failed: {reason}")

                    # Fallback: connect to workspace and setup database
                    self.client = duckdb.connect(workspace_connection)
                    self._ensure_db_context(db_name, allow_create)

            # Test connection
            try:
                self.client.execute("SELECT 1").fetchone()
                logger.info("MotherDuck connection test successful")
            except Exception as e:
                # Release the unusable connection instead of leaving it open.
                broken_client, self.client = self.client, None
                try:
                    broken_client.close()
                except Exception as close_err:
                    logger.debug(f"Ignoring error while closing failed connection: {close_err}")
                raise SQLToolError(f"MotherDuck connection test failed: {e}") from e

        except ImportError as e:
            raise SQLToolError(
                "DuckDB not installed. Install with: pip install duckdb"
            ) from e
    
    def _ensure_db_context(self, db_name: str, allow_create: bool):
        """
        Select ``db_name`` as the active MotherDuck database.

        Reserved workspace names are a no-op. Otherwise a ``USE`` is tried
        first; if that fails the database is created (when ``allow_create``
        is true) and then selected.

        Raises:
            SQLToolError: If ``USE`` fails and creation is disabled, or if
                creating/selecting the database fails.
        """
        if db_name in RESERVED_MD_WORKSPACE_NAMES:
            return

        ident = self._quote_ident(db_name)
        use_statement = f"USE {ident};"

        # First attempt: the database may already exist.
        try:
            self.client.execute(use_statement)
        except Exception as use_err:
            logger.info(f"Database '{db_name}' not found: {use_err}")

            if not allow_create:
                raise SQLToolError(
                    f"Database '{db_name}' does not exist and ALLOW_CREATE_DB is disabled. "
                    f"Either create the database manually or set ALLOW_CREATE_DB=true"
                )
        else:
            logger.info(f"Using existing database: {db_name}")
            return

        # Second attempt: create the database, then switch to it.
        try:
            logger.info(f"Creating database: {db_name}")
            self.client.execute(f"CREATE DATABASE IF NOT EXISTS {ident};")
            self.client.execute(use_statement)
            logger.info(f"Database '{db_name}' created and selected")
        except Exception as create_err:
            raise SQLToolError(
                f"Failed to create database '{db_name}': {create_err}"
            ) from create_err
    
    @staticmethod
    def _quote_ident(name: str) -> str:
        """
        Safely quote SQL identifiers.
        Replaces non-alphanumeric characters with underscores.
        """
        if not name:
            return "unnamed"
        
        # Remove dangerous characters
        safe = re.sub(r"[^a-zA-Z0-9_]", "_", name)
        
        # Ensure it doesn't start with a number
        if safe[0].isdigit():
            safe = "_" + safe
        
        return safe
    
    def _validate_sql(self, sql: str) -> tuple[bool, str]:
        """
        Validate SQL query for basic safety.

        Checks, in order: non-empty text, maximum length, a single statement,
        and absence of a small set of destructive keywords.

        The statement check is deliberately conservative: a semicolon is only
        accepted as a terminator, so text with anything other than whitespace
        on both sides of a semicolon (including a semicolon inside a string
        literal, or a comment after the terminator) is rejected.

        Returns (is_valid, error_message).
        """
        if not sql or not sql.strip():
            return False, "Empty SQL query"

        if len(sql) > MAX_QUERY_LENGTH:
            return False, f"Query too long (max {MAX_QUERY_LENGTH} characters)"

        # Block multiple statements. Counting semicolons alone misses
        # "SELECT 1; DROP VIEW v" (one semicolon, two statements), so also
        # reject text that has content on both sides of the semicolon.
        before, separator, after = sql.partition(";")
        stacked = bool(separator) and bool(before.strip()) and bool(after.strip())
        if sql.count(';') > 1 or stacked:
            return False, "Multiple SQL statements not allowed"

        # Dangerous patterns check (case-insensitive via lowercased text)
        sql_lower = sql.lower()

        dangerous_patterns = [
            (r'\bdrop\s+table\b', "DROP TABLE"),
            (r'\bdrop\s+database\b', "DROP DATABASE"),
            (r'\bdelete\s+from\b', "DELETE FROM"),
            (r'\btruncate\b', "TRUNCATE"),
            (r'\bexec\s*\(', "EXEC"),
            (r'\bexecute\s*\(', "EXECUTE"),
        ]

        for pattern, name in dangerous_patterns:
            if re.search(pattern, sql_lower):
                logger.warning(f"Blocked query with {name} pattern")
                return False, f"Query contains blocked operation: {name}"

        return True, ""
    
    def _nl_to_sql(self, message: str) -> str:
        """
        Convert natural language to SQL query.
        This is a simple heuristic - replace with proper NL2SQL model for production.
        """
        m = message.lower()
        
        # If it's already SQL, return as-is (after validation)
        if re.match(r'^\s*select\s', m, re.IGNORECASE):
            return message.strip()
        
        # Template-based generation (customize for your schema)
        if "avg" in m or "average" in m:
            if "by month" in m or "monthly" in m:
                return """
SELECT 
    DATE_TRUNC('month', date_col) AS month,
    AVG(metric_col) AS avg_metric
FROM analytics.fact_table
GROUP BY 1
ORDER BY 1 DESC
LIMIT 100;
"""
        
        if "top" in m:
            # Extract number if present
            match = re.search(r'top\s+(\d+)', m)
            limit = match.group(1) if match else "10"
            return f"""
SELECT *
FROM analytics.fact_table
ORDER BY metric_col DESC
LIMIT {limit};
"""
        
        if "count" in m:
            return """
SELECT 
    category_col,
    COUNT(*) AS count
FROM analytics.fact_table
GROUP BY 1
ORDER BY 2 DESC
LIMIT 100;
"""
        
        # Default fallback
        return """
SELECT *
FROM analytics.fact_table
LIMIT 100;
"""
    
    def run(self, message: str) -> pd.DataFrame:
        """
        Execute SQL query from natural language or SQL statement.

        The message is converted to SQL, validated, traced, and executed on
        the configured backend. Results longer than MAX_RESULT_ROWS are
        truncated. Every failure is logged and recorded as a "sql_error"
        trace event before it is raised.

        Args:
            message: Natural language query or SQL statement

        Returns:
            DataFrame with query results

        Raises:
            SQLToolError: If validation or query execution fails
        """
        try:
            # Convert to SQL
            sql = self._nl_to_sql(message)
            logger.info(f"Generated SQL query (first 200 chars): {sql[:200]}")

            # Validate SQL
            is_valid, error_msg = self._validate_sql(sql)
            if not is_valid:
                raise SQLToolError(f"Invalid SQL query: {error_msg}")

            # Log query attempt
            self.tracer.trace_event("sql_query", {
                "sql": sql[:1000],  # Limit logged SQL length
                "backend": self.backend,
                "message": message[:500]
            })

            # Execute based on backend
            if self.backend == "bigquery":
                result = self._execute_bigquery(sql)
            else:  # motherduck
                result = self._execute_duckdb(sql)

            # Validate result
            if not isinstance(result, pd.DataFrame):
                raise SQLToolError("Query did not return a DataFrame")

            # Check result size
            if len(result) > MAX_RESULT_ROWS:
                logger.warning(f"Result truncated from {len(result)} to {MAX_RESULT_ROWS} rows")
                result = result.head(MAX_RESULT_ROWS)

            logger.info(f"Query successful: {len(result)} rows, {len(result.columns)} columns")
            self.tracer.trace_event("sql_success", {
                "rows": len(result),
                "columns": len(result.columns)
            })

            return result

        except SQLToolError as e:
            # Validation failures and backend errors arrive here already
            # wrapped; record them as well so they are not lost from the
            # log and trace, then re-raise the same exception unchanged.
            logger.error(f"SQL query failed: {e}")
            self.tracer.trace_event("sql_error", {"error": str(e)})
            raise
        except Exception as e:
            error_msg = f"SQL execution failed: {str(e)}"
            logger.error(error_msg)
            self.tracer.trace_event("sql_error", {"error": error_msg})
            raise SQLToolError(error_msg) from e
    
    def _execute_bigquery(self, sql: str) -> pd.DataFrame:
        """Execute query on BigQuery."""
        try:
            query_job = self.client.query(sql)
            df = query_job.to_dataframe()
            return df
        except Exception as e:
            raise SQLToolError(f"BigQuery execution error: {str(e)}") from e
    
    def _execute_duckdb(self, sql: str) -> pd.DataFrame:
        """Execute query on DuckDB/MotherDuck."""
        try:
            result = self.client.execute(sql)
            df = result.fetch_df()
            return df
        except Exception as e:
            raise SQLToolError(f"DuckDB execution error: {str(e)}") from e
    
    def test_connection(self) -> bool:
        """Test database connection."""
        try:
            test_query = "SELECT 1 AS test"
            result = self.run(test_query)
            return len(result) == 1 and result.iloc[0, 0] == 1
        except Exception as e:
            logger.error(f"Connection test failed: {e}")
            return False