# File size: 3,612 Bytes
# c18e004
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import json
import math
import os
import re
from typing import Any

import aiohttp

from config import DB_SPACE_URL, HF_TOKEN

# Bearer-token auth header attached to every request made through the shared
# session (passed to ClientSession in get_session); token comes from config.
HEADERS = dict(Authorization=f"Bearer {HF_TOKEN}")
# Lazily created shared HTTP session: None until first use and after close_session().
_session: aiohttp.ClientSession | None = None


async def get_session() -> aiohttp.ClientSession:
    """Return the shared module-level ``aiohttp.ClientSession``, creating it lazily.

    A new session (configured with ``HEADERS``) is created on first use and
    again whenever the cached one has been closed, for example by code that
    called ``.close()`` on the returned object directly. Callers therefore
    never receive a closed session.

    Returns:
        The shared, open client session.
    """
    global _session
    # A closed ClientSession cannot be reused (requests on it fail), so treat
    # it the same as "no session yet" and build a fresh one.
    if _session is None or _session.closed:
        _session = aiohttp.ClientSession(headers=HEADERS)
    return _session


async def close_session() -> None:
    """Close the shared session, if one exists, and forget it.

    The next ``get_session()`` call after this creates a fresh session.
    Calling this when no session exists is a no-op.
    """
    global _session
    # Detach the global *before* awaiting close(). If close() raises or the
    # task is cancelled, the module is not left holding a half-closed
    # session. Coroutines calling get_session() meanwhile get a new session
    # instead of the one being torn down.
    session, _session = _session, None
    if session is not None:
        await session.close()


def _escape(val: Any) -> str:
    """Safe value escaping for SQLite inline usage."""
    if val is None:
        return "NULL"
    t = type(val)
    if t is int or t is float:
        return str(val)
    if t is bool:
        return "1" if val else "0"
    s = str(val)
    # Instead of manual escaping, we use parameterized queries via the JSON API
    return json.dumps(s)


def _build_sql(sql: str, params: list | None = None) -> str:
    """
    Build SQL string by replacing ? placeholders with escaped values.
    NOTE: This is a fallback. The prefered approach is to pass params
    separately when the DB supports parameterized queries.
    """
    if not params:
        return sql
    parts = sql.split("?")
    result = parts[0]
    for i, p in enumerate(params):
        if i + 1 < len(parts):
            result += _escape(p) + parts[i + 1]
        else:
            result += _escape(p)
    return result


async def exec(sql: str, params: list = None) -> dict:
    """Run a write or DDL statement (INSERT, UPDATE, DELETE, CREATE, ...).

    ``params`` are inlined into the ``?`` placeholders via ``_build_sql``.
    The resulting SQL is POSTed as ``{"sql": ...}`` JSON to
    ``{DB_SPACE_URL}/exec``.

    Returns:
        The decoded JSON response on success. For an HTTP status >= 400,
        returns ``{"ok": False, "error": <first 500 chars of the body>}``.
    """
    session = await get_session()
    body: dict[str, Any] = {"sql": _build_sql(sql, params)}
    endpoint = f"{DB_SPACE_URL}/exec"
    async with session.post(endpoint, json=body) as response:
        if response.status < 400:
            return await response.json()
        detail = await response.text()
        return {"ok": False, "error": detail[:500]}


async def query(sql: str, params: list = None) -> dict:
    """Run a read query and return the server's raw JSON response.

    ``params`` are inlined into the ``?`` placeholders via ``_build_sql``.
    The SQL is sent as the ``sql`` query-string parameter of a GET to
    ``{DB_SPACE_URL}/query``.

    Returns:
        The decoded JSON response on success. For an HTTP status >= 400,
        returns ``{"ok": False, "error": <first 500 chars of the body>,
        "cols": [], "rows": []}``.
    """
    session = await get_session()
    statement = _build_sql(sql, params)
    endpoint = f"{DB_SPACE_URL}/query"
    async with session.get(endpoint, params={"sql": statement}) as response:
        if response.status < 400:
            return await response.json()
        detail = await response.text()
        return {"ok": False, "error": detail[:500], "cols": [], "rows": []}


async def fetch_all(sql: str, params: list = None) -> list[dict]:
    """Run a query and return each result row as a ``{column_name: value}`` dict.

    Returns ``[]`` when the response is empty, lacks ``rows``, has no rows,
    or is flagged ``"ok": False``. A failed query is therefore
    indistinguishable here from one that matched nothing.

    Column names come from the response's ``cols`` entry, falling back to
    ``columns``. Each entry may be a plain name or a dict carrying a
    ``"name"`` key.
    """
    r = await query(sql, params)
    has_rows = bool(r) and "rows" in r and bool(r["rows"])
    if not has_rows or r.get("ok") is False:
        return []
    header = r.get("cols") or r.get("columns", [])
    names = []
    for col in header:
        names.append(col["name"] if isinstance(col, dict) else col)
    return [dict(zip(names, values)) for values in r["rows"]]


async def fetch_one(sql: str, params: list = None) -> dict | None:
    """Return the first result row as a dict, or ``None`` when there are no rows.

    NOTE: this delegates to ``fetch_all``, so every matching row is
    transferred and converted before the first is picked. Add ``LIMIT 1``
    to the SQL when only one row is wanted.
    """
    for row in await fetch_all(sql, params):
        return row
    return None


async def execute(sql: str, params: list = None) -> dict:
    """Same as ``exec()``, exposed under a name that does not clash with the ``exec`` builtin."""
    result = await exec(sql, params)
    return result


async def init_tables(sqls: list[str]) -> list[dict]:
    """Run schema statements sequentially, in order, and collect every response.

    Statements are sent without parameters. An HTTP error response (an
    ``{"ok": False, ...}`` dict from ``exec``) does not stop the loop. The
    returned list has one entry per statement, in the same order as ``sqls``.
    """
    return [await exec(statement) for statement in sqls]


async def execute_many(sql: str, params_list: list[list]) -> list[dict]:
    """Run ``sql`` once per parameter set, sequentially, and collect each response.

    Each set results in its own ``exec`` request; nothing is batched. The
    responses are returned in the same order as ``params_list``.
    """
    return [await exec(sql, param_set) for param_set in params_list]