Spaces:
Running
Running
| """ | |
| Database adapters for Supabase and SQLite providers. | |
| """ | |
from __future__ import annotations

import json
import os
import sqlite3
import threading
import uuid
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any

from agno.utils.log import logger
from sqlalchemy import MetaData, and_, create_engine, delete, func, inspect, or_, select, update
from sqlalchemy.dialects.mysql import insert as mysql_insert
from sqlalchemy.dialects.postgresql import insert as postgres_insert

from ..models.db import DbFilter, DbQueryRequest, DbQueryResponse
from .db_registry import ProviderConfig
from .sqlite_schema import SCHEMA_STATEMENTS
| def _utc_now_iso() -> str: | |
| return datetime.utcnow().replace(microsecond=0).isoformat() + "Z" | |
# Columns stored as JSON text in the database, keyed by table name. Values in
# these columns are json.dumps()'d on write and json.loads()'d on read.
JSON_COLUMNS: dict[str, set[str]] = {
    "agents": {"tool_ids", "skill_ids"},
    "conversations": {"title_emojis", "session_summary"},
    "conversation_messages": {
        "content",
        "tool_calls",
        "tool_call_history",
        "research_step_history",
        "related_questions",
        "sources",
        "document_sources",
        "grounding_supports",
        "stream_blocks",
    },
    "conversation_events": {"payload"},
    "attachments": {"data"},
    "document_sections": {"title_path", "loc"},
    "document_chunks": {"title_path", "loc", "embedding"},
    "memory_domains": {"aliases"},
    "user_tools": {"config", "input_schema"},
    "pending_form_runs": {"requirements_data", "messages"},
    "scrapbook": {"tags"},
}

# Tables whose primary key is a client-generated UUID string; insert/upsert
# paths fill in ``id`` automatically when the payload omits it.
TABLES_WITH_ID = {
    "spaces",
    "agents",
    "conversations",
    "conversation_messages",
    "conversation_events",
    "attachments",
    "space_documents",
    "document_sections",
    "document_chunks",
    "home_notes",
    "home_shortcuts",
    "memory_domains",
    "memory_summaries",
    "user_tools",
    "pending_form_runs",
    "email_provider_configs",
    "email_notifications",
    "scrapbook",
}

# Tables carrying an ``updated_at`` timestamp, auto-stamped on writes when the
# payload does not provide one.
TABLES_WITH_UPDATED_AT = {
    "spaces",
    "agents",
    "conversations",
    "space_documents",
    "document_sections",
    "document_chunks",
    "home_notes",
    "home_shortcuts",
    "user_settings",
    "memory_domains",
    "memory_summaries",
    "user_tools",
    "email_provider_configs",
    "scrapbook",
}

# Tables carrying a ``created_at`` timestamp, auto-stamped on insert/upsert.
TABLES_WITH_CREATED_AT = {
    "spaces",
    "agents",
    "conversations",
    "conversation_messages",
    "conversation_events",
    "attachments",
    "space_documents",
    "conversation_documents",
    "document_sections",
    "document_chunks",
    "space_agents",
    "home_notes",
    "home_shortcuts",
    "memory_domains",
    "memory_summaries",
    "user_tools",
    "pending_form_runs",
    "email_provider_configs",
    "email_notifications",
    "scrapbook",
}
| def _serialize_value(table: str, column: str, value: Any) -> Any: | |
| if value is None: | |
| return None | |
| if table in JSON_COLUMNS and column in JSON_COLUMNS[table]: | |
| try: | |
| return json.dumps(value, ensure_ascii=False) | |
| except TypeError: | |
| return json.dumps(str(value), ensure_ascii=False) | |
| if isinstance(value, bool): | |
| return 1 if value else 0 | |
| return value | |
def _deserialize_row(table: str, row: dict[str, Any]) -> dict[str, Any]:
    """Decode a raw DB row in place: parse JSON columns and restore booleans."""
    for column in JSON_COLUMNS.get(table, ()):
        raw = row.get(column)
        if raw is None:
            continue
        try:
            row[column] = json.loads(raw)
        except Exception:
            # Leave malformed/legacy values untouched rather than failing the read.
            pass
    # SQLite has no boolean type; "is_*" columns come back as 0/1 integers.
    for key, value in row.items():
        if key.startswith("is_") and isinstance(value, int):
            row[key] = bool(value)
    return row
