import threading
import time
from pathlib import Path
from typing import Any, Dict, List, Optional

import duckdb

from app.hf_store import HFStore
from app.metadata import metadata
from app.query_parser import query_parser

# Per-workspace connection pool:
#   workspace_id -> {"conn": DuckDB connection,
#                    "registered": set of "database.table" keys already exposed as views,
#                    "last_used": time.time() of the most recent use}
_pool: Dict[str, dict] = {}
_pool_lock = threading.Lock()  # guards every read/write of _pool
_POOL_TTL = 300  # seconds idle before eviction
_MAX_POOL_SIZE = 20  # reduced from 50 to prevent memory issues (20 * 256MB = 5GB max)
_MEMORY_LIMIT_MB = 256  # per-connection memory limit


def _sql_quote_path(path) -> str:
    """Return *path* as text safe to embed inside a single-quoted SQL string literal.

    SQL escapes a single quote inside a literal by doubling it.
    """
    return str(path).replace("'", "''")


def _close_quietly(conn) -> None:
    """Close *conn*, ignoring errors (eviction is best-effort)."""
    try:
        conn.close()
    except Exception:
        pass


def _evict_stale_locked(now: float) -> None:
    """Close and drop pool entries idle longer than _POOL_TTL. Caller must hold _pool_lock."""
    stale = [key for key, entry in _pool.items() if now - entry["last_used"] > _POOL_TTL]
    for key in stale:
        _close_quietly(_pool.pop(key)["conn"])


def _evict_lru_locked(max_size: int) -> None:
    """Close least-recently-used entries until at most *max_size* remain. Caller must hold _pool_lock."""
    while _pool and len(_pool) > max_size:
        oldest = min(_pool, key=lambda key: _pool[key]["last_used"])
        _close_quietly(_pool.pop(oldest)["conn"])


def _evict_idle_connections():
    """Remove connections idle longer than TTL, and LRU if over max size."""
    with _pool_lock:
        _evict_stale_locked(time.time())
        _evict_lru_locked(_MAX_POOL_SIZE)


def _open_connection():
    """Open an in-memory DuckDB connection with the per-workspace resource limits applied.

    The connection is closed again if applying a setting fails, so it never leaks.
    """
    conn = duckdb.connect(':memory:')
    try:
        conn.execute(f"SET memory_limit='{_MEMORY_LIMIT_MB}MB'")
        conn.execute("SET max_temp_directory_size='1GB'")
        conn.execute("SET threads=2")
    except Exception:
        conn.close()
        raise
    return conn


