File size: 9,817 Bytes
723f9ab
 
 
 
 
 
e68ff31
 
 
 
 
 
 
9e7f5fb
 
e68ff31
 
 
5f30028
e68ff31
 
5f30028
e68ff31
 
 
 
 
 
 
5f30028
 
 
 
 
 
 
 
723f9ab
 
 
 
e68ff31
 
 
 
 
 
 
 
9e7f5fb
e68ff31
 
 
 
 
 
 
 
 
 
 
 
 
723f9ab
 
 
 
 
 
 
 
 
 
 
14b48d2
 
 
ea520d9
14b48d2
 
723f9ab
 
2131972
 
723f9ab
 
2131972
723f9ab
 
 
 
 
 
14b48d2
 
 
 
 
 
 
 
ea520d9
14b48d2
 
723f9ab
 
 
ea520d9
723f9ab
 
204988a
723f9ab
 
204988a
723f9ab
 
 
 
 
 
 
 
 
 
 
 
 
037873c
 
 
723f9ab
037873c
 
723f9ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
037873c
 
f422e47
 
 
 
 
 
 
 
037873c
 
 
 
 
 
 
 
 
 
723f9ab
 
 
 
 
 
 
 
 
 
 
 
037873c
723f9ab
 
037873c
723f9ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ea520d9
e68ff31
 
ea520d9
 
723f9ab
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
import duckdb
from pathlib import Path
from app.hf_store import HFStore
from app.metadata import metadata
from app.query_parser import query_parser
from typing import List, Dict, Any
import threading
import time

# Per-workspace connection pool, keyed by workspace_id. Each value is a dict
# with the keys "conn" (the DuckDB connection), "registered" (set of
# "database.table" keys already registered on that connection) and
# "last_used" (time.time() timestamp of the most recent engine creation).
_pool: Dict[str, dict] = {}
_pool_lock = threading.Lock()  # held while entries are added to or removed from _pool
_POOL_TTL = 300  # seconds idle before eviction
_MAX_POOL_SIZE = 20  # reduced from 50 to prevent memory issues (20 * 256MB = 5GB max)
_MEMORY_LIMIT_MB = 256  # per-connection memory limit


def _evict_idle_connections():
    """Prune the connection pool.

    Entries whose ``last_used`` timestamp is older than ``_POOL_TTL`` seconds
    are closed and removed first. If the pool is still larger than
    ``_MAX_POOL_SIZE`` afterwards, the least recently used entries are closed
    and removed until it fits. The whole pass runs under ``_pool_lock``.
    """

    def _discard(workspace_id):
        # Errors while closing are ignored; the entry is removed regardless.
        try:
            _pool[workspace_id]["conn"].close()
        except Exception:
            pass
        del _pool[workspace_id]

    now = time.time()
    with _pool_lock:
        # Pass 1: everything that has been idle for longer than the TTL.
        expired = [
            workspace_id
            for workspace_id, entry in _pool.items()
            if now - entry["last_used"] > _POOL_TTL
        ]
        for workspace_id in expired:
            _discard(workspace_id)

        # Pass 2: trim least-recently-used entries while over the size cap.
        while len(_pool) > _MAX_POOL_SIZE:
            _discard(min(_pool, key=lambda workspace_id: _pool[workspace_id]["last_used"]))