| def _prepare_payload(table: str, payload: dict[str, Any]) -> dict[str, Any]: | |
| return {key: _serialize_value(table, key, value) for key, value in payload.items()} | |
| def _deserialize_rows(table: str, rows: list[dict[str, Any]]) -> list[dict[str, Any]]: | |
| return [_deserialize_row(table, dict(row)) for row in rows] | |
| def _normalize_columns(columns: str | list[str] | None) -> list[str] | None: | |
| if columns is None: | |
| return None | |
| if isinstance(columns, list): | |
| return columns | |
| if isinstance(columns, str): | |
| trimmed = columns.strip() | |
| if trimmed == "*": | |
| return None | |
| return [col.strip() for col in trimmed.split(",") if col.strip()] | |
| return None | |
| def _build_where_clause(filters: list[DbFilter] | None) -> tuple[str, list[Any]]: | |
| if not filters: | |
| return "", [] | |
| def build(f: DbFilter) -> tuple[str, list[Any]]: | |
| if f.op == "or" and f.filters: | |
| or_clauses: list[str] = [] | |
| or_params: list[Any] = [] | |
| for inner in f.filters: | |
| clause, params = build(inner) | |
| if clause: | |
| or_clauses.append(clause) | |
| or_params.extend(params) | |
| if not or_clauses: | |
| return "", [] | |
| return f"({ ' OR '.join(or_clauses) })", or_params | |
| column = f.column | |
| if not column: | |
| return "", [] | |
| if f.op == "eq": | |
| return f"{column} = ?", [f.value] | |
| if f.op == "gt": | |
| return f"{column} > ?", [f.value] | |
| if f.op == "lt": | |
| return f"{column} < ?", [f.value] | |
| if f.op == "ilike": | |
| return f"LOWER({column}) LIKE LOWER(?)", [f"%{f.value}%"] | |
| if f.op == "is_null": | |
| return f"{column} IS NULL", [] | |
| if f.op in {"in", "not_in"}: | |
| values = f.values or [] | |
| if not values: | |
| return ("1=0", []) if f.op == "in" else ("1=1", []) | |
| placeholders = ", ".join(["?"] * len(values)) | |
| operator = "IN" if f.op == "in" else "NOT IN" | |
| return f"{column} {operator} ({placeholders})", list(values) | |
| return "", [] | |
| clauses: list[str] = [] | |
| params: list[Any] = [] | |
| for filt in filters: | |
| clause, clause_params = build(filt) | |
| if clause: | |
| clauses.append(clause) | |
| params.extend(clause_params) | |
| if not clauses: | |
| return "", [] | |
| return "WHERE " + " AND ".join(clauses), params | |
| def _extract_filter_values(filters: list[DbFilter] | None, column: str) -> list[str]: | |
| if not filters: | |
| return [] | |
| values: list[str] = [] | |
| def walk(f: DbFilter) -> None: | |
| if f.op == "or" and f.filters: | |
| for inner in f.filters: | |
| walk(inner) | |
| return | |
| if f.column != column: | |
| return | |
| if f.op == "eq" and f.value is not None: | |
| values.append(str(f.value)) | |
| return | |
| if f.op == "in" and f.values: | |
| for item in f.values: | |
| if item is not None: | |
| values.append(str(item)) | |
| for filt in filters: | |
| walk(filt) | |
| return list(dict.fromkeys(values)) | |
def _build_sa_expression(table_obj, filt: DbFilter):
    """Translate one DbFilter node into a SQLAlchemy boolean expression.

    Returns ``None`` for nodes that cannot be expressed (missing/unknown
    column or unsupported operator); callers silently drop those.
    """
    if filt.op == "or" and filt.filters:
        children = []
        for inner in filt.filters:
            expr = _build_sa_expression(table_obj, inner)
            if expr is not None:
                children.append(expr)
        return or_(*children) if children else None
    name = filt.column
    if not name or name not in table_obj.c:
        return None
    col = table_obj.c[name]
    # Lazy builders so only the matching comparison is ever constructed.
    builders = {
        "eq": lambda: col == filt.value,
        "gt": lambda: col > filt.value,
        "lt": lambda: col < filt.value,
        "ilike": lambda: col.ilike(f"%{filt.value}%"),
        "is_null": lambda: col.is_(None),
        "in": lambda: col.in_(filt.values or []),
        "not_in": lambda: ~col.in_(filt.values or []),
    }
    builder = builders.get(filt.op)
    return builder() if builder else None
@dataclass
class SQLiteAdapter:
    """Adapter that serves DbQueryRequest operations from a local SQLite file.

    A single shared connection (opened with ``check_same_thread=False``) is
    guarded by a lock so the adapter can be used from multiple threads.
    """

    # Provider settings; only ``sqlite_path`` is consumed here.
    config: ProviderConfig

    def __post_init__(self) -> None:
        """Open the SQLite connection, tune pragmas, and apply schema migrations."""
        # BUG FIX: the @dataclass decorator was missing, so no __init__ was
        # generated and __post_init__ never ran on construction.
        self._lock = threading.Lock()
        if not self.config.sqlite_path:
            raise ValueError("SQLite provider missing path")
        # Ensure the parent directory exists before connecting.
        os.makedirs(os.path.dirname(self.config.sqlite_path) or ".", exist_ok=True)
        self._conn = sqlite3.connect(
            self.config.sqlite_path,
            check_same_thread=False,
            timeout=5.0,
        )
        # Row objects allow dict(row) conversion in the fetch helpers.
        self._conn.row_factory = sqlite3.Row
        self._configure_connection()
        self._ensure_schema()
| def _configure_connection(self) -> None: | |
| """ | |
| Tune SQLite for mixed read/write concurrency. | |
| """ | |
| with self._lock: | |
| cursor = self._conn.cursor() | |
| cursor.execute("PRAGMA journal_mode=WAL;") | |
| cursor.execute("PRAGMA synchronous=NORMAL;") | |
| cursor.execute("PRAGMA busy_timeout=5000;") | |
| cursor.execute("PRAGMA temp_store=MEMORY;") | |
| self._conn.commit() | |
    def _ensure_schema(self) -> None:
        """Create base tables and apply additive, idempotent migrations for older local DBs."""
        with self._lock:
            cursor = self._conn.cursor()
            # Base schema; executescript permits multi-statement DDL per entry.
            for stmt in SCHEMA_STATEMENTS:
                cursor.executescript(stmt)
            # Lightweight forward migrations for existing local DBs.
            cursor.execute("PRAGMA table_info(conversation_messages)")
            columns = {str(row[1]) for row in cursor.fetchall()}
            if "stream_blocks" not in columns:
                cursor.execute(
                    "ALTER TABLE conversation_messages "
                    "ADD COLUMN stream_blocks TEXT NOT NULL DEFAULT '[]'"
                )
            if "stream_schema_version" not in columns:
                cursor.execute(
                    "ALTER TABLE conversation_messages "
                    "ADD COLUMN stream_schema_version INTEGER NOT NULL DEFAULT 1"
                )
            # Agent customization columns added over time.
            cursor.execute("PRAGMA table_info(agents)")
            agent_columns = {str(row[1]) for row in cursor.fetchall()}
            if "use_global_model_settings" not in agent_columns:
                cursor.execute(
                    "ALTER TABLE agents "
                    "ADD COLUMN use_global_model_settings INTEGER NOT NULL DEFAULT 1"
                )
            if "avatar_type" not in agent_columns:
                cursor.execute(
                    "ALTER TABLE agents "
                    "ADD COLUMN avatar_type TEXT NOT NULL DEFAULT 'emoji'"
                )
            if "avatar_image" not in agent_columns:
                cursor.execute("ALTER TABLE agents ADD COLUMN avatar_image TEXT")
            if "avatar_shape" not in agent_columns:
                cursor.execute(
                    "ALTER TABLE agents "
                    "ADD COLUMN avatar_shape TEXT NOT NULL DEFAULT 'circle'"
                )
            if "banner_mode" not in agent_columns:
                cursor.execute(
                    "ALTER TABLE agents "
                    "ADD COLUMN banner_mode TEXT NOT NULL DEFAULT 'none'"
                )
            if "banner_image" not in agent_columns:
                cursor.execute("ALTER TABLE agents ADD COLUMN banner_image TEXT")
            if "skill_ids" not in agent_columns:
                cursor.execute(
                    "ALTER TABLE agents "
                    "ADD COLUMN skill_ids TEXT NOT NULL DEFAULT '[]'"
                )
            # Forward migration: ensure scrapbook table exists for pre-existing DBs.
            cursor.execute(
                "CREATE TABLE IF NOT EXISTS scrapbook ("
                "id TEXT PRIMARY KEY, "
                "title TEXT NOT NULL DEFAULT '', "
                "emoji TEXT, "
                "summary TEXT NOT NULL DEFAULT '', "
                "content TEXT NOT NULL DEFAULT '', "
                "source_url TEXT, "
                "platform TEXT NOT NULL DEFAULT 'manual', "
                "thumbnail TEXT, "
                "tags TEXT NOT NULL DEFAULT '[]', "
                "created_at TEXT NOT NULL, "
                "updated_at TEXT NOT NULL)"
            )
            cursor.execute(
                "CREATE INDEX IF NOT EXISTS idx_scrapbook_created_at "
                "ON scrapbook(created_at DESC)"
            )
            # The scrapbook CREATE above may predate the emoji column.
            cursor.execute("PRAGMA table_info(scrapbook)")
            scrapbook_columns = {str(row[1]) for row in cursor.fetchall()}
            if "emoji" not in scrapbook_columns:
                cursor.execute("ALTER TABLE scrapbook ADD COLUMN emoji TEXT")
            # Conversations gained an optional scrapbook link.
            cursor.execute("PRAGMA table_info(conversations)")
            conv_columns = {str(row[1]) for row in cursor.fetchall()}
            if "scrapbook_id" not in conv_columns:
                cursor.execute("ALTER TABLE conversations ADD COLUMN scrapbook_id TEXT")
            self._conn.commit()
