# corpusdb / app/query_engine.py
# Uploaded by mrsavage1 in commit ea520d9 (verified): "Upload 62 files"
import threading
import time
from pathlib import Path
from typing import Any, Dict, List, Optional

import duckdb

from app.hf_store import HFStore
from app.metadata import metadata
from app.query_parser import query_parser
# Process-wide DuckDB connection pool, keyed by workspace_id. Each value is a
# dict with the keys:
#   "conn"       - the workspace's in-memory DuckDB connection
#   "registered" - set of "db.table" keys already exposed as views
#   "last_used"  - time.time() when the entry was last acquired
_pool: Dict[str, dict] = {}
# Serializes all access to _pool.
_pool_lock = threading.Lock()

# Seconds a pooled connection may sit idle before it becomes evictable.
_POOL_TTL = 300
# Maximum pooled connections (lowered from 50): 20 * 256 MB keeps the combined
# DuckDB memory_limit at about 5 GB.
_MAX_POOL_SIZE = 20
# memory_limit applied to each pooled connection, in MB.
_MEMORY_LIMIT_MB = 256
def _evict_idle_connections():
    """Evict idle pooled connections, then trim the pool to its size cap.

    Entries idle for longer than ``_POOL_TTL`` seconds are removed first. If
    the pool still holds more than ``_MAX_POOL_SIZE`` entries, the least
    recently used ones are removed until it fits. Ties are broken by pool
    insertion order, because ``sorted`` is stable.

    Entries are detached while holding ``_pool_lock``. The connections are
    closed only after the lock is released, so slow teardown does not block
    other threads that are acquiring connections. Close errors are ignored
    because the connection is being discarded either way.
    """
    now = time.time()
    victims = []
    with _pool_lock:
        stale = [k for k, v in _pool.items() if now - v["last_used"] > _POOL_TTL]
        for key in stale:
            victims.append(_pool.pop(key)["conn"])

        overflow = len(_pool) - _MAX_POOL_SIZE
        if overflow > 0:
            # One stable sort replaces repeated min() scans over the pool.
            lru_keys = sorted(_pool, key=lambda k: _pool[k]["last_used"])[:overflow]
            for key in lru_keys:
                victims.append(_pool.pop(key)["conn"])

    for conn in victims:
        try:
            conn.close()
        except Exception:
            # Best-effort teardown of a connection that is already detached.
            pass
