Spaces:
Sleeping
Sleeping
File size: 7,029 Bytes
4ef118d 592cb1d 4ef118d 592cb1d 4ef118d 592cb1d 4ef118d 592cb1d 4ef118d 592cb1d 4ef118d 592cb1d 4ef118d 592cb1d 4ef118d 592cb1d 4ef118d 592cb1d 4ef118d 592cb1d 4ef118d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 | """
Database proxy routes.
The backend exposes one active database configuration at a time.
"""
from __future__ import annotations

import hmac
import logging
import threading
from pathlib import Path

from fastapi import APIRouter, Header, HTTPException

from ..config import get_settings
from ..models.db import DbProviderUpsertRequest, DbQueryRequest, DbQueryResponse
from ..services.db_adapters import build_adapter
from ..services.db_registry import ProviderConfig, get_provider_registry, normalize_provider_type
from ..services.db_service import initialize_provider_schema, invalidate_db_adapter_cache
# Module-level logger and the FastAPI router the endpoints below register on.
logger = logging.getLogger(__name__)
router = APIRouter()

# Process-local cache of built adapters, keyed by provider id.
# Every read/write of this dict goes through _adapters_lock.
_adapters: dict[str, object] = {}
_adapters_lock = threading.Lock()
def _get_adapter(provider_id: str):
    """Return the cached adapter for *provider_id*, building it on first use.

    The registry lookup is performed while holding ``_adapters_lock``, the same
    lock ``_invalidate_adapter_cache`` takes. Provider changes save to the
    registry and then invalidate, so a concurrent change either becomes visible
    to our lookup or clears the entry we cache. A stale-config adapter can
    therefore not survive an invalidation.

    Raises:
        HTTPException: 400 when *provider_id* is not registered.
    """
    registry = get_provider_registry()
    with _adapters_lock:
        provider = registry.get(provider_id)
        if not provider:
            raise HTTPException(status_code=400, detail="Unknown providerId")
        if provider_id not in _adapters:
            # build_adapter runs under the lock, so a slow construction
            # blocks other adapter lookups until it completes.
            _adapters[provider_id] = build_adapter(provider)
        return _adapters[provider_id]
def _invalidate_adapter_cache(provider_id: str | None = None):
    """Discard cached adapters locally and in the db_service layer.

    Args:
        provider_id: Provider whose cached adapter should be dropped.
            ``None`` drops every locally cached adapter.
    """
    with _adapters_lock:
        if provider_id is not None:
            _adapters.pop(provider_id, None)
        else:
            _adapters.clear()
    # The service-layer cache is told about the same id (or None), outside
    # of our local lock.
    invalidate_db_adapter_cache(provider_id)
@router.get("/db/providers")
def list_db_providers():
    """List the configured database providers.

    Connection details (Supabase url/anon key, SQLite path, generic connection
    URL) are only exposed when the registry is mutable. Otherwise those fields
    are ``None``.

    Returns:
        dict with ``providers`` (one dict per provider), ``provider`` (the
        first entry, or ``None`` when there are none) and ``mutable``.
    """
    registry = get_provider_registry()
    settings = get_settings()
    # A non-blank global key means every provider requires a key.
    global_requires_key = bool((settings.db_access_key or "").strip())
    # Single snapshot so the detail gating and the reported "mutable" flag
    # can never disagree.
    mutable = registry.is_mutable()

    def _describe(provider) -> dict:
        """Serialize one provider, hiding connection details unless mutable."""
        kind = provider.type
        show_supabase = mutable and kind == "supabase"
        show_sqlite = mutable and kind == "sqlite"
        show_generic = mutable and kind not in {"supabase", "sqlite"}
        return {
            "id": provider.id,
            "type": kind,
            "label": provider.label,
            "requiresAccessKey": bool(provider.access_key) or global_requires_key,
            "url": provider.supabase_url if show_supabase else None,
            "anonKey": provider.supabase_anon_key if show_supabase else None,
            "path": provider.sqlite_path if show_sqlite else None,
            "connectionUrl": provider.connection_url if show_generic else None,
        }

    data = [_describe(provider) for provider in registry.list()]
    return {"providers": data, "provider": data[0] if data else None, "mutable": mutable}
def _provider_config_from_request(
    request: DbProviderUpsertRequest,
    provider_id: str,
    provider_type: str,
    label: str,
    access_key: str | None,
) -> ProviderConfig:
    """Build the ProviderConfig for *provider_type* from the upsert request.

    Raises:
        HTTPException: 400 when a field required by the provider type is blank.
    """
    if provider_type == "supabase":
        url = (request.url or "").strip()
        anon_key = (request.anon_key or "").strip()
        if not url or not anon_key:
            raise HTTPException(status_code=400, detail="Supabase url and anonKey are required")
        return ProviderConfig(
            id=provider_id,
            type="supabase",
            label=label,
            supabase_url=url,
            supabase_anon_key=anon_key,
            access_key=access_key,
        )

    if provider_type == "sqlite":
        raw_path = (request.path or "").strip()
        if not raw_path:
            raise HTTPException(status_code=400, detail="SQLite path is required")
        sqlite_path = Path(raw_path)
        if not sqlite_path.is_absolute():
            # Relative paths are anchored two directories above this module's
            # directory, not at the process working directory.
            sqlite_path = (Path(__file__).resolve().parents[2] / raw_path).resolve()
        return ProviderConfig(
            id=provider_id,
            type="sqlite",
            label=label,
            sqlite_path=str(sqlite_path),
            access_key=access_key,
        )

    # Every other supported type is configured by a single connection URL.
    url = (request.url or "").strip()
    if not url:
        raise HTTPException(status_code=400, detail="Database url is required")
    return ProviderConfig(
        id=provider_id,
        type=provider_type,
        label=label,
        connection_url=url,
        access_key=access_key,
    )