| def _execute(self, sql: str, params: list[Any] | tuple[Any, ...] = ()) -> sqlite3.Cursor: | |
| with self._lock: | |
| cursor = self._conn.cursor() | |
| cursor.execute(sql, params) | |
| self._conn.commit() | |
| return cursor | |
| def _fetchall(self, sql: str, params: list[Any]) -> list[dict[str, Any]]: | |
| with self._lock: | |
| cursor = self._conn.cursor() | |
| cursor.execute(sql, params) | |
| rows = [dict(row) for row in cursor.fetchall()] | |
| return rows | |
| def _fetchone(self, sql: str, params: list[Any]) -> dict[str, Any] | None: | |
| with self._lock: | |
| cursor = self._conn.cursor() | |
| cursor.execute(sql, params) | |
| row = cursor.fetchone() | |
| return dict(row) if row else None | |
| def execute(self, req: DbQueryRequest) -> DbQueryResponse: | |
| action = req.action | |
| if action == "select": | |
| return self._select(req) | |
| if action == "insert": | |
| return self._insert(req) | |
| if action == "update": | |
| return self._update(req) | |
| if action == "delete": | |
| return self._delete(req) | |
| if action == "upsert": | |
| return self._upsert(req) | |
| if action == "rpc": | |
| return self._rpc(req) | |
| if action == "test": | |
| return self._test(req) | |
| return DbQueryResponse(error="Unsupported action") | |
| def _select(self, req: DbQueryRequest) -> DbQueryResponse: | |
| table = req.table | |
| if not table: | |
| return DbQueryResponse(error="Missing table") | |
| columns = _normalize_columns(req.columns) | |
| select_clause = "*" if not columns else ", ".join(columns) | |
| where_clause, params = _build_where_clause(req.filters) | |
| order_clause = "" | |
| if req.order: | |
| orders = [f"{o.column} {'ASC' if o.ascending else 'DESC'}" for o in req.order] | |
| order_clause = " ORDER BY " + ", ".join(orders) | |
| limit_clause = "" | |
| if req.range: | |
| limit = max(0, req.range.to - req.range.from_ + 1) | |
| limit_clause = f" LIMIT {limit} OFFSET {req.range.from_}" | |
| elif req.limit: | |
| limit_clause = f" LIMIT {req.limit}" | |
| sql = f"SELECT {select_clause} FROM {table} {where_clause}{order_clause}{limit_clause}" | |
| rows = self._fetchall(sql, params) | |
| data = [_deserialize_row(table, row) for row in rows] | |
| count = None | |
| if req.count == "exact": | |
| count_sql = f"SELECT COUNT(*) as count FROM {table} {where_clause}" | |
| count_row = self._fetchone(count_sql, params) | |
| count = int(count_row["count"]) if count_row else 0 | |
| if req.single or req.maybe_single: | |
| data = data[0] if data else None | |
| return DbQueryResponse(data=data, count=count) | |
    def _insert(self, req: DbQueryRequest) -> DbQueryResponse:
        """Insert one or many rows, auto-stamping ids/timestamps, optionally re-selecting."""
        table = req.table
        if not table:
            return DbQueryResponse(error="Missing table")
        values = req.values
        if values is None:
            return DbQueryResponse(error="Missing values")
        rows = values if isinstance(values, list) else [values]
        now = _utc_now_iso()
        prepared = []
        for row in rows:
            payload = dict(row)
            # Generate a UUID primary key when the caller did not supply one.
            if table in TABLES_WITH_ID and not payload.get("id"):
                payload["id"] = str(uuid.uuid4())
            if table in TABLES_WITH_CREATED_AT or "created_at" in payload:
                payload.setdefault("created_at", now)
            if table in TABLES_WITH_UPDATED_AT or "updated_at" in payload:
                payload.setdefault("updated_at", now)
            prepared.append(payload)
        # Union of keys across rows so heterogeneous payloads share one statement.
        columns = sorted({key for row in prepared for key in row.keys()})
        placeholders = ", ".join(["?"] * len(columns))
        sql = f"INSERT INTO {table} ({', '.join(columns)}) VALUES ({placeholders})"
        with self._lock:
            cursor = self._conn.cursor()
            for payload in prepared:
                params = [_serialize_value(table, col, payload.get(col)) for col in columns]
                cursor.execute(sql, params)
                # Inserting a message bumps the parent conversation's recency.
                if table == "conversation_messages":
                    conv_id = payload.get("conversation_id")
                    if conv_id:
                        cursor.execute(
                            "UPDATE conversations SET updated_at = ? WHERE id = ?",
                            [now, conv_id],
                        )
            self._conn.commit()
        single_mode = bool(req.single or req.maybe_single)
        if req.columns or single_mode:
            # Re-select by id so returned rows reflect the stored representation.
            ids = [row.get("id") for row in prepared if row.get("id")]
            data = None
            if ids:
                filters = [DbFilter(op="in", column="id", values=ids)]
                select_req = DbQueryRequest(
                    providerId=req.provider_id,
                    action="select",
                    table=table,
                    columns=req.columns,
                    filters=filters,
                    single=single_mode,
                )
                return self._select(select_req)
            # No ids available (table without generated keys): echo the payloads.
            data = prepared[0] if single_mode else prepared
            return DbQueryResponse(data=data)
        return DbQueryResponse(data=prepared[0] if single_mode else prepared)
    def _update(self, req: DbQueryRequest) -> DbQueryResponse:
        """Update filtered rows; auto-stamps updated_at and optionally re-selects."""
        table = req.table
        if not table:
            return DbQueryResponse(error="Missing table")
        payload = req.payload or {}
        if not payload:
            return DbQueryResponse(error="Missing payload")
        now = _utc_now_iso()
        if table in TABLES_WITH_UPDATED_AT and "updated_at" not in payload:
            payload["updated_at"] = now
        set_clause = ", ".join([f"{k} = ?" for k in payload.keys()])
        params = [_serialize_value(table, k, v) for k, v in payload.items()]
        where_clause, where_params = _build_where_clause(req.filters)
        sql = f"UPDATE {table} SET {set_clause} {where_clause}"
        self._execute(sql, params + where_params)
        # NOTE(review): the conversation recency bump only fires when the update
        # payload itself contains conversation_id — confirm that is intended.
        if table == "conversation_messages":
            conv_id = payload.get("conversation_id")
            if conv_id:
                self._execute(
                    "UPDATE conversations SET updated_at = ? WHERE id = ?",
                    [now, conv_id],
                )
        if req.columns or req.single or req.maybe_single:
            # Re-select with the same filters so callers see the stored state.
            select_req = DbQueryRequest(
                providerId=req.provider_id,
                action="select",
                table=table,
                columns=req.columns,
                filters=req.filters,
                single=bool(req.single or req.maybe_single),
            )
            return self._select(select_req)
        return DbQueryResponse(data=None)
    def _delete(self, req: DbQueryRequest) -> DbQueryResponse:
        """Delete filtered rows; also purges pending form runs for removed conversations."""
        table = req.table
        if not table:
            return DbQueryResponse(error="Missing table")
        # Work out, before deleting, which pending_form_runs rows must be purged.
        pending_cleanup_all = False
        pending_cleanup_conversation_ids: list[str] = []
        if table == "conversations":
            if req.filters:
                pending_cleanup_conversation_ids = _extract_filter_values(req.filters, "id")
            else:
                # Unfiltered delete wipes every conversation, so wipe all runs.
                pending_cleanup_all = True
        elif table == "conversation_messages":
            pending_cleanup_conversation_ids = _extract_filter_values(req.filters, "conversation_id")
        where_clause, params = _build_where_clause(req.filters)
        sql = f"DELETE FROM {table} {where_clause}"
        self._execute(sql, params)
        # Keep pending HITL runs in sync with conversation lifecycle.
        if pending_cleanup_all:
            self._execute("DELETE FROM pending_form_runs", [])
        elif pending_cleanup_conversation_ids:
            placeholders = ", ".join(["?"] * len(pending_cleanup_conversation_ids))
            self._execute(
                f"DELETE FROM pending_form_runs WHERE conversation_id IN ({placeholders})",
                pending_cleanup_conversation_ids,
            )
        return DbQueryResponse(data=None)
    def _upsert(self, req: DbQueryRequest) -> DbQueryResponse:
        """Insert-or-update rows via SQLite ON CONFLICT, optionally re-selecting."""
        table = req.table
        if not table:
            return DbQueryResponse(error="Missing table")
        values = req.values
        if values is None:
            return DbQueryResponse(error="Missing values")
        rows = values if isinstance(values, list) else [values]
        now = _utc_now_iso()
        prepared = []
        for row in rows:
            payload = dict(row)
            if table in TABLES_WITH_ID and not payload.get("id"):
                payload["id"] = str(uuid.uuid4())
            if table in TABLES_WITH_CREATED_AT:
                payload.setdefault("created_at", now)
            if table in TABLES_WITH_UPDATED_AT and "updated_at" not in payload:
                payload["updated_at"] = now
            prepared.append(payload)
        columns = sorted({key for row in prepared for key in row.keys()})
        placeholders = ", ".join(["?"] * len(columns))
        # Conflict target may arrive as a list or as a comma-separated string.
        on_conflict_raw = req.on_conflict or ["id"]
        on_conflict = (
            [str(item).strip() for item in on_conflict_raw if str(item).strip()]
            if isinstance(on_conflict_raw, list)
            else [part.strip() for part in str(on_conflict_raw).split(",") if part.strip()]
        )
        # Every non-key column is refreshed from the incoming row on conflict.
        update_cols = [col for col in columns if col not in on_conflict]
        update_clause = ", ".join([f"{col}=excluded.{col}" for col in update_cols])
        sql = (
            f"INSERT INTO {table} ({', '.join(columns)}) VALUES ({placeholders}) "
            f"ON CONFLICT({', '.join(on_conflict)}) DO UPDATE SET {update_clause}"
        )
        with self._lock:
            cursor = self._conn.cursor()
            for payload in prepared:
                params = [_serialize_value(table, col, payload.get(col)) for col in columns]
                cursor.execute(sql, params)
            self._conn.commit()
        single_mode = bool(req.single or req.maybe_single)
        if req.columns or single_mode:
            # Re-select by id so callers get the stored representation back.
            ids = [row.get("id") for row in prepared if row.get("id")]
            filters = [DbFilter(op="in", column="id", values=ids)] if ids else None
            select_req = DbQueryRequest(
                providerId=req.provider_id,
                action="select",
                table=table,
                columns=req.columns,
                filters=filters,
                single=single_mode,
            )
            return self._select(select_req)
        return DbQueryResponse(data=prepared[0] if single_mode else prepared)
