""" Database adapters for Supabase and SQLite providers. """ from __future__ import annotations import json import os import sqlite3 import threading import uuid from dataclasses import dataclass from datetime import datetime from typing import Any from agno.utils.log import logger from sqlalchemy import MetaData, and_, create_engine, delete, func, inspect, or_, select, update from sqlalchemy.dialects.mysql import insert as mysql_insert from sqlalchemy.dialects.postgresql import insert as postgres_insert from ..models.db import DbFilter, DbQueryRequest, DbQueryResponse from .db_registry import ProviderConfig from .sqlite_schema import SCHEMA_STATEMENTS def _utc_now_iso() -> str: return datetime.utcnow().replace(microsecond=0).isoformat() + "Z" JSON_COLUMNS: dict[str, set[str]] = { "agents": {"tool_ids", "skill_ids"}, "conversations": {"title_emojis", "session_summary"}, "conversation_messages": { "content", "tool_calls", "tool_call_history", "research_step_history", "related_questions", "sources", "document_sources", "grounding_supports", "stream_blocks", }, "conversation_events": {"payload"}, "attachments": {"data"}, "document_sections": {"title_path", "loc"}, "document_chunks": {"title_path", "loc", "embedding"}, "memory_domains": {"aliases"}, "user_tools": {"config", "input_schema"}, "pending_form_runs": {"requirements_data", "messages"}, "scrapbook": {"tags"}, } TABLES_WITH_ID = { "spaces", "agents", "conversations", "conversation_messages", "conversation_events", "attachments", "space_documents", "document_sections", "document_chunks", "home_notes", "home_shortcuts", "memory_domains", "memory_summaries", "user_tools", "pending_form_runs", "email_provider_configs", "email_notifications", "scrapbook", } TABLES_WITH_UPDATED_AT = { "spaces", "agents", "conversations", "space_documents", "document_sections", "document_chunks", "home_notes", "home_shortcuts", "user_settings", "memory_domains", "memory_summaries", "user_tools", "email_provider_configs", "scrapbook", 
} TABLES_WITH_CREATED_AT = { "spaces", "agents", "conversations", "conversation_messages", "conversation_events", "attachments", "space_documents", "conversation_documents", "document_sections", "document_chunks", "space_agents", "home_notes", "home_shortcuts", "memory_domains", "memory_summaries", "user_tools", "pending_form_runs", "email_provider_configs", "email_notifications", "scrapbook", } def _serialize_value(table: str, column: str, value: Any) -> Any: if value is None: return None if table in JSON_COLUMNS and column in JSON_COLUMNS[table]: try: return json.dumps(value, ensure_ascii=False) except TypeError: return json.dumps(str(value), ensure_ascii=False) if isinstance(value, bool): return 1 if value else 0 return value def _deserialize_row(table: str, row: dict[str, Any]) -> dict[str, Any]: if table in JSON_COLUMNS: for column in JSON_COLUMNS[table]: if column in row and row[column] is not None: try: row[column] = json.loads(row[column]) except Exception: pass # SQLite stores bool as int for key, value in row.items(): if isinstance(value, int) and key.startswith("is_"): row[key] = bool(value) return row def _prepare_payload(table: str, payload: dict[str, Any]) -> dict[str, Any]: return {key: _serialize_value(table, key, value) for key, value in payload.items()} def _deserialize_rows(table: str, rows: list[dict[str, Any]]) -> list[dict[str, Any]]: return [_deserialize_row(table, dict(row)) for row in rows] def _normalize_columns(columns: str | list[str] | None) -> list[str] | None: if columns is None: return None if isinstance(columns, list): return columns if isinstance(columns, str): trimmed = columns.strip() if trimmed == "*": return None return [col.strip() for col in trimmed.split(",") if col.strip()] return None def _build_where_clause(filters: list[DbFilter] | None) -> tuple[str, list[Any]]: if not filters: return "", [] def build(f: DbFilter) -> tuple[str, list[Any]]: if f.op == "or" and f.filters: or_clauses: list[str] = [] or_params: list[Any] = 
[] for inner in f.filters: clause, params = build(inner) if clause: or_clauses.append(clause) or_params.extend(params) if not or_clauses: return "", [] return f"({ ' OR '.join(or_clauses) })", or_params column = f.column if not column: return "", [] if f.op == "eq": return f"{column} = ?", [f.value] if f.op == "gt": return f"{column} > ?", [f.value] if f.op == "lt": return f"{column} < ?", [f.value] if f.op == "ilike": return f"LOWER({column}) LIKE LOWER(?)", [f"%{f.value}%"] if f.op == "is_null": return f"{column} IS NULL", [] if f.op in {"in", "not_in"}: values = f.values or [] if not values: return ("1=0", []) if f.op == "in" else ("1=1", []) placeholders = ", ".join(["?"] * len(values)) operator = "IN" if f.op == "in" else "NOT IN" return f"{column} {operator} ({placeholders})", list(values) return "", [] clauses: list[str] = [] params: list[Any] = [] for filt in filters: clause, clause_params = build(filt) if clause: clauses.append(clause) params.extend(clause_params) if not clauses: return "", [] return "WHERE " + " AND ".join(clauses), params def _extract_filter_values(filters: list[DbFilter] | None, column: str) -> list[str]: if not filters: return [] values: list[str] = [] def walk(f: DbFilter) -> None: if f.op == "or" and f.filters: for inner in f.filters: walk(inner) return if f.column != column: return if f.op == "eq" and f.value is not None: values.append(str(f.value)) return if f.op == "in" and f.values: for item in f.values: if item is not None: values.append(str(item)) for filt in filters: walk(filt) return list(dict.fromkeys(values)) def _build_sa_expression(table_obj, filt: DbFilter): if filt.op == "or" and filt.filters: expressions = [_build_sa_expression(table_obj, inner) for inner in filt.filters] expressions = [expr for expr in expressions if expr is not None] return or_(*expressions) if expressions else None column_name = filt.column if not column_name or column_name not in table_obj.c: return None column = table_obj.c[column_name] if filt.op 
== "eq": return column == filt.value if filt.op == "gt": return column > filt.value if filt.op == "lt": return column < filt.value if filt.op == "ilike": return column.ilike(f"%{filt.value}%") if filt.op == "is_null": return column.is_(None) if filt.op == "in": return column.in_(filt.values or []) if filt.op == "not_in": return ~column.in_(filt.values or []) return None @dataclass class SQLiteAdapter: config: ProviderConfig def __post_init__(self) -> None: self._lock = threading.Lock() if not self.config.sqlite_path: raise ValueError("SQLite provider missing path") os.makedirs(os.path.dirname(self.config.sqlite_path) or ".", exist_ok=True) self._conn = sqlite3.connect( self.config.sqlite_path, check_same_thread=False, timeout=5.0, ) self._conn.row_factory = sqlite3.Row self._configure_connection() self._ensure_schema() def _configure_connection(self) -> None: """ Tune SQLite for mixed read/write concurrency. """ with self._lock: cursor = self._conn.cursor() cursor.execute("PRAGMA journal_mode=WAL;") cursor.execute("PRAGMA synchronous=NORMAL;") cursor.execute("PRAGMA busy_timeout=5000;") cursor.execute("PRAGMA temp_store=MEMORY;") self._conn.commit() def _ensure_schema(self) -> None: with self._lock: cursor = self._conn.cursor() for stmt in SCHEMA_STATEMENTS: cursor.executescript(stmt) # Lightweight forward migrations for existing local DBs. 
cursor.execute("PRAGMA table_info(conversation_messages)") columns = {str(row[1]) for row in cursor.fetchall()} if "stream_blocks" not in columns: cursor.execute( "ALTER TABLE conversation_messages " "ADD COLUMN stream_blocks TEXT NOT NULL DEFAULT '[]'" ) if "stream_schema_version" not in columns: cursor.execute( "ALTER TABLE conversation_messages " "ADD COLUMN stream_schema_version INTEGER NOT NULL DEFAULT 1" ) cursor.execute("PRAGMA table_info(agents)") agent_columns = {str(row[1]) for row in cursor.fetchall()} if "use_global_model_settings" not in agent_columns: cursor.execute( "ALTER TABLE agents " "ADD COLUMN use_global_model_settings INTEGER NOT NULL DEFAULT 1" ) if "avatar_type" not in agent_columns: cursor.execute( "ALTER TABLE agents " "ADD COLUMN avatar_type TEXT NOT NULL DEFAULT 'emoji'" ) if "avatar_image" not in agent_columns: cursor.execute("ALTER TABLE agents ADD COLUMN avatar_image TEXT") if "avatar_shape" not in agent_columns: cursor.execute( "ALTER TABLE agents " "ADD COLUMN avatar_shape TEXT NOT NULL DEFAULT 'circle'" ) if "banner_mode" not in agent_columns: cursor.execute( "ALTER TABLE agents " "ADD COLUMN banner_mode TEXT NOT NULL DEFAULT 'none'" ) if "banner_image" not in agent_columns: cursor.execute("ALTER TABLE agents ADD COLUMN banner_image TEXT") if "skill_ids" not in agent_columns: cursor.execute( "ALTER TABLE agents " "ADD COLUMN skill_ids TEXT NOT NULL DEFAULT '[]'" ) # Forward migration: ensure scrapbook table exists for pre-existing DBs. 
cursor.execute( "CREATE TABLE IF NOT EXISTS scrapbook (" "id TEXT PRIMARY KEY, " "title TEXT NOT NULL DEFAULT '', " "emoji TEXT, " "summary TEXT NOT NULL DEFAULT '', " "content TEXT NOT NULL DEFAULT '', " "source_url TEXT, " "platform TEXT NOT NULL DEFAULT 'manual', " "thumbnail TEXT, " "tags TEXT NOT NULL DEFAULT '[]', " "created_at TEXT NOT NULL, " "updated_at TEXT NOT NULL)" ) cursor.execute( "CREATE INDEX IF NOT EXISTS idx_scrapbook_created_at " "ON scrapbook(created_at DESC)" ) cursor.execute("PRAGMA table_info(scrapbook)") scrapbook_columns = {str(row[1]) for row in cursor.fetchall()} if "emoji" not in scrapbook_columns: cursor.execute("ALTER TABLE scrapbook ADD COLUMN emoji TEXT") cursor.execute("PRAGMA table_info(conversations)") conv_columns = {str(row[1]) for row in cursor.fetchall()} if "scrapbook_id" not in conv_columns: cursor.execute("ALTER TABLE conversations ADD COLUMN scrapbook_id TEXT") self._conn.commit() def _execute(self, sql: str, params: list[Any] | tuple[Any, ...] 
= ()) -> sqlite3.Cursor: with self._lock: cursor = self._conn.cursor() cursor.execute(sql, params) self._conn.commit() return cursor def _fetchall(self, sql: str, params: list[Any]) -> list[dict[str, Any]]: with self._lock: cursor = self._conn.cursor() cursor.execute(sql, params) rows = [dict(row) for row in cursor.fetchall()] return rows def _fetchone(self, sql: str, params: list[Any]) -> dict[str, Any] | None: with self._lock: cursor = self._conn.cursor() cursor.execute(sql, params) row = cursor.fetchone() return dict(row) if row else None def execute(self, req: DbQueryRequest) -> DbQueryResponse: action = req.action if action == "select": return self._select(req) if action == "insert": return self._insert(req) if action == "update": return self._update(req) if action == "delete": return self._delete(req) if action == "upsert": return self._upsert(req) if action == "rpc": return self._rpc(req) if action == "test": return self._test(req) return DbQueryResponse(error="Unsupported action") def _select(self, req: DbQueryRequest) -> DbQueryResponse: table = req.table if not table: return DbQueryResponse(error="Missing table") columns = _normalize_columns(req.columns) select_clause = "*" if not columns else ", ".join(columns) where_clause, params = _build_where_clause(req.filters) order_clause = "" if req.order: orders = [f"{o.column} {'ASC' if o.ascending else 'DESC'}" for o in req.order] order_clause = " ORDER BY " + ", ".join(orders) limit_clause = "" if req.range: limit = max(0, req.range.to - req.range.from_ + 1) limit_clause = f" LIMIT {limit} OFFSET {req.range.from_}" elif req.limit: limit_clause = f" LIMIT {req.limit}" sql = f"SELECT {select_clause} FROM {table} {where_clause}{order_clause}{limit_clause}" rows = self._fetchall(sql, params) data = [_deserialize_row(table, row) for row in rows] count = None if req.count == "exact": count_sql = f"SELECT COUNT(*) as count FROM {table} {where_clause}" count_row = self._fetchone(count_sql, params) count = 
int(count_row["count"]) if count_row else 0 if req.single or req.maybe_single: data = data[0] if data else None return DbQueryResponse(data=data, count=count) def _insert(self, req: DbQueryRequest) -> DbQueryResponse: table = req.table if not table: return DbQueryResponse(error="Missing table") values = req.values if values is None: return DbQueryResponse(error="Missing values") rows = values if isinstance(values, list) else [values] now = _utc_now_iso() prepared = [] for row in rows: payload = dict(row) if table in TABLES_WITH_ID and not payload.get("id"): payload["id"] = str(uuid.uuid4()) if table in TABLES_WITH_CREATED_AT or "created_at" in payload: payload.setdefault("created_at", now) if table in TABLES_WITH_UPDATED_AT or "updated_at" in payload: payload.setdefault("updated_at", now) prepared.append(payload) columns = sorted({key for row in prepared for key in row.keys()}) placeholders = ", ".join(["?"] * len(columns)) sql = f"INSERT INTO {table} ({', '.join(columns)}) VALUES ({placeholders})" with self._lock: cursor = self._conn.cursor() for payload in prepared: params = [_serialize_value(table, col, payload.get(col)) for col in columns] cursor.execute(sql, params) if table == "conversation_messages": conv_id = payload.get("conversation_id") if conv_id: cursor.execute( "UPDATE conversations SET updated_at = ? 
WHERE id = ?", [now, conv_id], ) self._conn.commit() single_mode = bool(req.single or req.maybe_single) if req.columns or single_mode: ids = [row.get("id") for row in prepared if row.get("id")] data = None if ids: filters = [DbFilter(op="in", column="id", values=ids)] select_req = DbQueryRequest( providerId=req.provider_id, action="select", table=table, columns=req.columns, filters=filters, single=single_mode, ) return self._select(select_req) data = prepared[0] if single_mode else prepared return DbQueryResponse(data=data) return DbQueryResponse(data=prepared[0] if single_mode else prepared) def _update(self, req: DbQueryRequest) -> DbQueryResponse: table = req.table if not table: return DbQueryResponse(error="Missing table") payload = req.payload or {} if not payload: return DbQueryResponse(error="Missing payload") now = _utc_now_iso() if table in TABLES_WITH_UPDATED_AT and "updated_at" not in payload: payload["updated_at"] = now set_clause = ", ".join([f"{k} = ?" for k in payload.keys()]) params = [_serialize_value(table, k, v) for k, v in payload.items()] where_clause, where_params = _build_where_clause(req.filters) sql = f"UPDATE {table} SET {set_clause} {where_clause}" self._execute(sql, params + where_params) if table == "conversation_messages": conv_id = payload.get("conversation_id") if conv_id: self._execute( "UPDATE conversations SET updated_at = ? 
WHERE id = ?", [now, conv_id], ) if req.columns or req.single or req.maybe_single: select_req = DbQueryRequest( providerId=req.provider_id, action="select", table=table, columns=req.columns, filters=req.filters, single=bool(req.single or req.maybe_single), ) return self._select(select_req) return DbQueryResponse(data=None) def _delete(self, req: DbQueryRequest) -> DbQueryResponse: table = req.table if not table: return DbQueryResponse(error="Missing table") pending_cleanup_all = False pending_cleanup_conversation_ids: list[str] = [] if table == "conversations": if req.filters: pending_cleanup_conversation_ids = _extract_filter_values(req.filters, "id") else: pending_cleanup_all = True elif table == "conversation_messages": pending_cleanup_conversation_ids = _extract_filter_values(req.filters, "conversation_id") where_clause, params = _build_where_clause(req.filters) sql = f"DELETE FROM {table} {where_clause}" self._execute(sql, params) # Keep pending HITL runs in sync with conversation lifecycle. 
if pending_cleanup_all: self._execute("DELETE FROM pending_form_runs", []) elif pending_cleanup_conversation_ids: placeholders = ", ".join(["?"] * len(pending_cleanup_conversation_ids)) self._execute( f"DELETE FROM pending_form_runs WHERE conversation_id IN ({placeholders})", pending_cleanup_conversation_ids, ) return DbQueryResponse(data=None) def _upsert(self, req: DbQueryRequest) -> DbQueryResponse: table = req.table if not table: return DbQueryResponse(error="Missing table") values = req.values if values is None: return DbQueryResponse(error="Missing values") rows = values if isinstance(values, list) else [values] now = _utc_now_iso() prepared = [] for row in rows: payload = dict(row) if table in TABLES_WITH_ID and not payload.get("id"): payload["id"] = str(uuid.uuid4()) if table in TABLES_WITH_CREATED_AT: payload.setdefault("created_at", now) if table in TABLES_WITH_UPDATED_AT and "updated_at" not in payload: payload["updated_at"] = now prepared.append(payload) columns = sorted({key for row in prepared for key in row.keys()}) placeholders = ", ".join(["?"] * len(columns)) on_conflict_raw = req.on_conflict or ["id"] on_conflict = ( [str(item).strip() for item in on_conflict_raw if str(item).strip()] if isinstance(on_conflict_raw, list) else [part.strip() for part in str(on_conflict_raw).split(",") if part.strip()] ) update_cols = [col for col in columns if col not in on_conflict] update_clause = ", ".join([f"{col}=excluded.{col}" for col in update_cols]) sql = ( f"INSERT INTO {table} ({', '.join(columns)}) VALUES ({placeholders}) " f"ON CONFLICT({', '.join(on_conflict)}) DO UPDATE SET {update_clause}" ) with self._lock: cursor = self._conn.cursor() for payload in prepared: params = [_serialize_value(table, col, payload.get(col)) for col in columns] cursor.execute(sql, params) self._conn.commit() single_mode = bool(req.single or req.maybe_single) if req.columns or single_mode: ids = [row.get("id") for row in prepared if row.get("id")] filters = 
[DbFilter(op="in", column="id", values=ids)] if ids else None select_req = DbQueryRequest( providerId=req.provider_id, action="select", table=table, columns=req.columns, filters=filters, single=single_mode, ) return self._select(select_req) return DbQueryResponse(data=prepared[0] if single_mode else prepared) def _rpc(self, req: DbQueryRequest) -> DbQueryResponse: if not req.rpc: return DbQueryResponse(error="Missing rpc") if req.rpc.name == "match_document_chunks": return self._rpc_match_document_chunks(req.rpc.params or {}) if req.rpc.name == "hybrid_search": return self._rpc_hybrid_search(req.rpc.params or {}) return DbQueryResponse(error=f"Unsupported rpc: {req.rpc.name}") def _rpc_match_document_chunks(self, params: dict[str, Any]) -> DbQueryResponse: document_ids = params.get("document_ids") or [] query_embedding = params.get("query_embedding") or [] match_count = int(params.get("match_count") or 3) if not document_ids or not query_embedding: return DbQueryResponse(data=[]) rows = self._fetchall( "SELECT id, document_id, section_id, title_path, text, source_hint, chunk_index, embedding " "FROM document_chunks WHERE document_id IN ({})".format( ", ".join(["?"] * len(document_ids)) ), [str(i) for i in document_ids], ) scored = [] for row in rows: try: embedding = json.loads(row.get("embedding") or "[]") except Exception: embedding = [] if not embedding or len(embedding) != len(query_embedding): continue dot = sum(a * b for a, b in zip(embedding, query_embedding)) norm_a = sum(a * a for a in embedding) ** 0.5 norm_b = sum(b * b for b in query_embedding) ** 0.5 similarity = dot / (norm_a * norm_b) if norm_a and norm_b else 0.0 row["similarity"] = similarity scored.append(row) scored.sort(key=lambda r: r["similarity"], reverse=True) limited = scored[: max(match_count, 1)] data = [] for row in limited: entry = _deserialize_row("document_chunks", row) entry["similarity"] = row["similarity"] data.append(entry) return DbQueryResponse(data=data) def 
    def _rpc_hybrid_search(self, params: dict[str, Any]) -> DbQueryResponse:
        """Combine cosine similarity with a crude keyword score for ranking.

        score = cosine similarity + 0.1 * (occurrences of the lowercased query
        text in the chunk text). Local replacement for the server-side hybrid
        search RPC.
        """
        document_ids = params.get("document_ids") or []
        query_text = str(params.get("query_text") or "").strip().lower()
        query_embedding = params.get("query_embedding") or []
        match_count = int(params.get("match_count") or 10)
        if not document_ids or not query_text or not query_embedding:
            return DbQueryResponse(data=[])
        rows = self._fetchall(
            "SELECT id, document_id, section_id, title_path, text, source_hint, chunk_index, embedding "
            "FROM document_chunks WHERE document_id IN ({})".format(
                ", ".join(["?"] * len(document_ids))
            ),
            [str(i) for i in document_ids],
        )
        scored = []
        for row in rows:
            try:
                embedding = json.loads(row.get("embedding") or "[]")
            except Exception:
                embedding = []
            if not embedding or len(embedding) != len(query_embedding):
                continue
            dot = sum(a * b for a, b in zip(embedding, query_embedding))
            norm_a = sum(a * a for a in embedding) ** 0.5
            norm_b = sum(b * b for b in query_embedding) ** 0.5
            similarity = dot / (norm_a * norm_b) if norm_a and norm_b else 0.0
            text = (row.get("text") or "").lower()
            # Simple substring-count stand-in for full-text search relevance.
            fts_score = text.count(query_text)
            score = similarity + (0.1 * fts_score)
            row["similarity"] = similarity
            row["fts_score"] = fts_score
            row["score"] = score
            scored.append(row)
        scored.sort(key=lambda r: r["score"], reverse=True)
        limited = scored[: max(match_count, 1)]
        data = []
        for row in limited:
            entry = _deserialize_row("document_chunks", row)
            entry["similarity"] = row["similarity"]
            entry["fts_score"] = row["fts_score"]
            entry["score"] = row["score"]
            data.append(entry)
        return DbQueryResponse(data=data)

    def _test(self, req: DbQueryRequest) -> DbQueryResponse:
        """Probe each required table with a trivial SELECT and report readiness."""
        tables = [
            "spaces",
            "agents",
            "space_agents",
            "conversations",
            "conversation_messages",
            "space_documents",
            "conversation_documents",
            "document_sections",
            "document_chunks",
            "user_settings",
            "memory_domains",
            "memory_summaries",
            "user_tools",
            "home_notes",
            "home_shortcuts",
        ]
        results = {}
        for table in tables:
            try:
                # Any exception (e.g. missing table) marks the table unavailable.
                self._fetchone(f"SELECT 1 FROM {table} LIMIT 1", [])
                results[table] = True
            except Exception:
                results[table] = False
        all_ok = all(results.values())
        return DbQueryResponse(
            data={
                "success": all_ok,
                "connection": True,
                "tables": results,
                "message": "Connection successful; required tables are present."
                if all_ok
                else "Connection OK, but missing tables.",
            }
        )


@dataclass
class SQLAlchemyAdapter:
    """Adapter for external SQL databases (PostgreSQL/MySQL) via SQLAlchemy Core.

    Tables are reflected lazily from the live database and cached; no ORM
    models are required.
    """

    # Provider settings; only ``connection_url`` and ``type`` are read here.
    config: ProviderConfig

    def __post_init__(self) -> None:
        """Create the engine and prepare the lazy reflection cache."""
        if not self.config.connection_url:
            raise ValueError(f"{self.config.type} provider missing connection url")
        # pool_pre_ping avoids handing out stale connections from the pool.
        self._engine = create_engine(self.config.connection_url, future=True, pool_pre_ping=True)
        self._metadata = MetaData()
        self._table_cache: dict[str, Any] = {}
        self._lock = threading.Lock()

    def _get_table(self, table_name: str):
        """Return the reflected Table object, reflecting it on first access."""
        with self._lock:
            if table_name in self._table_cache:
                return self._table_cache[table_name]
            table_obj = self._metadata.tables.get(table_name)
            if table_obj is None:
                self._metadata.reflect(bind=self._engine, only=[table_name], extend_existing=True)
                table_obj = self._metadata.tables.get(table_name)
            if table_obj is None:
                raise ValueError(f"Unknown table: {table_name}")
            self._table_cache[table_name] = table_obj
            return table_obj

    def _apply_filters(self, stmt, table_obj, filters: list[DbFilter] | None):
        """AND together the translatable filters and attach them to *stmt*."""
        expressions = [_build_sa_expression(table_obj, filt) for filt in (filters or [])]
        expressions = [expr for expr in expressions if expr is not None]
        if expressions:
            stmt = stmt.where(and_(*expressions))
        return stmt

    def execute(self, req: DbQueryRequest) -> DbQueryResponse:
        """Dispatch a DbQueryRequest; any exception becomes an error response."""
        try:
            if req.action == "test":
                return self._test()
            if req.action == "rpc":
                return DbQueryResponse(error=f"RPC is not supported for provider type '{self.config.type}'")
            if not req.table:
                return DbQueryResponse(error="Missing table")
            if req.action == "select":
                return self._select(req)
            if req.action == "insert":
                return self._insert(req)
            if req.action == "update":
                return self._update(req)
            if req.action == "delete":
                return self._delete(req)
            if req.action == "upsert":
                return self._upsert(req)
            return DbQueryResponse(error="Unsupported action")
        except Exception as exc:
            logger.error("%s adapter error: %s", self.config.type, exc)
            return DbQueryResponse(error=str(exc))

    def _select(self, req: DbQueryRequest) -> DbQueryResponse:
        """Build and run a Core SELECT with filters, ordering, paging, and count."""
        table_obj = self._get_table(req.table)
        columns = _normalize_columns(req.columns)
        # Unknown column names are silently dropped; no columns -> whole table.
        selected_columns = [table_obj.c[col] for col in columns if col in table_obj.c] if columns else [table_obj]
        stmt = select(*selected_columns)
        stmt = self._apply_filters(stmt, table_obj, req.filters)
        if req.order:
            for order in req.order:
                if order.column in table_obj.c:
                    column = table_obj.c[order.column]
                    stmt = stmt.order_by(column.asc() if order.ascending else column.desc())
        if req.range:
            # Inclusive range (Supabase-style), hence the +1.
            stmt = stmt.offset(req.range.from_).limit(max(0, req.range.to - req.range.from_ + 1))
        elif req.limit:
            stmt = stmt.limit(req.limit)
        with self._engine.begin() as conn:
            rows = [dict(row._mapping) for row in conn.execute(stmt).fetchall()]
            count = None
            if req.count == "exact":
                count_stmt = select(func.count()).select_from(table_obj)
                count_stmt = self._apply_filters(count_stmt, table_obj, req.filters)
                count = int(conn.execute(count_stmt).scalar_one() or 0)
        data: Any = _deserialize_rows(req.table, rows)
        if req.single or req.maybe_single:
            data = data[0] if data else None
        return DbQueryResponse(data=data, count=count)

    def _prepare_rows(self, table: str, values: list[dict[str, Any]] | dict[str, Any]) -> list[dict[str, Any]]:
        """Normalize insert/upsert input to a list of serialized row payloads,
        stamping id/created_at/updated_at where the table requires them."""
        rows = values if isinstance(values, list) else [values]
        now = _utc_now_iso()
        prepared = []
        for row in rows:
            payload = dict(row)
            if table in TABLES_WITH_ID and not payload.get("id"):
                payload["id"] = str(uuid.uuid4())
            if table in TABLES_WITH_CREATED_AT or "created_at" in payload:
                payload.setdefault("created_at", now)
            if table in TABLES_WITH_UPDATED_AT or "updated_at" in payload:
                payload.setdefault("updated_at", now)
            prepared.append(_prepare_payload(table, payload))
        return prepared

    def _insert(self, req: DbQueryRequest) -> DbQueryResponse:
        """Insert prepared rows and echo them back (collapsed in single mode)."""
        values = req.values if req.values is not None else req.payload
        if values is None:
            return DbQueryResponse(error="Missing values")
        table_obj = self._get_table(req.table)
        prepared = self._prepare_rows(req.table, values)
        with self._engine.begin() as conn:
            conn.execute(table_obj.insert(), prepared)
        data: Any = _deserialize_rows(req.table, prepared)
        if req.single or req.maybe_single:
            data = data[0] if data else None
        return DbQueryResponse(data=data)
req.values if req.values is not None else req.payload if values is None: return DbQueryResponse(error="Missing values") table_obj = self._get_table(req.table) prepared = self._prepare_rows(req.table, values) with self._engine.begin() as conn: conn.execute(table_obj.insert(), prepared) data: Any = _deserialize_rows(req.table, prepared) if req.single or req.maybe_single: data = data[0] if data else None return DbQueryResponse(data=data) def _update(self, req: DbQueryRequest) -> DbQueryResponse: payload = dict(req.payload or {}) if not payload: return DbQueryResponse(error="Missing payload") if req.table in TABLES_WITH_UPDATED_AT and "updated_at" not in payload: payload["updated_at"] = _utc_now_iso() table_obj = self._get_table(req.table) stmt = update(table_obj).values(**_prepare_payload(req.table, payload)) stmt = self._apply_filters(stmt, table_obj, req.filters) with self._engine.begin() as conn: conn.execute(stmt) if req.columns or req.single or req.maybe_single: return self._select( DbQueryRequest( providerId=req.provider_id, action="select", table=req.table, columns=req.columns, filters=req.filters, single=bool(req.single or req.maybe_single), ) ) return DbQueryResponse(data=None) def _delete(self, req: DbQueryRequest) -> DbQueryResponse: table_obj = self._get_table(req.table) stmt = delete(table_obj) stmt = self._apply_filters(stmt, table_obj, req.filters) with self._engine.begin() as conn: conn.execute(stmt) return DbQueryResponse(data=None) def _upsert(self, req: DbQueryRequest) -> DbQueryResponse: values = req.values if req.values is not None else req.payload if values is None: return DbQueryResponse(error="Missing values") table_obj = self._get_table(req.table) prepared = self._prepare_rows(req.table, values) on_conflict_raw = req.on_conflict or ["id"] on_conflict = ( [str(item).strip() for item in on_conflict_raw if str(item).strip()] if isinstance(on_conflict_raw, list) else [part.strip() for part in str(on_conflict_raw).split(",") if part.strip()] ) 
update_cols = [col.name for col in table_obj.columns if col.name not in on_conflict] dialect = self._engine.dialect.name if dialect == "postgresql": stmt = postgres_insert(table_obj).values(prepared) stmt = stmt.on_conflict_do_update( index_elements=on_conflict, set_={col: getattr(stmt.excluded, col) for col in update_cols}, ) elif dialect in {"mysql", "mariadb"}: stmt = mysql_insert(table_obj).values(prepared) stmt = stmt.on_duplicate_key_update( **{col: getattr(stmt.inserted, col) for col in update_cols} ) else: return DbQueryResponse( error=f"Upsert is not supported for SQL dialect '{dialect}' on provider type '{self.config.type}'" ) with self._engine.begin() as conn: conn.execute(stmt) data: Any = _deserialize_rows(req.table, prepared) if req.single or req.maybe_single: data = data[0] if data else None return DbQueryResponse(data=data) def _test(self) -> DbQueryResponse: inspector = inspect(self._engine) table_names = set(inspector.get_table_names()) tables = [ "spaces", "agents", "space_agents", "conversations", "conversation_messages", "space_documents", "conversation_documents", "document_sections", "document_chunks", "user_settings", "memory_domains", "memory_summaries", "user_tools", "home_notes", "home_shortcuts", "pending_form_runs", "scrapbook", ] results = {table: table in table_names for table in tables} all_ok = all(results.values()) return DbQueryResponse( data={ "success": all_ok, "connection": True, "tables": results, "message": "Connection successful; required tables are present." 
if all_ok else "Connection OK, but missing tables.", } ) @dataclass class SupabaseAdapter: config: ProviderConfig def __post_init__(self) -> None: from supabase import create_client if not self.config.supabase_url or not self.config.supabase_anon_key: raise ValueError("Supabase provider missing url or anon key") self._client = create_client(self.config.supabase_url, self.config.supabase_anon_key) def execute(self, req: DbQueryRequest) -> DbQueryResponse: try: if req.action == "test": return self._test() if req.action == "rpc": return self._rpc(req) if not req.table: return DbQueryResponse(error="Missing table") if req.action == "select": return self._select(req) if req.action == "insert": return self._insert(req) if req.action == "update": return self._update(req) if req.action == "delete": return self._delete(req) if req.action == "upsert": return self._upsert(req) return DbQueryResponse(error="Unsupported action") except Exception as exc: logger.error("Supabase adapter error: %s", exc) return DbQueryResponse(error=str(exc)) def _table(self, table: str): if hasattr(self._client, "table"): return self._client.table(table) if hasattr(self._client, "from_"): return self._client.from_(table) raise AttributeError("Supabase client has no table/from_ method") def _apply_filters(self, query, filters: list[DbFilter] | None): if not filters: return query for filt in filters: if filt.op == "or" and filt.filters: or_parts = [] for inner in filt.filters: if inner.op == "is_null": or_parts.append(f"{inner.column}.is.null") elif inner.op == "not_in": values = ",".join([str(v) for v in (inner.values or [])]) or_parts.append(f"{inner.column}.not.in.({values})") elif inner.op == "eq": or_parts.append(f"{inner.column}.eq.{inner.value}") if hasattr(query, "or_") and or_parts: query = query.or_(",".join(or_parts)) continue col = filt.column if not col: continue if filt.op == "eq": query = query.eq(col, filt.value) elif filt.op == "gt": query = query.gt(col, filt.value) elif filt.op == 
"lt": query = query.lt(col, filt.value) elif filt.op == "ilike": query = query.ilike(col, f"%{filt.value}%") elif filt.op == "in": query = query.in_(col, filt.values or []) elif filt.op == "not_in": if hasattr(query, "not_"): query = query.not_.in_(col, filt.values or []) elif filt.op == "is_null": if hasattr(query, "is_"): query = query.is_(col, "null") return query def _select(self, req: DbQueryRequest) -> DbQueryResponse: query = self._table(req.table) columns = req.columns or "*" if isinstance(columns, list): columns = ",".join([str(col).strip() for col in columns if str(col).strip()]) or "*" if req.count: query = query.select(columns, count=req.count) else: query = query.select(columns) query = self._apply_filters(query, req.filters) if req.order: for order in req.order: query = query.order(order.column, desc=not order.ascending) if req.range: query = query.range(req.range.from_, req.range.to) elif req.limit: query = query.limit(req.limit) if req.maybe_single: if hasattr(query, "maybe_single"): query = query.maybe_single() elif hasattr(query, "maybeSingle"): query = query.maybeSingle() else: # Keep best-effort maybe-single semantics without forcing object coercion. 
if not req.limit: query = query.limit(1) elif req.single and hasattr(query, "single"): query = query.single() try: result = query.execute() except Exception as exc: if req.maybe_single: error_text = str(exc) if ( "PGRST116" in error_text or "Cannot coerce the result to a single JSON object" in error_text or "The result contains 0 rows" in error_text ): return DbQueryResponse(data=None, count=None) raise data = getattr(result, "data", None) count = getattr(result, "count", None) error = getattr(result, "error", None) if req.maybe_single and isinstance(data, list): data = data[0] if data else None if error and req.maybe_single: error_text = str(error) if ( "PGRST116" in error_text or "Cannot coerce the result to a single JSON object" in error_text or "The result contains 0 rows" in error_text ): return DbQueryResponse(data=None, count=count) if error: return DbQueryResponse(error=str(error)) return DbQueryResponse(data=data, count=count) def _insert(self, req: DbQueryRequest) -> DbQueryResponse: query = self._table(req.table) values = req.values if req.values is not None else req.payload if values is None: return DbQueryResponse(error="Missing values") # Use upsert with ignoreDuplicates=False to get data back, or just insert # Supabase insert returns 204 by default; we handle this gracefully try: result = query.insert(values).execute() except Exception as exc: # 204 No Content is not an error — insert succeeded but no data returned if "204" in str(exc) or "Missing response" in str(exc): return DbQueryResponse(data=None) raise error = getattr(result, "error", None) if error: return DbQueryResponse(error=str(error)) data = getattr(result, "data", None) if (req.single or req.maybe_single) and isinstance(data, list): data = data[0] if data else None return DbQueryResponse(data=data) def _update(self, req: DbQueryRequest) -> DbQueryResponse: query = self._table(req.table) payload = req.payload or {} query = query.update(payload) query = self._apply_filters(query, 
req.filters) # Supabase update returns 204 by default; handle gracefully try: result = query.execute() except Exception as exc: if "204" in str(exc) or "Missing response" in str(exc): return DbQueryResponse(data=None) raise error = getattr(result, "error", None) if error: return DbQueryResponse(error=str(error)) data = getattr(result, "data", None) if (req.single or req.maybe_single) and isinstance(data, list): data = data[0] if data else None return DbQueryResponse(data=data) def _delete(self, req: DbQueryRequest) -> DbQueryResponse: pending_cleanup_all = False pending_cleanup_conversation_ids: list[str] = [] if req.table == "conversations": if req.filters: pending_cleanup_conversation_ids = _extract_filter_values(req.filters, "id") else: pending_cleanup_all = True elif req.table == "conversation_messages": pending_cleanup_conversation_ids = _extract_filter_values(req.filters, "conversation_id") query = self._table(req.table) query = query.delete() query = self._apply_filters(query, req.filters) result = query.execute() error = getattr(result, "error", None) if error: return DbQueryResponse(error=str(error)) # Keep pending HITL runs in sync with conversation lifecycle. 
if pending_cleanup_all: self._table("pending_form_runs").delete().execute() elif pending_cleanup_conversation_ids: self._table("pending_form_runs").delete().in_( "conversation_id", pending_cleanup_conversation_ids ).execute() return DbQueryResponse(data=getattr(result, "data", None)) def _upsert(self, req: DbQueryRequest) -> DbQueryResponse: query = self._table(req.table) values = req.values if req.values is not None else req.payload if values is None: return DbQueryResponse(error="Missing values") on_conflict = req.on_conflict if isinstance(on_conflict, list): on_conflict = ",".join([str(item).strip() for item in on_conflict if str(item).strip()]) query = query.upsert(values, on_conflict=on_conflict) result = query.execute() error = getattr(result, "error", None) if error: return DbQueryResponse(error=str(error)) data = getattr(result, "data", None) if (req.single or req.maybe_single) and isinstance(data, list): data = data[0] if data else None return DbQueryResponse(data=data) def _rpc(self, req: DbQueryRequest) -> DbQueryResponse: if not req.rpc: return DbQueryResponse(error="Missing rpc") result = self._client.rpc(req.rpc.name, req.rpc.params or {}).execute() error = getattr(result, "error", None) if error: return DbQueryResponse(error=str(error)) return DbQueryResponse(data=getattr(result, "data", None)) def _test(self) -> DbQueryResponse: table_fields = { "spaces": "id", "agents": "id", "space_agents": "space_id", "conversations": "id", "conversation_messages": "id", "space_documents": "id", "conversation_documents": "conversation_id", "document_sections": "id", "document_chunks": "id", "user_settings": "key", "memory_domains": "id", "memory_summaries": "id", "user_tools": "id", "home_notes": "id", "home_shortcuts": "id", } results = {} for table, field in table_fields.items(): try: query = self._table(table).select(field).limit(1) result = query.execute() results[table] = getattr(result, "error", None) is None except Exception: results[table] = False all_ok 
= all(results.values()) return DbQueryResponse( data={ "success": all_ok, "connection": True, "tables": results, "message": "Connection successful; required tables are present." if all_ok else "Connection OK, but missing tables.", } ) def build_adapter(config: ProviderConfig): if config.type == "sqlite": return SQLiteAdapter(config) if config.type == "supabase": return SupabaseAdapter(config) if config.type in {"postgres", "mysql", "mariadb"}: return SQLAlchemyAdapter(config) raise ValueError("Unsupported provider type")