File size: 6,323 Bytes
4ef118d
 
592cb1d
 
4ef118d
 
 
 
 
 
 
 
 
 
 
 
 
592cb1d
 
4ef118d
 
 
 
592cb1d
4ef118d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
592cb1d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4ef118d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
592cb1d
 
 
 
 
 
 
 
 
4ef118d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
"""
Database service for resolving adapters and executing queries.

The backend now uses one active database configuration at a time.
"""

from __future__ import annotations

import asyncio
import logging
import threading
from typing import Any, Union
from pathlib import Path
import sqlite3

import psycopg2

from .db_adapters import SQLAlchemyAdapter, SQLiteAdapter, SupabaseAdapter, build_adapter
from .db_registry import ProviderConfig, get_provider_registry, normalize_provider_type
from ..config import get_settings

logger = logging.getLogger(__name__)

# Union of the concrete adapter implementations that get_db_adapter() can return.
DbAdapter = Union[SQLiteAdapter, SupabaseAdapter, SQLAlchemyAdapter]

# Process-wide cache of built adapters, keyed by provider id. Guarded by the
# lock below since get_db_adapter()/invalidate_db_adapter_cache() may be
# called from multiple threads.
_adapter_cache: dict[str, DbAdapter] = {}
_adapter_cache_lock = threading.Lock()

# Application-owned tables dropped by initialize_provider_schema() during a
# schema reset (SQLite drops with foreign keys off; Postgres drops CASCADE,
# so drop order is not load-bearing).
APP_TABLES: list[str] = [
    "conversation_documents",
    "space_agents",
    "attachments",
    "conversation_events",
    "conversation_messages",
    "document_chunks",
    "document_sections",
    "space_documents",
    "conversations",
    "agents",
    "spaces",
    "home_shortcuts",
    "home_notes",
    "user_settings",
    "memory_summaries",
    "memory_domains",
    "user_tools",
    "pending_form_runs",
    "scrapbook",
]


def _resolve_provider(provider_id_or_type: str | None) -> ProviderConfig | None:
    """Resolve an optional provider selector to a registered ProviderConfig.

    Resolution order:
      1. No registered providers at all -> None.
      2. Missing/blank selector -> the active provider (first in the registry).
      3. Exact registry id match -> that provider.
      4. Normalized type matching the active provider's type -> active provider.
      5. Otherwise -> None.
    """
    registry = get_provider_registry()
    available = registry.list()
    if not available:
        return None

    active = available[0]
    selector = str(provider_id_or_type).strip() if provider_id_or_type else ""
    if not selector:
        # Blank or absent selector means "use whatever is active".
        return active

    matched = registry.get(selector)
    if matched:
        return matched

    return active if normalize_provider_type(selector) == active.type else None


def get_db_adapter(provider_id_or_type: str | None = None) -> DbAdapter | None:
    """Return a cached (or freshly built) adapter for the given provider.

    Falls back to the active provider when no selector is supplied. Returns
    ``None`` when the provider cannot be resolved or the adapter fails to
    build; build failures are logged rather than raised.
    """
    provider = _resolve_provider(provider_id_or_type)
    if not provider:
        # Stay quiet when the default provider is simply absent; warn only on
        # an explicit miss, or when the registry is completely empty.
        if provider_id_or_type or not get_provider_registry().list():
            logger.warning("[DB] No database provider found for: %s", provider_id_or_type)
        return None

    with _adapter_cache_lock:
        cached = _adapter_cache.get(provider.id)
        if cached is not None:
            return cached

        try:
            adapter = build_adapter(provider)
        except Exception as e:
            logger.error("[DB] Failed to build adapter for %s: %s", provider.id, e)
            return None

        _adapter_cache[provider.id] = adapter
        logger.info("[DB] Built adapter for provider: %s (%s)", provider.id, provider.type)
        return adapter


def invalidate_db_adapter_cache(provider_id: str | None = None) -> None:
    """Drop cached adapters: one provider's entry, or the whole cache when ``None``."""
    with _adapter_cache_lock:
        if provider_id is not None:
            # Missing entries are ignored on purpose.
            _adapter_cache.pop(provider_id, None)
        else:
            _adapter_cache.clear()


