File size: 7,029 Bytes
4ef118d
592cb1d
 
 
4ef118d
 
 
 
 
 
 
 
 
 
 
 
 
592cb1d
4ef118d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
592cb1d
 
 
 
 
4ef118d
 
 
592cb1d
4ef118d
 
 
 
 
 
 
 
592cb1d
 
 
 
4ef118d
592cb1d
4ef118d
 
592cb1d
4ef118d
 
 
 
 
 
 
 
 
 
 
 
592cb1d
4ef118d
 
 
 
 
 
 
 
 
 
 
 
 
592cb1d
 
 
 
 
 
 
 
 
 
 
4ef118d
 
 
 
 
 
 
 
 
 
 
 
592cb1d
4ef118d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
"""
Database proxy routes.

The backend exposes one active database configuration at a time.
"""

from __future__ import annotations

import hmac
import logging
import threading
from pathlib import Path

from fastapi import APIRouter, Header, HTTPException

from ..config import get_settings
from ..models.db import DbProviderUpsertRequest, DbQueryRequest, DbQueryResponse
from ..services.db_adapters import build_adapter
from ..services.db_registry import ProviderConfig, get_provider_registry, normalize_provider_type
from ..services.db_service import initialize_provider_schema, invalidate_db_adapter_cache

logger = logging.getLogger(__name__)
router = APIRouter()

# Adapters built by build_adapter(), keyed by provider id. Every read and write
# of this dict in this module happens while holding _adapters_lock.
_adapters: dict = dict()
_adapters_lock = threading.Lock()


def _get_adapter(provider_id: str):
    """Return the cached adapter for ``provider_id``, building it on first use.

    The provider is looked up in the registry first. The adapter is then
    fetched from the module cache, or created with ``build_adapter`` and
    stored, all while ``_adapters_lock`` is held.

    Raises:
        HTTPException: 400 if ``provider_id`` is not registered.
    """
    config = get_provider_registry().get(provider_id)
    if not config:
        raise HTTPException(status_code=400, detail="Unknown providerId")

    with _adapters_lock:
        if provider_id in _adapters:
            return _adapters[provider_id]
        fresh = build_adapter(config)
        _adapters[provider_id] = fresh
        return fresh


def _invalidate_adapter_cache(provider_id: str | None = None):
    """Forget cached adapters, then forward the invalidation to db_service.

    Args:
        provider_id: Provider whose adapter should be dropped. ``None`` clears
            every cached adapter in this module.
    """
    with _adapters_lock:
        if provider_id is not None:
            _adapters.pop(provider_id, None)
        else:
            _adapters.clear()
    # db_service keeps its own adapter cache; pass the same scope along.
    invalidate_db_adapter_cache(provider_id)


@router.get("/db/providers")
def list_db_providers():
    """List every registered provider.

    Connection details (Supabase url/anonKey, SQLite path, generic
    connectionUrl) are only filled in when the registry is mutable; otherwise
    those fields are ``None``. ``requiresAccessKey`` is true when the provider
    has its own key or a non-blank global ``db_access_key`` is configured.

    Returns:
        ``{"providers": [...], "provider": <first entry or None>, "mutable": bool}``
    """
    registry = get_provider_registry()
    settings = get_settings()
    global_requires_key = bool((settings.db_access_key or "").strip())
    include_details = registry.is_mutable()

    def describe(entry):
        kind = entry.type
        show_supabase = include_details and kind == "supabase"
        show_sqlite = include_details and kind == "sqlite"
        show_generic = include_details and kind not in {"supabase", "sqlite"}
        return {
            "id": entry.id,
            "type": kind,
            "label": entry.label,
            "requiresAccessKey": bool(entry.access_key) or global_requires_key,
            "url": entry.supabase_url if show_supabase else None,
            "anonKey": entry.supabase_anon_key if show_supabase else None,
            "path": entry.sqlite_path if show_sqlite else None,
            "connectionUrl": entry.connection_url if show_generic else None,
        }

    described = [describe(entry) for entry in registry.list()]
    first = described[0] if described else None
    return {"providers": described, "provider": first, "mutable": registry.is_mutable()}


