File size: 9,817 Bytes
723f9ab e68ff31 9e7f5fb e68ff31 5f30028 e68ff31 5f30028 e68ff31 5f30028 723f9ab e68ff31 9e7f5fb e68ff31 723f9ab 14b48d2 ea520d9 14b48d2 723f9ab 2131972 723f9ab 2131972 723f9ab 14b48d2 ea520d9 14b48d2 723f9ab ea520d9 723f9ab 204988a 723f9ab 204988a 723f9ab 037873c 723f9ab 037873c 723f9ab 037873c f422e47 037873c 723f9ab 037873c 723f9ab 037873c 723f9ab ea520d9 e68ff31 ea520d9 723f9ab | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 | import duckdb
from pathlib import Path
from app.hf_store import HFStore
from app.metadata import metadata
from app.query_parser import query_parser
from typing import List, Dict, Any
import threading
import time
# Per-workspace connection pool: workspace_id -> entry dict with keys
#   "conn"       -> the workspace's in-memory DuckDB connection
#   "registered" -> set of sanitized "database.table" keys already exposed as views
#   "last_used"  -> time.time() of the most recent QueryEngine creation for the
#                   workspace; drives idle (TTL) and LRU eviction
_pool: Dict[str, Dict[str, Any]] = {}
_pool_lock = threading.Lock()  # held around every read/write of _pool
_POOL_TTL = 300  # seconds a pooled connection may sit idle before eviction
# Cap on pooled connections; reduced from 50 to bound memory (20 * 256MB = 5GB max).
_MAX_POOL_SIZE = 20
_MEMORY_LIMIT_MB = 256  # DuckDB memory_limit applied to each pooled connection
def _evict_idle_connections():
    """Trim the connection pool under ``_pool_lock``.

    Two passes:

    1. every entry idle for more than ``_POOL_TTL`` seconds is dropped;
    2. if more than ``_MAX_POOL_SIZE`` entries remain, the least recently
       used ones are dropped until the pool is back at the cap.

    Dropping an entry closes its DuckDB connection, ignoring any error from
    ``close()``, and always removes the entry from ``_pool``.
    """
    now = time.time()

    def _discard(workspace_id):
        # Best-effort close: the entry is removed even if close() raises.
        try:
            _pool[workspace_id]["conn"].close()
        except Exception:
            pass
        del _pool[workspace_id]

    with _pool_lock:
        expired = [
            workspace_id
            for workspace_id, entry in _pool.items()
            if now - entry["last_used"] > _POOL_TTL
        ]
        for workspace_id in expired:
            _discard(workspace_id)

        overflow = len(_pool) - _MAX_POOL_SIZE
        if overflow > 0:
            # Stable sort: ties keep pool order, matching repeated min().
            least_recent_first = sorted(_pool, key=lambda w: _pool[w]["last_used"])
            for workspace_id in least_recent_first[:overflow]:
                _discard(workspace_id)