async def execute_db_async(adapter: DbAdapter, request: Any) -> Any:
    """Run the adapter's synchronous ``execute`` without blocking the event loop.

    Delegates to a worker thread via ``asyncio.to_thread`` and returns
    whatever the adapter produced for *request*.
    """
    result = await asyncio.to_thread(adapter.execute, request)
    return result


def initialize_provider_schema(provider: ProviderConfig) -> dict[str, Any]:
    """
    Initialize required database schema for provider.
    - SQLite: schema is auto-created by adapter constructor.
    - Supabase/Postgres: execute supabase/schema.sql when SUPABASE_DB_URL is configured.

    Returns a ``{"success": bool, "message": str}`` dict; never raises —
    failures are logged and reported in the message.
    """
    if provider.type == "sqlite":
        return _reset_sqlite_schema(provider)

    if provider.type in {"postgres", "mysql", "mariadb"}:
        # No generic DDL runner for these engines yet; require external setup.
        return {
            "success": False,
            "message": (
                f"Automatic schema initialization is not implemented for provider type '{provider.type}'. "
                "Create the schema externally and use /db/query test to validate connectivity."
            ),
        }

    if provider.type != "supabase":
        return {"success": False, "message": f"Unsupported provider type: {provider.type}"}

    return _reset_supabase_schema()


def _reset_sqlite_schema(provider: ProviderConfig) -> dict[str, Any]:
    """Drop all app tables in the SQLite file, then rebuild via the adapter."""
    if not provider.sqlite_path:
        return {"success": False, "message": "SQLite path is missing."}

    # Ensure we don't reuse a stale adapter/connection after reset.
    invalidate_db_adapter_cache(provider.id)
    conn = None
    try:
        conn = sqlite3.connect(provider.sqlite_path)
        cursor = conn.cursor()
        # Disable FK enforcement so the drop order of APP_TABLES doesn't matter.
        cursor.execute("PRAGMA foreign_keys=OFF;")
        for table in APP_TABLES:
            # Table names come from the fixed APP_TABLES list, not user input.
            cursor.execute(f"DROP TABLE IF EXISTS {table};")
        conn.commit()
    except Exception as exc:
        logger.exception("[DB] SQLite reset failed: %s", exc)
        return {"success": False, "message": f"SQLite reset failed: {exc}"}
    finally:
        if conn:
            conn.close()

    # Rebuilding the adapter recreates the schema (the adapter constructor
    # auto-creates tables, per the module docstring contract).
    invalidate_db_adapter_cache(provider.id)
    adapter = get_db_adapter(provider.id)
    if not adapter:
        return {"success": False, "message": "Failed to initialize SQLite adapter."}
    return {"success": True, "message": "SQLite schema reset and initialized."}


def _reset_supabase_schema() -> dict[str, Any]:
    """Drop app tables and re-apply supabase/schema.sql over SUPABASE_DB_URL."""
    settings = get_settings()
    db_url = (settings.supabase_db_url or "").strip()
    if not db_url:
        return {
            "success": False,
            "message": "SUPABASE_DB_URL is required for automatic Supabase initialization.",
        }

    # parents[3] assumes this module sits three levels below the repo root,
    # next to a top-level supabase/ directory — verify if the layout changes.
    schema_path = Path(__file__).resolve().parents[3] / "supabase" / "schema.sql"
    if not schema_path.exists():
        return {"success": False, "message": f"Schema file not found: {schema_path}"}

    schema_sql = schema_path.read_text(encoding="utf-8")
    if not schema_sql.strip():
        return {"success": False, "message": "Schema SQL file is empty."}

    conn = None
    try:
        conn = psycopg2.connect(db_url)
        conn.autocommit = True
        with conn.cursor() as cursor:
            # CASCADE clears dependent objects so the fresh schema applies cleanly.
            for table in APP_TABLES:
                cursor.execute(f"DROP TABLE IF EXISTS public.{table} CASCADE;")
            cursor.execute(schema_sql)
        return {"success": True, "message": "Supabase schema reset and initialized."}
    except Exception as exc:
        logger.exception("[DB] Supabase initialization failed: %s", exc)
        return {"success": False, "message": f"Supabase initialization failed: {exc}"}
    finally:
        if conn:
            conn.close()