@router.post("/db/providers")
def upsert_db_provider(request: DbProviderUpsertRequest):
    """Create or replace a database provider in the mutable registry.

    Validates the type-specific fields of ``request`` and builds a
    ``ProviderConfig``. It then saves the config through the registry and
    drops any cached adapter for that provider id.

    Args:
        request: Provider id, type, label and optional access key, plus the
            type-specific fields: ``url``/``anon_key`` for Supabase, ``path``
            for SQLite, and ``url`` for every other type.

    Returns:
        ``{"provider": {...}}`` describing the saved provider. It uses the same
        keys as the entries returned by ``GET /db/providers``.

    Raises:
        HTTPException: 403 when the registry is not mutable. 400 when the type
            is unsupported or a required field is missing.
    """
    registry = get_provider_registry()
    if not registry.is_mutable():
        raise HTTPException(status_code=403, detail="Provider changes are only allowed in Electron mode")

    # Tolerate a missing id as well as a blank one.
    provider_id = (request.id or "").strip() or "default"
    provider_type = normalize_provider_type(request.type)
    if not provider_type:
        raise HTTPException(status_code=400, detail="Unsupported database provider type")

    # Apply the fallback after stripping, so a whitespace-only label still
    # falls back to the type name instead of becoming "".
    label = (request.label or "").strip() or provider_type
    access_key = (request.access_key or "").strip() or None

    if provider_type == "supabase":
        url = (request.url or "").strip()
        anon_key = (request.anon_key or "").strip()
        if not url or not anon_key:
            raise HTTPException(status_code=400, detail="Supabase url and anonKey are required")
        provider = ProviderConfig(
            id=provider_id,
            type="supabase",
            label=label,
            supabase_url=url,
            supabase_anon_key=anon_key,
            access_key=access_key,
        )
    elif provider_type == "sqlite":
        raw_path = (request.path or "").strip()
        if not raw_path:
            raise HTTPException(status_code=400, detail="SQLite path is required")
        sqlite_path = Path(raw_path)
        if not sqlite_path.is_absolute():
            # Relative paths are anchored two directories above this module,
            # not at the process working directory.
            sqlite_path = (Path(__file__).resolve().parents[2] / raw_path).resolve()
        provider = ProviderConfig(
            id=provider_id,
            type="sqlite",
            label=label,
            sqlite_path=str(sqlite_path),
            access_key=access_key,
        )
    else:
        url = (request.url or "").strip()
        if not url:
            raise HTTPException(status_code=400, detail="Database url is required")
        provider = ProviderConfig(
            id=provider_id,
            type=provider_type,
            label=label,
            connection_url=url,
            access_key=access_key,
        )

    saved = registry.upsert(provider)
    _invalidate_adapter_cache(saved.id)

    # db_query also enforces the global key, so report it here exactly as
    # list_db_providers does; otherwise clients would skip sending the key.
    settings = get_settings()
    global_requires_key = bool((settings.db_access_key or "").strip())
    return {
        "provider": {
            "id": saved.id,
            "type": saved.type,
            "label": saved.label,
            "requiresAccessKey": bool(saved.access_key) or global_requires_key,
            "url": saved.supabase_url if saved.type == "supabase" else None,
            "anonKey": saved.supabase_anon_key if saved.type == "supabase" else None,
            "path": saved.sqlite_path if saved.type == "sqlite" else None,
            "connectionUrl": saved.connection_url if saved.type not in {"supabase", "sqlite"} else None,
        }
    }


@router.delete("/db/providers/{provider_id}")
def delete_db_provider(provider_id: str):
    """Remove a provider from the mutable registry.

    The cached adapter for the stripped id is cleared whether or not the
    registry actually removed anything.

    Returns:
        ``{"deleted": <value returned by registry.remove>}``

    Raises:
        HTTPException: 403 when the registry is not mutable. 400 when the id is
            blank after stripping.
    """
    registry = get_provider_registry()
    if not registry.is_mutable():
        raise HTTPException(status_code=403, detail="Provider changes are only allowed in Electron mode")

    target_id = provider_id.strip()
    if target_id == "":
        raise HTTPException(status_code=400, detail="Provider id is required")

    was_removed = registry.remove(target_id)
    _invalidate_adapter_cache(target_id)
    return {"deleted": was_removed}


@router.post("/db/providers/{provider_id}/initialize")
def initialize_db_provider(provider_id: str):
    """Run schema initialization for a registered provider.

    Cached adapters for the provider are cleared before
    ``initialize_provider_schema`` runs. Its result is returned unchanged.

    Raises:
        HTTPException: 404 when no provider matches the stripped id.
    """
    # NOTE(review): unlike db_query this route checks neither the access key
    # nor registry.is_mutable() — confirm that is intended.
    target = get_provider_registry().get(provider_id.strip())
    if not target:
        raise HTTPException(status_code=404, detail="Provider not found")

    _invalidate_adapter_cache(target.id)
    return initialize_provider_schema(target)


def _access_key_matches(supplied: str | None, expected: str) -> bool:
    """Compare a client-supplied access key with the expected one in constant time."""
    if supplied is None:
        return False
    # Compare as bytes: compare_digest rejects non-ASCII str arguments.
    return hmac.compare_digest(supplied.encode("utf-8"), expected.encode("utf-8"))


@router.post("/db/query", response_model=DbQueryResponse)
def db_query(request: DbQueryRequest, x_db_access_key: str | None = Header(default=None)):
    """Execute a proxied query against a registered provider.

    Args:
        request: The query; ``request.provider_id`` selects the provider.
        x_db_access_key: Value of the ``X-Db-Access-Key`` header, if sent.

    Returns:
        Whatever the provider's adapter ``execute`` returns. On an unexpected
        error, a ``DbQueryResponse`` carrying an error message instead.

    Raises:
        HTTPException: 400 for an unknown providerId. 401 when an access key is
            required and the supplied one is missing or wrong.
    """
    try:
        registry = get_provider_registry()
        provider = registry.get(request.provider_id)
        if not provider:
            raise HTTPException(status_code=400, detail="Unknown providerId")

        # Prefer the per-provider access key; fall back to the global key if set.
        # The global key is stripped, so a whitespace-only setting means
        # "no key", matching how list_db_providers reports requiresAccessKey.
        settings = get_settings()
        expected_key = provider.access_key or (settings.db_access_key or "").strip()
        # Constant-time comparison avoids leaking the key through response timing.
        if expected_key and not _access_key_matches(x_db_access_key, expected_key):
            raise HTTPException(status_code=401, detail="Invalid database access key")

        adapter = _get_adapter(request.provider_id)
        return adapter.execute(request)
    except HTTPException:
        raise
    except Exception as exc:
        logger.exception("[DB] Unhandled db_query error: %s", exc)
        return DbQueryResponse(error=f"Internal db proxy error: {exc}")