class QueryEngine:
    def __init__(self, user_store: HFStore):
        """Bind this engine to the pooled DuckDB connection of the store's workspace.

        The workspace id is read from ``user_store.workspace_id`` and falls
        back to ``'default'`` when the attribute is missing or falsy. Idle
        pool entries are evicted first; then the workspace's pool entry is
        reused, or a new in-memory connection is created and configured
        (memory limit, temp directory size, thread count).

        Args:
            user_store: Store used to resolve local Parquet paths.

        Raises:
            Exception: Whatever DuckDB raises while opening or configuring a
                new connection; the half-configured connection is closed
                before the error propagates.
        """
        self.store = user_store
        self._workspace_id = getattr(user_store, 'workspace_id', None) or 'default'
        self._owned = False  # whether this instance owns the pooled connection

        _evict_idle_connections()

        with _pool_lock:
            entry = _pool.get(self._workspace_id)
            if entry is None:
                conn = duckdb.connect(':memory:')
                try:
                    conn.execute(f"SET memory_limit='{_MEMORY_LIMIT_MB}MB'")
                    conn.execute("SET max_temp_directory_size='1GB'")
                    conn.execute("SET threads=2")
                except Exception:
                    # Do not leak a connection that never made it into the pool.
                    conn.close()
                    raise
                entry = {
                    "conn": conn,
                    "registered": set(),
                    "last_used": time.time(),
                }
                _pool[self._workspace_id] = entry
                self._owned = True
            entry["last_used"] = time.time()

            # Capture the entry's fields while still holding the lock: once it
            # is released, another thread may remove the entry from the pool
            # and a fresh `_pool[...]` lookup would raise KeyError.
            self.conn = entry["conn"]
            self._registered_tables = entry["registered"]
    
    def _register_table(self, database: str, table: str):
        """Expose a table's Parquet file as a DuckDB view, once per connection.

        Both identifiers are sanitized first. If the resulting
        ``"database.table"`` key is already in the registered set, nothing
        happens. Otherwise the schema is created (errors ignored) and, when the
        Parquet file exists locally, a view over it is created and the key is
        recorded. A missing Parquet file leaves the table unregistered.
        """
        schema_name = query_parser.sanitize_identifier(database)
        view_name = query_parser.sanitize_identifier(table)

        registry_key = f"{schema_name}.{view_name}"
        if registry_key in self._registered_tables:
            return

        # The schema has to exist before a view can be created inside it.
        try:
            self.conn.execute(f'CREATE SCHEMA IF NOT EXISTS "{schema_name}"')
        except Exception:
            pass

        parquet_path = self.store.local("data", schema_name, f"{view_name}.parquet")
        if not parquet_path.exists():
            return

        # Double any single quote so the path stays one valid SQL string literal.
        quoted_path = str(parquet_path).replace("'", "''")
        self.conn.execute(f"""
                CREATE OR REPLACE VIEW "{schema_name}"."{view_name}" AS 
                SELECT * FROM read_parquet('{quoted_path}')
            """)
        self._registered_tables.add(registry_key)
    
    def _auto_discover_tables(self):
        """Register every table reported by ``metadata.list_tables`` for this store.

        Runs in two passes: first each distinct (sanitized) database gets its
        schema created, then each table is registered. Failures in either pass
        are ignored so one bad entry does not stop the rest.
        """
        discovered = metadata.list_tables(self.store)

        # Pass 1: schemas. A schema is only remembered once its creation succeeded.
        created_schemas = set()
        for info in discovered:
            schema_name = query_parser.sanitize_identifier(info['database'])
            if schema_name in created_schemas:
                continue
            try:
                self.conn.execute(f'CREATE SCHEMA IF NOT EXISTS "{schema_name}"')
            except Exception:
                continue
            created_schemas.add(schema_name)

        # Pass 2: tables. Registration errors are ignored per table.
        for info in discovered:
            try:
                self._register_table(info['database'], info['table'])
            except Exception:
                continue
    
    def execute_sql(self, sql: str, params: List[Any] = None, auto_discover=True, allow_write=True, allow_schema_ops=False):
        """Validate and run a SQL statement, registering referenced tables on demand.

        Args:
            sql: Statement to execute.
            params: Optional values bound to the statement's placeholders.
            auto_discover: When true, every ``database.table`` name extracted
                from the statement is registered before execution.
            allow_write: Forwarded to ``query_parser.validate_query``.
            allow_schema_ops: Forwarded to ``query_parser.validate_query``.

        Returns:
            ``{'ok': True, 'data': ..., 'columns': ..., 'rows': ...}`` on
            success, or ``{'ok': False, 'error': ...}`` when validation or
            execution fails.
        """
        is_valid, error = query_parser.validate_query(sql, allow_write=allow_write, allow_schema_ops=allow_schema_ops)
        if not is_valid:
            return {'ok': False, 'error': f"Query validation failed: {error}"}

        if auto_discover:
            # Lazy discovery: only the tables this statement mentions.
            for qualified_name in query_parser.extract_tables(sql):
                if '.' not in qualified_name:
                    # Unqualified names are not mapped to any default database.
                    continue
                db, _, tbl = qualified_name.partition('.')
                try:
                    schema_name = query_parser.sanitize_identifier(db)
                    self.conn.execute(f'CREATE SCHEMA IF NOT EXISTS "{schema_name}"')
                    self._register_table(db, tbl)
                except Exception as e:
                    print(f"Warning: Failed to register {db}.{tbl}: {e}")

        try:
            # Bind parameters only when some were supplied.
            cursor = self.conn.execute(sql, params) if params else self.conn.execute(sql)
            frame = cursor.fetchdf()
            return {
                'ok': True,
                'data': frame.to_dict('records'),
                'columns': list(frame.columns),
                'rows': len(frame),
            }
        except Exception as e:
            return {'ok': False, 'error': str(e)}
    
    def query_table(self, database: str, table: str, filters: Dict = None, limit: int = None, offset: int = None):
        """Select rows from one table with optional filters, limit and offset.

        The table is registered first, the statement is produced by
        ``query_parser.build_safe_query`` and the filter values are passed as
        bound parameters in the dict's iteration order (presumably matching
        the placeholder order the builder emits; verify against
        ``build_safe_query``).

        Returns:
            The result dict of ``execute_sql``.
        """
        self._register_table(database, table)

        statement = query_parser.build_safe_query(database, table, filters, limit, offset)

        # Filter values become the bound parameters; no filters means none.
        bound_values = list(filters.values()) if filters else []

        return self.execute_sql(statement, params=bound_values, auto_discover=False)
    
    def insert_row(self, database: str, table: str, row: Dict):
        """Append one row to a table's Parquet file via a temporary DuckDB table.

        The Parquet file is materialised as an in-memory table, the row is
        inserted with bound parameters, and the table is written back over
        the same Parquet file. The temporary table is dropped afterwards in
        every case: a leftover table would keep a full in-memory copy of the
        data on the pooled connection and collide with the view that
        ``_register_table`` later creates under the same name.

        Args:
            database: Database (schema) name; sanitized before use.
            table: Table name; sanitized before use.
            row: Mapping of column name to value for the new row.

        Returns:
            ``{'ok': True}`` on success, otherwise
            ``{'ok': False, 'error': <message>}`` (empty row, missing Parquet
            file, or a failed insert/export).

        Raises:
            Exception: Errors from identifier sanitization, schema creation
                or loading the Parquet file propagate to the caller.
        """
        if not row:
            # 'INSERT ... () VALUES ()' is not valid SQL; fail before the
            # whole Parquet file is loaded into memory for nothing.
            return {'ok': False, 'error': 'No column values provided for insert.'}

        database = query_parser.sanitize_identifier(database)
        table = query_parser.sanitize_identifier(table)

        # Sanitize column names up front so a failure here cannot leave the
        # temporary table behind.
        safe_columns = [query_parser.sanitize_identifier(c) for c in row]
        placeholders = ', '.join('?' for _ in safe_columns)
        column_list = ', '.join(f'"{c}"' for c in safe_columns)

        parquet_path = self.store.local("data", database, f"{table}.parquet")

        # Ensure schema exists
        self.conn.execute(f'CREATE SCHEMA IF NOT EXISTS "{database}"')

        # Drop any existing view so we can create a real TABLE for mutation
        try:
            self.conn.execute(f'DROP VIEW IF EXISTS "{database}"."{table}"')
        except Exception:
            pass
        self._registered_tables.discard(f"{database}.{table}")

        if not parquet_path.exists():
            return {'ok': False, 'error': 'Table does not exist. Create it first.'}

        # Create a temporary table from parquet (not a view)
        safe_path = str(parquet_path).replace("'", "''")
        self.conn.execute(f"""
                CREATE OR REPLACE TABLE "{database}"."{table}" AS 
                SELECT * FROM read_parquet('{safe_path}')
            """)

        sql = f'INSERT INTO "{database}"."{table}" ({column_list}) VALUES ({placeholders})'

        try:
            self.conn.execute(sql, list(row.values()))

            # Export back to Parquet
            self.conn.execute(f"""
                COPY (SELECT * FROM "{database}"."{table}") 
                TO '{safe_path}' (FORMAT PARQUET)
            """)

            return {'ok': True}
        except Exception as e:
            return {'ok': False, 'error': str(e)}
        finally:
            # Best-effort cleanup; must not mask the result being returned.
            try:
                self.conn.execute(f'DROP TABLE IF EXISTS "{database}"."{table}"')
            except Exception as e:
                print(f"Warning: Failed to drop temporary table {database}.{table}: {e}")
    
    def get_table_stats(self, database: str, table: str):
        """Return the row count of a table.

        The table is registered first; the count query is then run through
        ``execute_sql`` with discovery disabled, so the return value is that
        method's result dict with a single ``row_count`` column on success.
        """
        self._register_table(database, table)

        schema_name = query_parser.sanitize_identifier(database)
        table_name = query_parser.sanitize_identifier(table)

        return self.execute_sql(f"""
            SELECT 
                COUNT(*) as row_count
            FROM "{schema_name}"."{table_name}"
        """, auto_discover=False)
    
    def cross_database_query(self, sql: str, params: List[Any] = None):
        """Run a statement that may reference tables from several databases.

        Every known table is registered up front, so the statement is executed
        with per-query discovery switched off. Returns the result dict of
        ``execute_sql``.
        """
        # Register everything first; execute_sql then skips its own discovery.
        self._auto_discover_tables()
        outcome = self.execute_sql(sql, params=params, auto_discover=False)
        return outcome
    
    def close(self):
        """Close this engine's connection and drop its pool entry.

        The pool entry for the workspace is removed only when it still holds
        this engine's connection. If the slot has since been refilled with a
        different connection (for example after an eviction), that entry is
        left in place so the newer connection stays tracked by the pool.

        NOTE(review): the connection is shared by every engine created for
        the same workspace, so closing it here also affects those engines.
        """
        with _pool_lock:
            entry = _pool.get(self._workspace_id)
            if entry is not None and entry["conn"] is self.conn:
                del _pool[self._workspace_id]
        self.conn.close()

def create_query_engine(user_store: HFStore) -> QueryEngine:
    """Build a ``QueryEngine`` bound to the given user store."""
    engine = QueryEngine(user_store)
    return engine