class QueryEngine:
    """Run SQL over a workspace's Parquet files through a pooled DuckDB connection.

    Each ``data/<database>/<table>.parquet`` file resolved through the user's
    store can be exposed as a view named ``"<database>"."<table>"``. Engines
    created for the same workspace share one in-memory DuckDB connection, and
    its set of registered views, from the module-level pool.

    NOTE(review): DuckDB's Python docs say a connection object should not be
    used concurrently from several threads (they recommend one ``cursor()`` per
    thread). If engines for the same workspace can run in parallel threads,
    confirm that callers serialize access.
    """

    def __init__(self, user_store: HFStore):
        """Bind the engine to ``user_store`` and its workspace's pooled connection.

        Args:
            user_store: Store used to resolve local Parquet paths. Its
                ``workspace_id`` attribute selects the pooled connection;
                ``'default'`` is used when it is missing or falsy.
        """
        self.store = user_store
        self._workspace_id = getattr(user_store, 'workspace_id', None) or 'default'
        # True when this instance created the pooled connection it is bound to.
        self._owned = False
        _evict_idle_connections()
        self._acquire_connection()

    def _acquire_connection(self):
        """(Re)bind ``self.conn`` to the workspace's pooled connection.

        Creates and configures a new connection when the workspace has no pool
        entry (first use, evicted, or closed by another engine). Also refreshes
        the entry's ``last_used`` timestamp. Public methods call this on every
        operation, so a long-lived engine keeps its idle timer fresh and picks
        up a new connection if its previous one was evicted or closed.
        """
        with _pool_lock:
            entry = _pool.get(self._workspace_id)
            if entry is None:
                conn = duckdb.connect(':memory:')
                conn.execute(f"SET memory_limit='{_MEMORY_LIMIT_MB}MB'")
                conn.execute("SET max_temp_directory_size='1GB'")
                conn.execute("SET threads=2")
                entry = {"conn": conn, "registered": set(), "last_used": time.time()}
                _pool[self._workspace_id] = entry
                self._owned = True
            elif entry["conn"] is not getattr(self, "conn", None):
                # Binding to a connection that another engine created.
                self._owned = False
            entry["last_used"] = time.time()
            self.conn = entry["conn"]
            self._registered_tables = entry["registered"]

    def _register_table(self, database: str, table: str):
        """Expose ``data/<database>/<table>.parquet`` as view "database"."table".

        Both identifiers are passed through ``query_parser.sanitize_identifier``.
        Registration is cached per pooled connection. Nothing is registered or
        cached when the Parquet file does not exist, but the schema is still
        created.
        """
        database = query_parser.sanitize_identifier(database)
        table = query_parser.sanitize_identifier(table)
        key = f"{database}.{table}"
        if key in self._registered_tables:
            return

        try:
            self.conn.execute(f'CREATE SCHEMA IF NOT EXISTS "{database}"')
        except duckdb.Error:
            # Best effort: a real problem will surface from CREATE VIEW below.
            pass

        parquet_path = self.store.local("data", database, f"{table}.parquet")
        if not parquet_path.exists():
            return
        # The path is embedded as a SQL string literal (a view definition
        # cannot carry bound parameters), so escape it by doubling single quotes.
        safe_path = str(parquet_path).replace("'", "''")
        self.conn.execute(
            f'CREATE OR REPLACE VIEW "{database}"."{table}" AS '
            f"SELECT * FROM read_parquet('{safe_path}')"
        )
        self._registered_tables.add(key)

    def _auto_discover_tables(self):
        """Register a view for every table reported by ``metadata.list_tables``.

        A failure on one table is reported and skipped, so one bad file does
        not block the rest.
        """
        for table_info in metadata.list_tables(self.store):
            db, tbl = table_info['database'], table_info['table']
            try:
                self._register_table(db, tbl)
            except Exception as e:
                print(f"Warning: Failed to register {db}.{tbl}: {e}")

    def execute_sql(self, sql: str, params: Optional[List[Any]] = None, auto_discover=True, allow_write=True, allow_schema_ops=False):
        """Validate and execute ``sql``, lazily registering referenced tables.

        Args:
            sql: SQL text. It is checked by ``query_parser.validate_query``
                before it runs.
            params: Values bound to ``?`` placeholders. ``None`` or an empty
                list means the query is executed without parameters.
            auto_discover: When true, every ``db.table`` name that
                ``query_parser.extract_tables`` finds is registered first.
            allow_write: Forwarded to the validator.
            allow_schema_ops: Forwarded to the validator.

        Returns:
            ``{'ok': True, 'data': [...records], 'columns': [...], 'rows': n}``
            on success, otherwise ``{'ok': False, 'error': message}``.
        """
        is_valid, error = query_parser.validate_query(sql, allow_write=allow_write, allow_schema_ops=allow_schema_ops)
        if not is_valid:
            return {
                'ok': False,
                'error': f"Query validation failed: {error}"
            }

        self._acquire_connection()

        if auto_discover:
            for table_name in query_parser.extract_tables(sql):
                # Only "db.table" references map to Parquet files. Unqualified
                # names are not auto-registered; they resolve only to relations
                # that already exist on the connection.
                if '.' not in table_name:
                    continue
                db, tbl = table_name.split('.', 1)
                try:
                    # _register_table also creates the schema.
                    self._register_table(db, tbl)
                except Exception as e:
                    print(f"Warning: Failed to register {db}.{tbl}: {e}")

        try:
            if params:
                result = self.conn.execute(sql, params).fetchdf()
            else:
                result = self.conn.execute(sql).fetchdf()
            return {
                'ok': True,
                'data': result.to_dict('records'),
                'columns': list(result.columns),
                'rows': len(result)
            }
        except Exception as e:
            return {
                'ok': False,
                'error': str(e)
            }

    def query_table(self, database: str, table: str, filters: Dict = None, limit: int = None, offset: int = None):
        """Query one table with optional equality filters, limit and offset.

        Returns the same result dict shape as ``execute_sql``.
        """
        self._acquire_connection()
        self._register_table(database, table)
        sql = query_parser.build_safe_query(database, table, filters, limit, offset)
        # Assumes build_safe_query emits one "?" per filter, in the dict's
        # iteration order. TODO confirm against query_parser.
        params = list(filters.values()) if filters else []
        return self.execute_sql(sql, params=params, auto_discover=False)

    def insert_row(self, database: str, table: str, row: Dict):
        """Append ``row`` to the table's Parquet file.

        The file is loaded into a per-call TEMP staging table. The row is
        inserted there with bound parameters, and the result is written to a
        sibling temp file that then atomically replaces the original. The
        query-facing view ``"database"."table"`` is never replaced by a TABLE.
        Since it reads the file through ``read_parquet``, later queries see the
        new row, and later ``_register_table`` calls do not collide with a
        TABLE of the same name.

        Args:
            database: Database (directory) name; sanitized before use.
            table: Table (file stem) name; sanitized before use.
            row: Mapping of column name to value. Column names are sanitized.

        Returns:
            ``{'ok': True}`` on success, otherwise ``{'ok': False, 'error': message}``.
        """
        self._acquire_connection()
        database = query_parser.sanitize_identifier(database)
        table = query_parser.sanitize_identifier(table)
        parquet_path = Path(self.store.local("data", database, f"{table}.parquet"))
        if not parquet_path.exists():
            return {'ok': False, 'error': 'Table does not exist. Create it first.'}
        if not row:
            return {'ok': False, 'error': 'Row must contain at least one column.'}

        ident = threading.get_ident()
        # Temp tables are private to the connection; the thread id keeps
        # concurrent calls on a shared connection from reusing one name.
        staging = f"_insert_staging_{ident}"
        tmp_path = parquet_path.with_name(f"{parquet_path.name}.{ident}.tmp")
        safe_path = str(parquet_path).replace("'", "''")
        safe_tmp = str(tmp_path).replace("'", "''")

        safe_columns = [query_parser.sanitize_identifier(c) for c in row]
        column_list = ', '.join(f'"{c}"' for c in safe_columns)
        placeholders = ', '.join('?' for _ in safe_columns)

        try:
            self.conn.execute(
                f'CREATE OR REPLACE TEMP TABLE "{staging}" AS '
                f"SELECT * FROM read_parquet('{safe_path}')"
            )
            self.conn.execute(
                f'INSERT INTO "{staging}" ({column_list}) VALUES ({placeholders})',
                list(row.values()),
            )
            # Write to a temp file first so a failure cannot truncate the table.
            self.conn.execute(f"COPY \"{staging}\" TO '{safe_tmp}' (FORMAT PARQUET)")
            tmp_path.replace(parquet_path)
        except Exception as e:
            return {'ok': False, 'error': str(e)}
        finally:
            try:
                self.conn.execute(f'DROP TABLE IF EXISTS "{staging}"')
            except duckdb.Error:
                pass
            try:
                tmp_path.unlink(missing_ok=True)
            except OSError:
                pass

        # The row is persisted. Make sure the table is queryable right away,
        # even through execute_sql(auto_discover=False).
        try:
            self._register_table(database, table)
        except Exception as e:
            print(f"Warning: Failed to register {database}.{table}: {e}")
        return {'ok': True}

    def get_table_stats(self, database: str, table: str):
        """Return an ``execute_sql`` result with one ``row_count`` column for the table."""
        self._acquire_connection()
        self._register_table(database, table)
        db = query_parser.sanitize_identifier(database)
        tbl = query_parser.sanitize_identifier(table)
        stats_sql = f"""
        SELECT
            COUNT(*) as row_count
        FROM "{db}"."{tbl}"
        """
        return self.execute_sql(stats_sql, auto_discover=False)

    def cross_database_query(self, sql: str, params: Optional[List[Any]] = None):
        """Register every known table, then validate and execute ``sql``."""
        self._acquire_connection()
        self._auto_discover_tables()
        return self.execute_sql(sql, params=params, auto_discover=False)

    def close(self):
        """Close this engine's connection and drop it from the pool.

        The pool entry is removed only if it still holds this engine's
        connection. That way, closing an engine whose connection was already
        evicted cannot discard a newer connection created for the same
        workspace. Other engines that shared the closed connection get a fresh
        one on their next call through ``_acquire_connection``.
        """
        with _pool_lock:
            entry = _pool.get(self._workspace_id)
            if entry is not None and entry["conn"] is self.conn:
                del _pool[self._workspace_id]
        self.conn.close()
def create_query_engine(user_store: HFStore) -> QueryEngine:
    """Build a ``QueryEngine`` bound to ``user_store``.

    Engines built for the same workspace reuse one pooled DuckDB connection.
    """
    engine = QueryEngine(user_store)
    return engine