| def _rpc(self, req: DbQueryRequest) -> DbQueryResponse: | |
| if not req.rpc: | |
| return DbQueryResponse(error="Missing rpc") | |
| if req.rpc.name == "match_document_chunks": | |
| return self._rpc_match_document_chunks(req.rpc.params or {}) | |
| if req.rpc.name == "hybrid_search": | |
| return self._rpc_hybrid_search(req.rpc.params or {}) | |
| return DbQueryResponse(error=f"Unsupported rpc: {req.rpc.name}") | |
| def _rpc_match_document_chunks(self, params: dict[str, Any]) -> DbQueryResponse: | |
| document_ids = params.get("document_ids") or [] | |
| query_embedding = params.get("query_embedding") or [] | |
| match_count = int(params.get("match_count") or 3) | |
| if not document_ids or not query_embedding: | |
| return DbQueryResponse(data=[]) | |
| rows = self._fetchall( | |
| "SELECT id, document_id, section_id, title_path, text, source_hint, chunk_index, embedding " | |
| "FROM document_chunks WHERE document_id IN ({})".format( | |
| ", ".join(["?"] * len(document_ids)) | |
| ), | |
| [str(i) for i in document_ids], | |
| ) | |
| scored = [] | |
| for row in rows: | |
| try: | |
| embedding = json.loads(row.get("embedding") or "[]") | |
| except Exception: | |
| embedding = [] | |
| if not embedding or len(embedding) != len(query_embedding): | |
| continue | |
| dot = sum(a * b for a, b in zip(embedding, query_embedding)) | |
| norm_a = sum(a * a for a in embedding) ** 0.5 | |
| norm_b = sum(b * b for b in query_embedding) ** 0.5 | |
| similarity = dot / (norm_a * norm_b) if norm_a and norm_b else 0.0 | |
| row["similarity"] = similarity | |
| scored.append(row) | |
| scored.sort(key=lambda r: r["similarity"], reverse=True) | |
| limited = scored[: max(match_count, 1)] | |
| data = [] | |
| for row in limited: | |
| entry = _deserialize_row("document_chunks", row) | |
| entry["similarity"] = row["similarity"] | |
| data.append(entry) | |
| return DbQueryResponse(data=data) | |
| def _rpc_hybrid_search(self, params: dict[str, Any]) -> DbQueryResponse: | |
| document_ids = params.get("document_ids") or [] | |
| query_text = str(params.get("query_text") or "").strip().lower() | |
| query_embedding = params.get("query_embedding") or [] | |
| match_count = int(params.get("match_count") or 10) | |
| if not document_ids or not query_text or not query_embedding: | |
| return DbQueryResponse(data=[]) | |
| rows = self._fetchall( | |
| "SELECT id, document_id, section_id, title_path, text, source_hint, chunk_index, embedding " | |
| "FROM document_chunks WHERE document_id IN ({})".format( | |
| ", ".join(["?"] * len(document_ids)) | |
| ), | |
| [str(i) for i in document_ids], | |
| ) | |
| scored = [] | |
| for row in rows: | |
| try: | |
| embedding = json.loads(row.get("embedding") or "[]") | |
| except Exception: | |
| embedding = [] | |
| if not embedding or len(embedding) != len(query_embedding): | |
| continue | |
| dot = sum(a * b for a, b in zip(embedding, query_embedding)) | |
| norm_a = sum(a * a for a in embedding) ** 0.5 | |
| norm_b = sum(b * b for b in query_embedding) ** 0.5 | |
| similarity = dot / (norm_a * norm_b) if norm_a and norm_b else 0.0 | |
| text = (row.get("text") or "").lower() | |
| fts_score = text.count(query_text) | |
| score = similarity + (0.1 * fts_score) | |
| row["similarity"] = similarity | |
| row["fts_score"] = fts_score | |
| row["score"] = score | |
| scored.append(row) | |
| scored.sort(key=lambda r: r["score"], reverse=True) | |
| limited = scored[: max(match_count, 1)] | |
| data = [] | |
| for row in limited: | |
| entry = _deserialize_row("document_chunks", row) | |
| entry["similarity"] = row["similarity"] | |
| entry["fts_score"] = row["fts_score"] | |
| entry["score"] = row["score"] | |
| data.append(entry) | |
| return DbQueryResponse(data=data) | |
| def _test(self, req: DbQueryRequest) -> DbQueryResponse: | |
| tables = [ | |
| "spaces", | |
| "agents", | |
| "space_agents", | |
| "conversations", | |
| "conversation_messages", | |
| "space_documents", | |
| "conversation_documents", | |
| "document_sections", | |
| "document_chunks", | |
| "user_settings", | |
| "memory_domains", | |
| "memory_summaries", | |
| "user_tools", | |
| "home_notes", | |
| "home_shortcuts", | |
| ] | |
| results = {} | |
| for table in tables: | |
| try: | |
| self._fetchone(f"SELECT 1 FROM {table} LIMIT 1", []) | |
| results[table] = True | |
| except Exception: | |
| results[table] = False | |
| all_ok = all(results.values()) | |
| return DbQueryResponse( | |
| data={ | |
| "success": all_ok, | |
| "connection": True, | |
| "tables": results, | |
| "message": "Connection successful; required tables are present." | |
| if all_ok | |
| else "Connection OK, but missing tables.", | |
| } | |
| ) | |
@dataclass
class SQLAlchemyAdapter:
    """Adapter that serves DbQueryRequest operations through a SQLAlchemy engine."""

    # Provider settings; requires ``connection_url``.
    config: ProviderConfig

    def __post_init__(self) -> None:
        """Create the engine and the lazily-populated table-reflection cache."""
        # BUG FIX: the @dataclass decorator was missing, so no __init__ was
        # generated and __post_init__ never ran on construction.
        if not self.config.connection_url:
            raise ValueError(f"{self.config.type} provider missing connection url")
        # pool_pre_ping guards against handing out stale pooled connections.
        self._engine = create_engine(self.config.connection_url, future=True, pool_pre_ping=True)
        self._metadata = MetaData()
        self._table_cache: dict[str, Any] = {}
        self._lock = threading.Lock()