class QueryEngine:
    """Run SQL over a workspace's Parquet files through a pooled DuckDB connection.

    Each ``data/<database>/<table>.parquet`` file in the store can be exposed as the
    view ``"database"."table"``. Engines for the same workspace share one pooled
    connection and its set of registered views.
    """

    def __init__(self, user_store: HFStore):
        self.store = user_store
        self._workspace_id = getattr(user_store, 'workspace_id', None) or 'default'
        self._owned = False  # True if this instance created the pooled connection
        # NOTE(review): DuckDB documents DuckDBPyConnection as not thread-safe; engines
        # built on different threads share this connection. Consider conn.cursor() per engine.
        with _pool_lock:
            now = time.time()
            _evict_stale_locked(now)
            entry = _pool.get(self._workspace_id)
            if entry is None:
                # Make room *before* inserting so the pool never exceeds _MAX_POOL_SIZE.
                _evict_lru_locked(_MAX_POOL_SIZE - 1)
                entry = {"conn": _open_connection(), "registered": set(), "last_used": now}
                _pool[self._workspace_id] = entry
                self._owned = True
            entry["last_used"] = now
        self.conn = entry["conn"]
        self._registered_tables = entry["registered"]

    def _touch(self):
        """Refresh the pool timestamp so a connection in active use is not evicted as idle.

        Only updates the entry if it still holds this engine's connection.
        """
        with _pool_lock:
            entry = _pool.get(self._workspace_id)
            if entry is not None and entry["conn"] is self.conn:
                entry["last_used"] = time.time()

    def _register_table(self, database: str, table: str):
        """Expose ``data/<database>/<table>.parquet`` as view ``"database"."table"`` if needed.

        The schema is always created (best-effort). If the Parquet file does not exist,
        no view is created and the table is not marked registered, so a later call retries.
        """
        database = query_parser.sanitize_identifier(database)
        table = query_parser.sanitize_identifier(table)
        key = f"{database}.{table}"
        if key in self._registered_tables:
            return

        try:
            self.conn.execute(f'CREATE SCHEMA IF NOT EXISTS "{database}"')
        except Exception:
            pass  # best-effort; a real failure resurfaces in CREATE VIEW below

        parquet_path = self.store.local("data", database, f"{table}.parquet")
        if not parquet_path.exists():
            return
        # The path is embedded in a single-quoted SQL literal; embedded quotes are doubled.
        safe_path = _sql_quote_path(parquet_path)
        self.conn.execute(
            f'CREATE OR REPLACE VIEW "{database}"."{table}" AS '
            f"SELECT * FROM read_parquet('{safe_path}')"
        )
        self._registered_tables.add(key)

    def _auto_discover_tables(self):
        """Register a view for every table ``metadata.list_tables`` reports for this store.

        Each failure is reported and skipped so one bad entry does not block the rest.
        """
        for table_info in metadata.list_tables(self.store):
            db, tbl = table_info['database'], table_info['table']
            try:
                # _register_table also creates the schema, so no separate schema pass is needed.
                self._register_table(db, tbl)
            except Exception as e:
                print(f"Warning: Failed to register {db}.{tbl}: {e}")

    def execute_sql(self, sql: str, params: Optional[List[Any]] = None, auto_discover=True,
                    allow_write=True, allow_schema_ops=False):
        """Validate and execute raw SQL, lazily registering the tables it references.

        Args:
            sql: SQL text; checked by ``query_parser.validate_query`` first.
            params: Values bound to ``?`` placeholders; ignored when empty/None.
            auto_discover: Register every ``database.table`` name found in *sql* first.
            allow_write / allow_schema_ops: Forwarded to the validator.

        Returns:
            ``{'ok': True, 'data': [row dicts], 'columns': [...], 'rows': n}`` or
            ``{'ok': False, 'error': message}``.
        """
        is_valid, error = query_parser.validate_query(
            sql, allow_write=allow_write, allow_schema_ops=allow_schema_ops
        )
        if not is_valid:
            return {'ok': False, 'error': f"Query validation failed: {error}"}

        self._touch()

        if auto_discover:
            for table_name in query_parser.extract_tables(sql):
                if '.' not in table_name:
                    # Unqualified names are left to DuckDB's own name resolution.
                    continue
                db, tbl = table_name.split('.', 1)
                try:
                    self._register_table(db, tbl)
                except Exception as e:
                    print(f"Warning: Failed to register {db}.{tbl}: {e}")

        try:
            if params:
                result = self.conn.execute(sql, params).fetchdf()
            else:
                result = self.conn.execute(sql).fetchdf()
            return {
                'ok': True,
                'data': result.to_dict('records'),
                'columns': list(result.columns),
                'rows': len(result),
            }
        except Exception as e:
            return {'ok': False, 'error': str(e)}

    def query_table(self, database: str, table: str, filters: Dict = None,
                    limit: int = None, offset: int = None):
        """Query a table with optional filters using safe parameterized queries.

        Filter values are bound in ``filters`` iteration order; this assumes
        ``build_safe_query`` emits placeholders in that same order.
        """
        self._register_table(database, table)
        sql = query_parser.build_safe_query(database, table, filters, limit, offset)
        params = list(filters.values()) if filters else []
        return self.execute_sql(sql, params=params, auto_discover=False)

    def insert_row(self, database: str, table: str, row: Dict):
        """Append one row to ``data/<database>/<table>.parquet``.

        The file is loaded into a scratch DuckDB table, and the row is inserted with
        bound parameters. The table is then exported to a sibling temp file that is
        renamed over the original, so a failed export never truncates the data. The
        scratch table is always dropped afterwards and the view re-registered.

        Returns ``{'ok': True}`` or ``{'ok': False, 'error': message}``.
        """
        database = query_parser.sanitize_identifier(database)
        table = query_parser.sanitize_identifier(table)
        key = f"{database}.{table}"
        qualified = f'"{database}"."{table}"'
        parquet_path = Path(self.store.local("data", database, f"{table}.parquet"))

        self._touch()
        self.conn.execute(f'CREATE SCHEMA IF NOT EXISTS "{database}"')
        # Drop any existing view so a real TABLE with the same name can be created for mutation.
        try:
            self.conn.execute(f'DROP VIEW IF EXISTS {qualified}')
        except Exception:
            pass
        self._registered_tables.discard(key)

        if not parquet_path.exists():
            return {'ok': False, 'error': 'Table does not exist. Create it first.'}

        safe_columns = [query_parser.sanitize_identifier(c) for c in row.keys()]
        column_list = ', '.join(f'"{c}"' for c in safe_columns)
        placeholders = ', '.join('?' for _ in safe_columns)
        insert_sql = f'INSERT INTO {qualified} ({column_list}) VALUES ({placeholders})'
        tmp_path = parquet_path.with_name(parquet_path.name + '.tmp')

        try:
            self.conn.execute(
                f"CREATE OR REPLACE TABLE {qualified} AS "
                f"SELECT * FROM read_parquet('{_sql_quote_path(parquet_path)}')"
            )
            self.conn.execute(insert_sql, list(row.values()))
            self.conn.execute(
                f"COPY (SELECT * FROM {qualified}) "
                f"TO '{_sql_quote_path(tmp_path)}' (FORMAT PARQUET)"
            )
            tmp_path.replace(parquet_path)
            return {'ok': True}
        except Exception as e:
            return {'ok': False, 'error': str(e)}
        finally:
            # A lingering TABLE would make _register_table's CREATE OR REPLACE VIEW fail
            # and would hold a stale in-memory copy, so swap it back for the file view.
            try:
                self.conn.execute(f'DROP TABLE IF EXISTS {qualified}')
                self._register_table(database, table)
            except Exception as e:
                print(f"Warning: Failed to restore view {key}: {e}")
            try:
                tmp_path.unlink()
            except OSError:
                pass  # already renamed into place, or never written

    def get_table_stats(self, database: str, table: str):
        """Return the row count of ``database.table`` in the ``execute_sql`` result format."""
        self._register_table(database, table)
        db = query_parser.sanitize_identifier(database)
        tbl = query_parser.sanitize_identifier(table)
        stats_sql = f'SELECT COUNT(*) as row_count FROM "{db}"."{tbl}"'
        return self.execute_sql(stats_sql, auto_discover=False)

    def cross_database_query(self, sql: str, params: List[Any] = None):
        """Register every known table, then execute *sql* with validation."""
        self._auto_discover_tables()
        return self.execute_sql(sql, params=params, auto_discover=False)

    def close(self):
        """Close this engine's connection and remove it from the pool.

        The pool entry is removed only if it still holds this engine's connection, so a
        newer connection created for the workspace after an eviction is not orphaned.
        Other engines sharing this workspace's connection lose it as well.
        """
        with _pool_lock:
            entry = _pool.get(self._workspace_id)
            if entry is not None and entry["conn"] is self.conn:
                del _pool[self._workspace_id]
            self.conn.close()


def create_query_engine(user_store: HFStore) -> QueryEngine:
    """Factory function to create a query engine for a user"""
    return QueryEngine(user_store)