@router.post("/db/providers")
def upsert_db_provider(request: DbProviderUpsertRequest):
    """Create or replace a database provider (only while the registry is mutable).

    Returns:
        ``{"provider": {...}}`` describing the saved provider, including its
        connection details.

    Raises:
        HTTPException: 403 when the registry is read-only; 400 for an
            unsupported type or a missing type-specific field.
    """
    registry = get_provider_registry()
    if not registry.is_mutable():
        raise HTTPException(status_code=403, detail="Provider changes are only allowed in Electron mode")

    # Treat a missing id the same as a blank one.
    provider_id = (request.id or "").strip() or "default"
    provider_type = normalize_provider_type(request.type)
    if not provider_type:
        raise HTTPException(status_code=400, detail="Unsupported database provider type")
    # Strip before falling back so a whitespace-only label still gets the
    # type name instead of becoming "".
    label = (request.label or "").strip() or provider_type
    access_key = (request.access_key or "").strip() or None

    provider = _provider_config_from_request(request, provider_id, provider_type, label, access_key)
    saved = registry.upsert(provider)
    _invalidate_adapter_cache(saved.id)

    # Report the effective requirement: a non-blank global key also applies,
    # matching what the provider listing reports.
    global_requires_key = bool((get_settings().db_access_key or "").strip())
    return {
        "provider": {
            "id": saved.id,
            "type": saved.type,
            "label": saved.label,
            "requiresAccessKey": bool(saved.access_key) or global_requires_key,
            "url": saved.supabase_url if saved.type == "supabase" else None,
            "anonKey": saved.supabase_anon_key if saved.type == "supabase" else None,
            "path": saved.sqlite_path if saved.type == "sqlite" else None,
            "connectionUrl": saved.connection_url if saved.type not in {"supabase", "sqlite"} else None,
        }
    }
@router.delete("/db/providers/{provider_id}")
def delete_db_provider(provider_id: str):
    """Remove a provider from the registry and drop its cached adapters.

    Returns:
        ``{"deleted": <value returned by registry.remove>}``.

    Raises:
        HTTPException: 403 when the registry is read-only, 400 when the id
            is blank after trimming.
    """
    registry = get_provider_registry()
    if not registry.is_mutable():
        raise HTTPException(status_code=403, detail="Provider changes are only allowed in Electron mode")

    target_id = provider_id.strip()
    if target_id == "":
        raise HTTPException(status_code=400, detail="Provider id is required")

    outcome = registry.remove(target_id)
    # Invalidate unconditionally; popping an absent id is harmless.
    _invalidate_adapter_cache(target_id)
    return {"deleted": outcome}
@router.post("/db/providers/{provider_id}/initialize")
def initialize_db_provider(provider_id: str):
    """Run schema initialization for a registered provider.

    Cached adapters for the provider are discarded first. The result of
    ``initialize_provider_schema`` is returned as-is.

    Raises:
        HTTPException: 404 when no provider matches the trimmed id.
    """
    # NOTE(review): unlike /db/query, this endpoint checks neither the access
    # key nor registry mutability before touching the database. Confirm
    # that is intended.
    target = get_provider_registry().get(provider_id.strip())
    if not target:
        raise HTTPException(status_code=404, detail="Provider not found")

    _invalidate_adapter_cache(target.id)
    return initialize_provider_schema(target)
@router.post("/db/query", response_model=DbQueryResponse)
def db_query(request: DbQueryRequest, x_db_access_key: str | None = Header(default=None)):
    """Execute a query against the requested provider after checking its access key.

    The expected key is the provider's own key, falling back to the global
    ``db_access_key`` setting (blank/whitespace global keys count as unset,
    matching the provider listing's ``requiresAccessKey``).

    Args:
        request: The query; ``request.provider_id`` selects the provider.
        x_db_access_key: Value of the ``X-DB-Access-Key`` header, if sent.

    Returns:
        The adapter's response, or a ``DbQueryResponse`` carrying an error
        message for unexpected failures.

    Raises:
        HTTPException: 400 for an unknown provider, 401 for a wrong/missing key.
    """
    try:
        registry = get_provider_registry()
        provider = registry.get(request.provider_id)
        if not provider:
            raise HTTPException(status_code=400, detail="Unknown providerId")

        # Prefer the per-provider key; otherwise fall back to the (trimmed)
        # global key when one is configured.
        settings = get_settings()
        expected_key = provider.access_key or (settings.db_access_key or "").strip()
        if expected_key:
            supplied_key = x_db_access_key or ""
            # Constant-time comparison on bytes: avoids leaking key prefixes
            # via timing, and compare_digest rejects non-ASCII str arguments.
            if not hmac.compare_digest(supplied_key.encode("utf-8"), expected_key.encode("utf-8")):
                raise HTTPException(status_code=401, detail="Invalid database access key")

        adapter = _get_adapter(request.provider_id)
        return adapter.execute(request)
    except HTTPException:
        raise
    except Exception as exc:
        # Top-level boundary: log with traceback and report to the client.
        logger.exception("[DB] Unhandled db_query error: %s", exc)
        return DbQueryResponse(error=f"Internal db proxy error: {exc}")
|