| def _get_table(self, table_name: str): | |
| with self._lock: | |
| if table_name in self._table_cache: | |
| return self._table_cache[table_name] | |
| table_obj = self._metadata.tables.get(table_name) | |
| if table_obj is None: | |
| self._metadata.reflect(bind=self._engine, only=[table_name], extend_existing=True) | |
| table_obj = self._metadata.tables.get(table_name) | |
| if table_obj is None: | |
| raise ValueError(f"Unknown table: {table_name}") | |
| self._table_cache[table_name] = table_obj | |
| return table_obj | |
| def _apply_filters(self, stmt, table_obj, filters: list[DbFilter] | None): | |
| expressions = [_build_sa_expression(table_obj, filt) for filt in (filters or [])] | |
| expressions = [expr for expr in expressions if expr is not None] | |
| if expressions: | |
| stmt = stmt.where(and_(*expressions)) | |
| return stmt | |
| def execute(self, req: DbQueryRequest) -> DbQueryResponse: | |
| try: | |
| if req.action == "test": | |
| return self._test() | |
| if req.action == "rpc": | |
| return DbQueryResponse(error=f"RPC is not supported for provider type '{self.config.type}'") | |
| if not req.table: | |
| return DbQueryResponse(error="Missing table") | |
| if req.action == "select": | |
| return self._select(req) | |
| if req.action == "insert": | |
| return self._insert(req) | |
| if req.action == "update": | |
| return self._update(req) | |
| if req.action == "delete": | |
| return self._delete(req) | |
| if req.action == "upsert": | |
| return self._upsert(req) | |
| return DbQueryResponse(error="Unsupported action") | |
| except Exception as exc: | |
| logger.error("%s adapter error: %s", self.config.type, exc) | |
| return DbQueryResponse(error=str(exc)) | |
    def _select(self, req: DbQueryRequest) -> DbQueryResponse:
        """Run a SELECT with projection, filters, ordering, paging, and optional count."""
        table_obj = self._get_table(req.table)
        columns = _normalize_columns(req.columns)
        # Unknown columns are silently dropped; empty projection falls back to the whole table.
        selected_columns = [table_obj.c[col] for col in columns if col in table_obj.c] if columns else [table_obj]
        stmt = select(*selected_columns)
        stmt = self._apply_filters(stmt, table_obj, req.filters)
        if req.order:
            for order in req.order:
                if order.column in table_obj.c:
                    column = table_obj.c[order.column]
                    stmt = stmt.order_by(column.asc() if order.ascending else column.desc())
        if req.range:
            # Inclusive range semantics -> OFFSET/LIMIT.
            stmt = stmt.offset(req.range.from_).limit(max(0, req.range.to - req.range.from_ + 1))
        elif req.limit:
            stmt = stmt.limit(req.limit)
        with self._engine.begin() as conn:
            rows = [dict(row._mapping) for row in conn.execute(stmt).fetchall()]
            count = None
            if req.count == "exact":
                # Count reuses the same connection/transaction as the row fetch.
                count_stmt = select(func.count()).select_from(table_obj)
                count_stmt = self._apply_filters(count_stmt, table_obj, req.filters)
                count = int(conn.execute(count_stmt).scalar_one() or 0)
        data: Any = _deserialize_rows(req.table, rows)
        if req.single or req.maybe_single:
            data = data[0] if data else None
        return DbQueryResponse(data=data, count=count)