class QueryEngine:
    """Run SQL against one workspace's Parquet files through DuckDB.

    All engines for a workspace (``user_store.workspace_id``, or
    ``'default'`` when the store has none) share one pooled in-memory DuckDB
    connection. The Parquet file at
    ``user_store.local("data", <database>, "<table>.parquet")`` is exposed
    lazily as the view ``"<database>"."<table>"``.

    NOTE(review): ``_evict_idle_connections`` can close a pooled connection
    (idle TTL or LRU cap) while an engine still holds it. If engines for one
    workspace are used from several threads at once they share a single
    connection; DuckDB's Python docs recommend a cursor per thread. Confirm
    the intended concurrency model.
    """

    def __init__(self, user_store: HFStore):
        """Attach to, or create, the pooled connection for the store's workspace."""
        self.store = user_store
        self._workspace_id = getattr(user_store, 'workspace_id', None) or 'default'
        # True only when this instance created the pooled connection.
        self._owned = False
        with _pool_lock:
            entry = _pool.get(self._workspace_id)
            if entry is None:
                entry = {
                    "conn": self._new_connection(),
                    "registered": set(),
                    "last_used": time.time(),
                }
                _pool[self._workspace_id] = entry
                self._owned = True
            entry["last_used"] = time.time()
            self.conn = entry["conn"]
            self._registered_tables = entry["registered"]
        # Trim the pool only after our entry is in place and freshly touched.
        # It is then the most recently used entry, so it survives. The cap
        # also counts a connection created just above; trimming first let the
        # pool reach _MAX_POOL_SIZE + 1 entries.
        _evict_idle_connections()

    @staticmethod
    def _new_connection():
        """Open an in-memory DuckDB connection with the pool's resource limits.

        Settings: ``_MEMORY_LIMIT_MB`` of memory, 1GB of temp-directory space
        and 2 threads.
        """
        conn = duckdb.connect(':memory:')
        try:
            conn.execute(f"SET memory_limit='{_MEMORY_LIMIT_MB}MB'")
            conn.execute("SET max_temp_directory_size='1GB'")
            conn.execute("SET threads=2")
        except Exception:
            # Don't leak a half-configured connection.
            conn.close()
            raise
        return conn

    @staticmethod
    def _sql_path_literal(path) -> str:
        """Render ``path`` as a single-quoted SQL string literal.

        Embedded single quotes are doubled (the standard SQL escape), so a
        file name cannot terminate the literal early.
        """
        return "'" + str(path).replace("'", "''") + "'"

    def _register_table(self, database: str, table: str):
        """Expose ``<database>/<table>.parquet`` as the view ``"database"."table"``.

        Identifiers are sanitized first. Each sanitized ``database.table`` is
        registered at most once per pooled connection. If the Parquet file
        does not exist, nothing is registered, so a later call retries.
        Errors from creating the view propagate to the caller.
        """
        database = query_parser.sanitize_identifier(database)
        table = query_parser.sanitize_identifier(table)
        key = f"{database}.{table}"
        if key in self._registered_tables:
            return
        # Best effort: if the schema really is missing, CREATE VIEW below fails loudly.
        try:
            self.conn.execute(f'CREATE SCHEMA IF NOT EXISTS "{database}"')
        except Exception:
            pass
        parquet_path = self.store.local("data", database, f"{table}.parquet")
        if not parquet_path.exists():
            return
        # A view (not a table) so every query reads the current file contents.
        self.conn.execute(
            f'CREATE OR REPLACE VIEW "{database}"."{table}" AS '
            f'SELECT * FROM read_parquet({self._sql_path_literal(parquet_path)})'
        )
        self._registered_tables.add(key)

    def _auto_discover_tables(self):
        """Register a view for every table ``metadata.list_tables`` reports for this store.

        Best effort: a table whose name or file fails to register is skipped,
        so one bad entry does not abort discovery of the rest.
        """
        for table_info in metadata.list_tables(self.store):
            try:
                # _register_table also creates the table's schema.
                self._register_table(table_info['database'], table_info['table'])
            except Exception:
                pass

    def execute_sql(self, sql: str, params: List[Any] = None, auto_discover=True, allow_write=True, allow_schema_ops=False):
        """Validate ``sql`` and run it on the workspace connection.

        Args:
            sql: Query text, checked by ``query_parser.validate_query`` first.
            params: Values bound to the query's placeholders. ``None`` or an
                empty list runs the query without parameters.
            auto_discover: When true, every qualified ``database.table`` name
                that ``query_parser.extract_tables`` finds in ``sql`` is
                registered as a view before running. Unqualified names are
                not registered.
            allow_write: Forwarded to ``validate_query``.
            allow_schema_ops: Forwarded to ``validate_query``.

        Returns:
            ``{'ok': True, 'data': [row dicts], 'columns': [names], 'rows': n}``
            on success, or ``{'ok': False, 'error': message}`` when validation
            or execution fails.
        """
        is_valid, error = query_parser.validate_query(
            sql, allow_write=allow_write, allow_schema_ops=allow_schema_ops
        )
        if not is_valid:
            return {
                'ok': False,
                'error': f"Query validation failed: {error}"
            }

        if auto_discover:
            for table_name in query_parser.extract_tables(sql):
                if '.' not in table_name:
                    # Unqualified names are left to DuckDB's default schema.
                    continue
                db, tbl = table_name.split('.', 1)
                try:
                    # _register_table creates the schema before the view.
                    self._register_table(db, tbl)
                except Exception as e:
                    # Registration is best effort; the query itself reports real failures.
                    print(f"Warning: Failed to register {db}.{tbl}: {e}")

        try:
            if params:
                result = self.conn.execute(sql, params).fetchdf()
            else:
                result = self.conn.execute(sql).fetchdf()
            return {
                'ok': True,
                'data': result.to_dict('records'),
                'columns': list(result.columns),
                'rows': len(result)
            }
        except Exception as e:
            return {
                'ok': False,
                'error': str(e)
            }

    def query_table(self, database: str, table: str, filters: Dict = None, limit: int = None, offset: int = None):
        """Select from ``"database"."table"`` with optional filters and paging.

        The SQL comes from ``query_parser.build_safe_query``. The filter values
        are bound as parameters in ``filters`` iteration order. This presumably
        requires one placeholder per filter in that same order (TODO confirm
        against build_safe_query).

        Returns the same result dict as ``execute_sql``.
        """
        self._register_table(database, table)
        sql = query_parser.build_safe_query(database, table, filters, limit, offset)
        params = list(filters.values()) if filters else []
        return self.execute_sql(sql, params=params, auto_discover=False)

    def insert_row(self, database: str, table: str, row: Dict):
        """Append ``row`` to ``<database>/<table>.parquet`` and rewrite the file.

        A view cannot be inserted into, so the file is loaded into a DuckDB
        table. The row is inserted with bound parameters and the table is
        written back over the Parquet file. Whatever the outcome, the
        materialized table is then dropped and the view restored. Otherwise a
        stale in-memory copy would linger, count against the connection's
        memory limit, and make the next ``CREATE OR REPLACE VIEW`` of that name
        fail (DuckDB will not replace a table with a view).

        Args:
            database: Database (schema) name; sanitized before use.
            table: Table name; sanitized before use.
            row: Column name -> value. Column names are sanitized; values are
                bound as parameters.

        Returns:
            ``{'ok': True}`` on success, otherwise ``{'ok': False, 'error': message}``.
            Errors while loading, inserting or writing are returned, not raised.
        """
        db = query_parser.sanitize_identifier(database)
        tbl = query_parser.sanitize_identifier(table)
        parquet_path = self.store.local("data", db, f"{tbl}.parquet")
        if not parquet_path.exists():
            return {'ok': False, 'error': 'Table does not exist. Create it first.'}
        if not row:
            return {'ok': False, 'error': 'Row must contain at least one column.'}

        # Build everything that can fail on bad input before touching the catalog.
        target = f'"{db}"."{tbl}"'
        path_sql = self._sql_path_literal(parquet_path)
        column_list = ', '.join(f'"{query_parser.sanitize_identifier(c)}"' for c in row)
        placeholders = ', '.join('?' for _ in row)
        values = list(row.values())

        self.conn.execute(f'CREATE SCHEMA IF NOT EXISTS "{db}"')
        try:
            self.conn.execute(f'DROP VIEW IF EXISTS {target}')
        except Exception:
            pass  # best effort; a remaining conflict surfaces in CREATE TABLE below
        self._registered_tables.discard(f"{db}.{tbl}")
        try:
            self.conn.execute(
                f'CREATE OR REPLACE TABLE {target} AS SELECT * FROM read_parquet({path_sql})'
            )
            self.conn.execute(
                f'INSERT INTO {target} ({column_list}) VALUES ({placeholders})', values
            )
            self.conn.execute(f'COPY (SELECT * FROM {target}) TO {path_sql} (FORMAT PARQUET)')
            return {'ok': True}
        except Exception as e:
            return {'ok': False, 'error': str(e)}
        finally:
            self._restore_view(database, table)

    def _restore_view(self, database: str, table: str):
        """Drop the table materialized by ``insert_row`` and re-register the Parquet view.

        Best effort: if re-registration fails, the key stays out of the
        registered set, so the next ``_register_table`` call retries.
        """
        try:
            db = query_parser.sanitize_identifier(database)
            tbl = query_parser.sanitize_identifier(table)
            self.conn.execute(f'DROP TABLE IF EXISTS "{db}"."{tbl}"')
        except Exception:
            pass
        try:
            self._register_table(database, table)
        except Exception:
            pass

    def get_table_stats(self, database: str, table: str):
        """Return the row count of ``"database"."table"`` as an ``execute_sql`` result dict.

        The single result row has one column, ``row_count``.
        """
        self._register_table(database, table)
        db = query_parser.sanitize_identifier(database)
        tbl = query_parser.sanitize_identifier(table)
        stats_sql = f'SELECT COUNT(*) AS row_count FROM "{db}"."{tbl}"'
        return self.execute_sql(stats_sql, auto_discover=False)

    def cross_database_query(self, sql: str, params: List[Any] = None):
        """Register every known table, then validate and run ``sql``.

        For queries that may reference tables in several databases. Returns
        the ``execute_sql`` result dict.
        """
        self._auto_discover_tables()
        return self.execute_sql(sql, params=params, auto_discover=False)

    def close(self):
        """Close this engine's connection and remove it from the pool.

        The pool entry is removed only while it still holds this engine's
        connection. If that connection was evicted and the workspace has
        since been re-pooled with a new connection, the newer entry (in use
        by other engines) is left intact.

        NOTE(review): other engines created for the same workspace share this
        connection and will find it closed.
        """
        with _pool_lock:
            entry = _pool.get(self._workspace_id)
            if entry is not None and entry["conn"] is self.conn:
                del _pool[self._workspace_id]
        self.conn.close()
def create_query_engine(user_store: HFStore) -> QueryEngine:
    """Return a new ``QueryEngine`` bound to ``user_store``.

    The engine attaches to the pooled DuckDB connection of the store's
    workspace, creating it if needed.
    """
    engine = QueryEngine(user_store)
    return engine
|