import json
import math
import os
import re
from typing import Any

import aiohttp

from config import DB_SPACE_URL, HF_TOKEN
|
|
# Default headers for every request made through the shared session below.
# NOTE(review): if HF_TOKEN is unset this sends the literal "Bearer None" — confirm config guarantees a token.
HEADERS = dict(Authorization="Bearer {}".format(HF_TOKEN))

# Lazily-created shared HTTP session; created by get_session(), torn down by close_session().
_session: aiohttp.ClientSession | None = None
|
|
|
|
async def get_session() -> aiohttp.ClientSession:
    """Return the shared module-level ClientSession, creating it on demand.

    A new session (carrying ``HEADERS``) is created on first use, and again
    whenever the cached one has been closed — e.g. by a caller invoking
    ``.close()`` on it directly — so callers never receive a closed session.

    NOTE(review): an aiohttp session is tied to the event loop it was created
    on; this cache assumes a single long-lived loop.
    """
    global _session
    if _session is None or _session.closed:
        _session = aiohttp.ClientSession(headers=HEADERS)
    return _session
|
|
|
|
async def close_session() -> None:
    """Close the shared ClientSession, if there is one, and forget it.

    The global is cleared *before* awaiting ``close()``. This has two effects:
    a ``get_session()`` call made while the close is in progress builds a
    fresh session instead of returning the one being torn down, and the
    global is reset even if ``close()`` raises. Calling this when no session
    exists is a no-op.
    """
    global _session
    session, _session = _session, None
    if session is not None:
        await session.close()
|
|
|
|
def _escape(val: Any) -> str:
    """Render a Python value as a single inline SQLite literal.

    This is the only thing standing between caller data and the SQL text
    built by ``_build_sql``. Every value must therefore come out as exactly
    one well-formed SQL token.

    Mapping:
      * ``None``                       -> ``NULL``
      * ``bool``                       -> ``1`` / ``0``
      * ``int`` (incl. subclasses)     -> decimal integer literal
      * ``float`` (incl. subclasses)   -> numeric literal. NaN becomes ``NULL``
        (SQLite has no NaN value; a bound NaN is stored as NULL too).
        +/-inf becomes ``9e999`` / ``-9e999``, which SQLite reads as +/-Infinity.
      * ``bytes``/``bytearray``/``memoryview`` -> blob literal ``X'<hex>'``
      * anything else                  -> ``str(val)`` as a single-quoted
        string literal, with embedded ``'`` doubled (standard SQL quoting)

    Double quotes and backslash escapes are deliberately avoided. SQLite
    parses ``"..."`` as an identifier, and it does not interpret backslashes,
    so that style of quoting lets crafted values break out of the literal.
    """
    if val is None:
        return "NULL"
    # bool must be tested before int: bool is a subclass of int.
    if isinstance(val, bool):
        return "1" if val else "0"
    if isinstance(val, int):
        return str(int(val))
    if isinstance(val, float):
        if math.isnan(val):
            return "NULL"
        if math.isinf(val):
            return "9e999" if val > 0 else "-9e999"
        return repr(float(val))
    if isinstance(val, (bytes, bytearray, memoryview)):
        return "X'" + bytes(val).hex() + "'"
    return "'" + str(val).replace("'", "''") + "'"
|
|
|
|
| def _build_sql(sql: str, params: list | None = None) -> str: |
| """ |
| Build SQL string by replacing ? placeholders with escaped values. |
| NOTE: This is a fallback. The prefered approach is to pass params |
| separately when the DB supports parameterized queries. |
| """ |
| if not params: |
| return sql |
| parts = sql.split("?") |
| result = parts[0] |
| for i, p in enumerate(params): |
| if i + 1 < len(parts): |
| result += _escape(p) + parts[i + 1] |
| else: |
| result += _escape(p) |
| return result |
|
|
|
|
async def exec(sql: str, params: list | None = None) -> dict:
    """Execute a SQL statement (INSERT, UPDATE, DELETE, DDL).

    ``params`` are inlined into ``sql`` client-side by ``_build_sql``. The
    resulting statement is POSTed as JSON ``{"sql": ...}`` to
    ``{DB_SPACE_URL}/exec`` over the shared session.

    Note: this name shadows the builtin ``exec`` inside this module. It is
    kept for backward compatibility; ``execute`` is an alias.

    Returns:
        The decoded JSON response on success. On an HTTP status >= 400,
        ``{"ok": False, "error": <first 500 chars of the response body>}``.
        Exceptions raised by aiohttp itself (connection failures, a success
        body that is not JSON) are not caught here.
    """
    session = await get_session()
    payload: dict[str, Any] = {"sql": _build_sql(sql, params)}
    async with session.post(f"{DB_SPACE_URL}/exec", json=payload) as resp:
        if resp.status >= 400:
            text = await resp.text()
            return {"ok": False, "error": text[:500]}
        return await resp.json()