| def _prepare_rows(self, table: str, values: list[dict[str, Any]] | dict[str, Any]) -> list[dict[str, Any]]: | |
| rows = values if isinstance(values, list) else [values] | |
| now = _utc_now_iso() | |
| prepared = [] | |
| for row in rows: | |
| payload = dict(row) | |
| if table in TABLES_WITH_ID and not payload.get("id"): | |
| payload["id"] = str(uuid.uuid4()) | |
| if table in TABLES_WITH_CREATED_AT or "created_at" in payload: | |
| payload.setdefault("created_at", now) | |
| if table in TABLES_WITH_UPDATED_AT or "updated_at" in payload: | |
| payload.setdefault("updated_at", now) | |
| prepared.append(_prepare_payload(table, payload)) | |
| return prepared | |
| def _insert(self, req: DbQueryRequest) -> DbQueryResponse: | |
| values = req.values if req.values is not None else req.payload | |
| if values is None: | |
| return DbQueryResponse(error="Missing values") | |
| table_obj = self._get_table(req.table) | |
| prepared = self._prepare_rows(req.table, values) | |
| with self._engine.begin() as conn: | |
| conn.execute(table_obj.insert(), prepared) | |
| data: Any = _deserialize_rows(req.table, prepared) | |
| if req.single or req.maybe_single: | |
| data = data[0] if data else None | |
| return DbQueryResponse(data=data) | |
| def _update(self, req: DbQueryRequest) -> DbQueryResponse: | |
| payload = dict(req.payload or {}) | |
| if not payload: | |
| return DbQueryResponse(error="Missing payload") | |
| if req.table in TABLES_WITH_UPDATED_AT and "updated_at" not in payload: | |
| payload["updated_at"] = _utc_now_iso() | |
| table_obj = self._get_table(req.table) | |
| stmt = update(table_obj).values(**_prepare_payload(req.table, payload)) | |
| stmt = self._apply_filters(stmt, table_obj, req.filters) | |
| with self._engine.begin() as conn: | |
| conn.execute(stmt) | |
| if req.columns or req.single or req.maybe_single: | |
| return self._select( | |
| DbQueryRequest( | |
| providerId=req.provider_id, | |
| action="select", | |
| table=req.table, | |
| columns=req.columns, | |
| filters=req.filters, | |
| single=bool(req.single or req.maybe_single), | |
| ) | |
| ) | |
| return DbQueryResponse(data=None) | |
| def _delete(self, req: DbQueryRequest) -> DbQueryResponse: | |
| table_obj = self._get_table(req.table) | |
| stmt = delete(table_obj) | |
| stmt = self._apply_filters(stmt, table_obj, req.filters) | |
| with self._engine.begin() as conn: | |
| conn.execute(stmt) | |
| return DbQueryResponse(data=None) | |
    def _upsert(self, req: DbQueryRequest) -> DbQueryResponse:
        """Insert-or-update via dialect-specific upsert (PostgreSQL / MySQL only)."""
        values = req.values if req.values is not None else req.payload
        if values is None:
            return DbQueryResponse(error="Missing values")
        table_obj = self._get_table(req.table)
        prepared = self._prepare_rows(req.table, values)
        # Conflict target may arrive as a list or as a comma-separated string.
        on_conflict_raw = req.on_conflict or ["id"]
        on_conflict = (
            [str(item).strip() for item in on_conflict_raw if str(item).strip()]
            if isinstance(on_conflict_raw, list)
            else [part.strip() for part in str(on_conflict_raw).split(",") if part.strip()]
        )
        # Every non-key column of the table is refreshed on conflict.
        update_cols = [col.name for col in table_obj.columns if col.name not in on_conflict]
        dialect = self._engine.dialect.name
        if dialect == "postgresql":
            stmt = postgres_insert(table_obj).values(prepared)
            stmt = stmt.on_conflict_do_update(
                index_elements=on_conflict,
                set_={col: getattr(stmt.excluded, col) for col in update_cols},
            )
        elif dialect in {"mysql", "mariadb"}:
            # MySQL keys the upsert off the table's unique indexes, not on_conflict.
            stmt = mysql_insert(table_obj).values(prepared)
            stmt = stmt.on_duplicate_key_update(
                **{col: getattr(stmt.inserted, col) for col in update_cols}
            )
        else:
            return DbQueryResponse(
                error=f"Upsert is not supported for SQL dialect '{dialect}' on provider type '{self.config.type}'"
            )
        with self._engine.begin() as conn:
            conn.execute(stmt)
        # Echo the prepared rows back; no RETURNING round-trip is performed.
        data: Any = _deserialize_rows(req.table, prepared)
        if req.single or req.maybe_single:
            data = data[0] if data else None
        return DbQueryResponse(data=data)
| def _test(self) -> DbQueryResponse: | |
| inspector = inspect(self._engine) | |
| table_names = set(inspector.get_table_names()) | |
| tables = [ | |
| "spaces", | |
| "agents", | |
| "space_agents", | |
| "conversations", | |
| "conversation_messages", | |
| "space_documents", | |
| "conversation_documents", | |
| "document_sections", | |
| "document_chunks", | |
| "user_settings", | |
| "memory_domains", | |
| "memory_summaries", | |
| "user_tools", | |
| "home_notes", | |
| "home_shortcuts", | |
| "pending_form_runs", | |
| "scrapbook", | |
| ] | |
| results = {table: table in table_names for table in tables} | |
| all_ok = all(results.values()) | |
| return DbQueryResponse( | |
| data={ | |
| "success": all_ok, | |
| "connection": True, | |
| "tables": results, | |
| "message": "Connection successful; required tables are present." | |
| if all_ok | |
| else "Connection OK, but missing tables.", | |
| } | |
| ) | |
@dataclass
class SupabaseAdapter:
    """Execute ``DbQueryRequest`` actions against a Supabase project.

    The ``@dataclass`` decorator is required here: it generates the
    ``__init__(config)`` that ``build_adapter`` relies on and is what
    invokes ``__post_init__`` to create the client. Without it,
    ``SupabaseAdapter(config)`` raises ``TypeError`` and no client is
    ever constructed.
    """

    config: ProviderConfig

    def __post_init__(self) -> None:
        # Imported lazily so the supabase package is only required when a
        # Supabase provider is actually configured.
        from supabase import create_client

        if not self.config.supabase_url or not self.config.supabase_anon_key:
            raise ValueError("Supabase provider missing url or anon key")
        self._client = create_client(self.config.supabase_url, self.config.supabase_anon_key)

    def execute(self, req: DbQueryRequest) -> DbQueryResponse:
        """Dispatch ``req`` to the matching handler.

        Any exception is logged and converted into an error response so
        callers always receive a ``DbQueryResponse``.
        """
        try:
            if req.action == "test":
                return self._test()
            if req.action == "rpc":
                return self._rpc(req)
            if not req.table:
                return DbQueryResponse(error="Missing table")
            if req.action == "select":
                return self._select(req)
            if req.action == "insert":
                return self._insert(req)
            if req.action == "update":
                return self._update(req)
            if req.action == "delete":
                return self._delete(req)
            if req.action == "upsert":
                return self._upsert(req)
            return DbQueryResponse(error="Unsupported action")
        except Exception as exc:
            logger.error("Supabase adapter error: %s", exc)
            return DbQueryResponse(error=str(exc))

    def _table(self, table: str):
        """Return a query builder for ``table``, tolerating client API variants."""
        if hasattr(self._client, "table"):
            return self._client.table(table)
        if hasattr(self._client, "from_"):
            return self._client.from_(table)
        raise AttributeError("Supabase client has no table/from_ method")

    def _apply_filters(self, query, filters: list[DbFilter] | None):
        """Translate ``DbFilter`` objects into PostgREST query-builder calls.

        ``or`` groups only handle the inner ops this adapter emits
        (``is_null`` / ``not_in`` / ``eq``); unrecognized ops are skipped
        silently rather than raising.
        """
        if not filters:
            return query
        for filt in filters:
            if filt.op == "or" and filt.filters:
                or_parts = []
                for inner in filt.filters:
                    if inner.op == "is_null":
                        or_parts.append(f"{inner.column}.is.null")
                    elif inner.op == "not_in":
                        values = ",".join([str(v) for v in (inner.values or [])])
                        or_parts.append(f"{inner.column}.not.in.({values})")
                    elif inner.op == "eq":
                        or_parts.append(f"{inner.column}.eq.{inner.value}")
                if hasattr(query, "or_") and or_parts:
                    query = query.or_(",".join(or_parts))
                continue
            col = filt.column
            if not col:
                continue
            if filt.op == "eq":
                query = query.eq(col, filt.value)
            elif filt.op == "gt":
                query = query.gt(col, filt.value)
            elif filt.op == "lt":
                query = query.lt(col, filt.value)
            elif filt.op == "ilike":
                # Wrap the value so callers get substring matching.
                query = query.ilike(col, f"%{filt.value}%")
            elif filt.op == "in":
                query = query.in_(col, filt.values or [])
            elif filt.op == "not_in":
                if hasattr(query, "not_"):
                    query = query.not_.in_(col, filt.values or [])
            elif filt.op == "is_null":
                if hasattr(query, "is_"):
                    query = query.is_(col, "null")
        return query

    def _select(self, req: DbQueryRequest) -> DbQueryResponse:
        """Run a select with optional ordering, range/limit, and single-row modes.

        ``maybe_single`` is best-effort: PGRST116 / zero-row coercion errors
        are mapped to a ``None`` result instead of an error response.
        """
        query = self._table(req.table)
        columns = req.columns or "*"
        if isinstance(columns, list):
            columns = ",".join([str(col).strip() for col in columns if str(col).strip()]) or "*"
        if req.count:
            query = query.select(columns, count=req.count)
        else:
            query = query.select(columns)
        query = self._apply_filters(query, req.filters)
        if req.order:
            for order in req.order:
                query = query.order(order.column, desc=not order.ascending)
        if req.range:
            query = query.range(req.range.from_, req.range.to)
        elif req.limit:
            query = query.limit(req.limit)
        if req.maybe_single:
            if hasattr(query, "maybe_single"):
                query = query.maybe_single()
            elif hasattr(query, "maybeSingle"):
                query = query.maybeSingle()
            else:
                # Keep best-effort maybe-single semantics without forcing object coercion.
                if not req.limit:
                    query = query.limit(1)
        elif req.single and hasattr(query, "single"):
            query = query.single()
        try:
            result = query.execute()
        except Exception as exc:
            if req.maybe_single:
                error_text = str(exc)
                if (
                    "PGRST116" in error_text
                    or "Cannot coerce the result to a single JSON object" in error_text
                    or "The result contains 0 rows" in error_text
                ):
                    return DbQueryResponse(data=None, count=None)
            raise
        data = getattr(result, "data", None)
        count = getattr(result, "count", None)
        error = getattr(result, "error", None)
        if req.maybe_single and isinstance(data, list):
            data = data[0] if data else None
        if error and req.maybe_single:
            error_text = str(error)
            if (
                "PGRST116" in error_text
                or "Cannot coerce the result to a single JSON object" in error_text
                or "The result contains 0 rows" in error_text
            ):
                return DbQueryResponse(data=None, count=count)
        if error:
            return DbQueryResponse(error=str(error))
        return DbQueryResponse(data=data, count=count)

    def _insert(self, req: DbQueryRequest) -> DbQueryResponse:
        """Insert rows; tolerate the 204 No Content responses Supabase may return."""
        query = self._table(req.table)
        values = req.values if req.values is not None else req.payload
        if values is None:
            return DbQueryResponse(error="Missing values")
        # Use upsert with ignoreDuplicates=False to get data back, or just insert
        # Supabase insert returns 204 by default; we handle this gracefully
        try:
            result = query.insert(values).execute()
        except Exception as exc:
            # 204 No Content is not an error — insert succeeded but no data returned
            if "204" in str(exc) or "Missing response" in str(exc):
                return DbQueryResponse(data=None)
            raise
        error = getattr(result, "error", None)
        if error:
            return DbQueryResponse(error=str(error))
        data = getattr(result, "data", None)
        if (req.single or req.maybe_single) and isinstance(data, list):
            data = data[0] if data else None
        return DbQueryResponse(data=data)

    def _update(self, req: DbQueryRequest) -> DbQueryResponse:
        """Update rows matching ``req.filters`` with ``req.payload``."""
        query = self._table(req.table)
        payload = req.payload or {}
        query = query.update(payload)
        query = self._apply_filters(query, req.filters)
        # Supabase update returns 204 by default; handle gracefully
        try:
            result = query.execute()
        except Exception as exc:
            if "204" in str(exc) or "Missing response" in str(exc):
                return DbQueryResponse(data=None)
            raise
        error = getattr(result, "error", None)
        if error:
            return DbQueryResponse(error=str(error))
        data = getattr(result, "data", None)
        if (req.single or req.maybe_single) and isinstance(data, list):
            data = data[0] if data else None
        return DbQueryResponse(data=data)

    def _delete(self, req: DbQueryRequest) -> DbQueryResponse:
        """Delete rows and cascade cleanup of pending_form_runs for conversations.

        Deleting conversations (all or by id) or conversation_messages (by
        conversation_id) also deletes the matching pending HITL runs so the
        two tables stay in sync.
        """
        pending_cleanup_all = False
        pending_cleanup_conversation_ids: list[str] = []
        if req.table == "conversations":
            if req.filters:
                pending_cleanup_conversation_ids = _extract_filter_values(req.filters, "id")
            else:
                pending_cleanup_all = True
        elif req.table == "conversation_messages":
            pending_cleanup_conversation_ids = _extract_filter_values(req.filters, "conversation_id")
        query = self._table(req.table)
        query = query.delete()
        query = self._apply_filters(query, req.filters)
        result = query.execute()
        error = getattr(result, "error", None)
        if error:
            return DbQueryResponse(error=str(error))
        # Keep pending HITL runs in sync with conversation lifecycle.
        if pending_cleanup_all:
            self._table("pending_form_runs").delete().execute()
        elif pending_cleanup_conversation_ids:
            self._table("pending_form_runs").delete().in_(
                "conversation_id", pending_cleanup_conversation_ids
            ).execute()
        return DbQueryResponse(data=getattr(result, "data", None))

    def _upsert(self, req: DbQueryRequest) -> DbQueryResponse:
        """Upsert rows, normalizing a list ``on_conflict`` into the CSV form
        the Supabase client expects."""
        query = self._table(req.table)
        values = req.values if req.values is not None else req.payload
        if values is None:
            return DbQueryResponse(error="Missing values")
        on_conflict = req.on_conflict
        if isinstance(on_conflict, list):
            on_conflict = ",".join([str(item).strip() for item in on_conflict if str(item).strip()])
        query = query.upsert(values, on_conflict=on_conflict)
        result = query.execute()
        error = getattr(result, "error", None)
        if error:
            return DbQueryResponse(error=str(error))
        data = getattr(result, "data", None)
        if (req.single or req.maybe_single) and isinstance(data, list):
            data = data[0] if data else None
        return DbQueryResponse(data=data)

    def _rpc(self, req: DbQueryRequest) -> DbQueryResponse:
        """Invoke a Postgres function via the Supabase RPC endpoint."""
        if not req.rpc:
            return DbQueryResponse(error="Missing rpc")
        result = self._client.rpc(req.rpc.name, req.rpc.params or {}).execute()
        error = getattr(result, "error", None)
        if error:
            return DbQueryResponse(error=str(error))
        return DbQueryResponse(data=getattr(result, "data", None))

    def _test(self) -> DbQueryResponse:
        """Probe each required table with a one-row select to verify access.

        NOTE(review): unlike the SQL adapter's test, pending_form_runs and
        scrapbook are not probed here — confirm whether that is intentional.
        """
        table_fields = {
            "spaces": "id",
            "agents": "id",
            "space_agents": "space_id",
            "conversations": "id",
            "conversation_messages": "id",
            "space_documents": "id",
            "conversation_documents": "conversation_id",
            "document_sections": "id",
            "document_chunks": "id",
            "user_settings": "key",
            "memory_domains": "id",
            "memory_summaries": "id",
            "user_tools": "id",
            "home_notes": "id",
            "home_shortcuts": "id",
        }
        results = {}
        for table, field in table_fields.items():
            try:
                query = self._table(table).select(field).limit(1)
                result = query.execute()
                results[table] = getattr(result, "error", None) is None
            except Exception:
                results[table] = False
        all_ok = all(results.values())
        return DbQueryResponse(
            data={
                "success": all_ok,
                "connection": True,
                "tables": results,
                "message": "Connection successful; required tables are present."
                if all_ok
                else "Connection OK, but missing tables.",
            }
        )
def build_adapter(config: ProviderConfig):
    """Return the adapter implementation matching ``config.type``.

    Raises ``ValueError`` for any provider type not in the dispatch table.
    """
    factories = {
        "sqlite": SQLiteAdapter,
        "supabase": SupabaseAdapter,
        "postgres": SQLAlchemyAdapter,
        "mysql": SQLAlchemyAdapter,
        "mariadb": SQLAlchemyAdapter,
    }
    factory = factories.get(config.type)
    if factory is None:
        raise ValueError("Unsupported provider type")
    return factory(config)