|
|
|
|
async def query(sql: str, params: list | None = None) -> dict:
    """Query the database and return the server's raw decoded JSON response.

    ``params`` are inlined into ``sql`` client-side by ``_build_sql``. The
    statement is sent as the ``sql`` query-string parameter of
    ``GET {DB_SPACE_URL}/query`` over the shared session.

    Returns:
        The decoded JSON body on success. On an HTTP status >= 400,
        ``{"ok": False, "error": <first 500 chars of body>, "cols": [], "rows": []}``.
        This presumably mirrors the success shape, so result-reading callers
        such as ``fetch_all`` need no special case. Exceptions raised by
        aiohttp itself are not caught here.
    """
    session = await get_session()
    full_sql = _build_sql(sql, params)
    # NOTE(review): the whole statement travels in the URL, so very long SQL
    # (e.g. large inlined parameter values) may exceed server/proxy URL limits.
    async with session.get(f"{DB_SPACE_URL}/query", params={"sql": full_sql}) as resp:
        if resp.status >= 400:
            text = await resp.text()
            return {"ok": False, "error": text[:500], "cols": [], "rows": []}
        return await resp.json()
|
|
|
|
async def fetch_all(sql: str, params: list | None = None) -> list[dict]:
    """Fetch multiple rows, each as a ``{column_name: value}`` dict.

    Column names come from the response's ``cols`` list, falling back to
    ``columns``. Entries may be plain names or dicts with a ``"name"`` key.

    Returns an empty list in three cases: the query yields no rows, the
    response is not a dict, or the response reports ``"ok": False``. An error
    is therefore indistinguishable from an empty result here; call ``query``
    directly when the difference matters.
    """
    r = await query(sql, params)
    if not isinstance(r, dict) or r.get("ok") is False:
        return []
    rows = r.get("rows")
    if not rows:
        return []
    # ``or []`` rather than a ``.get`` default, so a key that is present with a
    # null value is also treated as "no columns" instead of iterating None.
    cols = r.get("cols") or r.get("columns") or []
    col_names = [c["name"] if isinstance(c, dict) else c for c in cols]
    return [dict(zip(col_names, row)) for row in rows]
|
|
|
|
async def fetch_one(sql: str, params: list | None = None) -> dict | None:
    """Fetch the first result row as a dict, or None if there is none.

    Delegates to ``fetch_all``, so None is also returned whenever
    ``fetch_all`` yields an empty list.

    NOTE(review): every row is transferred and all but the first discarded.
    Add ``LIMIT 1`` to ``sql`` at the call site when the result may be large.
    """
    rows = await fetch_all(sql, params)
    return rows[0] if rows else None
|
|
|
|
async def execute(sql: str, params: list | None = None) -> dict:
    """Alias for ``exec()``: same arguments, same return value."""
    return await exec(sql, params)
|
|
|
|
async def init_tables(sqls: list[str]) -> list[dict]:
    """Run schema statements one after another via ``exec()``.

    Each statement is awaited before the next one is sent. The response
    dicts are returned in the same order as ``sqls``. An HTTP error that
    ``exec()`` reports as an ``{"ok": False, ...}`` dict does not stop the
    remaining statements.
    """
    return [await exec(statement) for statement in sqls]
|
|
|
|
async def execute_many(sql: str, params_list: list[list]) -> list[dict]:
    """Run ``sql`` once for each parameter set, sequentially, via ``exec()``.

    Returns one response dict per parameter set, in input order.

    NOTE(review): nothing here wraps the calls in a transaction. Each is an
    independent request, so a failure part-way leaves earlier sets applied.
    """
    return [await exec(sql, param_set) for param_set in params_list]
|
|